# -*- coding: utf-8 -*-
import pymysql
from pyspark.conf import SparkConf
import pytispark.pytispark as pti
# from pyspark.sql import SQLContext
from pyspark.sql import SparkSession
import datetime
import pandas as pd

def app_list_func(x,l):
    """Translate a comma-separated string of tokens into their integer ids.

    :param x: comma-separated tokens, e.g. ``"a,b,c"``.
    :param l: mapping from token to id.
    :return: the ids joined with commas, in the same order as the tokens;
        a token that is not a key of ``l`` is written as ``0``.
    """
    return ",".join(str(l.get(token, 0)) for token in x.split(","))

def multi_hot(df,column,n):
    """Build an id vocabulary for a comma-separated multi-valued column.

    The distinct values of ``column`` are pulled to the driver, split on
    commas, and every distinct token is given a consecutive integer id
    starting at ``n``.

    :param df: Spark DataFrame holding ``column``.
    :param column: name of a column whose values are comma-separated strings.
    :param n: id assigned to the first token.
    :return: ``(number, app_list_map)`` where ``number`` is the count of
        distinct tokens and ``app_list_map`` maps token -> id.
    """
    tokens = set()
    # collect() returns a plain list of Row objects, so the de-duplication has
    # to happen in Spark (distinct) rather than through a pandas-style .unique().
    for row in df.select(column).distinct().collect():
        value = row[0]
        if value is not None:  # a null cell contributes no tokens
            tokens.update(value.split(","))
    # Sorting makes the token -> id assignment reproducible between runs;
    # iterating a set of strings directly is not order-stable.
    app_list_map = {token: index for index, token in enumerate(sorted(tokens), n)}
    return len(app_list_map), app_list_map

