import pandas as pd
import tensorflow as tf
from tensorflow import feature_column as fc
from utils.cache import redis_db_client

from ..utils import (common_elements, create_boundaries, create_vocabulary_list, nth_element)

# Columns that tractate_feature_engineering keeps, in this exact order.
TRACTATE_COLUMNS = [
    # identifier, flags and reply counters
    "card_id",
    "is_pure_author",
    "is_have_pure_reply",
    "is_have_reply",
    "content_level",
    "show_tag_id",
    "reply_num",
    "reply_pure_num",
    # *_ctr statistics (presumably click-through rates over different windows — confirm)
    "one_ctr",
    "three_ctr",
    "seven_ctr",
    "fifteen_ctr",
    "thirty_ctr",
    "sixty_ctr",
    "ninety_ctr",
    "history_ctr",
    # comma-separated tag columns that are split into lists during feature engineering
    "first_demands",
    "second_demands",
    "first_solutions",
    "second_solutions",
    "first_positions",
    "second_positions",
    "projects",
]


def read_csv_data(dataset_path):
    """Load the tractate, click and conversion tables from pipe-separated CSV files.

    Args:
        dataset_path: directory object exposing ``joinpath`` (e.g. ``pathlib.Path``)
            that contains ``tractate.csv``, ``tractate_click.csv`` and
            ``tractate_click_cvr.csv``.

    Returns:
        Tuple ``(tractate_df, click_df, conversion_df)``. The tractate frame is a
        random sample of at most 5000 rows and the click frame a random sample of
        at most 10000 rows; the conversion frame is returned in full.
    """
    tractate_df = pd.read_csv(dataset_path.joinpath("tractate.csv"), sep="|")
    click_df = pd.read_csv(dataset_path.joinpath("tractate_click.csv"), sep="|")
    conversion_df = pd.read_csv(dataset_path.joinpath("tractate_click_cvr.csv"), sep="|")
    # TODO: the down-sampling looks like a development shortcut — confirm before relying on it.
    # DataFrame.sample raises ValueError when n exceeds the number of rows (replace=False),
    # so the requested size is capped at the frame length.
    tractate_sample = tractate_df.sample(min(5000, len(tractate_df)))
    click_sample = click_df.sample(min(10000, len(click_df)))
    return tractate_sample, click_sample, conversion_df


def get_tractate_from_redis():
    """Placeholder for fetching tractate features from Redis; not implemented yet.

    Intended result shape (per the original note):
        {tractate_id: {first_demands: [], is_pure_author: 1}}

    The body is currently empty, so every call returns ``None``.
    """
    return None


def tractate_feature_engineering(tractate_df):
    """Clean the raw tractate table and keep only the TRACTATE_COLUMNS columns.

    Works on a copy, so ``tractate_df`` itself is left untouched.

    Steps:
      * each comma-separated tag column is split into a list of strings; cells
        that do not yield a list (e.g. missing values) become ``[]``;
      * the three ``is_*`` flag columns are cast to ``int``;
      * the frame is restricted to ``TRACTATE_COLUMNS`` (in that order).

    A summary of columns that still contain nulls, plus the final shape, is
    printed to stdout.

    Returns:
        The cleaned DataFrame.
    """
    tag_columns = (
        "first_demands",
        "second_demands",
        "first_solutions",
        "second_solutions",
        "first_positions",
        "second_positions",
        "projects",
    )
    flag_columns = ("is_pure_author", "is_have_pure_reply", "is_have_reply")

    features = tractate_df.copy()

    for name in tag_columns:
        tokens = features[name].str.split(",")
        # str.split leaves non-string cells as NaN; normalise those to empty lists
        features[name] = tokens.apply(lambda cell: cell if isinstance(cell, list) else [])

    for name in flag_columns:
        features[name] = features[name].astype(int)

    features = features[TRACTATE_COLUMNS]

    print("tractate:")
    null_counts = features.isnull().sum()
    print(null_counts[null_counts > 0])
    print(features.shape)
    return features


