# -*- coding: utf-8 -*-
import pymysql
from pyspark.conf import SparkConf
import pytispark.pytispark as pti
# from pyspark.sql import SQLContext
from pyspark.sql import SparkSession
import datetime
import pandas as pd
import hdfs
import avro

def app_list_func(x,l):
    """Encode a comma-separated token string as comma-separated integer ids.

    :param x: string of tokens separated by ``","``.
    :param l: mapping from token to integer id.
    :return: the ids joined with ``","``, in the same order as the tokens;
        any token absent from ``l`` is encoded as ``0``.
    """
    encoded = (str(l[token]) if token in l else "0" for token in x.split(","))
    return ",".join(encoded)


def multi_hot(df,column,n):
    """Assign consecutive integer ids to every token of a comma-separated column.

    All distinct values of ``column`` are collected to the driver, split on
    ``","``, and the resulting unique tokens are numbered ``n, n+1, ...``.
    Tokens are numbered in sorted order so the mapping is reproducible across
    runs. Iteration order of a set of strings can differ between interpreter
    processes because of hash randomization.

    :param df: Spark DataFrame containing ``column``.
    :param column: name of a string column holding comma-separated tokens.
    :param n: first id to assign.
    :return: ``(number, token_to_id)`` where ``number`` is the count of unique
        tokens and ``token_to_id`` maps each token to an id in
        ``range(n, n + number)``.
    """
    distinct_values = df.select(column).distinct().rdd.map(lambda x: x[0]).collect()
    tokens = set()
    for value in distinct_values:
        # Nulls carry no tokens; skip them instead of failing on None.split().
        if value is None:
            continue
        tokens.update(value.split(","))
    ordered_tokens = sorted(tokens)
    number = len(ordered_tokens)
    app_list_map = {token: idx for idx, token in enumerate(ordered_tokens, start=n)}
    return number, app_list_map


