import time

import tensorflow as tf
from pyspark.conf import SparkConf
from pyspark.sql import SparkSession


def model_fn(features, labels, mode, params):
    """Bulid Model function f(x) for Estimator."""
    #------hyperparameters----
    field_size = params["field_size"]
    feature_size = params["feature_size"]
    embedding_size = params["embedding_size"]
    l2_reg = params["l2_reg"]
    learning_rate = params["learning_rate"]
    #optimizer = params["optimizer"]
    layers = list(map(int, params["deep_layers"].split(',')))
    dropout = list(map(float, params["dropout"].split(',')))
    ctr_task_wgt = params["ctr_task_wgt"]
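    # learning_rate, dropout and ctr_task_wgt are training-time hyperparameters;
    # they are read but unused here because this script only runs prediction.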
    common_dims = field_size*embedding_size

    #------build weights------
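    # Single embedding table shared by the CTR and CVR towers (shared-bottom, ESMM style).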
    Feat_Emb = tf.get_variable(name='embeddings', shape=[feature_size, embedding_size], initializer=tf.glorot_normal_initializer())

    feat_ids = features['ids']
    app_list = features['app_list']
    level2_list = features['level2_list']
    level3_list = features['level3_list']
    tag1_list = features['tag1_list']
    tag2_list = features['tag2_list']
    tag3_list = features['tag3_list']
    tag4_list = features['tag4_list']
    tag5_list = features['tag5_list']
    tag6_list = features['tag6_list']
    tag7_list = features['tag7_list']


    #------build f(x)------
    with tf.variable_scope("Shared-Embedding-layer"):
        embedding_id = tf.nn.embedding_lookup(Feat_Emb, feat_ids)
        app_id = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=app_list, sp_weights=None, combiner="sum")
        level2 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=level2_list, sp_weights=None, combiner="sum")
        level3 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=level3_list, sp_weights=None, combiner="sum")
        tag1 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag1_list, sp_weights=None, combiner="sum")
        tag2 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag2_list, sp_weights=None, combiner="sum")
        tag3 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag3_list, sp_weights=None, combiner="sum")
        tag4 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag4_list, sp_weights=None, combiner="sum")
        tag5 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag5_list, sp_weights=None, combiner="sum")
        tag6 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag6_list, sp_weights=None, combiner="sum")
        tag7 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag7_list, sp_weights=None, combiner="sum")

        x_concat = tf.concat([tf.reshape(embedding_id, shape=[-1, common_dims]), app_id, level2, level3, tag1,
                              tag2, tag3, tag4, tag5, tag6, tag7], axis=1)
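        # x_concat width: field_size*K for the dense ids plus K for each of the
        # 10 pooled list features, i.e. (field_size + 10) * embedding_size columns.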

    with tf.name_scope("CVR_Task"):
        # reserved for dropout / batch norm at training time; unused in this inference-only script
        train_phase = (mode == tf.estimator.ModeKeys.TRAIN)
        x_cvr = x_concat
        for i in range(len(layers)):
            # hidden layers use fully_connected's default ReLU activation
            x_cvr = tf.contrib.layers.fully_connected(inputs=x_cvr, num_outputs=layers[i],
                weights_regularizer=tf.contrib.layers.l2_regularizer(l2_reg), scope='cvr_mlp%d' % i)
        y_cvr = tf.contrib.layers.fully_connected(inputs=x_cvr, num_outputs=1, activation_fn=tf.identity,
            weights_regularizer=tf.contrib.layers.l2_regularizer(l2_reg), scope='cvr_out')
        y_cvr = tf.reshape(y_cvr, shape=[-1])

    with tf.name_scope("CTR_Task"):
        x_ctr = x_concat
        for i in range(len(layers)):
            x_ctr = tf.contrib.layers.fully_connected(inputs=x_ctr, num_outputs=layers[i],
                weights_regularizer=tf.contrib.layers.l2_regularizer(l2_reg), scope='ctr_mlp%d' % i)
        y_ctr = tf.contrib.layers.fully_connected(inputs=x_ctr, num_outputs=1, activation_fn=tf.identity,
            weights_regularizer=tf.contrib.layers.l2_regularizer(l2_reg), scope='ctr_out')
        y_ctr = tf.reshape(y_ctr, shape=[-1])

    with tf.variable_scope("MTL-Layer"):
        pctr = tf.sigmoid(y_ctr)
        pcvr = tf.sigmoid(y_cvr)
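        # ESMM factorization: pCTCVR = pCTR * pCVR, so the CVR estimate is
        # learned and applied over the entire impression space.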
        pctcvr = pctr*pcvr

    predictions = {"pcvr": pcvr, "pctr": pctr, "pctcvr": pctcvr}
    export_outputs = {tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: tf.estimator.export.PredictOutput(predictions)}
    # Provide an estimator spec for `ModeKeys.PREDICT`
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
                mode=mode,
                predictions=predictions,
                export_outputs=export_outputs)
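    # NOTE: only ModeKeys.PREDICT is handled; TRAIN and EVAL EstimatorSpecs would
    # have to be added before this model_fn could be used for training/evaluation.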

def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
    print('Parsing', filenames)
    def _parse_fn(record):
        features = {
            "y": tf.FixedLenFeature([], tf.float32),
            "z": tf.FixedLenFeature([], tf.float32),
            "ids": tf.FixedLenFeature([15], tf.int64),
            "app_list": tf.VarLenFeature(tf.int64),
            "level2_list": tf.VarLenFeature(tf.int64),
            "level3_list": tf.VarLenFeature(tf.int64),
            "tag1_list": tf.VarLenFeature(tf.int64),
            "tag2_list": tf.VarLenFeature(tf.int64),
            "tag3_list": tf.VarLenFeature(tf.int64),
            "tag4_list": tf.VarLenFeature(tf.int64),
            "tag5_list": tf.VarLenFeature(tf.int64),
            "tag6_list": tf.VarLenFeature(tf.int64),
            "tag7_list": tf.VarLenFeature(tf.int64)
        }
        parsed = tf.parse_single_example(record, features)
        y = parsed.pop('y')
        z = parsed.pop('z')
        return parsed, {"y": y, "z": z}

    # Read TFRecords with the Dataset API; accepts a single filename or a list of filenames
    dataset = tf.data.TFRecordDataset(filenames).map(_parse_fn, num_parallel_calls=10).prefetch(500000)    # multi-threaded parsing, then prefetch

    # Randomizes input using a window of 256 elements (read into memory)
    if perform_shuffle:
        dataset = dataset.shuffle(buffer_size=256)

    # Repeat after shuffling to prevent separate epochs from blending together.
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size) # Batch size to use
    # dataset = dataset.padded_batch(batch_size, padded_shapes=({"feeds_ids": [None], "feeds_vals": [None], "title_ids": [None]}, [None]))   # pad variable-length features

    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels
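
# A minimal sketch (TF 1.x graph mode assumed) of how one batch from input_fn
# could be inspected; the file path below is a placeholder:
#
#   feats, labels = input_fn(["/path/to/part-r-00000"], batch_size=2)
#   with tf.Session() as sess:
#       print(sess.run([feats["ids"], labels["y"]]))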

def main():
    # dt_dir = (date.today() + timedelta(-1)).strftime('%Y%m%d')
    model_dir = "hdfs://172.16.32.4:8020/strategy/esmm/model_ckpt/DeepCvrMTL/"
    te_files = ["hdfs://172.16.32.4:8020/strategy/esmm/nearby/part-r-00000"]
    model_params = {
        "field_size": 15,
        "feature_size": 600000,
        "embedding_size": 32,
        "learning_rate": 0.0001,
        "l2_reg": 0.005,
        "deep_layers": '512,256,128,64,32',
        "dropout": '0.3,0.3,0.3,0.3,0.3',
        "ctr_task_wgt":0.5
    }
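    # feature_size is the embedding vocabulary size: every feature id in the
    # TFRecords must be < feature_size. field_size matches the 15-element
    # "ids" feature parsed in input_fn.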
    config = tf.estimator.RunConfig().replace(session_config=tf.ConfigProto(device_count={'GPU': 0, 'CPU': 36}),
            log_step_count_steps=100, save_summary_steps=100)
    Estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir, params=model_params, config=config)

    preds = Estimator.predict(input_fn=lambda: input_fn(te_files, num_epochs=1, batch_size=10000), predict_keys=["pctcvr","pctr","pcvr"])
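    # Estimator.predict returns a lazy generator; predictions are produced as
    # the input pipeline is consumed in the loop below.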
    with open("/home/gmuser/esmm/nearby/pred.txt", "w") as fo:
        for prob in preds:
            fo.write("%f\t%f\t%f\n" % (prob['pctr'], prob['pcvr'], prob['pctcvr']))

def test_map(x):
    """Trivial map function used to smoke-test the Spark setup below."""
    return x * x


if __name__ == "__main__":
    sparkConf = SparkConf().set("spark.hive.mapred.supports.subdirectories", "true") \
        .set("spark.hadoop.mapreduce.input.fileinputformat.input.dir.recursive", "true") \
        .set("spark.tispark.plan.allow_index_double_read", "false") \
        .set("spark.tispark.plan.allow_index_read", "true") \
        .set("spark.sql.extensions", "org.apache.spark.sql.TiExtensions") \
        .set("spark.tispark.pd.addresses", "172.16.40.158:2379").set("spark.io.compression.codec", "lzf")\
        .set("spark.driver.maxResultSize", "8g").set("spark.sql.avro.compression.codec","snappy")
    spark = SparkSession.builder.config(conf=sparkConf).enableHiveSupport().getOrCreate()
    spark.sparkContext.setLogLevel("WARN")

    path = "hdfs://172.16.32.4:8020/strategy/esmm/"
    # df = spark.read.format("tfrecords").load(path+"nearby/part-r-00000")
    # df.show()

    # Smoke-test the Spark setup with a trivial RDD job. parallelize lives on
    # the SparkContext, and RDDs have no show(); collect the results instead.
    name = spark.sparkContext.parallelize([1, 2, 3, 4, 5])
    test = name.repartition(5).map(test_map)
    print(test.collect())

    b = time.time()
    tf.logging.set_verbosity(tf.logging.INFO)
    # tf.app.run()
    main()
    print("耗时(分钟):")
    print((time.time()-b)/60)