# diaryQueue.py 7.1 KB
import pickle
import xlearn as xl
import pandas as pd
import pymysql
from datetime import datetime
import utils
import warnings
from multiprocessing import Pool


# Local test script

def test_con_sql(device_id):
    """Fetch the four diary queues of one device from the test database.

    Reads ``native_queue``, ``nearby_queue``, ``nation_queue`` and
    ``megacity_queue`` from ``device_diary_queue`` (database ``doris_test``,
    described as the test TiDB instance by the original comment). Each column
    is split on ``","`` and every item is prefixed with ``"diary|"``.

    Args:
        device_id: Value matched against the ``device_id`` column.

    Returns:
        A tuple ``(native_queue, nearby_queue, nation_queue, megacity_queue)``
        of lists of ``"diary|<id>"`` strings built from the first matching
        row, or ``None`` (after printing a notice) when no row matches.
    """
    # NOTE(review): credentials are hard-coded; consider moving them to
    # configuration before this leaves local testing.
    db = pymysql.connect(host='rdsmaqevmuzj6jy.mysql.rds.aliyuncs.com', port=3306, user='work',
                         passwd='workwork', db='doris_test')
    try:
        cursor = db.cursor()
        # Let the driver quote device_id instead of formatting it into the SQL.
        cursor.execute(
            "select native_queue,nearby_queue,nation_queue,megacity_queue "
            "from device_diary_queue where device_id = %s;",
            (device_id,))
        row = cursor.fetchone()
    finally:
        # Close on every path, including the "no row" and error paths.
        db.close()

    if row is None:
        print("该用户对应的日记队列为空")
        return None

    native_queue, nearby_queue, nation_queue, megacity_queue = (
        ["diary|" + str(item) for item in column.split(",")] for column in row
    )
    return native_queue, nearby_queue, nation_queue, megacity_queue


# Fetch the latest native_queue before updating
def get_native_queue(device_id):
    """Return the device's current ``native_queue`` as ``"diary|<id>"`` strings.

    Args:
        device_id: Value matched against the ``device_id`` column of
            ``device_diary_queue`` (database ``doris_test``).

    Returns:
        A list built by splitting the stored column on ``","`` and prefixing
        each item with ``"diary|"``, or ``None`` when no row matches.
    """
    db = pymysql.connect(host='rdsmaqevmuzj6jy.mysql.rds.aliyuncs.com', port=3306, user='work',
                         passwd='workwork', db='doris_test')
    try:
        cursor = db.cursor()
        # Parameterized query: the driver handles quoting of device_id.
        cursor.execute("select native_queue from device_diary_queue where device_id = %s;",
                       (device_id,))
        row = cursor.fetchone()
    finally:
        # Close on every path; the original leaked the connection when no row matched.
        db.close()

    if row is None:
        return None
    return ["diary|" + str(item) for item in row[0].split(",")]


def feature_en(x_list, device_id, csv_path="/Users/mac/utils/result/data.csv"):
    """Build the prediction feature table for one device.

    The column order (``cid``, ``device_id``, ``hour``, ``minute``, ``y``)
    must stay identical to the order used for the training set.

    Args:
        x_list: Diary ids; they become the ``cid`` column, one row per id.
        device_id: Value written to every row of the ``device_id`` column.
        csv_path: Where to dump the table as CSV. Defaults to the path the
            script has always used; pass ``None`` to skip the dump.

    Returns:
        The feature ``DataFrame``. ``hour`` and ``minute`` hold the current
        local time as categoricals, with 0 encoded as 24 and 60 respectively.
    """
    # Naming the column up front keeps "cid" present even when x_list is empty.
    data = pd.DataFrame(x_list, columns=["cid"])
    data["device_id"] = device_id

    now = datetime.now()
    # Hour 0 and minute 0 are encoded as 24 and 60.
    data["hour"] = now.hour or 24
    data["minute"] = now.minute or 60
    data["hour"] = data["hour"].astype("category")
    data["minute"] = data["minute"].astype("category")

    # y is what gets predicted, but the FFM conversion requires a y column;
    # the placeholder value does not affect the prediction result.
    data["y"] = 0

    if csv_path is not None:
        data.to_csv(csv_path, index=False)

    return data


