import timeit

import numba
import tensorflow as tf
from tensorflow import feature_column as fc
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.ops.losses import losses


def build_deep_layer(net, params):
    """Stack ReLU dense layers on top of ``net``.

    One ``tf.layers.dense`` layer (Glorot-uniform kernel init, ReLU activation)
    is added per entry of ``params["hidden_units"]``, in order, with that entry
    as the layer width.

    Args:
        net: input tensor for the first layer.
        params: dict whose ``"hidden_units"`` entry is an iterable of layer widths.

    Returns:
        The output of the last layer, or ``net`` itself when ``hidden_units``
        is empty.
    """
    output = net
    for width in params["hidden_units"]:
        output = tf.layers.dense(
            output,
            units=width,
            activation=tf.nn.relu,
            kernel_initializer=tf.glorot_uniform_initializer(),
        )
    return output


def esmm_model_fn(features, labels, mode, params):
    """Estimator ``model_fn`` for an ESMM-style CTR / CTCVR multi-task model.

    A shared ``input_layer`` feeds two dense towers built by
    ``build_deep_layer``: one for CTR and one for CVR. Each tower ends in a
    single logit. The CTCVR probability is ``sigmoid(ctr) * sigmoid(cvr)``.

    The training/eval loss is the CTR sigmoid cross-entropy plus the CTCVR log
    loss, and both terms are *summed* over the batch.

    Previously the CTCVR term came from ``tf.compat.v1.losses.log_loss`` with
    its default reduction (``SUM_BY_NONZERO_WEIGHTS``, i.e. a mean). The
    wrapping ``reduce_sum`` then ran on a scalar and had no effect, so that
    term was weighted roughly batch_size times less than the CTR term.

    Args:
        features: feature dict from the input_fn, consumed through
            ``params["feature_columns"]``.
        labels: dict with ``"click_label"`` and ``"conversion_label"``. It is
            not read in PREDICT mode.
        mode: a ``tf.estimator.ModeKeys`` value.
        params: dict with ``"feature_columns"``, ``"hidden_units"`` and an
            optional ``"learning_rate"`` for Adagrad (default 0.03).

    Returns:
        A ``tf.estimator.EstimatorSpec`` for ``mode``.
    """
    net = tf.compat.v1.feature_column.input_layer(features, params["feature_columns"])
    # Two separate build_deep_layer calls: each creates its own dense layers.
    last_ctr_layer = build_deep_layer(net, params)
    last_cvr_layer = build_deep_layer(net, params)

    # One logit per binary (sigmoid) task. This value used to be read from the
    # private head_lib._binary_logistic_or_multi_class_head(n_classes=2)
    # .logits_dimension, which is 1 for the binary head.
    ctr_logits = tf.layers.dense(last_ctr_layer, units=1,
                                 kernel_initializer=tf.glorot_uniform_initializer())
    cvr_logits = tf.layers.dense(last_cvr_layer, units=1,
                                 kernel_initializer=tf.glorot_uniform_initializer())
    ctr_preds = tf.sigmoid(ctr_logits)
    cvr_preds = tf.sigmoid(cvr_logits)
    ctcvr_preds = tf.multiply(ctr_preds, cvr_preds)

    if mode == tf.estimator.ModeKeys.PREDICT:
        predictions = {
            "ctr_preds": ctr_preds,
            "cvr_preds": cvr_preds,
            "ctcvr_preds": ctcvr_preds,
        }
        # The serving signature exposes only the CVR probability.
        export_outputs = {"prediction": tf.estimator.export.PredictOutput(cvr_preds)}
        return tf.estimator.EstimatorSpec(mode, predictions=predictions, export_outputs=export_outputs)

    # Reshape labels to (batch, 1) to match the shape of the logits/predictions.
    ctr_labels = tf.reshape(tf.cast(labels["click_label"], tf.float32), (-1, 1))
    cvr_labels = tf.reshape(tf.cast(labels["conversion_label"], tf.float32), (-1, 1))

    ctr_loss = tf.reduce_sum(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=ctr_labels, logits=ctr_logits))
    # The CTCVR probability is trained against conversion_label. Use an explicit
    # SUM reduction so this term is on the same (summed) scale as ctr_loss.
    ctcvr_loss = tf.compat.v1.losses.log_loss(labels=cvr_labels,
                                              predictions=ctcvr_preds,
                                              reduction=tf.compat.v1.losses.Reduction.SUM)
    loss = ctr_loss + ctcvr_loss

    if mode == tf.estimator.ModeKeys.EVAL:
        ctr_accuracy = tf.compat.v1.metrics.accuracy(
            labels=ctr_labels,
            predictions=tf.cast(tf.greater_equal(ctr_preds, 0.5), tf.float32))
        ctcvr_accuracy = tf.compat.v1.metrics.accuracy(
            labels=cvr_labels,
            predictions=tf.cast(tf.greater_equal(ctcvr_preds, 0.5), tf.float32))
        ctr_auc = tf.compat.v1.metrics.auc(labels=ctr_labels, predictions=ctr_preds)
        ctcvr_auc = tf.compat.v1.metrics.auc(labels=cvr_labels, predictions=ctcvr_preds)
        metrics = {
            "ctcvr_accuracy": ctcvr_accuracy,
            "ctr_accuracy": ctr_accuracy,
            "ctr_auc": ctr_auc,
            "ctcvr_auc": ctcvr_auc,
        }
        # NOTE(review): index [1] is each metric's *update* op. Evaluating these
        # summaries would run the updates again, and the Estimator already
        # reports eval_metric_ops. Confirm these summaries are needed.
        tf.compat.v1.summary.scalar("ctr_accuracy", ctr_accuracy[1])
        tf.compat.v1.summary.scalar("ctcvr_accuracy", ctcvr_accuracy[1])
        tf.compat.v1.summary.scalar("ctr_auc", ctr_auc[1])
        tf.compat.v1.summary.scalar("ctcvr_auc", ctcvr_auc[1])
        return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics)

    # TRAIN: the optimizer is only needed here.
    optimizer = tf.compat.v1.train.AdagradOptimizer(learning_rate=params.get("learning_rate", 0.03))
    train_op = optimizer.minimize(loss, global_step=tf.compat.v1.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)