def feature_engineer():
    """Assemble the recent ESMM training rows in Spark and id-encode ``app_list``.

    Steps:
      1. Read ``max(stat_date)`` from ``esmm_train_data`` over pymysql and take
         the window that starts six days before that date.
      2. Start a Hive-enabled Spark session with TiSpark and map the
         ``jerry_test`` database.
      3. Join the training rows with the user / content / knowledge tables.
      4. Attach ``hospital_id`` via the ``api_service`` and ``api_doctor``
         tables read over JDBC.
      5. Replace every token in ``app_list`` with its integer id.

    Prints the validation date and window start, then shows the first six
    rows of the result. Returns nothing.
    """
    # Imported locally so the module-level import block stays as it is.
    from pyspark.sql.functions import udf
    from pyspark.sql.types import StringType

    # NOTE(review): hosts, users and passwords are hard-coded in this function;
    # they should come from configuration or a secret store.
    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
    sql = "select max(stat_date) from esmm_train_data"
    # con_sql closes the connection after fetching.
    validate_date = con_sql(db, sql)[0].values.tolist()[0]
    print("validate_date:" + validate_date)
    temp = datetime.datetime.strptime(validate_date, "%Y-%m-%d")
    start = (temp - datetime.timedelta(days=6)).strftime("%Y-%m-%d")
    print(start)

    sparkConf = (
        SparkConf()
        .set("spark.hive.mapred.supports.subdirectories", "true")
        .set("spark.hadoop.mapreduce.input.fileinputformat.input.dir.recursive", "true")
        .set("spark.tispark.plan.allow_index_double_read", "false")
        .set("spark.tispark.plan.allow_index_read", "true")
        .set("spark.sql.extensions", "org.apache.spark.sql.TiExtensions")
        .set("spark.tispark.pd.addresses", "172.16.40.158:2379")
        .set("spark.io.compression.codec", "lzf")
        .set("spark.driver.maxResultSize", "4g")
    )

    spark = SparkSession.builder.config(conf=sparkConf).enableHiveSupport().getOrCreate()
    ti = pti.TiContext(spark)
    ti.tidbMapDatabase("jerry_test")
    spark.sparkContext.setLogLevel("WARN")
    # `start` is produced by strftime above, so formatting it into the
    # statement cannot inject arbitrary SQL.
    sql = "select e.y,e.z,e.stat_date,e.ucity_id,feat.level2_ids,e.ccity_name,u.device_type,u.manufacturer," \
          "u.channel,c.top,cut.time,dl.app_list,e.diary_service_id,feat.level3_ids," \
          "k.treatment_method,k.price_min,k.price_max,k.treatment_time,k.maintain_time,k.recover_time " \
          "from esmm_train_data e left join user_feature u on e.device_id = u.device_id " \
          "left join cid_type_top c on e.device_id = c.device_id " \
          "left join cid_time_cut cut on e.cid_id = cut.cid " \
          "left join device_app_list dl on e.device_id = dl.device_id " \
          "left join diary_feat feat on e.cid_id = feat.diary_id " \
          "left join train_Knowledge_network_data k on feat.level2 = k.level2_id " \
          "where e.stat_date >= '{}'".format(start)

    df = spark.sql(sql)

    # Expose the two MySQL tables needed for the hospital lookup as temp views.
    url = "jdbc:mysql://172.16.30.143:3306/zhengxing"
    for table in ("api_service", "api_doctor"):
        spark.read.format("jdbc").option("driver", "com.mysql.jdbc.Driver").option("url", url) \
            .option("dbtable", table).option("user", 'work').option("password", 'BJQaT9VzDcuPBqkd').load() \
            .createOrReplaceTempView(table)

    sql = "select s.id as diary_service_id,d.hospital_id " \
          "from api_service s left join api_doctor d on s.doctor_id = d.id"
    hospital = spark.sql(sql)

    # fillna("na") replaces nulls in string columns, so the split below always
    # sees a string (assumes app_list is a string column - TODO confirm).
    df = df.join(hospital,"diary_service_id","left_outer").fillna("na")
    df = df.drop("level2").drop("diary_service_id")
    df = df.drop_duplicates(["ucity_id", "level2_ids", "ccity_name", "device_type", "manufacturer",
                              "channel", "top", "time", "stat_date", "app_list", "hospital_id", "level3_ids"])

    # Build the app vocabulary on the driver. collect() yields a list of Row
    # objects, so distinct() is applied in Spark before collecting.
    app_list_unique = set()
    for row in df.select("app_list").distinct().collect():
        app_list_unique.update(row[0].split(","))
    # Ids start at 1 because app_list_func writes 0 for unknown tokens; sorting
    # keeps the assignment stable between runs.
    app_list_map = {app: index for index, app in enumerate(sorted(app_list_unique), 1)}

    # A UDF rewrites app_list in place and keeps every other selected column;
    # DataFrame has no .map in Spark 2+, and mapping rows to a bare string
    # would have discarded the remaining features.
    encode_app_list = udf(lambda value: app_list_func(value, app_list_map), StringType())
    df = df.select("app_list","ucity_id", "level2_ids", "ccity_name", "device_type", "manufacturer","channel",
                   "top", "time", "stat_date", "hospital_id", "level3_ids","y","z",
                   "treatment_method","price_min","price_max","treatment_time","maintain_time","recover_time")\
        .withColumn("app_list", encode_app_list("app_list"))

    df.show(6)








    # app_list_number, app_list_map = multi_hot(df, "app_list", 2)
    # level2_number, level2_map = multi_hot(df, "clevel2_id", 2 + app_list_number)
    # level3_number, level3_map = multi_hot(df, "level3_ids", 2 + app_list_number + level2_number)
    #
    # unique_values = []
    # features = ["ucity_id", "ccity_name", "device_type", "manufacturer",
    #             "channel", "top", "time", "stat_date", "hospital_id",
    #             "method", "min", "max", "treatment_time", "maintain_time", "recover_time"]
    # for i in features:
    #     df[i] = df[i].astype("str")
    #     df[i] = df[i].fillna("lost")
    #     # The next line tells apart identical values that occur in different columns
    #     df[i] = df[i] + i
    #     unique_values.extend(list(df[i].unique()))
    #
    # temp = list(range(2 + app_list_number + level2_number + level3_number,
    #                   2 + app_list_number + level2_number + level3_number + len(unique_values)))
    # value_map = dict(zip(unique_values, temp))
    #
    # df = df.drop("device_id", axis=1)
    # train = df[df["stat_date"] != validate_date + "stat_date"]
    # test = df[df["stat_date"] == validate_date + "stat_date"]
    # for i in ["ucity_id", "ccity_name", "device_type", "manufacturer",
    #           "channel", "top", "time", "stat_date", "hospital_id",
    #           "method", "min", "max", "treatment_time", "maintain_time", "recover_time"]:
    #     train[i] = train[i].map(value_map)
    #     test[i] = test[i].map(value_map)
    #
    # print("train shape")
    # print(train.shape)
    # print("test shape")
    # print(test.shape)
    #
    # write_csv(train, "tr", 100000)
    # write_csv(test, "va", 80000)

