# -*- coding: utf-8 -*-
import pymysql
from pyspark.conf import SparkConf
import pytispark.pytispark as pti
from pyspark.sql import SparkSession
import datetime
import pandas as pd
import time


def app_list_func(x, l):
    """Encode a comma-separated value as a list of integer ids.

    ``x`` is converted with ``str`` and split on ","; each token is looked up
    in the mapping ``l`` and replaced by its id, or by 0 when the token is not
    a key of ``l``. Note that ``None`` therefore becomes the token "None".
    """
    return [l.get(token, 0) for token in str(x).split(",")]


def get_list(db, sql, n):
    """Assign consecutive integer ids to every comma-separated token in a column.

    Runs ``sql`` on ``db`` and reads the first column of every row. Each
    distinct value is converted with ``str`` and split on ",", and every
    distinct token gets an id in ``range(n, n + number)``.

    Tokens are sorted before ids are assigned, so the same data always yields
    the same ids. Iterating a ``set`` of strings is not stable across
    interpreter runs because of hash randomization.

    The connection ``db`` is always closed, even if the query fails.

    Args:
        db: DB-API connection; it is consumed (closed) by this call.
        sql: query whose first selected column holds comma-separated tokens.
        n: first id to hand out.

    Returns:
        tuple: ``(number, mapping)`` where ``number`` is the count of distinct
        tokens and ``mapping`` is a ``{token: id}`` dict.
    """
    try:
        cursor = db.cursor()
        cursor.execute(sql)
        rows = cursor.fetchall()
    finally:
        db.close()

    tokens = set()
    for value in {row[0] for row in rows}:
        # str() keeps the original behavior of turning NULL into the token "None".
        tokens.update(str(value).split(","))

    ordered_tokens = sorted(tokens)
    number = len(ordered_tokens)
    app_list_map = {token: idx for idx, token in enumerate(ordered_tokens, start=n)}
    return number, app_list_map


def get_map():
    """Build back-to-back integer id maps for the multi-valued features.

    Ids start at 1 and are laid out contiguously:
    - app-list tokens come first;
    - level2 ids follow, starting right after the app-list range;
    - level3 ids follow, starting right after the level2 range.

    The elapsed minutes of each query are printed after its label.

    Returns:
        tuple: ``(apps_number, app_list_map, leve2_number, leve2_map,
        leve3_number, leve3_map)``. Each pair is a distinct-token count and a
        ``{token: id}`` dict.
    """
    # (log label, query) in id-layout order; the order determines the offsets.
    sources = [
        ("applist", "select app_list from device_app_list"),
        ("leve2", "select level2_ids from diary_feat"),
        ("leve3", "select level3_ids from diary_feat"),
    ]

    offset = 1
    results = []
    for label, sql in sources:
        # get_list closes the connection it receives, so open a fresh one per query.
        # NOTE(review): credentials are hard-coded; consider moving them to config/env.
        db = pymysql.connect(host='172.16.40.158', port=4000, user='root',
                             passwd='3SYz54LS9#^9sBvC', db='jerry_test')
        started = time.time()
        number, mapping = get_list(db, sql, offset)
        print(label)
        print((time.time() - started) / 60)
        offset += number
        results.extend([number, mapping])

    return tuple(results)


def get_unique(db, sql):
    """Return the distinct values of the first column selected by ``sql``.

    Prints the query and the number of distinct values it produced. The list
    order follows ``set`` iteration and is not sorted. NULL comes back as
    ``None``.

    The connection ``db`` is always closed, even if the query fails.
    """
    try:
        cursor = db.cursor()
        cursor.execute(sql)
        result = cursor.fetchall()
    finally:
        db.close()
    v = list(set([row[0] for row in result]))
    print(sql)
    print(len(v))
    return v


