import tensorflow as tf
import json
import pandas as pd
import time

import utils.connUtils as connUtils

# Columns turned into numeric feature columns by getTrainColumns.
ITEM_NUMBER_COLUMNS = [
    "smart_rank2",
]
# Columns turned into embedding columns when a vocabulary exists for them.
embedding_columns = [
    "itemid",
    "userid",
    "doctor_id",
    "hospital_id",
]
# Multi-valued columns; not referenced by the code visible in this file.
multi_columns = [
    "tags_v3",
    "first_demands",
    "second_demands",
    "first_solutions",
    "second_solutions",
    "first_positions",
    "second_positions",
]
# Columns turned into indicator (one-hot) columns when a vocabulary exists.
one_hot_columns = [
    "service_type",
    "doctor_type",
    "doctor_famous",
    "hospital_city_tag_id",
    "hospital_type",
    "hospital_is_high_quality",
]
# history_columns = ["userRatedHistory"]

# Data loading
# Local development paths (disabled):
# data_path_train = "/Users/zhigangzheng/Desktop/work/guoyu/service_sort/train/part-00000-a61205d1-ad4e-4fa7-895d-ad8db41189e6-c000.csv"
# data_path_test = "/Users/zhigangzheng/Desktop/work/guoyu/service_sort/train/part-00000-a61205d1-ad4e-4fa7-895d-ad8db41189e6-c000.csv"

data_path_train = "/data/files/service_feature_train.csv"
data_path_test = "/data/files/service_feature_test.csv"
# The version selects both the Redis vocabulary key and the saved-model name.
version = "v1"
model_file = "service_mlp_" + version

# Data vocabulary
def getDataVocabFromRedis(version):
    """Fetch the feature vocabulary for *version* from Redis.

    Reads the key ``"Strategy:rec:vocab:service:" + version`` through the
    connection returned by ``connUtils.getRedisConn()`` and decodes the stored
    value as JSON. Each entry's name and length is printed for inspection.

    Args:
        version: Version suffix appended to the Redis key.

    Returns:
        The decoded JSON mapping, or ``None`` when the key is missing or its
        value is empty.
    """
    conn = connUtils.getRedisConn()
    key = "Strategy:rec:vocab:service:" + version
    dataVocabStr = conn.get(key)
    if not dataVocabStr:
        return None

    # json.loads() accepts str and UTF-8 bytes directly. The old ``encoding``
    # keyword was ignored since Python 3.1 and removed in 3.9, where passing
    # it raises TypeError.
    dataVocab = json.loads(dataVocabStr)
    print("-----data_vocab-----")
    for k, v in dataVocab.items():
        print(k, len(v))

    return dataVocab

# Data type conversion
def csvTypeConvert(df, data_vocab):
    """Fill missing values and cast columns to the dtypes used for training.

    Numeric columns (``ITEM_NUMBER_COLUMNS``) have missing values replaced by
    ``0.0`` and are cast to float. Every other missing value is replaced by
    the string ``"-1"``. Columns named by the keys of ``data_vocab`` are cast
    to pandas ``"string"`` dtype and ``label`` is cast to int. The resulting
    dtypes are printed.

    Args:
        df: Input DataFrame; it is not modified, a new frame is returned.
        data_vocab: Mapping whose keys are the categorical column names.

    Returns:
        A new DataFrame with the fills and casts applied.

    Raises:
        KeyError: If a vocabulary key, a numeric column or ``label`` is not
            a column of ``df``.
    """
    # Numeric columns must be filled first: a blanket fillna("-1") would
    # otherwise consume their NaNs, making the 0.0 default unreachable and
    # turning missing numeric values into -1.0 after the float cast.
    df = df.fillna({col: 0.0 for col in ITEM_NUMBER_COLUMNS})
    # Every remaining missing value (categorical columns included) gets "-1".
    df = df.fillna("-1")

    for k in data_vocab:
        df[k] = df[k].astype("string")

    for k in ITEM_NUMBER_COLUMNS:
        df[k] = df[k].astype("float")

    df["label"] = df["label"].astype("int")
    print(df.dtypes)
    return df

