import pandas as pd
from utils.cache import redis_db_client

# "channel_first", "city_first", "model_first",
DIARY_DEVICE_COLUMNS = [
    # Columns selected by device_feature_engineering for every content type
    # other than "tractate".
    # -- identity / activity
    "device_id",
    "active_type",
    "active_days",
    # -- consume ability / price sensitivity history
    "past_consume_ability_history",
    "potential_consume_ability_history",
    "price_sensitive_history",
    # -- comma-separated fields that are turned into lists
    "first_demands",
    "second_demands",
    "first_solutions",
    "second_solutions",
    "first_positions",
    "second_positions",
    "projects",
]

TRACTATE_DEVICE_COLUMNS = [
    # Columns selected by device_feature_engineering when content_type is
    # "tractate".
    # -- identity / activity
    "device_id",
    "active_type",
    "active_days",
    # -- "first" channel / city / model values
    "channel_first",
    "city_first",
    "model_first",
    # -- consume ability / price sensitivity history
    "past_consume_ability_history",
    "potential_consume_ability_history",
    "price_sensitive_history",
    # -- comma-separated fields that are turned into lists
    "first_demands",
    "second_demands",
    "first_solutions",
    "second_solutions",
    "first_positions",
    "second_positions",
    "projects",
    # -- clicked tractate ids (cast to str during feature engineering)
    "click_tractate_id1",
    "click_tractate_id2",
    "click_tractate_id3",
    "click_tractate_id4",
    "click_tractate_id5",
]


def read_csv_data(dataset_path):
    """
    Read ``device.csv`` from a dataset directory.

    dataset_path: object exposing ``joinpath`` (e.g. ``pathlib.Path``) that
        points at the directory containing ``device.csv``.
    return: DataFrame parsed with "|" as the separator, keeping only the first
        row for each ``device_id`` (the original row index is preserved).
    """
    csv_file = dataset_path.joinpath("device.csv")
    frame = pd.read_csv(csv_file, sep="|")
    return frame.drop_duplicates(subset=["device_id"])


def get_device_dict_from_redis():
    """
    Build a per-device feature dict from the rows stored in redis.

    The column names are read from the key "cvr:db:device:column" and the rows
    from ``hgetall("cvr:db:device")``. Both the header and every row are
    decoded as UTF-8 and split on "|"; the n-th field of a row belongs to the
    n-th column name. The demand / solution / position / project columns are
    additionally split on "," into lists (an empty field therefore becomes
    ``[""]``); every other column keeps its raw string.

    return: {device_id: {first_demands: [], city_first: ""}}
    raises: KeyError if a parsed row has no "device_id" column, IndexError if a
        row has more fields than there are column names.
    """
    # Built once instead of once per field.
    list_columns = frozenset((
        "first_demands", "second_demands", "first_solutions", "second_solutions", "first_positions",
        "second_positions", "projects",
    ))

    db_key = "cvr:db:device"
    column_key = db_key + ":column"
    columns = str(redis_db_client.get(column_key), "utf-8").split("|")

    res = {}
    for raw_row in redis_db_client.hgetall(db_key).values():
        row = {}
        for index, elem in enumerate(str(raw_row, "utf-8").split("|")):
            col_name = columns[index]
            row[col_name] = elem.split(",") if col_name in list_columns else elem
        # Index the row only after all of its fields are parsed: doing this
        # inside the field loop fails with KeyError whenever "device_id" is
        # not the first column.
        res[row["device_id"]] = row
    return res


def device_feature_engineering(device_df, content_type):
    """
    Normalise the raw device frame and select the columns for a content type.

    Works on a copy of ``device_df``:
      * the demand / solution / position / project columns are split on ","
        into lists, and any value that is not a list after the split (e.g. a
        missing value) is replaced by an empty list;
      * missing ``city_first`` / ``model_first`` / ``channel_first`` values
        become "";
      * ``click_diary_id1..5`` and ``click_tractate_id1..5`` are cast to str.

    The remaining per-column null counts and the frame shape are printed.

    device_df: DataFrame containing all of the columns listed above plus the
        columns named in the selected column list.
    content_type: "tractate" selects TRACTATE_DEVICE_COLUMNS; any other value
        selects DIARY_DEVICE_COLUMNS.
    return: DataFrame restricted to the selected columns.
    """
    list_columns = (
        "first_demands",
        "second_demands",
        "first_solutions",
        "second_solutions",
        "first_positions",
        "second_positions",
        "projects",
    )
    text_columns = ("city_first", "model_first", "channel_first")

    df = device_df.copy()

    for name in list_columns:
        pieces = df[name].str.split(",")
        # Anything that did not split into a list is mapped to an empty list.
        df[name] = pieces.apply(lambda d: d if isinstance(d, list) else [])

    for name in text_columns:
        df[name] = df[name].fillna("")

    for kind in ("diary", "tractate"):
        for n in range(1, 6):
            name = "click_{}_id{}".format(kind, n)
            df[name] = df[name].astype(str)

    if content_type == "tractate":
        columns = TRACTATE_DEVICE_COLUMNS
    else:
        columns = DIARY_DEVICE_COLUMNS

    # Report whatever nulls are still present after the clean-up above.
    nullseries = df.isnull().sum()
    print("device:")
    print(nullseries[nullseries > 0])
    print(df.shape)
    return df[columns]