def join_features(device_df, tractate_df, cc_df):
    """Join device, click and tractate features into one frame of scalar features.

    ``device_df`` is inner-joined to ``cc_df`` on ``device_id == cl_id`` and the
    result is inner-joined to ``tractate_df`` on ``card_id``. This assumes both
    ``device_df`` and ``tractate_df`` carry the seven list-valued tag columns, so
    that pandas suffixes them ``_x`` (device side) and ``_y`` (tractate side).

    For every tag column (abbreviated fd, sd, fs, ss, fp, sp, p) it derives:
      * ``device_<abbr>``  - ``nth_element(<col>_x, 0)``;
      * ``content_<abbr>`` - ``nth_element(<col>_y, 0)``;
      * ``<abbr>1..3``     - ``nth_element`` at positions 0, 1, 2 of
        ``common_elements(<col>_x, <col>_y)``.

    The list-valued columns and ``cl_id`` are dropped before returning. A summary
    of columns containing nulls, plus the frame shape (taken before the drop), is
    printed to stdout.

    Args:
        device_df: device-level features, including ``device_id``.
        tractate_df: tractate-level features, including ``card_id``.
        cc_df: click/conversion rows providing ``cl_id`` and ``card_id``.

    Returns:
        The joined DataFrame.
    """
    # (list-valued column, abbreviation used in the derived column names)
    list_columns = (
        ("first_demands", "fd"),
        ("second_demands", "sd"),
        ("first_solutions", "fs"),
        ("second_solutions", "ss"),
        ("first_positions", "fp"),
        ("second_positions", "sp"),
        ("projects", "p"),
    )

    device_clicks = pd.merge(device_df, cc_df, how="inner", left_on="device_id", right_on="cl_id")
    df = pd.merge(device_clicks, tractate_df, how="inner", left_on="card_id", right_on="card_id")

    # Tags shared by the device and the content. Every column is paired with its own
    # "_y" counterpart (first_positions used to be paired with second_positions_y).
    # Built with zip rather than a row-wise DataFrame.apply so that an empty join
    # result still yields an (empty) column.
    for column, _ in list_columns:
        shared = [common_elements(device_tags, content_tags)
                  for device_tags, content_tags in zip(df[column + "_x"], df[column + "_y"])]
        df[column] = pd.Series(shared, index=df.index, dtype=object)

    # Leading tag on the device side, then on the content side.
    for column, short in list_columns:
        df["device_" + short] = df[column + "_x"].apply(lambda tags: nth_element(tags, 0))
    for column, short in list_columns:
        df["content_" + short] = df[column + "_y"].apply(lambda tags: nth_element(tags, 0))

    # First three shared tags per column: fd1, fd2, fd3, sd1, ...
    for column, short in list_columns:
        for position in range(3):
            df[short + str(position + 1)] = df[column].apply(
                lambda tags, position=position: nth_element(tags, position))

    print("df:")
    nullseries = df.isnull().sum()
    print(nullseries[nullseries > 0])
    print(df.shape)

    # Only the scalar features derived above are kept.
    drop_columns = ["cl_id"]
    for column, _ in list_columns:
        drop_columns.extend([column + "_x", column + "_y", column])
    df.drop(drop_columns, inplace=True, axis=1)
    return df


def build_features(df):
    """Build the list of TensorFlow feature columns for the model.

    Numeric columns are bucketized with boundaries from ``create_boundaries(df, col)``;
    integer columns are declared with ``tf.int64`` and float columns with the
    default numeric dtype. ``card_id`` and ``device_id`` are hashed into buckets
    and embedded; every other categorical column becomes an indicator column over
    the vocabulary from ``create_vocabulary_list(df, col)``.

    Requires the module-level ``import tensorflow as tf`` (``tf.int64`` is used here).

    Args:
        df: DataFrame containing every column named below.

    Returns:
        list of feature columns: bucketized integer columns, then bucketized float
        columns, then the categorical columns in declaration order.
    """
    # TODO
    int_columns = ["active_days", "topic_num", "favor_num", "vote_num"]
    float_columns = ["one_ctr", "three_ctr", "seven_ctr", "fifteen_ctr"]

    numeric_features = [
        fc.bucketized_column(fc.numeric_column(col, dtype=tf.int64), boundaries=create_boundaries(df, col))
        for col in int_columns
    ]
    numeric_features.extend(
        fc.bucketized_column(fc.numeric_column(col), boundaries=create_boundaries(df, col))
        for col in float_columns)

    # TODO
    categorical_columns = [
        "device_id", "active_type", "past_consume_ability_history", "potential_consume_ability_history",
        "price_sensitive_history", "card_id", "is_pure_author", "is_have_reply", "is_have_pure_reply", "content_level",
        "device_fd", "content_fd", "fd1", "fd2", "fd3", "device_sd", "content_sd", "sd1", "sd2", "sd3", "device_fs", "content_fs",
        "fs1", "fs2", "fs3", "device_ss", "content_ss", "ss1", "ss2", "ss3", "device_fp", "content_fp", "fp1", "fp2", "fp3",
        "device_sp", "content_sp", "sp1", "sp2", "sp3", "device_p", "content_p", "p1", "p2", "p3"
    ]

    # Columns listed here are skipped; currently none.
    categorical_ignore_columns = []

    # Id columns are hashed and embedded instead of one-hot encoded.
    # card_id is hashed as int64; device_id uses the default key dtype.
    hashed_columns = {
        "card_id": {"hash_bucket_size": 20000, "dtype": tf.int64},
        "device_id": {"hash_bucket_size": 200000},
    }

    categorical_features = []
    for col in categorical_columns:
        if col in categorical_ignore_columns:
            continue
        if col in hashed_columns:
            hashed = fc.categorical_column_with_hash_bucket(col, **hashed_columns[col])
            # Embedding width is the fourth root of the number of rows in df.
            # NOTE(review): the usual heuristic uses the number of categories — confirm.
            categorical_features.append(fc.embedding_column(hashed, dimension=int(df[col].size**0.25)))
        else:
            vocabulary = fc.categorical_column_with_vocabulary_list(col, create_vocabulary_list(df, col))
            categorical_features.append(fc.indicator_column(vocabulary))

    all_features = (numeric_features + categorical_features)
    return all_features


def device_tractate_fe(device_id, tractate_ids, device_dict, tractate_dict):
    """Placeholder; not implemented yet.

    All arguments are currently ignored and the call returns ``None``.
    """
    return None
