import os
from pyspark import SparkConf
from pyspark.sql import SparkSession
import sys
import time
from datetime import date, timedelta
import pandas as pd
from gensim.models import Word2Vec
import pickle

sys.path.append(os.path.dirname(os.path.abspath(os.path.dirname(__file__))))
import utils.connUtils as connUtils

VERSION = "v1"
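# Pipeline overview:
#   1. Pull 'welfare_detail' page-view clicks from Hive for the last N days,
#      keeping only valid-channel, non-blacklisted devices (getClickSql).
#   2. Group clicks per device and split them into sessions whenever two
#      consecutive clicks are more than one hour apart (reverseCol).
#   3. Train a skip-gram Word2Vec model on the resulting card-id sessions.
#   4. Pickle the trained model into Redis together with an update-status flag.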

def getClickSql(start, end):
    sql = """
    SELECT DISTINCT t1.partition_date, t1.cl_id device_id, t1.card_id,t1.time_stamp,t1.page_stay
      FROM
        (
            select partition_date,cl_id,business_id as card_id,time_stamp,page_stay
            from online.bl_hdfs_maidian_updates
            where action = 'page_view'
            AND partition_date>='{startDay}' and partition_date<='{endDay}'
            AND page_name='welfare_detail'
            AND page_stay >= 10 
            AND cl_id is not null
            AND cl_id != ''
            AND business_id is not null
            AND business_id != ''
            group by partition_date,cl_id,business_id,time_stamp,page_stay
        ) AS t1
        join
        (   -- filter by acquisition channel and new/old-user status
            SELECT distinct device_id
            FROM online.ml_device_day_active_status
            where partition_date>='{startDay}' and partition_date<='{endDay}'
            AND active_type in ('1','2','4')
            and first_channel_source_type not in ('yqxiu1','yqxiu2','yqxiu3','yqxiu4','yqxiu5','mxyc1','mxyc2','mxyc3'
            ,'wanpu','jinshan','jx','maimai','zhuoyi','huatian','suopingjingling','mocha','mizhe','meika','lamabang'
            ,'js-az1','js-az2','js-az3','js-az4','js-az5','jfq-az1','jfq-az2','jfq-az3','jfq-az4','jfq-az5','toufang1'
            ,'toufang2','toufang3','toufang4','toufang5','toufang6','TF-toufang1','TF-toufang2','TF-toufang3','TF-toufang4'
            ,'TF-toufang5','tf-toufang1','tf-toufang2','tf-toufang3','tf-toufang4','tf-toufang5','benzhan','promotion_aso100'
            ,'promotion_qianka','promotion_xiaoyu','promotion_dianru','promotion_malioaso','promotion_malioaso-shequ'
            ,'promotion_shike','promotion_julang_jl03','promotion_zuimei','','unknown')
            AND first_channel_source_type not like 'promotion\_jf\_%'
        ) t2
        on t1.cl_id = t2.device_id

        LEFT JOIN
        (   -- exclude blacklisted (abnormal) devices
            select distinct device_id
            from ML.ML_D_CT_DV_DEVICECLEAN_DIMEN_D
            where PARTITION_DAY =regexp_replace(DATE_SUB(current_date,1) ,'-','')
            AND is_abnormal_device = 'true'
        )t3 
        on t3.device_id=t2.device_id
        WHERE t3.device_id is null
         """.format(startDay=start,endDay=end)
    print(sql)
    return sql

def addDays(n, fmt="%Y%m%d"):
    # Today's date shifted by n days, formatted to match partition_date.
    return (date.today() + timedelta(days=n)).strftime(fmt)

def reverseCol(x):
    """Split one device's clicks into sessions of card ids.

    Each element of x.time_card is a "timestamp_cardid" string. Clicks more
    than one hour apart start a new session; only sessions containing at
    least two cards are kept.
    """
    res = []
    # Sort clicks by their numeric timestamp prefix.
    datas = sorted(x.time_card, key=lambda s: int(s.split("_")[0]))
    last_time_stamp = int(datas[0].split("_")[0])
    res_line = []
    for d in datas:
        # Split once so card ids containing "_" are preserved intact.
        ts_str, card_id = d.split("_", 1)
        time_stamp = int(ts_str)
        if (time_stamp - last_time_stamp) > 60 * 60:
            # Gap longer than one hour: close the current session, start a new one.
            if len(res_line) > 1:
                res.append(res_line)
            res_line = [card_id]
        else:
            res_line.append(card_id)
        last_time_stamp = time_stamp
    if len(res_line) > 1:
        res.append(res_line)

    return res

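# Builds a SparkSession with Hive support: cross joins enabled, TiSpark index
# reads allowed (double index reads disabled), recursive Hive input
# directories, Kryo serialization, and MapReduce map/output compression off.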
def get_spark(appName):
    sparkConf = SparkConf()
    sparkConf.set("spark.sql.crossJoin.enabled", "true")
    sparkConf.set("spark.debug.maxToStringFields", "100")
    sparkConf.set("spark.tispark.plan.allow_index_double_read", "false")
    sparkConf.set("spark.tispark.plan.allow_index_read", "true")
    sparkConf.set("spark.hive.mapred.supports.subdirectories", "true")
    sparkConf.set("spark.hadoop.mapreduce.input.fileinputformat.input.dir.recursive", "true")
    sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
    sparkConf.set("mapreduce.output.fileoutputformat.compress", "false")
    sparkConf.set("mapreduce.map.output.compress", "false")
    spark = (SparkSession
             .builder
             .config(conf=sparkConf)
             .appName(appName)
             .enableHiveSupport()
             .getOrCreate())
    return spark

if __name__ == '__main__':
    start = time.time()
    # Command-line argument: number of days of training data to pull.
    trainDays = int(sys.argv[1])
    spark = get_spark("service_embedding")
    spark.sparkContext.setLogLevel("ERROR")

    print('trainDays:{}'.format(trainDays), flush=True)

    endDay = addDays(0)
    startDay = addDays(-int(trainDays))

    print("train_data start:{} end:{}".format(startDay, endDay))

    # Click behavior data
    clickSql = getClickSql(startDay, endDay)
    clickDF = spark.sql(clickSql)
    df = clickDF.toPandas()
    # df.to_csv("/tmp/service_click.csv", index=False)
    print("count:", len(df))

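    # Build a "timestamp_cardid" composite key per click, then group by device
    # and cut each device's click stream into card-id sessions (reverseCol).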
    df["time_card"] = df["time_stamp"].map(str) + "_" + df["card_id"].map(str)
    new_df = df.groupby(["device_id"]).apply(reverseCol).to_frame("card_seq").reset_index()

    df1 = new_df.loc[new_df["card_seq"].map(len) > 1]
    print("user seq size:", len(df1))

    datas = df1["card_seq"].tolist()
    train_datas = []
    for d in datas:
        train_datas.extend(d)
    print("train size:",len(train_datas))

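    # Skip-gram Word2Vec over the card-id sessions: sg=1 selects skip-gram,
    # vector_size=16 is the embedding dimension, window=5 the context window,
    # epochs=50 the number of training passes (gensim 4.x keyword names).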
    model = Word2Vec(train_datas, sg=1, vector_size=16, window=5, epochs=50)

    s = pickle.dumps(model)

    # Save the serialized model to Redis (30-day TTL).
    conn = connUtils.getRedisConn()
    model_key = "strategy:word2vec:{}:{}".format("service",VERSION)
    model_status_key = "strategy:word2vec:status:{}:{}".format("service",VERSION)
    conn.set(model_key,s)
    conn.expire(model_key,60*60*24*30)
    # Record the model-update status flag (30-day TTL).
    conn.set(model_status_key,"1")
    conn.expire(model_status_key,60*60*24*30)
    # conn.close()

    print("model to redis: success")