def feature_engineer():
    """Build the train/validation TFRecords and the id vocabularies.

    Steps:
      1. ``get_map`` builds id maps for the multi-valued features (app list,
         level2, level3).
      2. The distinct values of every single-valued categorical column are
         collected. Each one gets an id starting at
         ``2 + apps_number + level2_number + level3_number``.
      3. The joined training query runs in Spark and every row is encoded
         into id lists. Rows whose ``stat_date`` differs from the latest date
         are written to ``path + "tr/"``; rows on the latest date go to
         ``path + "va/"``.

    Relies on the module-level globals ``spark`` and ``path`` set in
    ``__main__``.

    Returns:
        tuple: ``(validate_date, value_map, app_list_map, level2_map, level3_map)``.
    """
    apps_number, app_list_map, level2_number, level2_map, level3_number, level3_map = get_map()

    def _connect(**extra):
        # get_unique/con_sql close the connection they are handed, so every
        # query needs a fresh one.
        # NOTE(review): credentials are hard-coded; consider moving them to config/env.
        return pymysql.connect(host='172.16.40.158', port=4000, user='root',
                               passwd='3SYz54LS9#^9sBvC', **extra)

    # (column, table) pairs whose distinct values become categorical ids.
    # Order matters: it fixes the id each value receives, and a value that
    # appears in several columns keeps the id of its LAST occurrence
    # (dict(zip(...)) below overwrites earlier entries).
    unique_sources = [
        ("stat_date", "esmm_train_data_dur"),
        ("ucity_id", "esmm_train_data_dur"),
        ("ccity_name", "esmm_train_data_dur"),
        ("time", "cid_time_cut"),
        ("device_type", "user_feature"),
        ("manufacturer", "user_feature"),
        ("channel", "user_feature"),
        ("top", "cid_type_top"),
        ("price_min", "train_Knowledge_network_data"),
        ("treatment_method", "train_Knowledge_network_data"),
        ("price_max", "train_Knowledge_network_data"),
        ("treatment_time", "train_Knowledge_network_data"),
        ("maintain_time", "train_Knowledge_network_data"),
        ("recover_time", "train_Knowledge_network_data"),
    ]
    unique_values = []
    for column, table in unique_sources:
        sql = "select distinct {} from {}".format(column, table)
        unique_values.extend(get_unique(_connect(db='jerry_test'), sql))

    sql = "select max(stat_date) from esmm_train_data_dur"
    validate_date = con_sql(_connect(db='jerry_test'), sql)[0].values.tolist()[0]
    print("validate_date:" + validate_date)
    latest = datetime.datetime.strptime(validate_date, "%Y-%m-%d")
    start = (latest - datetime.timedelta(days=300)).strftime("%Y-%m-%d")
    print(start)

    # This query spans the jerry_test and eagle databases, so the connection
    # is opened without a default database.
    sql = "select doctor.hospital_id from jerry_test.esmm_train_data_dur e " \
          "left join eagle.src_zhengxing_api_service service on e.diary_service_id = service.id " \
          "left join eagle.src_zhengxing_api_doctor doctor on service.doctor_id = doctor.id " \
          "where e.stat_date >= '{}'".format(start)
    unique_values.extend(get_unique(_connect(), sql))

    features = ["ucity_id", "ccity_name", "device_type", "manufacturer",
                "channel", "top", "time", "stat_date", "hospital_id",
                "treatment_method", "price_min", "price_max", "treatment_time", "maintain_time", "recover_time"]
    # df.na.fill below replaces NULLs in these columns with the column name
    # itself, so the column names need ids as well.
    unique_values.extend(features)
    print("unique_values length")
    print(len(unique_values))

    first_id = 2 + apps_number + level2_number + level3_number
    value_map = dict(zip(unique_values, range(first_id, first_id + len(unique_values))))

    # NOTE(review): this query reads jerry_test.esmm_train_data, while the
    # vocabulary above is built from esmm_train_data_dur. value_map[...]
    # below raises KeyError for any value missing from the vocabulary, so
    # confirm the two tables agree.
    sql = "select e.y,e.z,e.stat_date,e.ucity_id,feat.level2_ids,e.ccity_name,u.device_type,u.manufacturer," \
          "u.channel,c.top,cut.time,dl.app_list,feat.level3_ids,doctor.hospital_id," \
          "wiki.tag as tag1,question.tag as tag2,search.tag as tag3,budan.tag as tag4," \
          "ot.tag as tag5,sixin.tag as tag6,cart.tag as tag7," \
          "k.treatment_method,k.price_min,k.price_max,k.treatment_time,k.maintain_time,k.recover_time " \
          "from jerry_test.esmm_train_data e left join jerry_test.user_feature u on e.device_id = u.device_id " \
          "left join jerry_test.cid_type_top c on e.device_id = c.device_id " \
          "left join jerry_test.cid_time_cut cut on e.cid_id = cut.cid " \
          "left join jerry_test.device_app_list dl on e.device_id = dl.device_id " \
          "left join jerry_test.diary_feat feat on e.cid_id = feat.diary_id " \
          "left join jerry_test.train_Knowledge_network_data k on feat.level2 = k.level2_id " \
          "left join jerry_test.wiki_tag wiki on e.device_id = wiki.device_id " \
          "left join jerry_test.question_tag question on e.device_id = question.device_id " \
          "left join jerry_test.search_tag search on e.device_id = search.device_id " \
          "left join jerry_test.budan_tag budan on e.device_id = budan.device_id " \
          "left join jerry_test.order_tag ot on e.device_id = ot.device_id " \
          "left join jerry_test.sixin_tag sixin on e.device_id = sixin.device_id " \
          "left join jerry_test.cart_tag cart on e.device_id = cart.device_id " \
          "left join eagle.src_zhengxing_api_service service on e.diary_service_id = service.id " \
          "left join eagle.src_zhengxing_api_doctor doctor on service.doctor_id = doctor.id " \
          "where e.stat_date >= '{}'".format(start)

    df = spark.sql(sql)

    df = df.drop_duplicates(["ucity_id", "level2_ids", "ccity_name", "device_type", "manufacturer",
                             "channel", "top", "time", "stat_date", "app_list", "hospital_id", "level3_ids",
                             "tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7"])

    df = df.na.fill(dict(zip(features, features)))

    c = time.time()
    # Encoded row layout: (stat_date, y, z, app_list, level2, level3,
    #                      tag1..tag7, ids) -> 14 elements.
    rdd = df.select("stat_date", "y", "z", "app_list", "level2_ids", "level3_ids",
                    "tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7",
                    "ucity_id", "ccity_name", "device_type", "manufacturer", "channel", "top", "time",
                    "hospital_id", "treatment_method", "price_min", "price_max", "treatment_time",
                    "maintain_time", "recover_time").rdd.coalesce(200).map(
        lambda x: (x[0], float(x[1]), float(x[2]),
                   app_list_func(x[3], app_list_map), app_list_func(x[4], level2_map),
                   app_list_func(x[5], level3_map),
                   # tag1..tag7 are encoded with the level2 vocabulary
                   app_list_func(x[6], level2_map), app_list_func(x[7], level2_map),
                   app_list_func(x[8], level2_map), app_list_func(x[9], level2_map),
                   app_list_func(x[10], level2_map), app_list_func(x[11], level2_map),
                   app_list_func(x[12], level2_map),
                   [value_map[x[0]], value_map[x[13]], value_map[x[14]], value_map[x[15]], value_map[x[16]],
                    value_map[x[17]], value_map[x[18]], value_map[x[19]], value_map[x[20]], value_map[x[21]],
                    value_map[x[22]], value_map[x[23]], value_map[x[24]], value_map[x[25]], value_map[x[26]]]))
    d = time.time()
    rdd.persist()
    print("rdd")
    print((d - c) / 60)

    # Output columns: every encoded field except the leading stat_date.
    tfrecord_columns = ("y", "z", "app_list", "level2_list", "level3_list",
                        "tag1_list", "tag2_list", "tag3_list", "tag4_list",
                        "tag5_list", "tag6_list", "tag7_list", "ids")

    # TODO: once this goes live, remove the train filter below, because the
    # most recent day's data should also be part of the training set.
    train = rdd.filter(lambda x: x[0] != validate_date).map(lambda x: x[1:])
    f = time.time()
    spark.createDataFrame(train).toDF(*tfrecord_columns) \
        .write.format("tfrecords").save(path=path + "tr/", mode="overwrite")
    h = time.time()
    print("train tfrecord done")
    print((h - f) / 60)

    test = rdd.filter(lambda x: x[0] == validate_date).map(lambda x: x[1:])
    spark.createDataFrame(test).toDF(*tfrecord_columns) \
        .write.format("tfrecords").save(path=path + "va/", mode="overwrite")

    print("va tfrecord done")

    rdd.unpersist()

    return validate_date, value_map, app_list_map, level2_map, level3_map


