#!/srv/envs/nvwa/bin/python
# -*- coding: utf-8 -*-
import pickle
import xlearn as xl
import pandas as pd
import pymysql
from datetime import datetime
# The utils package must be imported here: the pickled ffm encoder was
# built against utils, and unpickling fails if utils is not importable.
import utils
import warnings
from multiprocessing import Pool
from userProfile import get_active_users
from sklearn.preprocessing import MinMaxScaler
import time
from config import *
from utils import judge_online,con_sql


def get_video_id(cache_video_id):
    """Fetch the boosted (video) diary ids from the eagle database.

    Falls back to *cache_video_id* when the query fails or returns no
    rows, so callers always receive a usable list.
    """
    if flag:
        db = pymysql.connect(host=ONLINE_EAGLE_HOST, port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='eagle')
    else:
        # Local database has no password; the connection may fail.
        db = pymysql.connect(host=LOCAL_EAGLE_HOST, port=4000, user='root', db='eagle')
    cursor = db.cursor()
    sql = "select diary_id from feed_diary_boost;"
    try:
        cursor.execute(sql)
        result = cursor.fetchall()
        df = pd.DataFrame(list(result))
    except Exception as e:
        # Fixed: the original printed the bare `Exception` class, which
        # carries no information — print the caught instance instead.
        print("发生异常", e)
        df = pd.DataFrame()
    finally:
        db.close()

    if df.empty:
        return cache_video_id
    video_id = df[0].values.tolist()
    print("videoid")
    print(video_id[:2])
    return video_id


# Attach device_id and time features to the city's hot-diary list.
# NOTE: the feature order below must stay identical to the training set.
def feature_en(x_list, device_id):
    """Build the prediction feature frame for one device.

    The id column must be named "cid" (not "diaryid") because the ffm
    encoder used at predict time was fitted on "cid".
    """
    now = datetime.now()
    data = pd.DataFrame(x_list).rename(columns={0: "cid"})
    data["device_id"] = device_id
    data["hour"] = now.hour
    data["minute"] = now.minute
    # Remap 0 -> 24 / 0 -> 60 so the categorical codes match training.
    data.loc[data["hour"] == 0, ["hour"]] = 24
    data.loc[data["minute"] == 0, ["minute"]] = 60
    for col in ("hour", "minute"):
        data[col] = data[col].astype("category")
    # The ffm conversion requires a "y" column even though we only predict;
    # its value does not affect the prediction result.
    data["y"] = 0
    print("done 特征工程")

    return data


# Load ffm.pkl and convert the feature frame into libffm text format.
def transform_ffm_format(df, queue_name, device_id):
    """Encode *df* with the pickled ffm encoder and write it to a
    per-device / per-queue csv; returns the csv path."""
    with open(path + "ffm.pkl", "rb") as f:
        encoder = pickle.load(f)
    ffm_data = encoder.native_transform(df)
    predict_file_name = path + "result/{0}_{1}.csv".format(device_id, queue_name)
    ffm_data.to_csv(predict_file_name, index=False, header=None)
    print("done ffm")
    return predict_file_name


def predict(queue_name, queue_arg, device_id):
    """Score the to-be-predicted cids (queue_arg[0]) with the ffm model and
    write raw sigmoid outputs to result/output{device_id}_{queue_name}.csv."""
    features = feature_en(queue_arg[0], device_id)
    test_file = transform_ffm_format(features, queue_name, device_id)
    model = xl.create_ffm()
    model.setTest(test_file)
    model.setSigmoid()
    output_file = path + "result/output{0}_{1}.csv".format(device_id, queue_name)
    model.predict(path + "model.out", output_file)


def save_result(queue_name, queue_arg, device_id):
    """Read the raw ffm scores, min-max scale them to [0, 1] and pair them
    with their cids.

    queue_arg is [predicted_cids, not_predicted_cids, full_queue]; the
    not-predicted cids are appended with score 0.  Returns a dataframe
    with "score" and "cid" columns ("diary|" prefix stripped from cid).
    """
    output_path = path + "result/output{0}_{1}.csv".format(device_id, queue_name)
    score_df = pd.read_csv(output_path, header=None)
    mm_scaler = MinMaxScaler()
    mm_scaler.fit(score_df)
    score_df = pd.DataFrame(mm_scaler.transform(score_df))
    score_df = score_df.rename(columns={0: "score"})
    score_df["cid"] = queue_arg[0]
    # Strip the "diary|" prefix from cid.
    score_df["cid"] = score_df["cid"].apply(lambda x: x[6:])
    if queue_arg[1] != []:
        df_temp = pd.DataFrame(queue_arg[1]).rename(columns={0: "cid"})
        df_temp["score"] = 0
        # Sort columns descending so the order is ["score", "cid"],
        # matching score_df.
        df_temp = df_temp.sort_index(axis=1, ascending=False)
        df_temp["cid"] = df_temp["cid"].apply(lambda x: x[6:])
        # Fixed: pd.concat replaces DataFrame.append, which is deprecated
        # and removed in pandas >= 2.0.
        return pd.concat([score_df, df_temp])
    return score_df