def feature_engineer():
    """Build id mappings for the ESMM training data and encode the train split.

    Steps:
    1. Read ``max(stat_date)`` from ``esmm_train_data`` (database ``jerry_test``
       via pymysql) and use it as ``validate_date``. Training rows cover the
       300 days up to it.
    2. Left-join the training rows with the user/device/diary feature tables
       through Spark SQL. Attach ``hospital_id`` from the ``api_service`` /
       ``api_doctor`` tables, which are read over JDBC.
    3. Assign integer ids. The comma-separated columns (``app_list``,
       ``level2_ids``, ``level3_ids``) go through ``multi_hot`` starting at 1;
       ``app_list_func`` emits 0 for unknown tokens. Every distinct value of
       the single-valued ``features`` columns gets an id in ``value_map``.
    4. Encode rows with ``stat_date != validate_date`` as the training set and
       print its shape.

    Relies on the module-level ``spark`` session created under ``__main__``.

    :return: ``(validate_date, value_map, app_list_map, leve2_map, leve3_map)``,
        the mappings that ``get_predict`` takes to encode prediction data.
    """
    # NOTE(review): hosts and credentials are hard-coded here and in the JDBC
    # options below; consider moving them to configuration / a secret store.
    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
    sql = "select max(stat_date) from esmm_train_data"
    # con_sql closes db. Column 0 / row 0 of the result is max(stat_date).
    # Assumes it comes back as a "YYYY-MM-DD" string: it is concatenated with
    # a str and parsed with strptime below.
    validate_date = con_sql(db, sql)[0].values.tolist()[0]
    print("validate_date:" + validate_date)
    temp = datetime.datetime.strptime(validate_date, "%Y-%m-%d")
    start = (temp - datetime.timedelta(days=300)).strftime("%Y-%m-%d")
    print(start)

    # Training window: stat_date >= validate_date - 300 days. All feature
    # tables are left-joined, so rows without features are kept (with nulls).
    sql = "select e.y,e.z,e.stat_date,e.ucity_id,feat.level2_ids,e.ccity_name,u.device_type,u.manufacturer," \
          "u.channel,c.top,cut.time,dl.app_list,e.diary_service_id,feat.level3_ids," \
          "k.treatment_method,k.price_min,k.price_max,k.treatment_time,k.maintain_time,k.recover_time " \
          "from esmm_train_data e left join user_feature u on e.device_id = u.device_id " \
          "left join cid_type_top c on e.device_id = c.device_id " \
          "left join cid_time_cut cut on e.cid_id = cut.cid " \
          "left join device_app_list dl on e.device_id = dl.device_id " \
          "left join diary_feat feat on e.cid_id = feat.diary_id " \
          "left join train_Knowledge_network_data k on feat.level2 = k.level2_id " \
          "where e.stat_date >= '{}'".format(start)

    df = spark.sql(sql)

    # Register the zhengxing api_service / api_doctor MySQL tables as temp
    # views so they can be joined in Spark SQL.
    url = "jdbc:mysql://172.16.30.143:3306/zhengxing"
    jdbcDF = spark.read.format("jdbc").option("driver", "com.mysql.jdbc.Driver").option("url", url) \
        .option("dbtable", "api_service").option("user", 'work').option("password", 'BJQaT9VzDcuPBqkd').load()
    jdbcDF.createOrReplaceTempView("api_service")
    jdbc = spark.read.format("jdbc").option("driver", "com.mysql.jdbc.Driver").option("url", url) \
        .option("dbtable", "api_doctor").option("user", 'work').option("password", 'BJQaT9VzDcuPBqkd').load()
    jdbc.createOrReplaceTempView("api_doctor")

    # service id -> hospital_id of the service's doctor.
    sql = "select s.id as diary_service_id,d.hospital_id " \
          "from api_service s left join api_doctor d on s.doctor_id = d.id"
    hospital = spark.sql(sql)

    # Attach hospital_id, then replace nulls in string columns with "na".
    df = df.join(hospital,"diary_service_id","left_outer").fillna("na")
    df = df.drop("diary_service_id")
    # De-duplicate on these columns only. Rows that differ only in y, z or the
    # knowledge-network columns collapse to a single kept row.
    df = df.drop_duplicates(["ucity_id", "level2_ids", "ccity_name", "device_type", "manufacturer",
                              "channel", "top", "time", "stat_date", "app_list", "hospital_id", "level3_ids"])

    # Single-valued categorical columns that share the value_map id space.
    features = ["ucity_id", "ccity_name", "device_type", "manufacturer",
                "channel", "top", "time", "stat_date", "hospital_id",
                "treatment_method", "price_min", "price_max", "treatment_time", "maintain_time", "recover_time"]

    # Fill each column's nulls with the column's own name.
    # NOTE(review): nulls in string columns were already replaced by "na"
    # above, so this presumably has little or no effect here. get_predict
    # applies this fill without the "na" step; confirm the two stay
    # consistent.
    df = df.na.fill(dict(zip(features,features)))

    # Id layout: app_list tokens start at 1, followed by level2 tokens, then
    # level3 tokens. 0 is the "unknown token" id produced by app_list_func.
    apps_number, app_list_map = multi_hot(df,"app_list",1)
    level2_number,leve2_map = multi_hot(df,"level2_ids",1 + apps_number)
    level3_number, leve3_map = multi_hot(df, "level3_ids", 1 + apps_number + level2_number)

    # Single-valued features get ids starting at 2 + apps + level2 + level3.
    # Id 1 + apps + level2 + level3 is therefore never assigned.
    # value_map is one dict keyed by raw value, so equal values from different
    # columns (e.g. "na") share an id. Duplicates also leave some ids in
    # `temp` unused, because later zip pairs overwrite earlier ones.
    unique_values = []
    for i in features:
        unique_values.extend(df.select(i).distinct().rdd.map(lambda x: x[0]).collect())
    temp = list(range(2 + apps_number + level2_number + level3_number,
                      2 + apps_number + level2_number + level3_number + len(unique_values)))
    value_map = dict(zip(unique_values, temp))

    # Positions used by the lambdas below: 0-2 multi-hot columns,
    # 3 stat_date, 4-17 single-valued features, 18 y, 19 z.
    rdd = df.select("app_list","level2_ids","level3_ids","stat_date","ucity_id", "ccity_name", "device_type", "manufacturer",
                  "channel", "top", "time", "hospital_id","treatment_method", "price_min",
                  "price_max", "treatment_time","maintain_time", "recover_time","y","z",).rdd
    rdd.persist()
    # TODO: after going live, remove the train filter below, because the most
    # recent day's data should also be used as training data.
    train = rdd.filter(lambda x: x[3]!= validate_date).map(lambda x: (app_list_func(x[0], app_list_map), app_list_func(x[1], leve2_map),
                            app_list_func(x[2], leve3_map),value_map[x[3]],value_map[x[4]],
                            value_map[x[5]],value_map[x[6]],value_map[x[7]],value_map[x[8]],
                            value_map[x[9]],value_map[x[10]],value_map[x[11]],value_map[x[12]],
                            value_map[x[13]],value_map[x[14]],value_map[x[15]],value_map[x[16]],
                            value_map[x[17]], x[18],x[19]))
    # Validation rows (stat_date == validate_date). Only the commented-out CSV
    # export below uses this RDD, so it is never materialized here.
    test = rdd.filter(lambda x: x[3] == validate_date)\
        .map(lambda x: (app_list_func(x[0], app_list_map), app_list_func(x[1], leve2_map),
                   app_list_func(x[2], leve3_map), value_map[x[3]], value_map[x[4]],
                   value_map[x[5]], value_map[x[6]], value_map[x[7]], value_map[x[8]],
                   value_map[x[9]], value_map[x[10]], value_map[x[11]], value_map[x[12]],
                   value_map[x[13]], value_map[x[14]], value_map[x[15]], value_map[x[16]],
                   value_map[x[17]], x[18], x[19]))

    # spark.createDataFrame(test).write.csv('/recommend/va', mode='overwrite', header=True)
    # spark.createDataFrame(train).write.csv('/recommend/tr', mode='overwrite', header=True)

    # NOTE(review): toPandas() collects the entire encoded training set onto
    # the driver just to print its shape. If only the row count is needed,
    # train.count() would be much cheaper.
    a = spark.createDataFrame(train).toPandas()
    print(a.shape)
    print("done")
    rdd.unpersist()

    return validate_date,value_map,app_list_map,leve2_map,leve3_map