def transform_ffm_format(df, device_id):
    """Convert the feature table to FFM format and write it to a CSV file.

    Loads the pickled converter from ``/Users/mac/utils/ffm.pkl``, calls its
    ``native_transform`` on ``df`` and writes the result without index or
    header.

    Args:
        df: Feature table to convert.
        device_id: Used as the first part of the output file name.

    Returns:
        Path of the written file:
        ``/Users/mac/utils/result/<device_id>_<timestamp>_<pid>.csv``.
    """
    import os

    # NOTE: pickle.load can execute arbitrary code; ffm.pkl must come from a
    # trusted source.
    with open("/Users/mac/utils/ffm.pkl", "rb") as f:
        ffm_format_pandas = pickle.load(f)

    data = ffm_format_pandas.native_transform(df)
    now = datetime.now().strftime("%Y-%m-%d-%H-%M")
    # The timestamp only has minute resolution and the script runs one worker
    # process per queue for the same device, so the process id is added to
    # keep concurrent workers from overwriting each other's file.
    predict_file_name = "/Users/mac/utils/result/{0}_{1}_{2}.csv".format(device_id, now, os.getpid())
    data.to_csv(predict_file_name, index=False, header=None)
    return predict_file_name


def predict(queue_name, x_list, device_id):
    """Score one diary queue with the FFM model and return the re-ranked ids.

    Builds the feature table, converts it to FFM format, runs the model stored
    at ``/Users/mac/utils/model.out`` with sigmoid output, and writes the raw
    predictions to ``/Users/mac/utils/result/<queue_name>_output.txt``.

    Args:
        queue_name: Queue being scored; used in the prediction output file name.
        x_list: Diary ids of the queue, one prediction row per id.
        device_id: Device the queue belongs to.

    Returns:
        The re-ranked diary id list produced by ``update_dairy_queue``.
    """
    data = feature_en(x_list, device_id)
    data_file_path = transform_ffm_format(data, device_id)

    ffm_model = xl.create_ffm()
    ffm_model.setTest(data_file_path)
    ffm_model.setSigmoid()

    ffm_model.predict("/Users/mac/utils/model.out",
                      "/Users/mac/utils/result/{0}_output.txt".format(queue_name))
    # Propagate the ranked list; the original dropped it, so callers got None.
    return save_result(queue_name, x_list)


def save_result(queue_name, x_list):
    """Load the model scores for a queue and attach the diary ids to them.

    Args:
        queue_name: Queue whose ``<queue_name>_output.txt`` file is read.
        x_list: Diary ids in the same row order as the prediction file.

    Returns:
        The re-ranked diary id list produced by ``update_dairy_queue``.
    """
    score_df = pd.read_csv("/Users/mac/utils/result/{0}_output.txt".format(queue_name), header=None)
    score_df = score_df.rename(columns={0: "score"})
    score_df["cid"] = x_list
    # NOTE(review): update_dairy_queue treats rows 1, 6, 11, ... of the
    # incoming order specially; sorting here changes which rows sit at those
    # positions. Confirm this is intended.
    score_df = score_df.sort_values(by="score", ascending=False)
    return merge_score(x_list, score_df)


def merge_score(x_list, score_df):
    """Add each diary's stored score to its model score.

    Looks up ``score`` in ``biz_feed_diary_score`` (database
    ``zhengxing_test``) for every id in ``x_list``; ids without a row
    contribute 0. ``score_df["score"]`` is updated in place.

    Args:
        x_list: Diary ids to look up.
        score_df: Frame with ``cid`` and ``score`` columns, in any row order.

    Returns:
        The re-ranked diary id list produced by ``update_dairy_queue``.
    """
    db = pymysql.connect(host='rdsmaqevmuzj6jy.mysql.rds.aliyuncs.com', port=3306, user='work',
                         passwd='workwork', db='zhengxing_test')
    biz_scores = {}
    try:
        cursor = db.cursor()
        # NOTE(review): the ids carry the "diary|" prefix at this point;
        # confirm diary_id is stored in the same form, otherwise every lookup
        # misses and falls back to 0.
        for cid in x_list:
            if cid in biz_scores:
                continue
            # One parameterized query per id (the original ran each query twice).
            cursor.execute("select score from biz_feed_diary_score where diary_id = %s;", (cid,))
            row = cursor.fetchone()
            if row is None or row[0] is None:
                # Diary id not found: its score defaults to 0.
                biz_scores[cid] = 0.0
            else:
                # float() so a DECIMAL column value can be added to the float model score.
                biz_scores[cid] = float(row[0])
    finally:
        db.close()

    # Match by cid rather than by position: score_df has already been sorted,
    # so its row order no longer follows x_list.
    score_df["score"] = score_df["score"] + score_df["cid"].map(biz_scores).fillna(0)
    return update_dairy_queue(score_df)


