import json
import time

import pymysql
import redis
import zstd
from pyspark import SparkConf
from pyspark.sql import SparkSession

def zstd_compress(x):
    """Serialize ``x`` as JSON and compress it with zstd at level 22; returns None for falsy input."""
    if x:
        return zstd.compress(bytes(json.dumps(x), encoding='utf-8'), 22)


def zstd_decompress(x):
    """Decompress a payload produced by ``zstd_compress`` and parse it back from JSON."""
    if x:
        result = zstd.decompress(x)
        return json.loads(result)

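# Round-trip sketch for the two helpers above (illustrative payload, not production data):
#   payload = {"6789": 0.0123}
#   blob = zstd_compress(payload)            # bytes, suitable for a Redis SET
#   assert zstd_decompress(blob) == payload  # zstd_decompress(None) returns None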

# On a CTR change, recompute and cache the smart rank scores of all meigou (services) for one device
def update_device_smart_rank(device_id, result_all_dict, service_detail_view_count_30_dict, result_smart_rank_score_dict):
    device_meigou_ctr_key = 'device_meigou_ctr:device_id:' + str(device_id)
    # This runs inside a Spark map, so each task opens its own short-lived connection.
    gm_kv_cli = redis.Redis(host="172.16.40.135", port=5379, db=2, socket_timeout=2000)
    if gm_kv_cli.exists(device_meigou_ctr_key):
        # Hash of per-service CTRs for this device: {service_id (bytes): ctr (bytes)}
        ts_device_meigou_ctr = gm_kv_cli.hgetall(device_meigou_ctr_key)
        device_meigou_smart_rank = dict()
        for i in ts_device_meigou_ctr:
            ts_ctr = float(ts_device_meigou_ctr[i])
            service_id = str(i, encoding="utf-8")
            meigou_smart_rank_score = get_meigou_smart_rank(service_id, result_all_dict, service_detail_view_count_30_dict, ts_ctr, result_smart_rank_score_dict)
            device_meigou_smart_rank[service_id] = meigou_smart_rank_score

        # Store the whole {service_id: score} dict as a single zstd-compressed value with a 24h TTL.
        device_meigou_smart_rank_key = 'device_meigou_smart_rank_zstd:device_id:' + str(device_id)
        REDIS_URL = 'redis://:ReDis!GmTx*0aN6@172.16.40.133:6379'
        cli_ins = redis.StrictRedis.from_url(REDIS_URL)
        cli_ins.set(device_meigou_smart_rank_key, zstd_compress(device_meigou_smart_rank))
        cli_ins.expire(device_meigou_smart_rank_key, time=24 * 60 * 60)
        # Returned value (score of the last service processed) is only collected by the driver.
        return meigou_smart_rank_score
    return "periodic update fail"

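# Read-side sketch (hypothetical consumer; key layout and REDIS_URL as in the writer above):
#   cli = redis.StrictRedis.from_url('redis://:ReDis!GmTx*0aN6@172.16.40.133:6379')
#   raw = cli.get('device_meigou_smart_rank_zstd:device_id:' + device_id)
#   scores = zstd_decompress(raw)  # {service_id: score} dict, or None once the 24h TTL expires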

# Compute the smart rank score of a single meigou (service)
def get_meigou_smart_rank(service_id, result_all_dict, service_detail_view_count_30_dict, meigou_ctr, result_smart_rank_score_dict, table_cpc_price=-1):
    if service_id in result_all_dict:
        consult_value = result_all_dict[service_id]["consult_value"]
        # A caller-supplied CPC price (table_cpc_price != -1) overrides the factor-table price.
        if table_cpc_price == -1:
            click_price = result_all_dict[service_id]["click_price"]
        else:
            click_price = table_cpc_price
        # Double .get(): the service may be missing from the 30-day view-count table entirely.
        service_detail_view_count_30 = service_detail_view_count_30_dict.get(service_id, {}).get("service_detail_view_count_30", 0)
        if click_price == 0 and service_detail_view_count_30 <= 500:
            # Unpaid, low-exposure service: personalize with the device-level CTR.
            ctr_value = meigou_ctr
        else:
            # Paid or high-exposure service: keep the precomputed global score.
            return float('%.4g' % result_smart_rank_score_dict[service_id]["new_smart_rank"])
        discount_value = result_all_dict[service_id]["discount_value"]
        cpt_value = result_all_dict[service_id]["cpt_value"]
        org_value = discount_value + 0.5 * cpt_value + click_price
    else:
        if service_id in result_smart_rank_score_dict:
            return float('%.4g' % result_smart_rank_score_dict[service_id]["new_smart_rank"])
        # No factors and no precomputed score: fall back to small default factors.
        consult_value = 0.001
        ctr_value = 0.1
        discount_value = 0.001
        cpt_value = 0.001
        click_price = 0
        org_value = discount_value + 0.5 * cpt_value + click_price
    # score = consult_value * ctr_value * (discount_value + 0.5 * cpt_value + click_price)
    meigou_smart_rank_score = consult_value * ctr_value * org_value
    return float('%.4g' % meigou_smart_rank_score)

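# Worked example of the scoring path above (numbers are illustrative only):
#   consult_value=0.2, device ctr=0.05, discount_value=1.0, cpt_value=2.0, click_price=0
#   org_value = 1.0 + 0.5 * 2.0 + 0 = 2.0
#   score = 0.2 * 0.05 * 2.0 = 0.02  ->  '%.4g' keeps 4 significant digits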

if __name__ == '__main__':
    try:
        start = time.time()
        db_zhengxing = pymysql.connect(host="172.16.30.141", port=3306, user="work", password="BJQaT9VzDcuPBqkd",
                                       db="zhengxing", cursorclass=pymysql.cursors.DictCursor)
        cur_zhengxing = db_zhengxing.cursor()

        sql = "select service_id,service_detail_view_count_30 from statistic_service_smart_rank_v3 where stat_date=(select max(stat_date) from statistic_service_smart_rank_v3)"
        cur_zhengxing.execute(sql)
        result = cur_zhengxing.fetchall()
        service_detail_view_count_30_dict = dict()
        for i in result:
            service_detail_view_count_30_dict.update({str(i["service_id"]): i})
        # meigou smart_rank所有因子
        sql_smart_rank = "select service_id,discount_value,cpt_value,click_price,consult_value,ctr_value from api_smart_rank_factor"
        cur_zhengxing.execute(sql_smart_rank)
        result_all = cur_zhengxing.fetchall()
        result_all_dict = dict()
        for i in result_all:
            result_all_dict.update({str(i["service_id"]): i})
        # smart_rank_score
        sql_smart_rank_score = "select service_id, new_smart_rank from api_smart_rank"
        cur_zhengxing.execute(sql_smart_rank_score)
        result_smart_rank_score = cur_zhengxing.fetchall()
        result_smart_rank_score_dict = dict()
        for i in result_smart_rank_score:
            result_smart_rank_score_dict.update({str(i["service_id"]): i})
        db_zhengxing.close()

        # Spark configuration (TiSpark extensions and cluster addresses from the existing deployment)
        sparkConf = SparkConf().set("spark.hive.mapred.supports.subdirectories", "true") \
            .set("spark.hadoop.mapreduce.input.fileinputformat.input.dir.recursive", "true") \
            .set("spark.tispark.plan.allow_index_double_read", "false") \
            .set("spark.tispark.plan.allow_index_read", "true") \
            .set("spark.sql.extensions", "org.apache.spark.sql.TiExtensions") \
            .set("spark.tispark.pd.addresses", "172.16.40.170:2379").set("spark.io.compression.codec", "lzf") \
            .set("spark.driver.maxResultSize", "8g").set("spark.sql.avro.compression.codec", "snappy")

        spark = SparkSession.builder.config(conf=sparkConf).enableHiveSupport().getOrCreate()
        spark.sparkContext.setLogLevel("WARN")
        REDIS_URL = 'redis://:ReDis!GmTx*0aN6@172.16.40.133:6379'
        cli_ins = redis.StrictRedis.from_url(REDIS_URL)
        # Redis set of gray-level (canary) device ids whose scores should be refreshed
        gray_level_device_ids = "doris:ctr_estimate:device_id_list"
        if cli_ins.exists(gray_level_device_ids):
            device_ids = cli_ins.smembers(gray_level_device_ids)
            # device_ids = [b"9C5E7C73-380C-4623-8F48-A64C8034E315" for i in range(1000)]  # local stress-test stub
            device_ids_rdd = spark.sparkContext.parallelize(device_ids)
            # The map writes each device's scores to Redis as a side effect; collect() just forces execution.
            result = device_ids_rdd.repartition(40).map(
                lambda x: update_device_smart_rank(str(x, encoding='utf-8'), result_all_dict,
                                                   service_detail_view_count_30_dict, result_smart_rank_score_dict))
            result.collect()
            print("elapsed: %.1f s" % (time.time() - start))
        spark.stop()
    except Exception as e:
        print("update_device_smart_rank job failed:", e)
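# Launch sketch (assumed command; master, resources, and the TiSpark jar path depend on the cluster):
#   spark-submit --master yarn --deploy-mode client \
#       --jars /path/to/tispark-assembly.jar update_device_smart_rank.py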