import os
import random
import shutil
import time
import timeit
from datetime import datetime
from pathlib import Path

import pandas as pd
import tensorflow as tf
from sklearn.model_selection import train_test_split

from models.esmm.fe import (click_feature_engineering, device_feature_engineering, diary_feature_engineering,
                            get_device_dict_from_redis, get_diary_dict_from_redis, join_features, read_csv_data)
from models.esmm.input_fn import build_features, esmm_input_fn
from models.esmm.model import esmm_model_fn, model_export, model_predict_diary

# tf.compat.v1.enable_eager_execution()


def main() -> None:
    """Train, evaluate and export the ESMM estimator, then load the export.

    Pipeline, in order:

    1. Read the device / diary / click / conversion tables from
       ``~/data/cvr_data/`` and run the per-table feature engineering helpers.
    2. Join the engineered tables and split them: 20% is held out as a test
       set, and 20% of the remainder becomes the validation set.
    3. Build a ``tf.estimator.Estimator`` around ``esmm_model_fn`` (checkpoints
       go to ``~/data/model_tmp/``) and run ``train_and_evaluate``.
    4. Export the model under ``~/data/models/`` and load the export back
       through ``tf.contrib.predictor.from_saved_model``.
    5. Print the total wall-clock time in minutes.
    """
    started_at = time.time()

    tf.logging.set_verbosity(tf.logging.INFO)

    # pathlib does not expand "~" by itself; expand it here, exactly as the
    # checkpoint and export paths below do, so the location does not depend on
    # what read_csv_data does with the path internally.
    data_dir = Path("~/data/cvr_data/").expanduser()
    device_df, diary_df, click_df, conversion_df = read_csv_data(data_dir)
    device_df = device_feature_engineering(device_df)
    diary_df = diary_feature_engineering(diary_df)
    cc_df = click_feature_engineering(click_df, conversion_df)
    df = join_features(device_df, diary_df, cc_df)

    # test_df is held out but not consumed further down in this script.
    train_df, test_df = train_test_split(df, test_size=0.2)
    train_df, val_df = train_test_split(train_df, test_size=0.2)

    # Feature columns are built from the full joined frame, not only the
    # training split.
    all_features = build_features(df)
    params = {"feature_columns": all_features, "hidden_units": [64, 32], "learning_rate": 0.1}
    model_path = str(Path("~/data/model_tmp/").expanduser())
    # Optional: start from a clean checkpoint directory.
    # if os.path.exists(model_path):
    #     shutil.rmtree(model_path)

    model = tf.estimator.Estimator(model_fn=esmm_model_fn, params=params, model_dir=model_path)
    # max_steps=None puts no cap on the number of training steps; presumably
    # esmm_input_fn yields a finite dataset so training terminates -- confirm.
    train_spec = tf.estimator.TrainSpec(input_fn=lambda: esmm_input_fn(train_df, shuffle=True), max_steps=None)
    eval_spec = tf.estimator.EvalSpec(input_fn=lambda: esmm_input_fn(val_df, shuffle=False))
    tf.estimator.train_and_evaluate(model, train_spec, eval_spec)

    model_export_path = str(Path("~/data/models/").expanduser())
    save_path = model_export(model, all_features, model_export_path)
    print("save to: " + save_path)

    # Optional: point save_path at an existing export instead, e.g.
    # save_path = str(Path("~/Desktop/models/1595297428").expanduser())

    # Load the export back as a predictor. predict_fn is only used by the
    # optional latency benchmark below.
    predict_fn = tf.contrib.predictor.from_saved_model(save_path)

    print("==============================")

    # Optional latency benchmark against features stored in redis:
    # device_dict = get_device_dict_from_redis()
    # diary_dict = get_diary_dict_from_redis()
    # device_ids = list(device_dict.keys())[:20]
    # diary_ids = list(diary_dict.keys())
    # for i in range(2):
    #     time_1 = timeit.default_timer()
    #     model_predict_diary(random.sample(device_ids, 1)[0], random.sample(diary_ids, 200), device_dict, diary_dict, predict_fn)
    #     total_1 = (timeit.default_timer() - time_1)
    #     print("total prediction cost {:.5f}s".format(total_1), "\n")

    total_time = (time.time() - started_at) / 60
    print("total cost {:.2f} mins at {}".format(total_time, datetime.now()))


# Run the training/export pipeline only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