def model_export(model, feature_columns, save_path):
    """Export ``model`` as a SavedModel whose serving input parses tf.train.Example protos.

    Args:
        model: an Estimator-like object exposing ``export_saved_model``.
        feature_columns: feature columns from which the parse spec is derived.
        save_path: base directory passed to ``export_saved_model``.

    Returns:
        The export directory, decoded from the bytes path that
        ``export_saved_model`` returns into a UTF-8 ``str``.
    """
    parse_spec = fc.make_parse_example_spec(feature_columns)
    receiver_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(parse_spec)
    export_dir = model.export_saved_model(save_path, receiver_fn)
    return str(export_dir, encoding="utf-8")


def _int64_feature(value):
    """Wrap a single integer in a ``tf.train.Feature`` backed by an Int64List."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)


def _float_feature(value):
    """Wrap a single float in a ``tf.train.Feature`` backed by a FloatList."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)


def _bytes_feature(value):
    """Wrap a single bytes value in a ``tf.train.Feature`` backed by a BytesList."""
    bytes_list = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=bytes_list)


def model_predict(inputs, predict_fn):
    """Serialize each row of ``inputs`` into a tf.train.Example and run ``predict_fn`` on the batch.

    This function was previously decorated with
    ``@numba.jit(nopython=True, parallel=True)``. Numba's nopython mode cannot
    type pandas DataFrames, dicts of protobuf messages, ``timeit`` or
    ``str.format``, so the first call failed with a numba typing error. All of
    the work here is pandas/protobuf/TensorFlow code that numba cannot
    accelerate, so the decorator has been removed.

    Args:
        inputs: a pandas DataFrame, iterated row by row. Columns named
            ``click_label`` / ``conversion_label`` are skipped. Known integer and
            float columns become Int64/Float features. Every other column is
            stringified and stored as a UTF-8 bytes feature.
        predict_fn: callable invoked as ``predict_fn({"examples": [...]})`` with
            the list of serialized Example bytes. Presumably a predictor for the
            SavedModel produced by ``model_export``; confirm against the caller.

    Returns:
        Whatever ``predict_fn`` returns.
    """
    label_columns = frozenset(("click_label", "conversion_label"))
    int_columns = frozenset((
        "active_type", "active_days", "card_id", "is_pure_author", "is_have_reply", "is_have_pure_reply",
        "content_level", "topic_num", "favor_num", "vote_num",
    ))
    float_columns = frozenset(("one_ctr", "three_ctr", "seven_ctr", "fifteen_ctr"))

    start = timeit.default_timer()
    examples = []
    for _, row in inputs.iterrows():
        features = {}
        # Series.items() replaces Series.iteritems(), which pandas 2.0 removed.
        for col, value in row.items():
            if col in label_columns:
                continue
            if col in int_columns:
                features[col] = _int64_feature(int(value))
            elif col in float_columns:
                features[col] = _float_feature(float(value))
            else:
                features[col] = _bytes_feature(str(value).encode(encoding="utf-8"))
        example = tf.train.Example(features=tf.train.Features(feature=features))
        examples.append(example.SerializeToString())
    print("make example cost {:.5f}s".format(timeit.default_timer() - start))

    start = timeit.default_timer()
    predictions = predict_fn({"examples": examples})
    print("prediction cost {:.5f}s".format(timeit.default_timer() - start))
    # TODO get the diary ids
    return predictions