def _write_predict_split(rdd, label, name, id_path, feature_path):
    """Write the id table and the encoded-feature table for one prediction label.

    :param rdd: encoded tuples built by ``get_predict``. Positions 3-5 are the
        raw ucity_id/device_id/cid_id, 6 is the label, and 0-2 plus 9-23 are
        the encoded features.
    :param label: value of position 6 that selects the rows to write.
    :param name: tag printed before the row count (e.g. ``"native"``).
    :param id_path: output directory for the ``city,uid,cid_id`` CSV.
    :param feature_path: output directory for the encoded-feature CSV.
    """
    subset = rdd.filter(lambda x: x[6] == label)

    ids = spark.createDataFrame(subset.map(lambda x: (x[3], x[4], x[5]))) \
        .toDF("city", "uid", "cid_id")
    print(name)
    print(ids.count())
    ids.write.csv(id_path, mode='overwrite', header=True)

    # Multi-hot strings (positions 0-2) followed by the 15 value_map ids (9-23).
    encoded = subset.map(lambda x: (x[0], x[1], x[2]) + tuple(x[9:24]))
    spark.createDataFrame(encoded) \
        .toDF("app_list", "level2_ids", "level3_ids", "ucity_id",
              "ccity_name", "device_type", "manufacturer", "channel", "time", "hospital_id",
              "treatment_method", "price_min", "price_max", "treatment_time", "maintain_time",
              "recover_time", "top", "stat_date") \
        .write.csv(feature_path, mode='overwrite', header=True)


