import datetime
import random
import sys
import time
import traceback

import tensorflow as tf

from models.esmm.diary_model import model_predict_diary
from models.esmm.fe import device_fe, diary_fe, tractate_fe
from models.esmm.tractate_model import model_predict_tractate
from utils.cache import get_essm_model_save_path, redis_client2
from utils.grey import recommed_service_category_device_id_by_tail
from utils.portrait import (get_user_portrait_tag3_read_v2, user_portrait_tag3_get_candidate_dict,
                            user_portrait_tag3_get_candidate_unread_list, user_portrait_tag3_write_ctcvr_data)


def user_portrait_scan_info(device_dict, diary_dict, tractate_dict, diary_predict_fn, tractate_predict_fn, tail_number):
    """Scan tag3 user-portrait device keys in Redis and run offline predictions.

    Iterates every key matching ``doris:user_portrait:tag3:device_id:*`` with
    Redis SCAN. A device is processed when both of these hold:

    * ``recommed_service_category_device_id_by_tail(device_id, [tail_number])``
      is truthy;
    * ``user_portrait_tag3_get_candidate_dict(device_id, "diary")`` is truthy.

    For such a device, diary and tractate predictions are computed and written
    back via ``offline_predict_diary`` and ``offline_predict_tractate``.

    Args:
        device_dict: device features, passed through to the predict helpers.
        diary_dict: diary features, passed through to the predict helpers.
        tractate_dict: tractate features, passed through to the predict helpers.
        diary_predict_fn: saved-model predictor used for diaries.
        tractate_predict_fn: saved-model predictor used for tractates.
        tail_number: device-id tail selector (e.g. "c"). Presumably matched
            against the end of the device id; see
            ``recommed_service_category_device_id_by_tail``.

    Errors are handled best-effort: any exception stops the scan and its
    traceback is printed instead of being raised.
    """
    key_prefix = "doris:user_portrait:tag3:device_id:"
    keys = key_prefix + "*"
    try:
        scan_round = 0
        all_count = 0
        cur = 0
        # SCAN is cursor based: cursor 0 starts the iteration and a returned
        # cursor of 0 means it is finished. Every batch, including the very
        # first one, must be processed, so handle the results first and check
        # the cursor afterwards.
        while True:
            scan_round += 1
            print("round: " + str(scan_round))
            cur, results = redis_client2.scan(cur, match=keys, count=3000)
            for key in results:
                # Keys come back as bytes unless the client decodes responses.
                if isinstance(key, bytes):
                    key = key.decode("utf-8")
                # MATCH guarantees the prefix, so everything after it is the id.
                device_id = key[len(key_prefix):]

                if recommed_service_category_device_id_by_tail(device_id, [tail_number]):
                    if user_portrait_tag3_get_candidate_dict(device_id, "diary"):
                        all_count += 1
                        print(str(all_count) + ": " + device_id)
                        offline_predict_diary(device_id, device_dict, diary_dict, diary_predict_fn)
                        offline_predict_tractate(device_id, device_dict, tractate_dict, tractate_predict_fn)
                        print("=========================================\n")
            if cur == 0:
                break

        print("all count: " + str(all_count))
        print("scan done " + str(datetime.datetime.now()))
    except Exception:
        # Best-effort batch job: report the full stack instead of only str(e).
        traceback.print_exc()


def offline_predict_diary(device_id, device_dict, diary_dict, predict_fn, size=300):
    """Predict scores for a device's unread diary candidates and store them.

    Fetches up to ``size`` unread diary candidate ids for ``device_id``. It
    scores them with ``model_predict_diary`` and writes the first 500 entries
    of the prediction result as "diary" ctcvr data. Counts and elapsed time
    are printed.

    NOTE(review): whether the prediction result is already ordered by score is
    decided inside ``model_predict_diary``; confirm before relying on the
    500-entry cut being a "top-N".
    """
    started_at = time.time()

    candidate_ids = user_portrait_tag3_get_candidate_unread_list(device_id, "diary", size=size)
    print("diary_ids: {}".format(len(candidate_ids)))

    predictions = model_predict_diary(device_id, candidate_ids, device_dict, diary_dict, predict_fn)
    print("res: {}".format(len(predictions)))

    head = predictions[:500]
    user_portrait_tag3_write_ctcvr_data(device_id, "diary", head)

    elapsed = time.time() - started_at
    print("total cost {:.5f}s".format(elapsed))


