#coding=utf-8
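
# Build per-device diary recommendation queues from ESMM predictions:
# rank native/nearby candidates by ctcvr per (device, city) and refresh
# the esmm_device_diary_queue table in TiDB (delete old rows, append new ones).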

import datetime

import pandas as pd
import pymysql
from pyspark.conf import SparkConf
from pyspark.sql import SparkSession
from sqlalchemy import create_engine


def con_sql(sql):
    """Run a query against TiDB and return all rows.

    :type sql: str
    :rtype: tuple
    """
    db = pymysql.connect(host='172.16.40.158', port=4000, user='root',
                         passwd='3SYz54LS9#^9sBvC', db='jerry_test')
    try:
        cursor = db.cursor()
        cursor.execute(sql)
        result = cursor.fetchall()
    finally:
        db.close()
    return result

def set_join(lst):
    """Join up to 500 unique ids from a pandas Series into a comma-separated string."""
    ids = [str(i) for i in lst.unique().tolist()]
    return ','.join(ids[:500])


def main():

    # native queue
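    # native.csv holds the native candidate rows; pred.txt holds the model's
    # ctr/cvr/ctcvr scores, assumed to be in the same row order, so they are
    # attached positionally below.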
    df2 = pd.read_csv(path+'/native.csv')
    df2['cid_id'] = df2['cid_id'].astype(str)

    df1 = pd.read_csv(path+"/native/pred.txt",sep='\t',header=None,names=["ctr","cvr","ctcvr"])
    df2["ctr"],df2["cvr"],df2["ctcvr"] = df1["ctr"],df1["cvr"],df1["ctcvr"]
    df3 = df2.groupby(by=["uid","city"]).apply(lambda x: x.sort_values(by="ctcvr",ascending=False))\
        .reset_index(drop=True).groupby(by=["uid","city"]).agg({'cid_id':set_join}).reset_index(drop=False)
    df3.columns = ["device_id","city_id","native_queue"]
    print("native_device_count",df3.shape)


    # nearby queue
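    # same procedure as the native queue above, applied to the nearby candidates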
    df2 = pd.read_csv(path+'/nearby.csv')
    df2['cid_id'] = df2['cid_id'].astype(str)

    df1 = pd.read_csv(path+"/nearby/pred.txt",sep='\t',header=None,names=["ctr","cvr","ctcvr"])
    df2["ctr"], df2["cvr"], df2["ctcvr"] = df1["ctr"], df1["cvr"], df1["ctcvr"]
    df4 = df2.groupby(by=["uid","city"]).apply(lambda x: x.sort_values(by="ctcvr",ascending=False))\
        .reset_index(drop=True).groupby(by=["uid","city"]).agg({'cid_id':set_join}).reset_index(drop=False)
    df4.columns = ["device_id","city_id","nearby_queue"]
    print("nearby_device_count",df4.shape)

    # union: outer-join the native and nearby queues on (device_id, city_id)
    df_all = pd.merge(df3,df4,on=['device_id','city_id'],how='outer').fillna("")
    df_all['device_id'] = df_all['device_id'].astype(str)
    df_all['city_id'] = df_all['city_id'].astype(str)
    df_all["time"] = str(datetime.datetime.now().strftime('%Y%m%d%H%M'))
    print("union_device_count",df_all.shape)

    host = '172.16.40.158'
    port = 4000
    user = 'root'
    password = '3SYz54LS9#^9sBvC'
    db = 'jerry_test'
    charset = 'utf8'

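    # refresh strategy: delete existing rows for these (device_id, city_id) pairs,
    # then append the new queues; the delete keys are split into 5 batches so
    # each DELETE statement stays reasonably small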
    df_merge = df_all['device_id'] + df_all['city_id']
    to_delete = list(df_merge.values)
    total = len(to_delete)
    df_merge_str = []
    for i in range(1, 6):
        start = int(total * (i - 1) / 5)
        end = int(total * i / 5)
        df_merge_str.append(str(to_delete[start:end]).strip('[]'))

    try:
        con = pymysql.connect(host=host, port=port, user=user, passwd=password, db=db)
        cur = con.cursor()
        for i in df_merge_str:
            if not i:
                # fewer than 5 rows can leave a batch empty; an empty IN () is invalid SQL
                continue
            delete_str = 'delete from esmm_device_diary_queue where concat(device_id,city_id) in ({0})'.format(i)
            cur.execute(delete_str)
            con.commit()
            print("delete done")
        con.close()
        engine = create_engine("mysql+pymysql://{}:{}@{}:{}/{}?charset={}".format(user, password, host, port, db, charset))
        df_all.to_sql('esmm_device_diary_queue', con=engine, if_exists='append', index=False, chunksize=8000)
        print("insert done")

    except Exception as e:
        print(e)


if __name__ == "__main__":
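    # TiSpark settings: read TiDB through the PD endpoint (172.16.40.158:2379)
    # with index reads enabled and double index reads disabled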
    sparkConf = SparkConf().set("spark.hive.mapred.supports.subdirectories", "true") \
        .set("spark.hadoop.mapreduce.input.fileinputformat.input.dir.recursive", "true") \
        .set("spark.tispark.plan.allow_index_double_read", "false") \
        .set("spark.tispark.plan.allow_index_read", "true") \
        .set("spark.sql.extensions", "org.apache.spark.sql.TiExtensions") \
        .set("spark.tispark.pd.addresses", "172.16.40.158:2379").set("spark.io.compression.codec", "lzf")\
        .set("spark.driver.maxResultSize", "8g").set("spark.sql.avro.compression.codec","snappy")
    spark = SparkSession.builder.config(conf=sparkConf).enableHiveSupport().getOrCreate()
    spark.sparkContext.setLogLevel("WARN")

    path = "hdfs://172.16.32.4:8020/strategy/esmm/"
    # df = spark.read.format("tfrecords").load(path+"nearby/part-r-00000")
    # df.show()
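    # Note: main() is defined above but not invoked in this block; the reads
    # below only load and display the nearby inputs, presumably as a sanity check.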

    uid1 = spark.read.format("csv").options(sep=",",header=True).load(path+"nearby/nearby.csv")
    uid1.show()
    pred1 = spark.read.format("csv").options(sep="\t").load(path+"nearby/pred.txt")
    pred1.show()