import os

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow.keras import (activations, callbacks, layers, losses, metrics, optimizers)

# Data and model directories are resolved relative to the current working
# directory, so run the script from the directory that contains `_data/`.
base_dir = os.getcwd()
DATA_PATH = os.path.join(base_dir, "_data")
MODEL_PATH = os.path.join(base_dir, "_models")

# Device-side model inputs: the device id, then the first element (no suffix)
# and the second element (suffix "2") of each multi-value device attribute.
# Order is significant: training inputs are assembled in
# DEVICE_COLUMNS + TRACTATE_COLUMNS order.
DEVICE_COLUMNS = ["device_id"] + [
    "device_{}{}".format(code, suffix)
    for suffix in ("", "2")
    for code in ("fd", "sd", "fs", "ss", "fp", "sp", "p")
]

# Tractate (card) model inputs: static card attributes, then every windowed
# statistic over each time window, then the first (no suffix) and second
# (suffix "2") element of each multi-value card attribute.
# Order is significant: training inputs are assembled in
# DEVICE_COLUMNS + TRACTATE_COLUMNS order.
TRACTATE_COLUMNS = [
    "card_id",
    "is_pure_author",
    "is_have_pure_reply",
    "is_have_reply",
    "content_level",
    "topic_seven_click_num",
    "topic_thirty_click_num",
    "topic_num",
    "seven_transform_num",
    "thirty_transform_num",
    "favor_num",
    "favor_pure_num",
    "vote_num",
    "vote_display_num",
    "reply_num",
    "reply_pure_num",
] + [
    "{}_{}".format(window, statistic)
    for statistic in (
        "click_num",
        "precise_exposure_num",
        "vote_user_num",
        "reply_user_num",
        "browse_user_num",
        "reply_num",
        "ctr",
        "vote_pure_rate",
        "reply_pure_rate",
    )
    for window in ("one", "three", "seven", "fifteen", "thirty", "sixty", "ninety", "history")
] + [
    "card_{}{}".format(code, suffix)
    for suffix in ("", "2")
    for code in ("fd", "sd", "fs", "ss", "fp", "sp", "p")
]


def nth_element(lst, n):
    """Return ``lst[n]``, or "" when ``n`` is at or past the end of ``lst``.

    A negative ``n`` goes straight to normal indexing (counting from the end).
    """
    return lst[n] if n < len(lst) else ""


def get_df(file):
    """Load a "|"-separated CSV file from DATA_PATH into a DataFrame.

    :param file: file name, relative to DATA_PATH.
    :return: the parsed DataFrame.
    """
    return pd.read_csv(os.path.join(DATA_PATH, file), sep="|")


# (field name, short feature code) for each comma-separated multi-value
# attribute present in both the device and the tractate feature files.
# The tuple order fixes the order in which derived columns are created.
_MULTI_VALUE_FIELDS = (
    ("first_demands", "fd"),
    ("second_demands", "sd"),
    ("first_solutions", "fs"),
    ("second_solutions", "ss"),
    ("first_positions", "fp"),
    ("second_positions", "sp"),
    ("projects", "p"),
)


def _expand_multi_value_columns(df, source_prefix, feature_prefix):
    """Replace comma-separated multi-value columns with their first two elements (in place).

    For every ``(field, code)`` in ``_MULTI_VALUE_FIELDS``, column
    ``source_prefix + field`` is split on "," and yields
    ``f"{feature_prefix}_{code}"`` (element 0) and ``f"{feature_prefix}_{code}2"``
    (element 1), using "" for missing elements. The source columns are dropped.
    """
    source_columns = [source_prefix + field for field, _ in _MULTI_VALUE_FIELDS]
    # Assign back rather than `df[cols].fillna(..., inplace=True)`, which fills a
    # temporary copy and leaves `df` untouched. Filling also turns an all-NaN
    # (float64) column into strings so the `.str` accessor below does not raise.
    df[source_columns] = df[source_columns].fillna("")
    for column in source_columns:
        df[column] = df[column].str.split(",").apply(lambda d: d if isinstance(d, list) else [])
    # All first elements are created before all second elements, matching the column lists.
    for position, suffix in ((0, ""), (1, "2")):
        for field, code in _MULTI_VALUE_FIELDS:
            df[f"{feature_prefix}_{code}{suffix}"] = df[source_prefix + field].apply(nth_element, args=(position, ))
    df.drop(columns=source_columns, inplace=True)