def update_dairy_queue(score_df):
    """Re-rank the ids in ``score_df`` by score while keeping every fifth slot.

    The ids found at positions 1, 6, 11, ... of the incoming row order form
    one group (the original code names them ``video_id``); all other ids form
    a second group. Both groups are sorted by ``score`` descending, then the
    first group is inserted back into the second at positions 1, 6, 11, ...

    Args:
        score_df: Frame with ``cid`` and ``score`` columns.

    Returns:
        The re-ordered list of ``cid`` values. When there is no id at
        position 1, this is simply all ids sorted by score descending.
    """
    ordered_ids = score_df["cid"].values.tolist()
    slot_ids = ordered_ids[1::5]

    if len(slot_ids) == 0:
        ranked = score_df.sort_values(by="score", ascending=False)
        return ranked["cid"].values.tolist()

    remaining_ids = list(set(ordered_ids) - set(slot_ids))

    remaining_ranked = score_df.loc[score_df["cid"].isin(remaining_ids)]
    remaining_ranked = remaining_ranked.sort_values(by="score", ascending=False)
    slot_ranked = score_df.loc[score_df["cid"].isin(slot_ids)]
    slot_ranked = slot_ranked.sort_values(by="score", ascending=False)

    merged = remaining_ranked["cid"].values.tolist()
    # Re-insert the slot group at indexes 1, 6, 11, ...; an index past the end
    # of the list appends.
    for offset, cid in enumerate(slot_ranked["cid"].values.tolist()):
        merged.insert(1 + 5 * offset, cid)
    return merged


def update_sql_dairy_queue(queue_name, diary_id, device_id):
    """Write a re-ranked diary queue back to ``device_diary_queue``.

    Args:
        queue_name: Name of the queue column to update.
        diary_id: New queue content. A list/tuple of ids is stored as a
            comma-separated string with the ``"diary|"`` prefix stripped, the
            format the readers in this file split on; a string is stored as is.
        device_id: Row selector for the ``device_id`` column.

    Raises:
        ValueError: If ``queue_name`` is not a plain identifier, or if
            ``diary_id`` is ``None``.
    """
    # A column name cannot be bound as a query parameter, so validate it
    # before formatting it into the statement.
    if not str(queue_name).isidentifier():
        raise ValueError("invalid queue column name: {!r}".format(queue_name))
    # Refuse to overwrite a queue with a missing result.
    if diary_id is None:
        raise ValueError("diary_id is None; refusing to overwrite the queue")

    if isinstance(diary_id, str):
        queue_value = diary_id
    else:
        prefix = "diary|"
        ids = [str(item) for item in diary_id]
        queue_value = ",".join(
            item[len(prefix):] if item.startswith(prefix) else item for item in ids
        )

    db = pymysql.connect(host='rdsmaqevmuzj6jy.mysql.rds.aliyuncs.com', port=3306, user='work',
                         passwd='workwork', db='doris_test')
    try:
        cursor = db.cursor()
        sql = "update device_diary_queue set `{}` = %s where device_id = %s".format(queue_name)
        cursor.execute(sql, (queue_value, device_id))
        # pymysql does not autocommit by default; without this the update is
        # discarded when the connection closes.
        db.commit()
    finally:
        db.close()


def multi_update(key, name_dict, device_id,native_queue_list):
    """Re-rank one queue and store it if the native queue is unchanged.

    Args:
        key: Name of the queue to process; also the column that gets updated.
        name_dict: Mapping from queue name to its list of diary ids.
        device_id: Device whose queue is processed.
        native_queue_list: The native queue as read before prediction started.
    """
    diary_id = predict(key, name_dict[key], device_id)

    # Only write when the stored native queue still equals the snapshot taken
    # before prediction.
    latest_native_queue = get_native_queue(device_id)
    if latest_native_queue != native_queue_list:
        print("不需要更新日记队列")
        return

    update_sql_dairy_queue(key, diary_id, device_id)
    print("更新结束")


if __name__ == "__main__":
    warnings.filterwarnings("ignore")
    # TODO: after going live, predict for multiple users in parallel processes
    device_id = "358035085192742"
    queues = test_con_sql(device_id)
    # test_con_sql returns None (after printing a notice) when the device has
    # no queue row; unpacking that would raise TypeError.
    if queues is not None:
        native_queue_list, nearby_queue_list, nation_queue_list, megacity_queue_list = queues
        name_dict = {"native_queue": native_queue_list, "nearby_queue": nearby_queue_list,
                     "nation_queue": nation_queue_list, "megacity_queue": megacity_queue_list}
        # One worker per queue.
        pool = Pool(4)
        pending = [
            (key, pool.apply_async(multi_update, (key, name_dict, device_id, native_queue_list)))
            for key in name_dict
        ]
        pool.close()
        pool.join()
        # apply_async swallows worker exceptions unless the result is
        # fetched; report each failure instead of losing it.
        for key, result in pending:
            try:
                result.get()
            except Exception as err:
                print("队列 {} 更新失败: {}".format(key, err))