def get_predict(date,value_map,app_list_map,level2_map,level3_map,
                native_pre_path='/recommend/native_pre',
                nearby_pre_path='/recommend/nearby_pre'):
    """Encode ``esmm_pre_data`` with the training mappings and write it to HDFS.

    Multi-hot columns are encoded with ``app_list_func`` and the given token
    maps. Single-valued columns are looked up in ``value_map``; values not
    seen during training fall back to fixed per-column ids 299987-300000. The
    ``stat_date`` column is set to ``value_map[date]``, which raises KeyError
    if ``date`` was not a training stat_date.

    Rows with label 0 are written as "native" and rows with label 1 as
    "nearby". For each label, two CSV directories are written:

    - the ``city,uid,cid_id`` ids, to ``native_pre_path`` / ``nearby_pre_path``;
    - the encoded features, to ``/recommend/native`` / ``/recommend/nearby``.

    Previously both id tables were written with ``mode='overwrite'`` to
    ``/recommend``. That directory is the parent of the feature outputs, so
    the nearby write deleted the native ids and ``/recommend/native``. Each id
    table now has its own path.

    Relies on the module-level ``spark`` session created under ``__main__``.
    """
    sql = "select e.y,e.z,e.label,e.ucity_id,feat.level2_ids,e.ccity_name," \
          "u.device_type,u.manufacturer,u.channel,c.top,e.device_id,e.cid_id,cut.time," \
          "dl.app_list,e.hospital_id,feat.level3_ids," \
          "k.treatment_method,k.price_min,k.price_max,k.treatment_time,k.maintain_time,k.recover_time " \
          "from esmm_pre_data e left join user_feature u on e.device_id = u.device_id " \
          "left join cid_type_top c on e.device_id = c.device_id " \
          "left join cid_time_cut cut on e.cid_id = cut.cid " \
          "left join device_app_list dl on e.device_id = dl.device_id " \
          "left join diary_feat feat on e.cid_id = feat.diary_id " \
          "left join train_Knowledge_network_data k on feat.level2 = k.level2_id"

    features = ["app_list", "level2_ids", "level3_ids","ucity_id", "ccity_name", "device_type", "manufacturer",
                "channel", "top", "time", "hospital_id",
                "treatment_method", "price_min", "price_max", "treatment_time", "maintain_time", "recover_time"]
    df = spark.sql(sql)

    # NOTE(review): nulls become the column's own name here. feature_engineer
    # first filled string nulls with "na", so null values may encode
    # differently at train and predict time. Confirm which fill is intended.
    df = df.na.fill(dict(zip(features, features)))
    df = df.drop_duplicates(["ucity_id", "level2_ids", "ccity_name", "device_type", "manufacturer",
                             "device_id","cid_id","label",
                             "channel", "top", "time", "app_list", "hospital_id", "level3_ids"])

    # Output tuple layout: 0-2 multi-hot strings, 3-5 raw ucity_id/device_id/
    # cid_id, 6 label, 7-8 y/z, 9-22 value_map ids, 23 value_map[date].
    rdd = df.select("app_list", "level2_ids", "level3_ids","ucity_id","device_id","cid_id","label", "y", "z",
                    "ccity_name", "device_type","manufacturer", "channel", "time", "hospital_id",
              "treatment_method", "price_min", "price_max", "treatment_time", "maintain_time",
              "recover_time","top") \
        .rdd.map(lambda x: (app_list_func(x[0], app_list_map), app_list_func(x[1], level2_map),
                        app_list_func(x[2], level3_map), x[3],x[4],x[5],x[6],x[7],x[8],
                            value_map.get(x[3], 300000),value_map.get(x[9], 299999),
                            value_map.get(x[10], 299998), value_map.get(x[11], 299997),
                            value_map.get(x[12], 299996), value_map.get(x[13], 299995),
                            value_map.get(x[14], 299994),value_map.get(x[15], 299993),
                            value_map.get(x[16], 299992),value_map.get(x[17], 299991),
                            value_map.get(x[18], 299990),value_map.get(x[19], 299989),
                            value_map.get(x[20], 299988),value_map.get(x[21], 299987),
                        value_map[date]))

    rdd.persist()
    try:
        _write_predict_split(rdd, 0, "native", native_pre_path, '/recommend/native')
        _write_predict_split(rdd, 1, "nearby", nearby_pre_path, '/recommend/nearby')
    finally:
        rdd.unpersist()


def con_sql(db,sql):
    """Run ``sql`` on the DB-API connection ``db`` and return all rows as a DataFrame.

    The cursor and the connection are closed before returning, even if the
    query raises, so ``db`` cannot be reused after this call.

    :param db: an open DB-API 2.0 connection (e.g. from ``pymysql.connect``).
    :param sql: the query to execute.
    :return: ``pandas.DataFrame`` with one row per result row and integer
        column labels (0, 1, ...) in select-list order.
    """
    try:
        cursor = db.cursor()
        try:
            cursor.execute(sql)
            result = cursor.fetchall()
        finally:
            cursor.close()
    finally:
        db.close()
    return pd.DataFrame(list(result))


