#coding=utf-8

import pymysql
import os
import json
from datetime import date, timedelta
import tensorflow as tf
import time
import pandas as pd
import datetime

#################### CMD Arguments ####################
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer("dist_mode", 0, "distribuion mode {0-loacal, 1-single_dist, 2-multi_dist}")
tf.app.flags.DEFINE_string("ps_hosts", '', "Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("worker_hosts", '', "Comma-separated list of hostname:port pairs")
tf.app.flags.DEFINE_string("job_name", '', "One of 'ps', 'worker'")
tf.app.flags.DEFINE_integer("task_index", 0, "Index of task within the job")
tf.app.flags.DEFINE_integer("num_threads", 16, "Number of threads")
tf.app.flags.DEFINE_integer("feature_size", 0, "Number of features")
tf.app.flags.DEFINE_integer("field_size", 0, "Number of common fields")
tf.app.flags.DEFINE_integer("embedding_size", 32, "Embedding size")
tf.app.flags.DEFINE_integer("num_epochs", 10, "Number of epochs")
tf.app.flags.DEFINE_integer("batch_size", 64, "Number of batch size")
tf.app.flags.DEFINE_integer("log_steps", 1000, "save summary every steps")
tf.app.flags.DEFINE_float("learning_rate", 0.0005, "learning rate")
tf.app.flags.DEFINE_float("l2_reg", 0.0001, "L2 regularization")
tf.app.flags.DEFINE_string("loss_type", 'log_loss', "loss type {square_loss, log_loss}")
tf.app.flags.DEFINE_float("ctr_task_wgt", 0.5, "loss weight of ctr task")
tf.app.flags.DEFINE_string("optimizer", 'Adam', "optimizer type {Adam, Adagrad, GD, Momentum}")
tf.app.flags.DEFINE_string("deep_layers", '256,128,64', "deep layers")
tf.app.flags.DEFINE_string("dropout", '0.5,0.5,0.5', "dropout rate")
tf.app.flags.DEFINE_boolean("batch_norm", False, "perform batch normaization (True or False)")
tf.app.flags.DEFINE_float("batch_norm_decay", 0.9, "decay for the moving average(recommend trying decay=0.9)")
tf.app.flags.DEFINE_string("hdfs_dir", '', "hdfs dir")
tf.app.flags.DEFINE_string("local_dir", '', "local dir")
tf.app.flags.DEFINE_string("dt_dir", '', "data dt partition")
tf.app.flags.DEFINE_string("model_dir", '', "model check point dir")
tf.app.flags.DEFINE_string("servable_model_dir", '', "export servable model for TensorFlow Serving")
tf.app.flags.DEFINE_string("task_type", 'train', "task type {train, infer, eval, export}")
tf.app.flags.DEFINE_boolean("clear_existing_model", False, "clear existing model or not")


def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
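    """Build a tf.data pipeline over TFRecord files.

    Returns a (features, labels) tuple of batched tensors; labels carry the
    click label "y" and the conversion label "z" for the two ESMM tasks.
    """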
    print('Parsing', filenames)
    def _parse_fn(record):
        features = {
            "y": tf.FixedLenFeature([], tf.float32),
            "z": tf.FixedLenFeature([], tf.float32),
            "ids": tf.FixedLenFeature([FLAGS.field_size], tf.int64),
            "app_list": tf.VarLenFeature(tf.int64),
            "level2_list": tf.VarLenFeature(tf.int64),
            "level3_list": tf.VarLenFeature(tf.int64),
            "tag1_list": tf.VarLenFeature(tf.int64),
            "tag2_list": tf.VarLenFeature(tf.int64),
            "tag3_list": tf.VarLenFeature(tf.int64),
            "tag4_list": tf.VarLenFeature(tf.int64),
            "tag5_list": tf.VarLenFeature(tf.int64),
            "tag6_list": tf.VarLenFeature(tf.int64),
            "tag7_list": tf.VarLenFeature(tf.int64),
            "search_tag2_list": tf.VarLenFeature(tf.int64),
            "search_tag3_list": tf.VarLenFeature(tf.int64),
            "uid": tf.VarLenFeature(tf.string),
            "city": tf.VarLenFeature(tf.string),
            "cid_id": tf.VarLenFeature(tf.string)
        }
        parsed = tf.parse_single_example(record, features)
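        # "y" is the click (CTR) label; "z" is the conversion label used by the CTCVR loss.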
        y = parsed.pop('y')
        z = parsed.pop('z')
        return parsed, {"y": y, "z": z}

    # Read TFRecord files with the tf.data API; `filenames` may be one path or a list.
    files = tf.data.Dataset.list_files(filenames)
    dataset = files.apply(
        tf.data.experimental.parallel_interleave(
            lambda filename: tf.data.TFRecordDataset(filename),
            cycle_length=8
        )
    )

    # Honor perform_shuffle and num_epochs; the shuffle/repeat steps had been
    # commented out, so these arguments previously had no effect on the pipeline.
    if perform_shuffle:
        dataset = dataset.shuffle(buffer_size=256)
    dataset = dataset.repeat(num_epochs)

    dataset = dataset.apply(tf.data.experimental.map_and_batch(map_func=_parse_fn, batch_size=batch_size, num_parallel_calls=8))
    dataset = dataset.prefetch(10000)   # the dataset is already batched, so this buffers up to 10000 batches


    # dataset = dataset.padded_batch(batch_size, padded_shapes=({"feeds_ids": [None], "feeds_vals": [None], "title_ids": [None]}, [None]))   # pad variable-length features to a common length
    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels

def model_fn(features, labels, mode, params):
    """Bulid Model function f(x) for Estimator."""
    #------hyperparameters----
    field_size = params["field_size"]
    feature_size = params["feature_size"]
    embedding_size = params["embedding_size"]
    l2_reg = params["l2_reg"]
    learning_rate = params["learning_rate"]
    #optimizer = params["optimizer"]
    layers = list(map(int, params["deep_layers"].split(',')))
    dropout = list(map(float, params["dropout"].split(',')))
    ctr_task_wgt = params["ctr_task_wgt"]
    common_dims = field_size*embedding_size

    #------build weights------
    Feat_Emb = tf.get_variable(name='embeddings', shape=[feature_size, embedding_size], initializer=tf.glorot_normal_initializer())
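    # This single table is shared by the CTR and CVR towers (ESMM's shared
    # embedding), so both tasks update the same feature vectors.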

    feat_ids = features['ids']
    app_list = features['app_list']
    level2_list = features['level2_list']
    level3_list = features['level3_list']
    tag1_list = features['tag1_list']
    tag2_list = features['tag2_list']
    tag3_list = features['tag3_list']
    tag4_list = features['tag4_list']
    tag5_list = features['tag5_list']
    tag6_list = features['tag6_list']
    tag7_list = features['tag7_list']
    search_tag2_list = features['search_tag2_list']
    search_tag3_list = features['search_tag3_list']
    uid = features['uid']
    city = features['city']
    cid_id = features['cid_id']


    if FLAGS.task_type != "infer":
        y = labels['y']
        z = labels['z']

    #------build f(x)------
    with tf.variable_scope("Shared-Embedding-layer"):
        embedding_id = tf.nn.embedding_lookup(Feat_Emb,feat_ids)
        app_id = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=app_list, sp_weights=None, combiner="sum")
        level2 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=level2_list, sp_weights=None, combiner="sum")
        level3 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=level3_list, sp_weights=None, combiner="sum")
        tag1 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag1_list, sp_weights=None, combiner="sum")
        tag2 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag2_list, sp_weights=None, combiner="sum")
        tag3 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag3_list, sp_weights=None, combiner="sum")
        tag4 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag4_list, sp_weights=None, combiner="sum")
        tag5 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag5_list, sp_weights=None, combiner="sum")
        tag6 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag6_list, sp_weights=None, combiner="sum")
        tag7 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag7_list, sp_weights=None, combiner="sum")
        search_tag2 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=search_tag2_list, sp_weights=None, combiner="sum")
        search_tag3 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=search_tag3_list, sp_weights=None, combiner="sum")


        # x_concat = tf.reshape(embedding_id,shape=[-1, common_dims])  # None * (F * K)
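        # Concatenate the flattened fixed-length field embeddings (F*K) with the
        # twelve sum-pooled variable-length feature embeddings (K each).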
        x_concat = tf.concat([tf.reshape(embedding_id, shape=[-1, common_dims]), app_id, level2, level3, tag1,
                              tag2, tag3, tag4, tag5, tag6, tag7,search_tag2,search_tag3], axis=1)

        uid = tf.sparse.to_dense(uid,default_value="")
        city = tf.sparse.to_dense(city,default_value="")
        cid_id = tf.sparse.to_dense(cid_id,default_value="")

    with tf.name_scope("CVR_Task"):
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_phase = True
        else:
            train_phase = False
        x_cvr = x_concat
        for i in range(len(layers)):
            x_cvr = tf.contrib.layers.fully_connected(inputs=x_cvr, num_outputs=layers[i], \
                weights_regularizer=tf.contrib.layers.l2_regularizer(l2_reg), scope='cvr_mlp%d' % i)

            if FLAGS.batch_norm:
                x_cvr = batch_norm_layer(x_cvr, train_phase=train_phase, scope_bn='cvr_bn_%d' % i)      # place BN after ReLU: https://github.com/ducha-aiki/caffenet-benchmark/blob/master/batchnorm.md#bn----before-or-after-relu
            if mode == tf.estimator.ModeKeys.TRAIN:
                x_cvr = tf.nn.dropout(x_cvr, keep_prob=dropout[i])                                  # dropout after all BN layers; keep_prob=0.8 means a drop ratio of 0.2

        y_cvr = tf.contrib.layers.fully_connected(inputs=x_cvr, num_outputs=1, activation_fn=tf.identity, \
            weights_regularizer=tf.contrib.layers.l2_regularizer(l2_reg), scope='cvr_out')
        y_cvr = tf.reshape(y_cvr,shape=[-1])

    with tf.name_scope("CTR_Task"):
        if mode == tf.estimator.ModeKeys.TRAIN:
            train_phase = True
        else:
            train_phase = False

        x_ctr = x_concat
        for i in range(len(layers)):
            x_ctr = tf.contrib.layers.fully_connected(inputs=x_ctr, num_outputs=layers[i], \
                weights_regularizer=tf.contrib.layers.l2_regularizer(l2_reg), scope='ctr_mlp%d' % i)

            if FLAGS.batch_norm:
                x_ctr = batch_norm_layer(x_ctr, train_phase=train_phase, scope_bn='ctr_bn_%d' % i)      # place BN after ReLU: https://github.com/ducha-aiki/caffenet-benchmark/blob/master/batchnorm.md#bn----before-or-after-relu
            if mode == tf.estimator.ModeKeys.TRAIN:
                x_ctr = tf.nn.dropout(x_ctr, keep_prob=dropout[i])                                  # dropout after all BN layers; keep_prob=0.8 means a drop ratio of 0.2

        y_ctr = tf.contrib.layers.fully_connected(inputs=x_ctr, num_outputs=1, activation_fn=tf.identity, \
            weights_regularizer=tf.contrib.layers.l2_regularizer(l2_reg), scope='ctr_out')
        y_ctr = tf.reshape(y_ctr,shape=[-1])

    with tf.variable_scope("MTL-Layer"):
        pctr = tf.sigmoid(y_ctr)
        pcvr = tf.sigmoid(y_cvr)
        pctcvr = pctr*pcvr

    predictions = {"pctcvr": pctcvr, "uid": uid, "city": city, "cid_id": cid_id}
    export_outputs = {tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: tf.estimator.export.PredictOutput(predictions)}
    # Provide an estimator spec for `ModeKeys.PREDICT`
    if mode == tf.estimator.ModeKeys.PREDICT:
        return tf.estimator.EstimatorSpec(
                mode=mode,
                predictions=predictions,
                export_outputs=export_outputs)

    if FLAGS.task_type != "infer":
        #------build loss------
        ctr_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=y_ctr, labels=y))
        #cvr_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=y_ctcvr, labels=z))
        cvr_loss = tf.reduce_mean(tf.losses.log_loss(predictions=pctcvr, labels=z))
        # Weighted sum of the two task losses plus L2 on the shared embedding table.
        loss = ctr_task_wgt * ctr_loss + (1 - ctr_task_wgt) * cvr_loss + l2_reg * tf.nn.l2_loss(Feat_Emb)

        tf.summary.scalar('ctr_loss', ctr_loss)
        tf.summary.scalar('cvr_loss', cvr_loss)

        # Provide an estimator spec for `ModeKeys.EVAL`
        eval_metric_ops = {
            "CTR_AUC": tf.metrics.auc(y, pctr),
            #"CTR_F1": tf.contrib.metrics.f1_score(y,pctr),
            #"CTR_Precision": tf.metrics.precision(y,pctr),
            #"CTR_Recall": tf.metrics.recall(y,pctr),
            "CVR_AUC": tf.metrics.auc(z, pcvr),
            "CTCVR_AUC": tf.metrics.auc(z, pctcvr)
        }
        if mode == tf.estimator.ModeKeys.EVAL:
            return tf.estimator.EstimatorSpec(
                    mode=mode,
                    predictions=predictions,
                    loss=loss,
                    eval_metric_ops=eval_metric_ops)

        #------build optimizer------
        if FLAGS.optimizer == 'Adam':
            optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=0.9, beta2=0.999, epsilon=1e-8)
        elif FLAGS.optimizer == 'Adagrad':
            optimizer = tf.train.AdagradOptimizer(learning_rate=learning_rate, initial_accumulator_value=1e-8)
        elif FLAGS.optimizer == 'Momentum':
            optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.95)
        elif FLAGS.optimizer == 'ftrl':
            optimizer = tf.train.FtrlOptimizer(learning_rate)
        else:
            # Fall back to plain gradient descent; this covers the documented 'GD'
            # option and keeps `optimizer` defined for unrecognized values.
            optimizer = tf.train.GradientDescentOptimizer(learning_rate)

        train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())

        # Provide an estimator spec for `ModeKeys.TRAIN` modes
        if mode == tf.estimator.ModeKeys.TRAIN:
            return tf.estimator.EstimatorSpec(
                    mode=mode,
                    predictions=predictions,
                    loss=loss,
                    train_op=train_op)

def batch_norm_layer(x, train_phase, scope_bn):
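    # Build train-mode and inference-mode BN over the same variables (reuse=True)
    # and select between them at runtime with tf.cond.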
    bn_train = tf.contrib.layers.batch_norm(x, decay=FLAGS.batch_norm_decay, center=True, scale=True, updates_collections=None, is_training=True,  reuse=None, scope=scope_bn)
    bn_infer = tf.contrib.layers.batch_norm(x, decay=FLAGS.batch_norm_decay, center=True, scale=True, updates_collections=None, is_training=False, reuse=True, scope=scope_bn)
    z = tf.cond(tf.cast(train_phase, tf.bool), lambda: bn_train, lambda: bn_infer)
    return z

def set_dist_env():
    if FLAGS.dist_mode == 1:        # local distributed test mode: 1 chief, 1 ps, 1 evaluator
        ps_hosts = FLAGS.ps_hosts.split(',')
        chief_hosts = FLAGS.chief_hosts.split(',')
        task_index = FLAGS.task_index
        job_name = FLAGS.job_name
        print('ps_host', ps_hosts)
        print('chief_hosts', chief_hosts)
        print('job_name', job_name)
        print('task_index', str(task_index))
        # no worker hosts in this mode
        tf_config = {
            'cluster': {'chief': chief_hosts, 'ps': ps_hosts},
            'task': {'type': job_name, 'index': task_index }
        }
        print(json.dumps(tf_config))
        os.environ['TF_CONFIG'] = json.dumps(tf_config)
    elif FLAGS.dist_mode == 2:      # cluster distributed mode
        ps_hosts = FLAGS.ps_hosts.split(',')
        worker_hosts = FLAGS.worker_hosts.split(',')
        chief_hosts = worker_hosts[0:1] # the first worker becomes the chief
        worker_hosts = worker_hosts[2:] # worker 1 becomes the evaluator below; the rest stay workers
        task_index = FLAGS.task_index
        job_name = FLAGS.job_name
        print('ps_host', ps_hosts)
        print('worker_host', worker_hosts)
        print('chief_hosts', chief_hosts)
        print('job_name', job_name)
        print('task_index', str(task_index))
        # use #worker=0 as chief
        if job_name == "worker" and task_index == 0:
            job_name = "chief"
        # use #worker=1 as evaluator
        if job_name == "worker" and task_index == 1:
            job_name = 'evaluator'
            task_index = 0
        # the others as worker
        if job_name == "worker" and task_index > 1:
            task_index -= 2

        tf_config = {
            'cluster': {'chief': chief_hosts, 'worker': worker_hosts, 'ps': ps_hosts},
            'task': {'type': job_name, 'index': task_index }
        }
        print(json.dumps(tf_config))
        os.environ['TF_CONFIG'] = json.dumps(tf_config)
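        # The resulting TF_CONFIG looks like (illustrative hosts):
        #   {"cluster": {"chief": ["h0:2222"], "worker": ["h2:2222"], "ps": ["h1:2222"]},
        #    "task": {"type": "worker", "index": 0}}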

def main(file_path):
    #------check Arguments------
    if FLAGS.dt_dir == "":
        FLAGS.dt_dir = (date.today() + timedelta(-1)).strftime('%Y%m%d')
    FLAGS.model_dir = FLAGS.model_dir + FLAGS.dt_dir
    #FLAGS.data_dir  = FLAGS.data_dir + FLAGS.dt_dir


    va_files = ["hdfs://172.16.32.4:8020/strategy/esmm/va/part-r-00000"]


    # if FLAGS.clear_existing_model:
    #     try:
    #         shutil.rmtree(FLAGS.model_dir)
    #     except Exception as e:
    #         print(e, "at clear_existing_model")
    #     else:
    #         print("existing model cleaned at %s" % FLAGS.model_dir)

    # set_dist_env()

    #------build Tasks------
    model_params = {
        "field_size": FLAGS.field_size,
        "feature_size": FLAGS.feature_size,
        "embedding_size": FLAGS.embedding_size,
        "learning_rate": FLAGS.learning_rate,
        "l2_reg": FLAGS.l2_reg,
        "deep_layers": FLAGS.deep_layers,
        "dropout": FLAGS.dropout,
        "ctr_task_wgt":FLAGS.ctr_task_wgt
    }
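    # CPU-only session config: zero GPU devices. Note device_count caps the number
    # of CPU *devices*, not threads; thread pools are separate ConfigProto fields.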
    config = tf.estimator.RunConfig().replace(session_config = tf.ConfigProto(device_count={'GPU':0, 'CPU':FLAGS.num_threads}),
            log_step_count_steps=FLAGS.log_steps, save_summary_steps=FLAGS.log_steps)
    Estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir=FLAGS.model_dir, params=model_params, config=config)

    if FLAGS.task_type == 'train':
        train_spec = tf.estimator.TrainSpec(input_fn=lambda: input_fn(file_path, num_epochs=FLAGS.num_epochs, batch_size=FLAGS.batch_size))
        eval_spec = tf.estimator.EvalSpec(input_fn=lambda: input_fn(va_files, num_epochs=1, batch_size=FLAGS.batch_size), steps=None, start_delay_secs=1000, throttle_secs=1200)
        result = tf.estimator.train_and_evaluate(Estimator, train_spec, eval_spec)
        for key,value in sorted(result[0].items()):
            print('%s: %s' % (key,value))
    elif FLAGS.task_type == 'eval':
        result = Estimator.evaluate(input_fn=lambda: input_fn(va_files, num_epochs=1, batch_size=FLAGS.batch_size))
        for key,value in sorted(result.items()):
            print('%s: %s' % (key,value))
    elif FLAGS.task_type == 'infer':
        preds = Estimator.predict(input_fn=lambda: input_fn(file_path, num_epochs=1, batch_size=FLAGS.batch_size), predict_keys=["pctcvr","uid","city","cid_id"])
        result = []
        for prob in preds:
            result.append([str(prob["uid"][0]), str(prob["city"][0]), str(prob["cid_id"][0]), str(prob['pctcvr'])])
        return result
    elif FLAGS.task_type == 'export':
        print("Not Implemented, Do It Yourself!")

def trans(x):
    # Strip the b'...' wrapper from a bytes repr, e.g. "b'abc'" -> "abc".
    return str(x)[2:-1] if str(x)[0] == 'b' else x

def set_join(lst):
    l = lst.unique().tolist()
    r = [str(i) for i in l]
    r = r[:500]   # cap each queue at 500 ids
    return ','.join(r)

def df_sort(result,queue_name):
    df = pd.DataFrame(result, columns=["uid", "city", "cid_id", "pctcvr"])
    # print(df.head(10))
    df['uid1'] = df['uid'].apply(trans)
    df['city1'] = df['city'].apply(trans)
    df['cid_id1'] = df['cid_id'].apply(trans)

    df2 = df.groupby(by=["uid1", "city1"]).apply(lambda x: x.sort_values(by="pctcvr", ascending=False)) \
        .reset_index(drop=True).groupby(by=["uid1", "city1"]).agg({'cid_id1': set_join}).reset_index(drop=False)
    df2.columns = ["device_id", "city_id", queue_name]
    df2["time"] = str(datetime.datetime.now().strftime('%Y%m%d%H%M'))
    return df2

def update_or_insert(df2, queue_name):
    device_count = df2.shape[0]
    con = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test', charset='utf8')
    cur = con.cursor()
    try:
        for i in range(0, device_count):
            # The column name cannot be bound as a query parameter, so it is
            # interpolated; all values go through the driver's placeholders so
            # they are escaped properly.
            query = """INSERT INTO esmm_device_diary_queue (device_id, city_id, time, {0}) VALUES(%s, %s, %s, %s) \
            ON DUPLICATE KEY UPDATE device_id=%s, city_id=%s, time=%s, {0}=%s""".format(queue_name)
            cur.execute(query, (df2.device_id[i], df2.city_id[i], df2.time[i], df2[queue_name][i],
                                df2.device_id[i], df2.city_id[i], df2.time[i], df2[queue_name][i]))
            con.commit()
        print("insert or update success")
    except Exception as e:
        print(e)
    finally:
        con.close()


if __name__ == "__main__":

    b = time.time()
    path = "hdfs://172.16.32.4:8020/strategy/esmm/"
    tf.logging.set_verbosity(tf.logging.INFO)
    if FLAGS.task_type == 'train':
        print("train task")
        tr_files = ["hdfs://172.16.32.4:8020/strategy/esmm/tr/part-r-00000"]
        main(tr_files)
    elif FLAGS.task_type == 'infer':
        te_files = ["%s/part-r-00000" % FLAGS.hdfs_dir]
        queue_name = te_files[0].split('/')[-2] + "_queue"
        print(queue_name + " task")
        result = main(te_files)
        df = df_sort(result,queue_name)
        update_or_insert(df,queue_name)
    print("耗时(分钟):")
    print((time.time()-b)/60)