import time
from datetime import date, timedelta

import tensorflow as tf
from pyspark.conf import SparkConf
from pyspark.sql import SparkSession

def model_fn(features, labels, mode, params):
    """Bulid Model function f(x) for Estimator."""
    #------hyperparameters----
    field_size = params["field_size"]
    feature_size = params["feature_size"]
    embedding_size = params["embedding_size"]
    l2_reg = params["l2_reg"]
    learning_rate = params["learning_rate"]
    #optimizer = params["optimizer"]
    layers = list(map(int, params["deep_layers"].split(',')))
    dropout = list(map(float, params["dropout"].split(',')))
    ctr_task_wgt = params["ctr_task_wgt"]
    common_dims = field_size*embedding_size

    #------build weights------
    Feat_Emb = tf.get_variable(name='embeddings', shape=[feature_size, embedding_size], initializer=tf.glorot_normal_initializer())
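    # The remaining entries in `features` come from input_fn: "ids" is a dense
    # [field_size] tensor of feature indices, while the *_list entries are
    # variable-length SparseTensors (installed-app list, category levels, tags).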

    feat_ids = features['ids']
    app_list = features['app_list']
    level2_list = features['level2_list']
    level3_list = features['level3_list']
    tag1_list = features['tag1_list']
    tag2_list = features['tag2_list']
    tag3_list = features['tag3_list']
    tag4_list = features['tag4_list']
    tag5_list = features['tag5_list']
    tag6_list = features['tag6_list']
    tag7_list = features['tag7_list']


    #------build f(x)------
    with tf.variable_scope("Shared-Embedding-layer"):
        embedding_id = tf.nn.embedding_lookup(Feat_Emb,feat_ids)
        app_id = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=app_list, sp_weights=None, combiner="sum")
        level2 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=level2_list, sp_weights=None, combiner="sum")
        level3 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=level3_list, sp_weights=None, combiner="sum")
        tag1 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag1_list, sp_weights=None, combiner="sum")
        tag2 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag2_list, sp_weights=None, combiner="sum")
        tag3 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag3_list, sp_weights=None, combiner="sum")
        tag4 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag4_list, sp_weights=None, combiner="sum")
        tag5 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag5_list, sp_weights=None, combiner="sum")
        tag6 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag6_list, sp_weights=None, combiner="sum")
        tag7 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag7_list, sp_weights=None, combiner="sum")

        # x_concat = tf.reshape(embedding_id,shape=[-1, common_dims])  # None * (F * K)
        x_concat = tf.concat([tf.reshape(embedding_id, shape=[-1, common_dims]), app_id, level2, level3, tag1,
                              tag2, tag3, tag4, tag5, tag6, tag7], axis=1)
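        # x_concat has shape [None, field_size*embedding_size + 10*embedding_size]:
        # the flattened dense-field embeddings concatenated with the ten pooled
        # sparse-feature embeddings.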

    with tf.name_scope("CVR_Task"):
        # train_phase is computed but not used below; dropout is parsed from
        # params yet never applied in this inference-only model_fn.
        train_phase = (mode == tf.estimator.ModeKeys.TRAIN)
        x_cvr = x_concat
        for i in range(len(layers)):
            x_cvr = tf.contrib.layers.fully_connected(inputs=x_cvr, num_outputs=layers[i], \
                weights_regularizer=tf.contrib.layers.l2_regularizer(l2_reg), scope='cvr_mlp%d' % i)
        y_cvr = tf.contrib.layers.fully_connected(inputs=x_cvr, num_outputs=1, activation_fn=tf.identity, \
            weights_regularizer=tf.contrib.layers.l2_regularizer(l2_reg), scope='cvr_out')
        y_cvr = tf.reshape(y_cvr,shape=[-1])

    with tf.name_scope("CTR_Task"):
        train_phase = (mode == tf.estimator.ModeKeys.TRAIN)

        x_ctr = x_concat
        for i in range(len(layers)):
            x_ctr = tf.contrib.layers.fully_connected(inputs=x_ctr, num_outputs=layers[i], \
                weights_regularizer=tf.contrib.layers.l2_regularizer(l2_reg), scope='ctr_mlp%d' % i)
        y_ctr = tf.contrib.layers.fully_connected(inputs=x_ctr, num_outputs=1, activation_fn=tf.identity, \
            weights_regularizer=tf.contrib.layers.l2_regularizer(l2_reg), scope='ctr_out')
        y_ctr = tf.reshape(y_ctr,shape=[-1])

    with tf.variable_scope("MTL-Layer"):
        pctr = tf.sigmoid(y_ctr)
        pcvr = tf.sigmoid(y_cvr)
        pctcvr = pctr*pcvr

    predictions={"pcvr": pcvr, "pctr": pctr, "pctcvr": pctcvr}
    export_outputs = {tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: tf.estimator.export.PredictOutput(predictions)}
    # Provide an estimator spec for `ModeKeys.PREDICT`
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
                mode=mode,
                predictions=predictions,
                export_outputs=export_outputs)
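
    # Only ModeKeys.PREDICT is handled; for TRAIN or EVAL this model_fn falls
    # through and returns None. A minimal sketch of what an ESMM training branch
    # could look like, assuming "y" is the click label and "z" the conversion
    # label, and reusing the otherwise-unused params (learning_rate, ctr_task_wgt):
    #
    #   ctr_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=y_ctr, labels=labels['y']))
    #   ctcvr_loss = tf.reduce_mean(tf.losses.log_loss(labels=labels['z'], predictions=pctcvr))
    #   loss = ctr_task_wgt * ctr_loss + (1 - ctr_task_wgt) * ctcvr_loss
    #   if mode == tf.estimator.ModeKeys.TRAIN:
    #       train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss, global_step=tf.train.get_global_step())
    #       return tf.estimator.EstimatorSpec(mode=mode, loss=loss, train_op=train_op)
    #   return tf.estimator.EstimatorSpec(mode=mode, loss=loss)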

def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
    print('Parsing', filenames)
    def _parse_fn(record):
        features = {
            "y": tf.FixedLenFeature([], tf.float32),
            "z": tf.FixedLenFeature([], tf.float32),
            "ids": tf.FixedLenFeature([15], tf.int64),
            "app_list": tf.VarLenFeature(tf.int64),
            "level2_list": tf.VarLenFeature(tf.int64),
            "level3_list": tf.VarLenFeature(tf.int64),
            "tag1_list": tf.VarLenFeature(tf.int64),
            "tag2_list": tf.VarLenFeature(tf.int64),
            "tag3_list": tf.VarLenFeature(tf.int64),
            "tag4_list": tf.VarLenFeature(tf.int64),
            "tag5_list": tf.VarLenFeature(tf.int64),
            "tag6_list": tf.VarLenFeature(tf.int64),
            "tag7_list": tf.VarLenFeature(tf.int64)
        }
        parsed = tf.parse_single_example(record, features)
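        # "y" and "z" are the labels (click and conversion in the ESMM setup);
        # split them out of the parsed dict so the dataset yields (features, labels).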
        y = parsed.pop('y')
        z = parsed.pop('z')
        return parsed, {"y": y, "z": z}

    # Extract lines from input files using the Dataset API, can pass one filename or filename list
    dataset = tf.data.TFRecordDataset(filenames).map(_parse_fn, num_parallel_calls=10).prefetch(500000)    # multi-thread pre-process then prefetch

    # Randomizes input using a window of 256 elements (read into memory)
    if perform_shuffle:
        dataset = dataset.shuffle(buffer_size=256)

    # Call repeat after shuffle so that separate epochs do not blend together.
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size) # Batch size to use
    # dataset = dataset.padded_batch(batch_size, padded_shapes=({"feeds_ids": [None], "feeds_vals": [None], "title_ids": [None]}, [None]))   # pad variable-length features to a common length

    #return dataset.make_one_shot_iterator()
    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    #return tf.reshape(batch_ids,shape=[-1,field_size]), tf.reshape(batch_vals,shape=[-1,field_size]), batch_labels
    #print("-"*100)
    #print(batch_features,batch_labels)
    return batch_features, batch_labels
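
# A quick, hypothetical way to sanity-check input_fn on its own (the local file
# path below is an assumption): because it returns one-shot-iterator tensors,
# a batch can be pulled inside a session.
#
#   feats, labels = input_fn(["part-r-00000"], batch_size=2)
#   with tf.Session() as sess:
#       print(sess.run([feats["ids"], labels["y"]]))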

def main(te_file):
    dt_dir = (date.today() + timedelta(-1)).strftime('%Y%m%d')
    model_dir = "hdfs://172.16.32.4:8020/strategy/esmm/model_ckpt/DeepCvrMTL/" + dt_dir
    # te_files = ["hdfs://172.16.32.4:8020/strategy/esmm/nearby/part-r-00000"]  # unused: the files to score arrive via the te_file argument
    model_params = {
        "field_size": 15,
        "feature_size": 600000,
        "embedding_size": 16,
        "learning_rate": 0.0001,
        "l2_reg": 0.005,
        "deep_layers": '512,256,128,64,32',
        "dropout": '0.3,0.3,0.3,0.3,0.3',
        "ctr_task_wgt":0.5
    }
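    # field_size must match the length of the "ids" FixedLenFeature (15), and
    # feature_size bounds the row count of the shared embedding table, so it must
    # exceed the largest feature index present in the TFRecords.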
    config = tf.estimator.RunConfig().replace(session_config = tf.ConfigProto(device_count={'GPU':0, 'CPU':36}),
            log_step_count_steps=100, save_summary_steps=100)
    estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir, params=model_params, config=config)

    preds = estimator.predict(input_fn=lambda: input_fn(te_file, num_epochs=1, batch_size=10000), predict_keys=["pctcvr", "pctr", "pcvr"])

    # with open("/home/gmuser/esmm/nearby/pred.txt", "w") as fo:
    #     for prob in preds:
    #         fo.write("%f\t%f\t%f\n" % (prob['pctr'], prob['pcvr'], prob['pctcvr']))

    indices = []
    for prob in preds:
        indices.append([prob['pctr'], prob['pcvr'], prob['pctcvr']])
    return indices

def test_map(x):
    return x * x


if __name__ == "__main__":
    b = time.time()  # start timing here so the elapsed minutes printed at the end are meaningful
    sparkConf = SparkConf().set("spark.hive.mapred.supports.subdirectories", "true") \
        .set("spark.hadoop.mapreduce.input.fileinputformat.input.dir.recursive", "true") \
        .set("spark.tispark.plan.allow_index_double_read", "false") \
        .set("spark.tispark.plan.allow_index_read", "true") \
        .set("spark.sql.extensions", "org.apache.spark.sql.TiExtensions") \
        .set("spark.tispark.pd.addresses", "172.16.40.158:2379").set("spark.io.compression.codec", "lzf")\
        .set("spark.driver.maxResultSize", "8g").set("spark.sql.avro.compression.codec","snappy")
    spark = SparkSession.builder.config(conf=sparkConf).enableHiveSupport().getOrCreate()
    spark.sparkContext.setLogLevel("WARN")

    path = "hdfs://172.16.32.4:8020/strategy/esmm/"
    # df = spark.read.format("tfrecords").load(path+"nearby/part-r-00000")
    # df.show()

    # name = spark.sparkContext.parallelize([1,2,3,4,5])
    #
    # test = name.repartition(5).map(lambda x: test_map(x))
    # print(test)
    # print(test.collect())

    tf.logging.set_verbosity(tf.logging.INFO)

    # te_files = [path + "nearby/part-r-00000"]
    # main(te_files)
    te_files = [[path+"nearby/part-r-00000"],[path+"native/part-r-00000"]]
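    # Each file list is scored in its own Spark task: main() restores the
    # Estimator checkpoint on the executor and returns a list of
    # [pctr, pcvr, pctcvr] triples for that file.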
    rdd_te_files = spark.sparkContext.parallelize(te_files)
    indices = rdd_te_files.repartition(2).map(lambda x: main(x))
    print(indices.collect())


    print("Elapsed time (minutes):")
    print((time.time() - b) / 60)