def offline_predict_tractate(device_id, device_dict, tractate_dict, predict_fn, size=300):
    """Predict scores for a device's unread tractate candidates and store them.

    Fetches up to ``size`` unread tractate candidate ids for ``device_id``. It
    scores them with ``model_predict_tractate`` and writes the first 500
    entries of the prediction result as "tractate" ctcvr data. Counts and
    elapsed time are printed.

    NOTE(review): ordering of the prediction result is determined by
    ``model_predict_tractate``; confirm before treating the 500-entry cut as
    a "top-N".
    """
    started_at = time.time()

    candidate_ids = user_portrait_tag3_get_candidate_unread_list(device_id, "tractate", size=size)
    print("tractate_ids: {}".format(len(candidate_ids)))

    predictions = model_predict_tractate(device_id, candidate_ids, device_dict, tractate_dict, predict_fn)
    print("res: {}".format(len(predictions)))

    head = predictions[:500]
    user_portrait_tag3_write_ctcvr_data(device_id, "tractate", head)

    elapsed = time.time() - started_at
    print("total cost {:.5f}s".format(elapsed))


def _load_predict_fn(model_type, default_path):
    """Load the saved-model predictor for ``model_type`` ("diary" or "tractate").

    Uses the path returned by ``get_essm_model_save_path(model_type)``. When
    that path is empty, it falls back to ``default_path`` and prints it with a
    loud marker so the fallback is visible in the logs.
    """
    save_path = get_essm_model_save_path(model_type)
    if not save_path:
        save_path = default_path
        print(save_path + "!!!!!!!!!!!!!!!!!!!!!!!!!!!")
    return tf.contrib.predictor.from_saved_model(save_path)


def main():
    """Entry point: load features and models, then run offline predictions.

    Expects exactly one CLI argument, the device-id tail selector (e.g.
    "c", "d", "e", "f"). It is forwarded to ``user_portrait_scan_info``.

    Raises:
        SystemExit: if the tail argument is missing.
    """
    # Validate CLI arguments before the expensive Redis and model loading, so
    # a missing argument fails fast with a usage message instead of an
    # IndexError after everything has been loaded.
    if len(sys.argv) < 2:
        raise SystemExit("usage: {} TAIL_NUMBER  (e.g. c, d, e, f)".format(sys.argv[0]))
    tail_number = sys.argv[1]

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    device_dict = device_fe.get_device_dict_from_redis()
    diary_dict = diary_fe.get_diary_dict_from_redis()
    tractate_dict = tractate_fe.get_tractate_dict_from_redis()
    print("redis data: " + str(len(device_dict)) + " " + str(len(diary_dict)) + " " + str(len(tractate_dict)))

    diary_predict_fn = _load_predict_fn("diary", "/data/files/models/diary/1597390452")
    tractate_predict_fn = _load_predict_fn("tractate", "/data/files/models/tractate/1597390051")

    # NOTE(review): this runs and writes predictions for one hard-coded device
    # before the real scan. It looks like a smoke test or debug leftover;
    # confirm it is still wanted in production.
    device_id = "androidid_a25a1129c0b38f7b"
    offline_predict_diary(device_id, device_dict, diary_dict, diary_predict_fn, size=100)
    offline_predict_tractate(device_id, device_dict, tractate_dict, tractate_predict_fn, size=100)

    user_portrait_scan_info(device_dict, diary_dict, tractate_dict, diary_predict_fn, tractate_predict_fn, tail_number)


if __name__ == "__main__":
    # Time the whole job and report wall-clock duration in minutes on completion.
    script_start = time.time()
    main()
    elapsed_minutes = (time.time() - script_start) / 60
    finished_at = datetime.datetime.now()
    print("total cost {:.2f} mins at {}".format(elapsed_minutes, finished_at))