def device_tractae_fe():
    """Build the joined training frame from click/exposure logs and device/tractate features.

    Reads ``tractate_click.csv``, ``tractate_exposure.csv``, ``device_feature.csv``
    and ``tractate_feature.csv`` via ``get_df``, labels every click/exposure row,
    expands the multi-value device and card attributes into first/second-element
    columns, and inner-joins everything on the shared columns.

    :return: the merged DataFrame, including a float ``label`` column.
    :raises Exception: if the merged frame contains any null value (the per-column
        null counts are printed first).
    """
    click_df = get_df("tractate_click.csv")
    exposure_df = get_df("tractate_exposure.csv")
    device_fe_df = get_df("device_feature.csv")
    tractate_fe_df = get_df("tractate_feature.csv")
    for frame in (click_df, exposure_df, device_fe_df, tractate_fe_df):
        print(frame.shape)

    # Labels: outer-join the click and exposure logs on their shared columns.
    click_df.drop("partition_date", inplace=True, axis=1)
    exposure_df.drop("partition_date", inplace=True, axis=1)
    base_df = pd.merge(click_df, exposure_df, how="outer", indicator="Exist")
    # Rows present only in the exposure log get 0.75; everything else gets 1.0.
    # NOTE(review): confirm 0.75 (rather than 0.0) is the intended label for
    # exposed-but-not-clicked rows; with it, binary_accuracy never matches y_true.
    base_df["label"] = np.where(base_df["Exist"] == "right_only", 0.75, 1.0)
    base_df.drop("Exist", inplace=True, axis=1)

    # Device features: every null becomes "", and cl_id is renamed to the join key.
    device_fe_df.fillna("", inplace=True)
    device_fe_df.rename(columns={"cl_id": "device_id"}, inplace=True)
    _expand_multi_value_columns(device_fe_df, "", "device")

    # Tractate features: only the multi-value card columns are null-filled.
    _expand_multi_value_columns(tractate_fe_df, "card_", "card")

    # Inner joins on the columns each pair of frames has in common.
    df = pd.merge(pd.merge(base_df, device_fe_df), tractate_fe_df)

    null_counts = df.isnull().sum()
    nulls = null_counts[null_counts > 0]
    if not nulls.empty:
        print(nulls)
        raise Exception("dataframe nulls")
    return df


def get_input_dim(df, columns):
    """Return ``{column: number of distinct values + 1}`` for each requested column.

    NaN counts as one distinct value (``Series.unique`` keeps it).
    """
    return {column: len(df[column].unique()) + 1 for column in columns}


# Embedding widths per input column; columns not listed use _DEFAULT_EMBEDDING_SIZE.
_DEVICE_EMBEDDING_SIZES = {"device_id": 10}
_TRACTATE_EMBEDDING_SIZES = {
    "card_id": 10,
    "is_pure_author": 2,
    "is_have_pure_reply": 2,
    "is_have_reply": 2,
}
_DEFAULT_EMBEDDING_SIZE = 3


def _build_tower(columns, dim_dict, embedding_sizes, embedding_name):
    """Build one tower of the two-tower model.

    Creates a scalar ``Input`` named after each column and embeds it with
    ``Embedding(dim_dict[column], width)``. The embeddings are concatenated,
    flattened, and projected through Dense(3000, relu) and then a
    Dense(1000, relu, l2) layer named ``embedding_name``.

    :return: ``(inputs, vector)`` — the Input tensors in ``columns`` order and
        the tower's (batch, 1000) output.
    """
    tower_inputs = [layers.Input(shape=(1, ), name=column) for column in columns]
    embeddings = [
        layers.Embedding(dim_dict[column], embedding_sizes.get(column, _DEFAULT_EMBEDDING_SIZE))(tower_input)
        for column, tower_input in zip(columns, tower_inputs)
    ]
    # Each Embedding outputs (batch, 1, width) for a shape-(1,) input. Flatten so the
    # tower emits (batch, features); otherwise the later reduction over axis 1
    # only removes the length-1 axis and never sums over the feature axis.
    vector = layers.Flatten()(layers.concatenate(embeddings))
    vector = layers.Dense(3000, activation=activations.relu)(vector)
    vector = layers.Dense(
        1000,
        activation=activations.relu,
        name=embedding_name,
        kernel_regularizer="l2",
    )(vector)
    return tower_inputs, vector


