import os
import random
import shutil
import time
import timeit
from datetime import datetime
from functools import wraps
from pathlib import Path

import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split

from models.esmm.fe import (click_feature_engineering, device_feature_engineering, diary_feature_engineering,
                            get_device_df_from_redis, get_diary_df_from_redis, join_device_diary, join_features, read_csv_data)
from models.esmm.input_fn import build_features, esmm_input_fn
from models.esmm.model import esmm_model_fn, model_export, model_predict

# tf.compat.v1.enable_eager_execution()


def time_cost(func):
    """Wrap *func* so every call prints how long it took.

    The duration is measured with ``timeit.default_timer`` around the call and
    printed as ``cost <seconds>s`` (5 decimal places) once the call returns.
    The wrapped function's return value is passed back unchanged. If *func*
    raises, the exception propagates and no timing line is printed.
    """

    @wraps(func)
    def _timed(*args, **kwargs):
        started = timeit.default_timer()
        result = func(*args, **kwargs)
        elapsed = timeit.default_timer() - started
        print(f"cost {elapsed:.5f}s")
        return result

    return _timed


def _time_join_and_predict(device_ids, diary_ids, device_df, diary_df, predict_fn, n_diaries=300):
    """Run one timed join + predict round for a random device and a random diary sample.

    Prints the time spent building the joined frame (``join_device_diary``,
    including the random sampling) and the time spent in ``model_predict``.

    Args:
        device_ids: candidate device ids; one is picked at random.
        diary_ids: candidate diary ids; ``n_diaries`` of them are sampled without replacement.
        device_df: device feature frame passed through to ``join_device_diary``.
        diary_df: diary feature frame passed through to ``join_device_diary``.
        predict_fn: predictor handed to ``model_predict``.
        n_diaries: requested sample size, capped at ``len(diary_ids)`` so that
            ``random.sample`` does not raise when fewer diaries are available.
    """
    sample_size = min(n_diaries, len(diary_ids))

    start = timeit.default_timer()
    joined_df = join_device_diary(random.choice(device_ids), random.sample(diary_ids, sample_size), device_df, diary_df)
    print("join df cost {:.5f}s".format(timeit.default_timer() - start))

    start = timeit.default_timer()
    model_predict(joined_df, predict_fn)
    print("total prediction cost {:.5f}s".format(timeit.default_timer() - start), "\n")


def main(save_path="/home/gmuser/data/models/1595317247", n_rounds=5, n_diaries=300):
    """Build offline features, load an exported model and benchmark Redis-backed predictions.

    Steps:
      1. Read the CSV snapshot under ``~/data/cvr_data/`` and run feature
         engineering and train/val/test splitting. The splits are only consumed
         by the disabled training/export code kept below.
      2. Load the SavedModel at ``save_path`` via ``tf.contrib.predictor``.
      3. Fetch device and diary features from Redis and run ``n_rounds`` timed
         join + predict rounds. Each round uses one random device out of the
         first 20 and up to ``n_diaries`` random diaries.

    Args:
        save_path: directory of the exported SavedModel to load.
        n_rounds: number of timed join/predict rounds to run.
        n_diaries: diaries sampled per round (capped at the number available).
    """
    time_begin = time.time()

    # Expand "~" before handing the path over: nothing guarantees the reader expands it.
    device_df, diary_df, click_df, conversion_df = read_csv_data(Path("~/data/cvr_data/").expanduser())
    device_df = device_feature_engineering(device_df)
    diary_df = diary_feature_engineering(diary_df)
    cc_df = click_feature_engineering(click_df, conversion_df)
    df = join_features(device_df, diary_df, cc_df)

    # Only used by the (currently disabled) training block below.
    train_df, test_df = train_test_split(df, test_size=0.2)
    train_df, val_df = train_test_split(train_df, test_size=0.2)

    # --- Training / export (disabled): re-enable to produce a fresh SavedModel. ---
    # all_features = build_features(df)
    # params = {"feature_columns": all_features, "hidden_units": [64, 32], "learning_rate": 0.1}
    # model_path = str(Path("~/data/model_tmp/").expanduser())
    # if os.path.exists(model_path):
    #     shutil.rmtree(model_path)
    # model = tf.estimator.Estimator(model_fn=esmm_model_fn, params=params, model_dir=model_path)
    # print("train")
    # model.train(input_fn=lambda: esmm_input_fn(train_df, shuffle=True), steps=5000)
    # metrics = model.evaluate(input_fn=lambda: esmm_input_fn(val_df, False), steps=5000)
    # print("metrics: " + str(metrics))
    # model_export_path = str(Path("~/data/models/").expanduser())
    # save_path = model_export(model, all_features, model_export_path)
    # print("save to: " + save_path)

    predict_fn = tf.contrib.predictor.from_saved_model(save_path)

    print("==============================")

    raw_device_df = get_device_df_from_redis()
    raw_diary_df = get_diary_df_from_redis()
    redis_device_df = device_feature_engineering(raw_device_df)
    redis_diary_df = diary_feature_engineering(raw_diary_df, from_redis=True)
    # Benchmark only against the first 20 devices returned from Redis.
    device_ids = list(redis_device_df["device_id"].values)[:20]
    diary_ids = list(redis_diary_df["card_id"].values)

    for _ in range(n_rounds):
        _time_join_and_predict(device_ids, diary_ids, redis_device_df, redis_diary_df, predict_fn, n_diaries)

    total_time = (time.time() - time_begin) / 60
    print("total cost {:.2f} mins at {}".format(total_time, datetime.now()))


if __name__ == "__main__":
    main()