def loadData(data_path):
    """Read a ``|``-separated CSV file and report how long the read took.

    Args:
        data_path: Path of the CSV file to read.

    Returns:
        The DataFrame produced by ``pd.read_csv(data_path, sep="|")``.
    """
    print("读取数据...")
    started_ms = int(round(time.time() * 1000))
    frame = pd.read_csv(data_path, sep="|")
    finished_ms = int(round(time.time() * 1000))
    elapsed_ms = finished_ms - started_ms
    print("读取数据耗时ms:{}".format(elapsed_ms))
    return frame


def getDataSet(df, shuffleSize=10000, batchSize=128):
    """Build a shuffled, batched ``tf.data.Dataset`` of (features, label).

    Note that the ``label`` column is popped from ``df`` in place, so the
    caller's DataFrame no longer contains it afterwards.

    Args:
        df: DataFrame holding the feature columns plus a ``label`` column.
        shuffleSize: Buffer size passed to ``Dataset.shuffle``.
        batchSize: Batch size passed to ``Dataset.batch``.

    Returns:
        The dataset of ``(dict of feature columns, label)`` batches.
    """
    target = df.pop('label')
    features = dict(df)
    slices = tf.data.Dataset.from_tensor_slices((features, target))
    shuffled = slices.shuffle(shuffleSize)
    return shuffled.batch(batchSize)

def getTrainColumns(train_columns, data_vocab):
    """Map column names to TensorFlow feature columns.

    For a name with a non-empty vocabulary in ``data_vocab``:
      * an embedding column (dimension 10) is built when the name starts
        with ``userRatedHistory``, contains ``__`` or is listed in
        ``embedding_columns``;
      * otherwise an indicator column is built when the name is listed in
        ``one_hot_columns`` or contains ``Bucket``;
      * otherwise the name is skipped.
    For a name without a vocabulary, a numeric column is built when it is in
    ``ITEM_NUMBER_COLUMNS`` or ends with ``RatingAvg`` / ``RatingStddev``.

    Args:
        train_columns: Iterable of column names to consider.
        data_vocab: Mapping from column name to its vocabulary list.

    Returns:
        List of feature columns, in the order of ``train_columns``.
    """
    feature_columns = []
    for name in train_columns:
        vocab = data_vocab.get(name)

        if not vocab:
            # No vocabulary: only known numeric features are kept.
            is_numeric = (
                name in ITEM_NUMBER_COLUMNS
                or name.endswith(("RatingAvg", "RatingStddev"))
            )
            if is_numeric:
                feature_columns.append(tf.feature_column.numeric_column(name))
            continue

        use_embedding = (
            name.startswith("userRatedHistory")
            or "__" in name
            or name in embedding_columns
        )
        use_one_hot = name in one_hot_columns or "Bucket" in name
        if not (use_embedding or use_one_hot):
            continue

        categorical = tf.feature_column.categorical_column_with_vocabulary_list(
            key=name, vocabulary_list=vocab)
        if use_embedding:
            # Embedding takes precedence over one-hot when both would match.
            feature_columns.append(tf.feature_column.embedding_column(categorical, 10))
        else:
            feature_columns.append(tf.feature_column.indicator_column(categorical))
    return feature_columns


def train(columns, train_dataset):
    """Build, fit and save the MLP ranking model.

    The network is ``DenseFeatures(columns)`` followed by two 128-unit ReLU
    layers and a single sigmoid output. It is compiled with MSE loss and the
    Adam optimizer, fitted for 5 epochs, and saved in TensorFlow SavedModel
    format (without optimizer state) to the module-level ``model_file`` path.

    Args:
        columns: Feature columns consumed by the ``DenseFeatures`` layer.
        train_dataset: Dataset passed to ``model.fit``.

    Returns:
        The fitted ``tf.keras.Model``, so callers can evaluate or reuse it.
    """
    model = tf.keras.Sequential([
        tf.keras.layers.DenseFeatures(columns),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(128, activation='relu'),
        tf.keras.layers.Dense(1, activation='sigmoid'),
    ])

    # Compile the model: loss function, optimizer and evaluation metrics.
    # The metric order (accuracy, ROC AUC, PR AUC) is what evaluate() unpacks.
    model.compile(
        loss='mse',
        optimizer='adam',
        metrics=['accuracy', tf.keras.metrics.AUC(curve='ROC'), tf.keras.metrics.AUC(curve='PR')])

    # Train the model.
    print("train start...")
    model.fit(train_dataset, epochs=5)
    print("train end...")
    print("train save...")

    model.save(model_file, include_optimizer=False, save_format='tf')

    # Previously nothing was returned, so `model = train(...)` was always None.
    return model

