import timeit
import traceback

import tensorflow as tf

from .fe.diary_fe import device_diary_fe
from .model import _bytes_feature, _float_feature, _int64_feature

# Feature columns encoded as int64 features in each serialized tf.train.Example.
_int_columns = [
    "active_type",
    "active_days",
    "card_id",
    "is_pure_author",
    "is_have_reply",
    "is_have_pure_reply",
    "content_level",
    "topic_num",
    "favor_num",
    "vote_num",
]

# Feature columns encoded as float features.
_float_columns = [
    "one_ctr",
    "three_ctr",
    "seven_ctr",
    "fifteen_ctr",
]

# Feature columns encoded as UTF-8 bytes features.
_categorical_columns = [
    # Device identity and device-level history columns.
    "device_id",
    "past_consume_ability_history",
    "potential_consume_ability_history",
    "price_sensitive_history",
    # Device-side columns.
    "device_fd", "device_sd", "device_fs", "device_ss", "device_fp", "device_sp", "device_p",
    # Content-side columns.
    "content_fd", "content_sd", "content_fs", "content_ss", "content_fp", "content_sp", "content_p",
    # Numbered slot columns (three per prefix).
    "fd1", "fd2", "fd3",
    "sd1", "sd2", "sd3",
    "fs1", "fs2", "fs3",
    "ss1", "ss2", "ss3",
    "fp1", "fp2", "fp3",
    "sp1", "sp2", "sp3",
    "p1", "p2", "p3",
]

# Every column the prediction path serializes, in int -> float -> categorical order.
PREDICTION_ALL_COLUMNS = [*_int_columns, *_float_columns, *_categorical_columns]


def model_predict_diary(device_id, diary_ids, device_dict, diary_dict, predict_fn):
    """Rank candidate diaries for a device by model score, highest first.

    Args:
        device_id: device identifier, forwarded to ``device_diary_fe``.
        diary_ids: candidate diary ids, forwarded to ``device_diary_fe``.
        device_dict: device-side lookup, forwarded to ``device_diary_fe``.
        diary_dict: diary-side lookup, forwarded to ``device_diary_fe``.
        predict_fn: callable taking ``{"examples": [bytes, ...]}`` (serialized
            ``tf.train.Example`` records) and returning a mapping whose
            ``"output"`` entry supports ``.tolist()`` — presumably a numpy array
            with one score row per example (TODO confirm against the model).

    Returns:
        list[int]: the diary ids returned by ``device_diary_fe``, sorted by
        descending score. ``sorted`` is stable, so ties keep feature-engineering
        order. Returns ``[]`` when there is nothing to score or when any step
        fails. Failures are printed with a traceback instead of being raised.
    """
    try:
        start = timeit.default_timer()
        device_info, diary_lst, diary_ids_res = device_diary_fe(device_id, diary_ids, device_dict, diary_dict)
        print("predict check: " + str(len(diary_lst)) + " " + str(len(diary_ids_res)))

        examples = [_build_example(device_info, diary_info) for diary_info in diary_lst]
        print("make example cost {:.5f}s".format(timeit.default_timer() - start))

        if not examples:
            # Nothing to score: skip the model call rather than sending an empty batch.
            return []

        start = timeit.default_timer()
        predictions = predict_fn({"examples": examples})
        scores = predictions["output"].tolist()
        if not len(diary_ids_res) == len(examples) == len(scores):
            # zip() would silently truncate here and pair ids with the wrong
            # scores, so fail loudly and let the best-effort handler below run.
            raise ValueError(
                "length mismatch: {} diary ids, {} examples, {} scores".format(
                    len(diary_ids_res), len(examples), len(scores)))

        ranked = sorted(zip(diary_ids_res, scores), key=lambda pair: pair[1], reverse=True)
        res = [int(diary_id) for diary_id, _ in ranked]
        print("prediction cost {:.5f}s".format(timeit.default_timer() - start))
        return res
    except Exception as e:
        # Best-effort contract: callers receive an empty ranking instead of an
        # exception. Keep the full traceback so failures are diagnosable
        # (print(e) alone shows only e.g. "'col_name'" for a KeyError).
        print("model_predict_diary failed: {!r}".format(e))
        traceback.print_exc()
        return []


def _build_example(device_info, diary_info):
    """Serialize one merged (device, diary) feature row into ``tf.train.Example`` bytes.

    ``diary_info`` entries override ``device_info`` entries that have the same key.

    Raises:
        KeyError: if a column from the module-level column lists is missing.
        ValueError / TypeError: if a value cannot be converted to int/float.
    """
    # dict() + update() mirrors the original merge semantics exactly.
    row = dict(device_info)
    row.update(diary_info)
    features = {}
    for col in _int_columns:
        features[col] = _int64_feature(int(row[col]))
    for col in _float_columns:
        features[col] = _float_feature(float(row[col]))
    for col in _categorical_columns:
        features[col] = _bytes_feature(str(row[col]).encode(encoding="utf-8"))
    example = tf.train.Example(features=tf.train.Features(feature=features))
    return example.SerializeToString()
