from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time

from pyspark import SparkConf
from pyspark.sql import SparkSession

from tool import *


def get_user_service_portrait(x, all_word_tags, all_tag_tag_type, all_3tag_2tag, all_tags_name, size=None, pay_time=0):
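    """Build a tag portrait for one device.

    x is a (device_id, search_words) pair. all_word_tags maps search words to
    tags, all_tag_tag_type maps tag id to tag level, all_3tag_2tag maps
    level-3 tags to their level-2 parents, and all_tags_name maps tag id to
    its Chinese name. size caps how many tags are returned (None keeps all).
    Returns (device_id, search_words, tag_names); the tag list is ["000"] when
    the device has no log at all. If no level-2/3 tags survive the filtering,
    the function falls through and returns None, which is dropped downstream.
    """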
    cl_id = x[0]
    search_info = x[1]
    user_df_service = get_user_log(cl_id, all_word_tags, pay_time=pay_time)

    # Add derived columns (days_diff_now, tag_type, tag2)
    if not user_df_service.empty:
        user_df_service["days_diff_now"] = round((int(time.time()) - user_df_service["time"].astype(float)) / (24 * 60 * 60))
        user_df_service["tag_type"] = user_df_service.apply(lambda x: all_tag_tag_type.get(x["tag_id"]), axis=1)
        user_df_service = user_df_service[user_df_service['tag_type'].isin(['2','3'])]
        if not user_df_service.empty:
            user_log_df_tag2_list = user_df_service[user_df_service['tag_type'] == '2']['tag_id'].unique().tolist()
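            # Roll level-3 tags up to a level-2 tag via get_tag2_from_tag3; rows that
            # are already level-2 keep their own tag id.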
            user_df_service["tag2"] = user_df_service.apply(lambda x:
                                                    get_tag2_from_tag3(x.tag_id, all_3tag_2tag, user_log_df_tag2_list)
                                                    if x.tag_type == '3' else x.tag_id, axis=1)
            user_df_service["tag2_type"] = user_df_service.apply(lambda x: all_tag_tag_type.get(x["tag2"]), axis=1)
            # Compute per-tag scores and their relative weights
            user_df_service["tag_score"] = user_df_service.apply(
                lambda x: compute_henqiang(x.days_diff_now, exponential=0)/get_action_tag_count(user_df_service, x.time) if x.score_type == "henqiang" else (
                    compute_jiaoqiang(x.days_diff_now, exponential=0)/get_action_tag_count(user_df_service, x.time) if x.score_type == "jiaoqiang" else (
                        compute_ai_scan(x.days_diff_now, exponential=0)/get_action_tag_count(user_df_service, x.time) if x.score_type == "ai_scan" else (
                        compute_ruoyixiang(x.days_diff_now, exponential=0)/get_action_tag_count(user_df_service, x.time) if x.score_type == "ruoyixiang" else
                        compute_validate(x.days_diff_now, exponential=0)/get_action_tag_count(user_df_service, x.time)))), axis=1)
            tag_score_sum = user_df_service.groupby(by=["tag2", "tag2_type"]).agg(
                {'tag_score': 'sum', 'cl_id': 'first', 'action': 'first'}
            ).reset_index().sort_values(by=["tag_score"], ascending=False)
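            # Each tag's weight is its share of the device's total score, as a percentage.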
            tag_score_sum['weight'] = 100 * tag_score_sum['tag_score'] / tag_score_sum['tag_score'].sum()
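            # pay_type flag: 3 = order validated, 2 = Alipay settlement callback, 1 = any other action.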
            tag_score_sum["pay_type"] = tag_score_sum.apply(
                lambda x: 3 if x.action == "api/order/validate" else (
                    2 if x.action == "api/settlement/alipay_callback" else 1
                ), axis=1
            )
            # size=None keeps every tag; a positive size would keep only the top-scoring ones.
            gmkv_tag_score2_sum = tag_score_sum[["tag2", "tag_score"]][:size].to_dict('records')
            gmkv_tag_score2_sum_dict = {i["tag2"]: i["tag_score"] for i in gmkv_tag_score2_sum}
            # Replace tag ids with their Chinese names and order by score, highest first.
            gmkv_tag_score3_sum_dict = {all_tags_name[i]: gmkv_tag_score2_sum_dict[i] for i in gmkv_tag_score2_sum_dict}
            gmkv_tag_score3_sum_dict_sort_list = sorted(gmkv_tag_score3_sum_dict.items(), key=lambda x: x[1], reverse=True)
            portrait_result = [i[0] for i in gmkv_tag_score3_sum_dict_sort_list]
            if not portrait_result:
                portrait_result = ["000"]
            return cl_id, search_info, portrait_result
    else:
        return cl_id, search_info, ["000"]

# Devices and the search words they issued, exported from Hive with the query below
device_info = []

# sql: select cl_id, collect_set(params["query"]) from bl_hdfs_maidian_updates where partition_date="20191111" and action="do_search" group by cl_id
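# Each line of the export is "<device_id>gyz<search_words>": the literal string "gyz"
# is the field separator and the second field is a Python list literal, hence the eval() below.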
with open("/home/gmuser/gyz/log/have_search_device_20191111.csv", "r") as f:
    for line in f:
        data = line.strip().split("gyz")
        device = data[0]
        search_words = eval(data[1])
        device_info.append([device, search_words])
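# 1573401600 is 2019-11-11 00:00:00 UTC+8 (the partition date above); it is passed
# through to get_user_log in tool.py, presumably as the payment-time cut-off.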
pay_time = 1573401600
# Tags for each search word and its synonyms
all_word_tags = get_all_word_tags()
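# Tag id -> tag level; only levels '2' and '3' are kept when building the portrait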
all_tag_tag_type = get_all_tag_tag_type()

# Level-2 parent tag for each level-3 tag
all_3tag_2tag = get_all_3tag_2tag()

# Chinese display name for each tag id
all_tags_name = get_all_tags_name()

# Spark session and distributed portrait computation
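# The spark.tispark.* and spark.sql.extensions settings wire the session to TiDB via
# TiSpark (PD endpoint 172.16.40.170:2379); the recursive-directory settings allow
# Hive tables stored in nested subdirectories to be read.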
sparkConf = SparkConf().set("spark.hive.mapred.supports.subdirectories", "true") \
    .set("spark.hadoop.mapreduce.input.fileinputformat.input.dir.recursive", "true") \
    .set("spark.tispark.plan.allow_index_double_read", "false") \
    .set("spark.tispark.plan.allow_index_read", "true") \
    .set("spark.sql.extensions", "org.apache.spark.sql.TiExtensions") \
    .set("spark.tispark.pd.addresses", "172.16.40.170:2379").set("spark.io.compression.codec", "lzf") \
    .set("spark.driver.maxResultSize", "8g").set("spark.sql.avro.compression.codec", "snappy")


spark = SparkSession.builder.config(conf=sparkConf).enableHiveSupport().getOrCreate()
spark.sparkContext.setLogLevel("WARN")
spark.sparkContext.addPyFile("/srv/apps/ffm-baseline_git/eda/smart_rank/tool.py")
device_ids_lst_rdd = spark.sparkContext.parallelize(device_info)
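# One portrait per device; devices whose filtered log ends up empty return None and
# are removed by the filter below.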
result = device_ids_lst_rdd.repartition(100).map(
    lambda x: get_user_service_portrait(x, all_word_tags, all_tag_tag_type, all_3tag_2tag,
                                         all_tags_name, size=None, pay_time=pay_time)
).filter(lambda x: x is not None)
print(result.count())
print(result.take(10))
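# Collect the results to the driver as a pandas DataFrame (dropping rows with nulls)
# and write them to CSV.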
df = spark.createDataFrame(result).toDF("device", "search_words", "user_portrait").na.drop().toPandas()
df.to_csv("~/gyz/log/user_action_20191111.csv", index=False)
spark.stop()