import timeit

import tensorflow as tf

from .fe.tractate_fe import (CATEGORICAL_COLUMNS, FLOAT_COLUMNS, INT_COLUMNS, device_tractate_fe)
from .model import _bytes_feature, _float_feature, _int64_feature


def model_predict_tractate(device_id, tractate_ids, device_dict, tractate_dict, predict_fn):
    """Run ``predict_fn`` over one serialized example per candidate tractate.

    Features come from ``device_tractate_fe``. For every entry of the tractate
    list it returns, the device features and that entry's features are merged
    (the entry's values win on key clashes) and packed into a serialized
    ``tf.train.Example``. The whole batch is handed to ``predict_fn`` as
    ``{"examples": [...]}``. Progress, the raw predictions and two timings are
    printed to stdout.

    Args:
        device_id: Forwarded unchanged to ``device_tractate_fe``.
        tractate_ids: Forwarded unchanged to ``device_tractate_fe``.
        device_dict: Forwarded unchanged to ``device_tractate_fe``.
        tractate_dict: Forwarded unchanged to ``device_tractate_fe``.
        predict_fn: Callable invoked once with ``{"examples": <list of
            serialized examples>}``; whatever it returns is printed.

    Returns:
        ``None`` when the prediction call completes (the result is only
        printed), or an empty list if any exception is raised on the way.
    """
    try:
        started = timeit.default_timer()
        device_info, tractate_lst, tractate_ids_res = device_tractate_fe(
            device_id, tractate_ids, device_dict, tractate_dict)
        row_count = len(tractate_lst)
        id_count = len(tractate_ids_res)
        print("predict check: " + str(row_count) + " " + str(id_count))

        # (columns, encoder) pairs, applied in this order so that a column
        # listed in more than one group ends up with the last group's encoding.
        encoders = (
            (INT_COLUMNS, lambda value: _int64_feature(int(value))),
            (FLOAT_COLUMNS, lambda value: _float_feature(float(value))),
            (CATEGORICAL_COLUMNS,
             lambda value: _bytes_feature(str(value).encode(encoding="utf-8"))),
        )

        examples = []
        for tractate_info in tractate_lst:
            # Device-level features first, then the per-tractate ones on top.
            merged = dict(device_info)
            merged.update(tractate_info)

            features = {}
            for columns, encode in encoders:
                for col in columns:
                    features[col] = encode(merged[col])

            proto = tf.train.Example(features=tf.train.Features(feature=features))
            examples.append(proto.SerializeToString())

        print("make example cost {:.5f}s".format(timeit.default_timer() - started))

        # The second timing deliberately covers the print of the predictions
        # as well as the predict_fn call itself.
        started = timeit.default_timer()
        predictions = predict_fn({"examples": examples})
        print(predictions)
        print("prediction cost {:.5f}s".format(timeit.default_timer() - started))

        # NOTE(review): nothing is derived from `predictions`, so callers get
        # None here but [] from the except branch. An earlier, commented-out
        # version sorted tractate_ids_res by predictions["output"] (descending)
        # and returned the ids as ints -- confirm whether that ranking should
        # be restored before relying on the return value.
        return None
    except Exception as e:
        # Best-effort: report the failure and fall back to an empty result.
        print(e)
        return []