def _write_predict_split(rdd, label, name):
    """Write the prediction rows whose label equals ``label``.

    Two outputs are produced:
    - ``local_path + name + ".csv"`` with columns city, uid and cid_id
      (collected through pandas);
    - a single-file TFRecord dataset under ``path + name + "/"``.

    The TFRecord is coalesced to one file so its row order can be matched
    against the CSV rows; both come from the same (persisted) ``rdd``.
    Uses the module-level globals ``spark``, ``path`` and ``local_path``.
    """
    # Taking label as a parameter binds it per call for the Spark lambdas.
    ids_df = spark.createDataFrame(rdd.filter(lambda x: x[0] == label).map(lambda x: (x[3], x[4], x[5]))) \
        .toDF("city", "uid", "cid_id")
    print(name + " csv")
    ids_df.toPandas().to_csv(local_path + name + ".csv", header=True)
    # TODO: write this CSV with Spark instead, e.g.
    #   ids_df.coalesce(1).write.format('com.databricks.spark.csv').save(<dir>, header='true')
    # NOTE(review): the earlier draft targeted path + name + "/", which is the
    # TFRecord output directory below; pick a different directory.

    f = time.time()
    # (y, z) followed by app/level2/level3/tag1..tag7 lists and ids (x[6]..x[16]).
    spark.createDataFrame(rdd.filter(lambda x: x[0] == label).map(lambda x: (x[1], x[2]) + x[6:])) \
        .toDF("y", "z", "app_list", "level2_list", "level3_list", "tag1_list", "tag2_list", "tag3_list",
              "tag4_list", "tag5_list", "tag6_list", "tag7_list", "ids") \
        .coalesce(1).write.format("tfrecords") \
        .save(path=path + name + "/", mode="overwrite")
    print(name + " tfrecord done")
    print((time.time() - f) / 60)


