import datetime
import random
import sys
import time

import tensorflow as tf

from models.esmm import device_fe as device_fe
from models.esmm import diary_fe as diary_fe
from models.esmm.model import model_predict_diary
from utils.cache import redis_client2
from utils.grey import recommed_service_category_device_id_by_tail
from utils.portrait import (user_portrait_tag3_get_candidate_dict, user_portrait_tag3_get_candidate_unread_list,
                            user_portrait_tag3_write_ctcvr_data)


def user_portrait_scan_info(device_dict, diary_dict, predict_fn, tail_number):
    """Scan tag3 device keys in Redis and run offline diary prediction for matching devices.

    Walks every key matching ``doris:user_portrait:tag3:device_id:*`` with
    Redis ``SCAN`` (COUNT hint 3000). For each key it takes the trailing
    ``device_id`` and calls ``offline_predict`` for that device when both
    conditions hold:

    * ``recommed_service_category_device_id_by_tail(device_id, [tail_number])``
      is truthy (presumably a device-id tail bucket filter — confirm in utils.grey);
    * the device has a truthy "diary" candidate dict.

    Args:
        device_dict: Device features, passed through to ``offline_predict``.
        diary_dict: Diary features, passed through to ``offline_predict``.
        predict_fn: Loaded predictor, passed through to ``offline_predict``.
        tail_number: Tail value forwarded to the grey filter.

    Any exception is printed and swallowed, so the job remains best effort.
    """
    pattern = "doris:user_portrait:tag3:device_id:*"
    scan_round = 0
    all_count = 0
    cursor = 0
    try:
        # SCAN is complete only when the server returns cursor 0. The batch
        # returned by EVERY call, including the first, must be processed.
        # The previous version dropped the first batch, and processed
        # nothing when a single call covered the whole keyspace.
        while True:
            cursor, results = redis_client2.scan(cursor, pattern, 3000)
            scan_round += 1
            print("round: " + str(scan_round))
            for key in results:
                # Accept both bytes keys (default client) and str keys
                # (decode_responses=True).
                if isinstance(key, bytes):
                    key = key.decode("utf-8")
                device_id = key.split(":")[-1]

                if not recommed_service_category_device_id_by_tail(device_id, [tail_number]):
                    continue
                if not user_portrait_tag3_get_candidate_dict(device_id, "diary"):
                    continue
                all_count += 1
                print(str(all_count) + ": " + device_id)
                offline_predict(device_id, device_dict, diary_dict, predict_fn)
            if cursor == 0:
                break

        print("all count: " + str(all_count))
        print("scan done " + str(datetime.datetime.now()))
    except Exception as e:
        print(e)


def offline_predict(device_id, device_dict, diary_dict, predict_fn):
    """Score one device's unread diary candidates and write back the first 500 results.

    Steps:
    1. Fetch the unread "diary" candidates for ``device_id``.
    2. Score them with ``model_predict_diary``.
    3. Store ``result[:500]`` through ``user_portrait_tag3_write_ctcvr_data``.

    If ``model_predict_diary`` returns results sorted by score (TODO confirm),
    step 3 keeps the top 500. Candidate count, result count and elapsed
    seconds are printed.
    """
    started_at = time.time()

    candidates = user_portrait_tag3_get_candidate_unread_list(device_id, "diary")
    print("diary_ids: " + str(len(candidates)))

    predictions = model_predict_diary(device_id, candidates, device_dict, diary_dict, predict_fn)
    print("res: " + str(len(predictions)))

    kept = predictions[:500]
    user_portrait_tag3_write_ctcvr_data(device_id, "diary", kept)

    elapsed = time.time() - started_at
    print("total cost {:.5f}s\n".format(elapsed))


def main():
    """Entry point: score unread diary candidates for one device-id tail bucket.

    Usage: ``python <script> TAIL_NUMBER [SAVED_MODEL_DIR]``

    ``TAIL_NUMBER`` is forwarded to ``user_portrait_scan_info``. The original
    notes list "c", "d", "e", "f" as values in use. ``SAVED_MODEL_DIR`` is
    optional and defaults to the previously hard-coded export path.
    """
    # Check the CLI before the expensive Redis loads and model restore.
    # A missing argument then fails at once with a usage message, not with
    # an IndexError after all the setup work.
    if len(sys.argv) < 2:
        sys.exit("usage: {} TAIL_NUMBER [SAVED_MODEL_DIR]".format(sys.argv[0]))
    tail_number = sys.argv[1]
    save_path = sys.argv[2] if len(sys.argv) > 2 else "/home/gmuser/data/models/1596018742"

    tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.INFO)

    device_dict = device_fe.get_device_dict_from_redis()
    diary_dict = diary_fe.get_diary_dict_from_redis()
    print("redis data: " + str(len(device_dict)) + " " + str(len(diary_dict)))

    predict_fn = tf.contrib.predictor.from_saved_model(save_path)

    user_portrait_scan_info(device_dict, diary_dict, predict_fn, tail_number)


if __name__ == "__main__":
    # Time the whole run (wall clock) and report it in minutes once main() returns.
    run_started = time.time()
    main()
    minutes_taken = (time.time() - run_started) / 60
    finished_at = datetime.datetime.now()
    print("total cost {:.2f} mins at {}".format(minutes_taken, finished_at))
