# -*- coding: utf-8 -*-
import pymysql
from pyspark.conf import SparkConf
import pytispark.pytispark as pti
# from pyspark.sql import SQLContext
from pyspark.sql import SparkSession
import datetime
import pandas as pd

def app_list_func(x,l):
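    """Map a comma-separated token string ``x`` to a comma-separated string of
    integer ids using the lookup dict ``l``; tokens missing from ``l`` are
    encoded as 0."""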
    b = x.split(",")
    e = []
    for i in b:
        if i in l.keys():
            e.append(l[i])
        else:
            e.append(0)
    return ",".join([str(j) for j in e])


def multi_hot(df,column,n):
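    """Build an integer id map for a comma-separated multi-value column.

    Collects every distinct token appearing in ``column`` (values are split on
    ",") and assigns consecutive ids starting at ``n``. Returns
    (vocabulary_size, token -> id dict). Illustrative example (token order is
    not deterministic because the vocabulary comes from a set):
    multi_hot(df, "app_list", 1) could return (3, {"a": 1, "b": 2, "c": 3}).
    """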
    v = set(df.select(column).rdd.map(lambda x: x[0]).collect())
    app_list_value = [i.split(",") for i in v]
    app_list_unique = []
    for i in app_list_value:
        app_list_unique.extend(i)
    app_list_unique = list(set(app_list_unique))
    number = len(app_list_unique)
    app_list_map = dict(zip(app_list_unique, list(range(n, number + n))))
    return number,app_list_map


def feature_engineer():
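    """Build the ESMM train/validation feature sets.

    Reads esmm_train_data joined with the user, diary and knowledge feature
    tables, encodes the multi-value columns (app_list, level2_ids, level3_ids)
    and the single-valued categorical features as integer ids, splits the rows
    by stat_date and writes the encoded sets to '/recommend/tr' and
    '/recommend/va' as CSV. Assumes stat_date is stored as 'YYYY-MM-DD'
    strings (it is parsed with strptime and compared as a string).

    Returns (validate_date, value_map, app_list_map, level2_map, level3_map).
    """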
    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
    sql = "select max(stat_date) from esmm_train_data"
    validate_date = con_sql(db, sql)[0].values.tolist()[0]
    print("validate_date:" + validate_date)
    temp = datetime.datetime.strptime(validate_date, "%Y-%m-%d")
    start = (temp - datetime.timedelta(days=1)).strftime("%Y-%m-%d")
    print(start)

    sql = "select e.y,e.z,e.stat_date,e.ucity_id,feat.level2_ids,e.ccity_name,u.device_type,u.manufacturer," \
          "u.channel,c.top,cut.time,dl.app_list,e.diary_service_id,feat.level3_ids," \
          "k.treatment_method,k.price_min,k.price_max,k.treatment_time,k.maintain_time,k.recover_time " \
          "from esmm_train_data e left join user_feature u on e.device_id = u.device_id " \
          "left join cid_type_top c on e.device_id = c.device_id " \
          "left join cid_time_cut cut on e.cid_id = cut.cid " \
          "left join device_app_list dl on e.device_id = dl.device_id " \
          "left join diary_feat feat on e.cid_id = feat.diary_id " \
          "left join train_Knowledge_network_data k on feat.level2 = k.level2_id " \
          "where e.stat_date >= '{}'".format(start)

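    # esmm_train_data and the joined feature tables live in TiDB ("jerry_test"),
    # which is exposed to Spark SQL through the TiSpark mapping set up in __main__.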
    df = spark.sql(sql)

    url = "jdbc:mysql://172.16.30.143:3306/zhengxing"
    jdbcDF = spark.read.format("jdbc").option("driver", "com.mysql.jdbc.Driver").option("url", url) \
        .option("dbtable", "api_service").option("user", 'work').option("password", 'BJQaT9VzDcuPBqkd').load()
    jdbcDF.createOrReplaceTempView("api_service")
    jdbc = spark.read.format("jdbc").option("driver", "com.mysql.jdbc.Driver").option("url", url) \
        .option("dbtable", "api_doctor").option("user", 'work').option("password", 'BJQaT9VzDcuPBqkd').load()
    jdbc.createOrReplaceTempView("api_doctor")

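    # Resolve each diary service to its hospital through the doctor record.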
    sql = "select s.id as diary_service_id,d.hospital_id " \
          "from api_service s left join api_doctor d on s.doctor_id = d.id"
    hospital = spark.sql(sql)

    df = df.join(hospital,"diary_service_id","left_outer").fillna("na")
    df = df.drop("level2").drop("diary_service_id")
    df = df.drop_duplicates(["ucity_id", "level2_ids", "ccity_name", "device_type", "manufacturer",
                              "channel", "top", "time", "stat_date", "app_list", "hospital_id", "level3_ids"])

    features = ["ucity_id", "ccity_name", "device_type", "manufacturer",
                "channel", "top", "time", "stat_date", "hospital_id",
                "treatment_method", "price_min", "price_max", "treatment_time", "maintain_time", "recover_time"]

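    # Fill nulls in the single-valued categorical columns with the column's own
    # name, so missing values become a per-column sentinel token.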
    df = df.na.fill(dict(zip(features,features)))

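    # Encode the comma-separated multi-value columns; each column's vocabulary
    # gets its own contiguous id range, starting at 1.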
    apps_number, app_list_map = multi_hot(df, "app_list", 1)
    level2_number, level2_map = multi_hot(df, "level2_ids", 1 + apps_number)
    level3_number, level3_map = multi_hot(df, "level3_ids", 1 + apps_number + level2_number)

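    # Continue the id space for the single-valued categorical features. The ids
    # start at 2 + apps_number + level2_number + level3_number; the multi-hot
    # ranges end at apps_number + level2_number + level3_number, so one id in
    # between is left unused.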
    unique_values = []
    for i in features:
        unique_values.extend(list(set(df.select(i).rdd.map(lambda x: x[0]).collect())))
    temp = list(range(2 + apps_number + level2_number + level3_number,
                      2 + apps_number + level2_number + level3_number + len(unique_values)))
    value_map = dict(zip(unique_values, temp))

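    # Split on stat_date: rows from validate_date (the latest date) form the
    # validation set, earlier rows the training set. Multi-value columns are
    # encoded with app_list_func, single-valued features through value_map, and
    # the labels y and z (x[18], x[19]) pass through unchanged.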
    train = df.select("app_list","level2_ids","level3_ids","stat_date","ucity_id", "ccity_name", "device_type", "manufacturer",
                  "channel", "top", "time", "hospital_id","treatment_method", "price_min",
                  "price_max", "treatment_time","maintain_time", "recover_time","y","z",)\
        .rdd.filter(lambda x: x[3] != validate_date)\
        .map(lambda x: (app_list_func(x[0], app_list_map), app_list_func(x[1], level2_map),
                        app_list_func(x[2], level3_map), value_map[x[3]], value_map[x[4]],
                        value_map[x[5]], value_map[x[6]], value_map[x[7]], value_map[x[8]],
                        value_map[x[9]], value_map[x[10]], value_map[x[11]], value_map[x[12]],
                        value_map[x[13]], value_map[x[14]], value_map[x[15]], value_map[x[16]],
                        value_map[x[17]], x[18], x[19]))
    test = df.select("app_list", "level2_ids", "level3_ids", "stat_date", "ucity_id", "ccity_name", "device_type",
                      "manufacturer","channel", "top", "time", "hospital_id", "treatment_method", "price_min",
                      "price_max", "treatment_time", "maintain_time", "recover_time", "y", "z", ) \
        .rdd.filter(lambda x: x[3] == validate_date)\
        .map(lambda x: (app_list_func(x[0], app_list_map), app_list_func(x[1], level2_map),
                        app_list_func(x[2], level3_map), value_map[x[3]], value_map[x[4]],
                        value_map[x[5]], value_map[x[6]], value_map[x[7]], value_map[x[8]],
                        value_map[x[9]], value_map[x[10]], value_map[x[11]], value_map[x[12]],
                        value_map[x[13]], value_map[x[14]], value_map[x[15]], value_map[x[16]],
                        value_map[x[17]], x[18], x[19]))

    print("test.count",test.count())
    print("train count",train.count())

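    # Persist the encoded sets as CSV on the cluster's default filesystem
    # ('/recommend/va' for validation, '/recommend/tr' for training).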
    spark.createDataFrame(test).write.csv('/recommend/va', mode='overwrite', header=True)
    spark.createDataFrame(train).write.csv('/recommend/tr', mode='overwrite', header=True)
    print("done")

    return validate_date, value_map, app_list_map, level2_map, level3_map


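# NOTE: get_predict below is a commented-out draft. The prediction SQL is built but
# never executed, and it references helpers (path, write_csv) that are not defined
# in this file.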
# def get_predict(date,value_map,app_list_map,level2_map,level3_map):
#
#     sql = "select e.y,e.z,e.label,e.ucity_id,feat.level2_ids,e.ccity_name," \
#           "u.device_type,u.manufacturer,u.channel,c.top,e.device_id,e.cid_id,cut.time," \
#           "dl.app_list,e.hospital_id,feat.level3_ids,feat.level2 " \
#           "from esmm_pre_data e left join user_feature u on e.device_id = u.device_id " \
#           "left join cid_type_top c on e.device_id = c.device_id " \
#           "left join cid_time_cut cut on e.cid_id = cut.cid " \
#           "left join device_app_list dl on e.device_id = dl.device_id " \
#           "left join diary_feat feat on e.cid_id = feat.diary_id"
#
#
#     df = df.rename(columns={0: "y", 1: "z", 2: "label", 3: "ucity_id", 4: "clevel2_id", 5: "ccity_name",
#                             6: "device_type", 7: "manufacturer", 8: "channel", 9: "top",10: "device_id",
#                             11: "cid_id", 12: "time",13:"app_list",14:"hospital_id",15:"level3_ids",
#                             16: "level2"})
#
#     db = pymysql.connect(host='10.66.157.22', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
#     sql = "select level2_id,treatment_method,price_min,price_max,treatment_time,maintain_time,recover_time " \
#           "from train_Knowledge_network_data"
#     knowledge = con_sql(db, sql)
#     knowledge = knowledge.rename(columns={0: "level2", 1: "method", 2: "min", 3: "max",
#                                           4: "treatment_time", 5: "maintain_time", 6: "recover_time"})
#     knowledge["level2"] = knowledge["level2"].astype("str")
#
#     df = pd.merge(df, knowledge, on='level2', how='left')
#     df = df.drop("level2", axis=1)
#     df = df.drop_duplicates(["ucity_id", "clevel2_id", "ccity_name", "device_type", "manufacturer",
#                              "channel", "top", "time", "app_list", "hospital_id", "level3_ids"])
#
#
#     df["stat_date"] = date
#     print(df.head(6))
#     df["app_list"] = df["app_list"].fillna("lost_na")
#     df["app_list"] = df["app_list"].apply(app_list_func,args=(app_list_map,))
#     df["clevel2_id"] = df["clevel2_id"].fillna("lost_na")
#     df["clevel2_id"] = df["clevel2_id"].apply(app_list_func, args=(level2_map,))
#     df["level3_ids"] = df["level3_ids"].fillna("lost_na")
#     df["level3_ids"] = df["level3_ids"].apply(app_list_func, args=(level3_map,))
#
#     # print("predict shape")
#     # print(df.shape)
#     df["uid"] = df["device_id"]
#     df["city"] = df["ucity_id"]
#     features = ["ucity_id", "ccity_name", "device_type", "manufacturer",
#                 "channel", "top", "time", "stat_date","hospital_id",
#                 "method", "min", "max", "treatment_time", "maintain_time", "recover_time"]
#     for i in features:
#         df[i] = df[i].astype("str")
#         df[i] = df[i].fillna("lost")
#         df[i] = df[i] + i
#
#     native_pre = df[df["label"] == 0]
#     native_pre = native_pre.drop("label", axis=1)
#     nearby_pre = df[df["label"] == 1]
#     nearby_pre = nearby_pre.drop("label", axis=1)
#
#     for i in ["ucity_id", "ccity_name", "device_type", "manufacturer",
#                 "channel", "top", "time", "stat_date","hospital_id",
#               "method", "min", "max", "treatment_time", "maintain_time", "recover_time"]:
#         native_pre[i] = native_pre[i].map(value_map)
#         # TODO: categories not covered by value_map become NaN; fill with 0 for now and improve later
#         native_pre[i] = native_pre[i].fillna(0)
#
#         nearby_pre[i] = nearby_pre[i].map(value_map)
#         # TODO: categories not covered by value_map become NaN; fill with 0 for now and improve later
#         nearby_pre[i] = nearby_pre[i].fillna(0)
#
#     print("native")
#     print(native_pre.shape)
#
#     native_pre[["uid","city","cid_id"]].to_csv(path+"native.csv",index=False)
#     write_csv(native_pre, "native",200000)
#
#     print("nearby")
#     print(nearby_pre.shape)
#
#     nearby_pre[["uid","city","cid_id"]].to_csv(path+"nearby.csv",index=False)
#     write_csv(nearby_pre, "nearby", 160000)

def con_sql(db,sql):
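    """Execute ``sql`` on the pymysql connection ``db``, return the rows as a
    pandas DataFrame (columns keep their integer positions) and close the
    connection."""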
    cursor = db.cursor()
    cursor.execute(sql)
    result = cursor.fetchall()
    df = pd.DataFrame(list(result))
    db.close()
    return df


def test():
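    """Smoke test: read a few rows of esmm_train_data through Spark SQL and
    write them to '/recommend/tr' as CSV."""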
    # spark.sql("use online")
    # spark.sql("ADD JAR /srv/apps/brickhouse-0.7.1-SNAPSHOT.jar")
    # spark.sql("ADD JAR /srv/apps/hive-udf-1.0-SNAPSHOT.jar")
    # spark.sql("CREATE TEMPORARY FUNCTION json_map AS 'brickhouse.udf.json.JsonMapUDF'")
    # spark.sql("CREATE TEMPORARY FUNCTION is_json AS 'com.gmei.hive.common.udf.UDFJsonFormatCheck'")

    spark.sparkContext.setLogLevel("WARN")
    df = spark.sql("select device_id,stat_date from esmm_train_data limit 60")
    df.show(6)
    df.write.csv('/recommend/tr', mode='overwrite', header=True)



    # data = [(0, 18.0), (1, 19.0), (2, 8.0), (3, 5.0), (4, 2.2), (5, 9.2), (6, 14.4)]
    # df = spark.createDataFrame(data, ["id", "hour"])
    # df.show(6)
    # t = df.rdd.map(lambda x:x[0]).collect()
    # print(t)

    # validate_date = spark.sql("select max(stat_date) from esmm_train_data").rdd.map(lambda x: str(x[0]))
    # print(validate_date.count())
    # print("validate_date:" + validate_date)
    # temp = datetime.datetime.strptime(validate_date, "%Y-%m-%d")
    # start = (temp - datetime.timedelta(days=10)).strftime("%Y-%m-%d")
    # print(start)


if __name__ == '__main__':
    sparkConf = SparkConf().set("spark.hive.mapred.supports.subdirectories", "true") \
        .set("spark.hadoop.mapreduce.input.fileinputformat.input.dir.recursive", "true") \
        .set("spark.tispark.plan.allow_index_double_read", "false") \
        .set("spark.tispark.plan.allow_index_read", "true") \
        .set("spark.sql.extensions", "org.apache.spark.sql.TiExtensions") \
        .set("spark.tispark.pd.addresses", "172.16.40.158:2379").set("spark.io.compression.codec", "lzf")
    # .set("spark.driver.maxResultSize", "4g")

    spark = SparkSession.builder.config(conf=sparkConf).enableHiveSupport().getOrCreate()
    ti = pti.TiContext(spark)
    ti.tidbMapDatabase("jerry_test")
    spark.sparkContext.setLogLevel("WARN")
    feature_engineer()