def tractate_dssm_model(device_dim_dict=None, tractate_dim_dict=None):
    """Build and compile the device/tractate two-tower model.

    :param device_dim_dict: ``{column: Embedding input_dim}`` for every column in
        DEVICE_COLUMNS. Defaults to the module-level ``DEVICE_DIM_DICT`` global,
        which the script entry point sets before calling this function.
    :param tractate_dim_dict: same for TRACTATE_COLUMNS; defaults to ``TRACTATE_DIM_DICT``.
    :return: a compiled ``tf.keras.Model``. Its inputs are ordered
        DEVICE_COLUMNS + TRACTATE_COLUMNS. Its output is a (batch, 1) sigmoid of an
        affine function of the dot product of the two 1000-d tower vectors.
    """
    if device_dim_dict is None:
        device_dim_dict = DEVICE_DIM_DICT
    if tractate_dim_dict is None:
        tractate_dim_dict = TRACTATE_DIM_DICT

    # user tower
    device_inputs, device_vector = _build_tower(
        DEVICE_COLUMNS, device_dim_dict, _DEVICE_EMBEDDING_SIZES, "device_embedding")
    # item tower
    tractate_inputs, tractate_vector = _build_tower(
        TRACTATE_COLUMNS, tractate_dim_dict, _TRACTATE_EMBEDDING_SIZES, "tractate_embedding")

    # Row-wise dot product: (batch, 1000) . (batch, 1000) -> (batch, 1).
    device_tractate_dot = layers.Dot(axes=1)([device_vector, tractate_vector])
    output = layers.Dense(1, activation=activations.sigmoid)(device_tractate_dot)

    model = tf.keras.Model(inputs=device_inputs + tractate_inputs, outputs=[output])
    # summary() prints on its own and returns None; print(...) around it emitted "None".
    model.summary()

    model.compile(
        loss=losses.MeanSquaredError(),
        optimizer=optimizers.RMSprop(),
        metrics=[metrics.binary_accuracy],
    )
    return model


if __name__ == "__main__":
    df = device_tractae_fe()
    print(df.head(3), df.shape)
    y = df["label"]

    feature_columns = DEVICE_COLUMNS + TRACTATE_COLUMNS
    # Every feature feeds a keras Embedding, which needs integer ids in
    # [0, input_dim). The raw columns hold strings (ids, demand/solution codes)
    # and numeric counts/rates. Map each column's distinct values to dense codes
    # 0..n_unique-1 so that get_input_dim's n_unique + 1 tables cover them.
    # NOTE(review): the code<->value mapping is not persisted; inference must
    # reproduce the same encoding.
    for column in feature_columns:
        df[column] = pd.factorize(df[column])[0]

    # tractate_dssm_model() reads these module-level names by default.
    DEVICE_DIM_DICT = get_input_dim(df, DEVICE_COLUMNS)
    TRACTATE_DIM_DICT = get_input_dim(df, TRACTATE_COLUMNS)

    model = tractate_dssm_model()

    # Positional inputs, in the same order as the model's Input layers.
    x_train = [df[column] for column in feature_columns]

    # NOTE(review): patience=10 exceeds epochs=5, so EarlyStopping can never trigger.
    history = model.fit(x=x_train,
                        y=y,
                        batch_size=320,
                        epochs=5,
                        verbose=1,
                        callbacks=[callbacks.EarlyStopping(monitor="loss", patience=10)])

    history_dict = history.history
    print(history_dict)