def get_predict():
    """Encode the prediction candidates and write them per label.

    Rows with label 0 are written as "native" and rows with label 1 as
    "nearby" (see ``_write_predict_split``).

    Categorical values missing from ``value_map`` fall back to fixed ids
    299999..299985, one per position in the ``ids`` list.

    Reads the module-level globals ``spark``, ``validate_date``, ``value_map``,
    ``app_list_map``, ``level2_map`` and ``level3_map`` set in ``__main__``.
    """
    sql = "select e.y,e.z,e.label,e.ucity_id,feat.level2_ids,e.ccity_name," \
          "u.device_type,u.manufacturer,u.channel,c.top,e.device_id,e.cid_id,cut.time," \
          "dl.app_list,e.hospital_id,feat.level3_ids," \
          "wiki.tag as tag1,question.tag as tag2,search.tag as tag3,budan.tag as tag4," \
          "ot.tag as tag5,sixin.tag as tag6,cart.tag as tag7," \
          "k.treatment_method,k.price_min,k.price_max,k.treatment_time,k.maintain_time,k.recover_time " \
          "from jerry_test.esmm_train_data_dur e " \
          "left join jerry_test.user_feature u on e.device_id = u.device_id " \
          "left join jerry_test.cid_type_top c on e.device_id = c.device_id " \
          "left join jerry_test.cid_time_cut cut on e.cid_id = cut.cid " \
          "left join jerry_test.device_app_list dl on e.device_id = dl.device_id " \
          "left join jerry_test.diary_feat feat on e.cid_id = feat.diary_id " \
          "left join jerry_test.wiki_tag wiki on e.device_id = wiki.device_id " \
          "left join jerry_test.question_tag question on e.device_id = question.device_id " \
          "left join jerry_test.search_tag search on e.device_id = search.device_id " \
          "left join jerry_test.budan_tag budan on e.device_id = budan.device_id " \
          "left join jerry_test.order_tag ot on e.device_id = ot.device_id " \
          "left join jerry_test.sixin_tag sixin on e.device_id = sixin.device_id " \
          "left join jerry_test.cart_tag cart on e.device_id = cart.device_id " \
          "left join jerry_test.train_Knowledge_network_data k on feat.level2 = k.level2_id"

    features = ["ucity_id", "ccity_name", "device_type", "manufacturer",
                "channel", "top", "time", "hospital_id",
                "treatment_method", "price_min", "price_max", "treatment_time", "maintain_time", "recover_time"]

    df = spark.sql(sql)
    # NULLs become the column name, which has its own id in value_map.
    df = df.na.fill(dict(zip(features, features)))

    c = time.time()
    # Encoded row layout: (label, y, z, ucity_id, device_id, cid_id,
    #                      app_list, level2, level3, tag1..tag7, ids) -> 17 elements.
    rdd = df.select("label", "y", "z", "ucity_id", "device_id", "cid_id", "app_list", "level2_ids", "level3_ids",
                    "tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7",
                    "ucity_id", "ccity_name", "device_type", "manufacturer", "channel", "top", "time",
                    "hospital_id", "treatment_method", "price_min", "price_max", "treatment_time",
                    "maintain_time", "recover_time") \
        .rdd.repartition(200).map(lambda x: (x[0], float(x[1]), float(x[2]), x[3], x[4], x[5],
                                             app_list_func(x[6], app_list_map), app_list_func(x[7], level2_map),
                                             app_list_func(x[8], level3_map), app_list_func(x[9], level2_map),
                                             app_list_func(x[10], level2_map), app_list_func(x[11], level2_map),
                                             app_list_func(x[12], level2_map), app_list_func(x[13], level2_map),
                                             app_list_func(x[14], level2_map), app_list_func(x[15], level2_map),
                                             [value_map.get(validate_date, 299999), value_map.get(x[16], 299998),
                                              value_map.get(x[17], 299997), value_map.get(x[18], 299996),
                                              value_map.get(x[19], 299995), value_map.get(x[20], 299994),
                                              value_map.get(x[21], 299993), value_map.get(x[22], 299992),
                                              value_map.get(x[23], 299991), value_map.get(x[24], 299990),
                                              value_map.get(x[25], 299989), value_map.get(x[26], 299988),
                                              value_map.get(x[27], 299987), value_map.get(x[28], 299986),
                                              value_map.get(x[29], 299985)
                                              ]))
    # NOTE(review): CSV/TFRecord row alignment depends on reading the same
    # cached partitions twice. If cached partitions are evicted, recomputing
    # repartition() may reorder rows; consider a storage level that spills
    # to disk.
    rdd.persist()
    d = time.time()
    print("rdd")
    print((d - c) / 60)

    _write_predict_split(rdd, 0, "native")
    _write_predict_split(rdd, 1, "nearby")

    rdd.unpersist()