def con_sql(db,sql):
    """Run ``sql`` on ``db`` and return every fetched row as a pandas DataFrame.

    The connection is consumed: both the cursor and ``db`` are closed before
    returning, and also when the query raises, so a failed query does not
    leave the connection open.

    :param db: open DB-API connection (created with ``pymysql.connect`` in this file).
    :param sql: SQL statement to execute.
    :return: DataFrame with one row per fetched record and integer column labels.
    """
    try:
        cursor = db.cursor()
        try:
            cursor.execute(sql)
            result = cursor.fetchall()
        finally:
            cursor.close()
    finally:
        db.close()
    return pd.DataFrame(list(result))

def test():
    """Manual smoke check of the Spark / Hive / TiSpark setup.

    Starts a Hive-enabled session, registers two JARs and two temporary
    functions, shows a few ``user_id`` values from one partition of
    ``online.tl_hdfs_maidian_view``, then maps the ``jerry_test`` TiDB database
    and prints the latest ``stat_date`` of ``esmm_train_data``. Nothing is
    asserted or returned; the output is meant to be read by a person.
    """
    settings = (
        ("spark.hive.mapred.supports.subdirectories", "true"),
        ("spark.hadoop.mapreduce.input.fileinputformat.input.dir.recursive", "true"),
        ("spark.tispark.plan.allow_index_double_read", "false"),
        ("spark.tispark.plan.allow_index_read", "true"),
        ("spark.sql.extensions", "org.apache.spark.sql.TiExtensions"),
        ("spark.tispark.pd.addresses", "172.16.40.158:2379"),
        ("spark.io.compression.codec", "lzf"),
    )
    sparkConf = SparkConf()
    for key, value in settings:
        sparkConf.set(key, value)

    spark = SparkSession.builder.config(conf=sparkConf).enableHiveSupport().getOrCreate()

    # Hive-side setup; the statements run in this exact order.
    setup_statements = (
        "use online",
        "ADD JAR /srv/apps/brickhouse-0.7.1-SNAPSHOT.jar",
        "ADD JAR /srv/apps/hive-udf-1.0-SNAPSHOT.jar",
        "CREATE TEMPORARY FUNCTION json_map AS 'brickhouse.udf.json.JsonMapUDF'",
        "CREATE TEMPORARY FUNCTION is_json AS 'com.gmei.hive.common.udf.UDFJsonFormatCheck'",
    )
    for statement in setup_statements:
        spark.sql(statement)

    sample_query = "select user_id from online.tl_hdfs_maidian_view where partition_date = '20190412' limit 10"
    spark.sql(sample_query).show(6)

    tidb = pti.TiContext(spark)
    tidb.tidbMapDatabase("jerry_test")

    spark.sparkContext.setLogLevel("WARN")
    latest = spark.sql("select max(stat_date) from esmm_train_data")
    latest.show()
    # Pull the single aggregate value to the driver as a list of strings.
    print(latest.rdd.map(lambda row: str(row[0])).collect())


    # data = [(0, 18.0), (1, 19.0), (2, 8.0), (3, 5.0), (4, 2.2), (5, 9.2), (6, 14.4)]
    # df = spark.createDataFrame(data, ["id", "hour"])
    # df.show(6)
    # t = df.rdd.map(lambda x:x[0]).collect()
    # print(t)

    # validate_date = spark.sql("select max(stat_date) from esmm_train_data").rdd.map(lambda x: str(x[0]))
    # print(validate_date.count())
    # print("validate_date:" + validate_date)
    # temp = datetime.datetime.strptime(validate_date, "%Y-%m-%d")
    # start = (temp - datetime.timedelta(days=10)).strftime("%Y-%m-%d")
    # print(start)


if __name__ == "__main__":
    # Script entry point: run the feature pipeline when executed directly.
    feature_engineer()