import tensorflow as tf
from tensorflow import feature_column as fc

from .utils import create_boundaries, create_vocabulary_list


def build_features(df):
    """Assemble the feature columns used by the model.

    Args:
        df: Table-like object (presumably a pandas DataFrame -- confirm
            against callers). It is indexed by column name, each column must
            expose ``.size``, and it is forwarded to ``create_boundaries``
            and ``create_vocabulary_list``.

    Returns:
        list: Bucketized numeric columns first, followed by the categorical
        columns, each group in declaration order. ``card_id`` and
        ``device_id`` become hashed embedding columns; every other
        categorical name becomes a vocabulary-list indicator column.
    """
    numeric_names = (
        "active_days", "topic_num", "favor_num", "vote_num",
        "one_ctr", "three_ctr", "seven_ctr", "fifteen_ctr",
    )
    categorical_names = (
        "device_id", "active_type", "past_consume_ability_history", "potential_consume_ability_history",
        "price_sensitive_history", "card_id", "is_pure_author", "is_have_reply", "is_have_pure_reply", "content_level",
        "device_fd", "content_fd", "fd1", "fd2", "fd3",
        "device_sd", "content_sd", "sd1", "sd2", "sd3",
        "device_fs", "content_fs", "fs1", "fs2", "fs3",
        "device_ss", "content_ss", "ss1", "ss2", "ss3",
        "device_fp", "content_fp", "fp1", "fp2", "fp3",
        "device_sp", "content_sp", "sp1", "sp2", "sp3",
        "device_p", "content_p", "p1", "p2", "p3",
    )
    # Categorical names listed here are left out of the result; none today.
    skipped_names = ()

    def bucketized(name):
        source = fc.numeric_column(name)
        edges = create_boundaries(df, name)
        return fc.bucketized_column(source, boundaries=edges)

    def embedding_dim(name):
        # Fourth root of the column's element count, truncated to an int.
        # NOTE(review): this scales with the amount of data rather than with
        # the hash bucket count -- confirm that is intended.
        return int(df[name].size ** 0.25)

    def categorical(name):
        if name == "card_id":
            hashed = fc.categorical_column_with_hash_bucket(name, 20000, dtype=tf.int64)
            return fc.embedding_column(hashed, dimension=embedding_dim(name))
        if name == "device_id":
            hashed = fc.categorical_column_with_hash_bucket(name, 200000)
            return fc.embedding_column(hashed, dimension=embedding_dim(name))
        vocabulary = create_vocabulary_list(df, name)
        return fc.indicator_column(fc.categorical_column_with_vocabulary_list(name, vocabulary))

    numeric_part = [bucketized(name) for name in numeric_names]
    categorical_part = [categorical(name) for name in categorical_names if name not in skipped_names]
    return numeric_part + categorical_part


def esmm_input_fn(dataframe, shuffle=False, batch_size=256):
    """Turn a dataframe into one batch of ``(features, labels)`` tensors.

    Args:
        dataframe: Table-like object (presumably a pandas DataFrame -- confirm
            against callers) containing at least the ``click_label`` and
            ``conversion_label`` columns. It is copied, not modified.
        shuffle: When true, shuffle with a 1000-element buffer and repeat
            the dataset indefinitely. When false, the data is iterated once
            in order.
        batch_size: Number of rows per batch.

    Returns:
        The ``get_next()`` result of a one-shot iterator over the batched
        dataset: a ``(features, labels)`` pair of dicts keyed by column name.
        The features dict is built from the whole frame, so it also contains
        the two label columns.
    """
    df = dataframe.copy()
    target = df[["click_label", "conversion_label"]]
    ds = tf.data.Dataset.from_tensor_slices((dict(df), dict(target)))
    if shuffle:
        ds = ds.shuffle(1000).repeat()
    ds = ds.batch(batch_size)

    # Dataset.make_one_shot_iterator() is deprecated in TF 1.x and is not
    # available on TF 2 datasets. Use the method where it still exists (same
    # behaviour as before) and fall back to the compat.v1 function otherwise.
    if hasattr(ds, "make_one_shot_iterator"):
        iterator = ds.make_one_shot_iterator()
    else:
        iterator = tf.compat.v1.data.make_one_shot_iterator(ds)
    return iterator.get_next()