def con_sql(db, sql):
    """Run ``sql`` on ``db`` and return every row as a pandas DataFrame.

    Columns are labelled positionally (0, 1, ...), in the order they were
    selected. The connection ``db`` is always closed, even if the query
    fails.
    """
    try:
        cursor = db.cursor()
        cursor.execute(sql)
        result = cursor.fetchall()
    finally:
        db.close()
    return pd.DataFrame(list(result))


if __name__ == '__main__':
    # Spark settings:
    # - read Hive table sub-directories recursively;
    # - configure TiSpark access to TiDB;
    # - set compression codecs and the driver result-size limit.
    spark_settings = [
        ("spark.hive.mapred.supports.subdirectories", "true"),
        ("spark.hadoop.mapreduce.input.fileinputformat.input.dir.recursive", "true"),
        ("spark.tispark.plan.allow_index_double_read", "false"),
        ("spark.tispark.plan.allow_index_read", "true"),
        ("spark.sql.extensions", "org.apache.spark.sql.TiExtensions"),
        ("spark.tispark.pd.addresses", "172.16.40.158:2379"),
        ("spark.io.compression.codec", "lzf"),
        ("spark.driver.maxResultSize", "8g"),
        ("spark.sql.avro.compression.codec", "snappy"),
    ]
    sparkConf = SparkConf()
    for conf_key, conf_value in spark_settings:
        sparkConf.set(conf_key, conf_value)

    spark = SparkSession.builder.config(conf=sparkConf).enableHiveSupport().getOrCreate()
    ti = pti.TiContext(spark)
    for tidb_database in ("jerry_test", "eagle"):
        ti.tidbMapDatabase(tidb_database)
    spark.sparkContext.setLogLevel("WARN")

    # Output locations: HDFS for TFRecords, local disk for CSVs.
    path = "hdfs:///strategy/esmm/"
    local_path = "/home/gmuser/esmm/"

    # feature_engineer() and get_predict() read these names as module-level
    # globals, so they must be bound here at module scope.
    validate_date, value_map, app_list_map, level2_map, level3_map = feature_engineer()
    get_predict()