def get_score(queue_arg):
    """Load the offline scores for every diary currently in the queue
    (queue_arg[2]); returns a dataframe of (score, diary_id) rows."""
    if flag:
        db = pymysql.connect(host=SCORE_DB_ONLINE["host"], port=SCORE_DB_ONLINE["port"],
                             user=SCORE_DB_ONLINE["user"], passwd=SCORE_DB_ONLINE["passwd"],
                             db=SCORE_DB_ONLINE["db"])
    else:
        db = pymysql.connect(host=SCORE_DB_LOCAL["host"], port=SCORE_DB_LOCAL["port"],
                             user=SCORE_DB_LOCAL["user"], passwd=SCORE_DB_LOCAL["passwd"],
                             db=SCORE_DB_LOCAL["db"])

    # Strip the "diary|" prefix from each id.
    diary_ids = [x[6:] for x in queue_arg[2]]
    # Fixed: building the IN clause via str(tuple(...)) renders a single
    # id as "('1',)" — the trailing comma is a MySQL syntax error.
    in_clause = ",".join("'{}'".format(i) for i in diary_ids)
    sql = "select score,diary_id from biz_feed_diary_score where diary_id in ({});".format(in_clause)
    score_df = con_sql(db, sql)
    print("get score")
    return score_df


def _merge_scores(score_df, predict_score_df):
    """Index both frames by cid, add the predicted score onto the offline
    score (alignment is by cid), and return the frame sorted by the
    combined score, descending."""
    score_df = score_df.set_index(["cid"])
    predict_score_df = predict_score_df.set_index(["cid"])
    score_df["score"] = score_df["score"] + predict_score_df["score"]
    return score_df.sort_values(by="score", ascending=False)


def update_dairy_queue(score_df, predict_score_df, total_video_id):
    """Combine offline and predicted scores into a new diary queue.

    Video diaries (ids present in *total_video_id*) are ranked separately
    and interleaved into the non-video ranking at positions 1, 6, 11, ...
    Returns the new queue as a list of cids.  (Refactored: the merge/sort
    logic that was duplicated three times now lives in _merge_scores.)
    """
    diary_id = score_df["cid"].values.tolist()
    if total_video_id != []:
        video_id = list(set(diary_id) & set(total_video_id))
        if len(video_id) > 0:
            not_video = list(set(diary_id) - set(video_id))
            not_video_df = _merge_scores(score_df.loc[score_df["cid"].isin(not_video)],
                                         predict_score_df.loc[predict_score_df["cid"].isin(not_video)])
            video_df = _merge_scores(score_df.loc[score_df["cid"].isin(video_id)],
                                     predict_score_df.loc[predict_score_df["cid"].isin(video_id)])

            new_queue = not_video_df.index.tolist()
            # Insert one video diary every 5 slots, starting at position 1.
            i = 1
            for vid in video_df.index.tolist():
                new_queue.insert(i, vid)
                i += 5

            print("分数合并成功")
            return new_queue
        # No video diaries survived the intersection: single merged ranking.
        else:
            merged = _merge_scores(score_df, predict_score_df)
            print("分数合并成功1")
            return merged.index.tolist()
    # total_video_id is an empty list: single merged ranking, no print.
    else:
        return _merge_scores(score_df, predict_score_df).index.tolist()


def update_sql_dairy_queue(queue_name, diary_id, device_id, city_id):
    """Persist the reordered diary queue back to device_diary_queue.

    *queue_name* selects the column to update; *diary_id* is the ordered
    list of diary ids, stored as a comma-separated string.
    """
    if flag:
        db = pymysql.connect(host=QUEUE_ONLINE_HOST, port=3306, user='doris', passwd='o5gbA27hXHHm',
                             db='doris_prod')
    else:
        db = pymysql.connect(host=LOCAL_HOST, port=3306, user='work', passwd='workwork', db='doris_test')
    try:
        cursor = db.cursor()
        # join() replaces the original O(n^2) string-concatenation loop.
        id_str = ",".join(str(i) for i in diary_id)
        # The column name cannot be parameterized, but the values are now
        # passed as query parameters instead of being formatted into the
        # SQL string (avoids injection/quoting problems).
        sql = "update device_diary_queue set {}=%s where device_id = %s and city_id = %s".format(queue_name)
        cursor.execute(sql, (id_str, device_id, city_id))
        db.commit()
    finally:
        # Close the connection even when execute/commit raises.
        db.close()
    print("成功写入diary_id")