def test():
    """Debug helper: write a small indexed sample of ``esmm_train_data`` as Avro.

    Takes up to 60 ``(stat_date, cid_id)`` rows for 2019-04-25, and numbers
    them with ``zipWithIndex`` into columns ``ind, k, v``. It shows 6 of them
    and saves the result in the "avro" format. Despite its ``.csv`` suffix,
    the output path holds Avro data.

    NOTE(review): the "avro" data source must be available to this Spark
    build (confirm the spark-avro package is on the classpath). ``save`` uses
    the default save mode, so it presumably fails if the path already exists.
    """
    sample_date = "2019-04-25"
    query = "select stat_date,cid_id from esmm_train_data e where stat_date = '{}' limit 60".format(sample_date)
    pairs = spark.sql(query).rdd.map(lambda row: (row[0], row[1]))
    # zipWithIndex yields ((stat_date, cid_id), index); flatten to (index, stat_date, cid_id).
    indexed = pairs.zipWithIndex().map(lambda item: (item[1], item[0][0], item[0][1]))
    df = spark.createDataFrame(indexed).toDF("ind", "k", "v")
    df.show(6)

    df.write.format("avro").save("/recommend/tr/avro.csv")


    #
    # from hdfs import InsecureClient
    # from hdfs.ext.dataframe import read_dataframe
    # from hdfs.ext.dataframe import write_dataframe
    #
    #
    # client = InsecureClient('http://nvwa01:50070')
    # # write_dataframe(client, '/recommend/va/a.csv', df)
    #
    # df = read_dataframe(client,"/recommend/va/a.csv")
    #
    #
    # print(df.head())

    # spark.sql("use online")
    # spark.sql("ADD JAR /srv/apps/brickhouse-0.7.1-SNAPSHOT.jar")
    # spark.sql("ADD JAR /srv/apps/hive-udf-1.0-SNAPSHOT.jar")
    # spark.sql("CREATE TEMPORARY FUNCTION json_map AS 'brickhouse.udf.json.JsonMapUDF'")
    # spark.sql("CREATE TEMPORARY FUNCTION is_json AS 'com.gmei.hive.common.udf.UDFJsonFormatCheck'")
    #
    # spark.sql("select cl_type from online.tl_hdfs_maidian_view where partition_date = '20190312' limit 6").show()



    # data = [(0, 18.0), (1, 19.0), (2, 8.0), (3, 5.0), (4, 2.2), (5, 9.2), (6, 14.4)]
    # df = spark.createDataFrame(data, ["id", "hour"])
    # df.show(6)
    # t = df.rdd.map(lambda x:x[0]).collect()
    # print(t)

    # validate_date = spark.sql("select max(stat_date) from esmm_train_data").rdd.map(lambda x: str(x[0]))
    # print(validate_date.count())
    # print("validate_date:" + validate_date)
    # temp = datetime.datetime.strptime(validate_date, "%Y-%m-%d")
    # start = (temp - datetime.timedelta(days=10)).strftime("%Y-%m-%d")
    # print(start)



if __name__ == '__main__':
    # Spark settings: recursive Hive input directories, TiSpark integration
    # (SQL extension, index-read plan options, PD address), lzf compression,
    # and an 8g cap on results collected to the driver.
    sparkConf = SparkConf().setAll([
        ("spark.hive.mapred.supports.subdirectories", "true"),
        ("spark.hadoop.mapreduce.input.fileinputformat.input.dir.recursive", "true"),
        ("spark.tispark.plan.allow_index_double_read", "false"),
        ("spark.tispark.plan.allow_index_read", "true"),
        ("spark.sql.extensions", "org.apache.spark.sql.TiExtensions"),
        ("spark.tispark.pd.addresses", "172.16.40.158:2379"),
        ("spark.io.compression.codec", "lzf"),
        ("spark.driver.maxResultSize", "8g"),
    ])

    # `spark` is the module-level session used by every function above.
    session_builder = SparkSession.builder.config(conf=sparkConf)
    spark = session_builder.enableHiveSupport().getOrCreate()

    # Expose the TiDB database jerry_test to Spark SQL.
    ti = pti.TiContext(spark)
    ti.tidbMapDatabase("jerry_test")
    spark.sparkContext.setLogLevel("WARN")

    # Full pipeline (currently disabled):
    # validate_date, value_map, app_list_map, leve2_map, leve3_map = feature_engineer()
    # get_predict(validate_date, value_map, app_list_map, leve2_map, leve3_map)

    test()