def evaluate(model, test_dataset):
    """Evaluate *model* on *test_dataset* and print the metrics.

    ``model.evaluate`` must yield exactly four values, read as loss,
    accuracy, ROC AUC and PR AUC. The elapsed wall-clock time in whole
    seconds is printed afterwards.

    Args:
        model: Object exposing ``evaluate(dataset)``.
        test_dataset: Dataset handed to ``model.evaluate``.
    """
    started_s = int(round(time.time()))
    print("evaluate:")
    results = model.evaluate(test_dataset)
    # Unpack explicitly so a mismatch in the number of metrics fails loudly.
    test_loss, test_accuracy, test_roc_auc, test_pr_auc = results
    report = '\n\nTest Loss {}, Test Accuracy {}, Test ROC AUC {}, Test PR AUC {}'.format(
        test_loss, test_accuracy, test_roc_auc, test_pr_auc)
    print(report)
    elapsed_s = int(round(time.time())) - started_s
    print("验证耗时s:", elapsed_s)


def predict(model_path, df):
    """Load a saved model and time repeated predictions on a row sample.

    Loads the model stored at ``model_path``, samples up to 1000 rows from
    ``df`` and runs ``model.predict`` on that sample 10 times, printing the
    sample size and the elapsed milliseconds of each run.

    Args:
        model_path: Path of the saved Keras model to load.
        df: DataFrame of feature columns to sample from.
    """
    print("加载模型中...")
    # Load the model the caller asked for; the path used to be hard-coded to
    # "service_fm_v3", silently ignoring ``model_path``.
    model_new = tf.keras.models.load_model(model_path)
    print("模型加载完成...")

    # DataFrame.sample raises when n exceeds the row count, so cap it.
    n = min(1000, len(df))
    dd = dict(df.sample(n=n))
    for _ in range(10):
        timestmp1 = int(round(time.time() * 1000))
        model_new.predict(dd, batch_size=10000)
        print("测试样本数：{},测试耗时ms:{}".format(n, int(round(time.time() * 1000)) - timestmp1))


if __name__ == '__main__':
    # Load the feature vocabulary from Redis.
    print("redis 中加载模型字典...")
    data_vocab = getDataVocabFromRedis(version)
    # Raise explicitly: an `assert` would be stripped when running with -O.
    if not data_vocab:
        raise RuntimeError(
            "no feature vocabulary found in Redis for version {!r}".format(version))

    print("读取数据...")
    timestmp1 = int(round(time.time()))
    df_train = loadData(data_path_train)
    df_test = loadData(data_path_test)
    timestmp2 = int(round(time.time()))
    print("读取数据耗时s:{}".format(timestmp2 - timestmp1))

    # df_train = df_train[list(data_vocab.keys()) + ITEM_NUMBER_COLUMNS + ["label"]]
    # df_test = df_test[list(data_vocab.keys()) + ITEM_NUMBER_COLUMNS + ["label"]]

    trainSize = df_train["label"].count()
    testSize = df_test["label"].count()
    print("trainSize:{},testSize{}".format(trainSize,testSize))

    # Data type conversion.
    df_train = csvTypeConvert(df_train,data_vocab)
    df_test = csvTypeConvert(df_test,data_vocab)
    columns = df_train.columns.tolist()

    # Build the datasets; the shuffle buffer covers the whole split.
    train_data = getDataSet(df_train,shuffleSize=trainSize)
    test_data = getDataSet(df_test,shuffleSize=testSize)

    # Build the feature columns and train.
    columns = getTrainColumns(columns,data_vocab)
    timestmp3 = int(round(time.time()))
    model = train(columns,train_data)
    timestmp4 = int(round(time.time()))
    # This interval measures training, not data loading, so label it as such.
    print("训练耗时h:{}".format((timestmp4 - timestmp3)/60/60))

    # Evaluation is disabled; enable only if train() returns the fitted model.
    # evaluate(model,test_data)

    # predict() samples rows with DataFrame.sample, so it needs the feature
    # DataFrame rather than the batched tf.data.Dataset. Drop the label
    # defensively in case it is still present.
    predict(model_file, df_test.drop(columns=["label"], errors="ignore"))