def queue_compare(old_list, new_list):
    """Return True when the diary ordering changed between the old and the
    new queue.

    *old_list* entries carry the "diary|" prefix; *new_list* entries are
    bare ids.
    """
    # Strip the "diary|" prefix so both sides hold comparable ids.
    old_ids = [int(x[6:]) for x in old_list]
    old_pos = {cid: idx for idx, cid in enumerate(old_ids)}
    new_pos = {cid: idx for idx, cid in enumerate(new_list)}
    # .get() avoids the KeyError the original raised when an id is missing
    # from the new queue; a missing id counts as a position change.
    moved = sum(1 for cid, idx in old_pos.items() if new_pos.get(cid) != idx)
    if moved > 0:
        # Fixed: the precision 2 belongs inside round(); the original passed
        # it as an unused fourth argument to format().
        print("日记队列更新前日记总个数{},位置发生变化个数{},发生变化率{}%".format(
            len(old_ids), moved, round(moved / len(old_ids) * 100, 2)))
        return True
    return False


def get_queue(device_id, city_id, queue_name):
    """Read one diary-queue column for (device_id, city_id).

    Returns the queue as a list of "diary|<id>" strings, or False when the
    device has no queue row.
    """
    if flag:
        db = pymysql.connect(host=QUEUE_ONLINE_HOST, port=3306, user='doris', passwd='o5gbA27hXHHm',
                             db='doris_prod')

    else:
        db = pymysql.connect(host=LOCAL_HOST, port=3306, user='work',
                             passwd='workwork', db='doris_test')
    try:
        cursor = db.cursor()
        sql = "select {} from device_diary_queue " \
              "where device_id = '{}' and city_id = '{}';".format(queue_name, device_id, city_id)
        cursor.execute(sql)
        result = cursor.fetchall()
        df = pd.DataFrame(list(result))
    finally:
        # Fixed: the original leaked the connection when the result was
        # empty (close() lived only on the non-empty path).
        db.close()

    if df.empty:
        print("该用户对应的日记为空")
        return False
    queue_list = df.loc[0, 0].split(",")
    queue_list = list(map(lambda x: "diary|" + str(x), queue_list))
    print("成功获取queue")
    return queue_list


def pipe_line(queue_name, queue_arg, device_id, total_video_id):
    """Run prediction for one queue, merge with the offline scores, and
    return the re-ranked queue — or False when no offline scores exist."""
    predict(queue_name, queue_arg, device_id)
    predict_score_df = save_result(queue_name, queue_arg, device_id)
    score_df = get_score(queue_arg)
    if score_df.empty:
        print("获取的日记列表是空")
        return False
    score_df = score_df.rename(columns={0: "score", 1: "cid"})
    return update_dairy_queue(score_df, predict_score_df, total_video_id)


def user_update(device_id, city_id, queue_name, data_set_cid, total_video_id):
    """Re-rank one queue of one device and write it back when it changed."""
    queue_list = get_queue(device_id, city_id, queue_name)
    if not queue_list:
        print("日记队列为空")
        return
    queue_set = set(queue_list)
    cid_set = set(data_set_cid)
    # [predictable cids, non-predictable cids, full queue]
    queue_arg = [list(queue_set & cid_set), list(queue_set - cid_set), queue_list]
    if queue_arg[0] == []:
        print("预测集是空,不需要预测")
        return
    diary_queue = pipe_line(queue_name, queue_arg, device_id, total_video_id)
    if diary_queue and queue_compare(queue_list, diary_queue):
        update_sql_dairy_queue(queue_name, diary_queue, device_id, city_id)
        print("更新结束")
    else:
        print("获取的日记列表是空或者日记队列顺序没有变化,所以不更新日记队列")


def multi_proecess_update(device_id, city_id, data_set_cid, total_video_id):
    """Update the four diary queues of one device with a 4-worker pool."""
    queue_names = ("native_queue", "nearby_queue", "nation_queue", "megacity_queue")
    pool = Pool(4)
    for name in queue_names:
        pool.apply_async(user_update, (device_id, city_id, name, data_set_cid, total_video_id))
    pool.close()
    pool.join()


if __name__ == "__main__":
    warnings.filterwarnings("ignore")
    # judge_online() decides online vs local and yields the data directory.
    flag,path = judge_online()
    # Cache of the video-diary id list; reused when the DB query fails.
    cache_video_id = []
    # Devices already predicted within the current 5-minute window.
    cache_device_city_list = []
    differ = 0
    while True:
        start = time.time()
        device_city_list = get_active_users(flag, path, differ)
        time1 = time.time()
        print("获取用户活跃表耗时:{}秒".format(time1-start))
        # Skip users that were already predicted in the last 5 minutes.
        device_city_list = list(set(tuple(device_city_list))-set(tuple(cache_device_city_list)))
        print("device_city_list")
        print(device_city_list)
        # Reset the dedup cache on every wall-clock minute divisible by 5.
        if datetime.now().minute % 5 == 0:
            cache_device_city_list = []
        if device_city_list != []:
            data_set_cid = pd.read_csv(path + "data_set_cid.csv")["cid"].values.tolist()
            total_video_id = get_video_id(cache_video_id)
            cache_video_id = total_video_id
            cache_device_city_list.extend(device_city_list)
            for device_city in device_city_list:
                multi_proecess_update(device_city[0], device_city[1], data_set_cid, total_video_id)
        differ = time.time()-start
        print("differ:{}秒".format(differ))


# TODO: after going live, switch per-user prediction to multi-process.