Commit 605673e8 authored by 张彦钊

update predictDiaryLocal file

parent 4758923b
from config import *
import pandas as pd
import pickle
import xlearn as xl
from userProfile import *
import time
from utils import *
import os
# pymysql and datetime are used directly below; imported explicitly in case the
# wildcard imports above do not already provide them
import pymysql
from datetime import datetime
# Local test script
@@ -20,19 +16,20 @@ def test_con_sql(device_id):
    result = cursor.fetchall()
    df = pd.DataFrame(list(result))
    if not df.empty:
        df = df.rename(columns={0: "native_queue", 1: "nearby_queue", 2: "nation_queue", 3: "megacity_queue"})
        native_queue = df.loc[0, "native_queue"].split(",")
        nearby_queue = df.loc[0, "nearby_queue"].split(",")
        nation_queue = df.loc[0, "nation_queue"].split(",")
        megacity_queue = df.loc[0, "megacity_queue"].split(",")
        db.close()
        return native_queue, nearby_queue, nation_queue, megacity_queue
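# Illustrative sketch (toy values, not real data): the queue columns of device_diary_queue
# appear to hold comma-separated strings of diary ids, since they are split on "," above.
# The helper below only demonstrates the expected row shape; the ids are hypothetical.
def _demo_queue_row():
    row = [("16,17,18", "19,20", "21,22", "23")]
    demo = pd.DataFrame(list(row)).rename(columns={0: "native_queue", 1: "nearby_queue",
                                                   2: "nation_queue", 3: "megacity_queue"})
    return demo.loc[0, "native_queue"].split(",")  # -> ['16', '17', '18']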
# Attach device_id and city_id to the corresponding city's hot diary table.
# Note: the feature order of the prediction set below must match the training set
def feature_en(x_list, device_id):
    data = pd.DataFrame(x_list)
    data = data.rename(columns={0: "diary_id"})
    data["device_id"] = device_id
    now = datetime.now()
    data["hour"] = now.hour
@@ -48,55 +45,99 @@ def feature_en(x_list,device_id):
# Load ffm.pkl and convert the table above into FFM format
def transform_ffm_format(df, device_id):
    with open("/Users/mac/utils/ffm.pkl", "rb") as f:
        ffm_format_pandas = pickle.load(f)
    data = ffm_format_pandas.transform(df)
    now = datetime.now().strftime("%Y-%m-%d-%H-%M")
    predict_file_name = "/Users/mac/utils/result/{0}_{1}.csv".format(device_id, now)
    data.to_csv(predict_file_name, index=False, header=None)
    print("Successfully wrote the FFM prediction file to local disk")
    return predict_file_name
# Load the model, run prediction, sort the predicted diary probabilities in descending order, and store them in a table
def predict(queue_name, x_list, device_id):
    data = feature_en(x_list, device_id)
    data_file_path = transform_ffm_format(data, device_id)
    ffm_model = xl.create_ffm()
    ffm_model.setTest(data_file_path)
    ffm_model.setSigmoid()
    ffm_model.predict("/Users/mac/utils/model.out",
                      "/Users/mac/utils/result/{0}_output.txt".format(queue_name))
    print("{} prediction finished".format(queue_name))
    # Return the re-ranked diary queue so the caller can write it back to MySQL
    return save_result(queue_name, x_list)
# Join the prediction results with device_id and sort by probability in descending order
def wrapper_result(user_profile, instance):
    proba = pd.read_csv(DIRECTORY_PATH +
                        "result/{0}_output.txt".format(user_profile['device_id']), header=None)
    proba = proba.rename(columns={0: "prob"})
    proba["cid"] = instance['cid']
    proba = proba.sort_values(by="prob", ascending=False)
    proba = proba.head(50)
    return proba
def save_result(queue_name, x_list):
    score_df = pd.read_csv("/Users/mac/utils/result/{0}_output.txt".format(queue_name), header=None)
    score_df = score_df.rename(columns={0: "score"})
    score_df["diary_id"] = x_list
    return merge_score(x_list, score_df)
# Save the prediction candidate set to local disk
def predict_save_to_local(user_profile, instance):
    proba = wrapper_result(user_profile, instance)
    # cid values appear to carry a 6-character prefix before the numeric diary id, hence x[6:]
    proba.loc[:, "url"] = proba["cid"].apply(lambda x: "http://m.igengmei.com/diary_book/" + str(x[6:]) + '/')
    proba.to_csv(DIRECTORY_PATH + "result/feed_{}".format(user_profile['device_id']), index=False)
    print("Successfully saved the prediction candidate set to local disk")
def merge_score(x_list, score_df):
    db = pymysql.connect(host='rdsmaqevmuzj6jy.mysql.rds.aliyuncs.com', port=3306, user='work',
                         passwd='workwork', db='zhengxing_test')
    cursor = db.cursor()
    score_list = []
    for i in x_list:
        sql = "select score from biz_feed_diary_score where diary_id = '{}';".format(i)
        if cursor.execute(sql) != 0:
            # fetchone() returns a tuple, so take its first element
            result = cursor.fetchone()
            score_list.append(result[0])
        # If this diary_id is not found, the default score is 0
        else:
            score_list.append(0)
    db.close()
    score_df["score"] = score_df["score"] + score_list
    return update_dairy_queue(score_df)
def update_dairy_queue(score_df):
    # Items at positions 1, 6, 11, ... of the incoming queue are treated as video diaries
    diary_id = score_df["diary_id"].values.tolist()
    video_id = []
    x = 1
    while x < len(diary_id):
        video_id.append(diary_id[x])
        x += 5
    not_video_id = list(set(diary_id) - set(video_id))
    not_video_id_df = score_df.loc[score_df["diary_id"].isin(not_video_id)]
    not_video_id_df = not_video_id_df.sort_values(by="score", ascending=False)
    video_id_df = score_df.loc[score_df["diary_id"].isin(video_id)]
    video_id_df = video_id_df.sort_values(by="score", ascending=False)
    not_video_id = not_video_id_df["diary_id"].values.tolist()
    video_id = video_id_df["diary_id"].values.tolist()
    # Sort both groups by score, then re-insert the videos at every 5th position
    diary_id = not_video_id
    i = 1
    for j in video_id:
        diary_id.insert(i, j)
        i += 5
    return diary_id
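# Illustrative sketch (hypothetical diary ids and scores): shows how update_dairy_queue
# re-ranks a toy queue -- non-video diaries sorted by score, with the video diaries
# (positions 1 and 6 of the incoming queue) re-inserted at every 5th slot.
def _demo_update_dairy_queue():
    toy = pd.DataFrame({"diary_id": ["d1", "v1", "d2", "d3", "d4", "d5", "v2", "d6"],
                        "score": [0.9, 0.2, 0.8, 0.7, 0.6, 0.5, 0.95, 0.4]})
    # Expected order: ['d1', 'v2', 'd2', 'd3', 'd4', 'd5', 'v1', 'd6']
    return update_dairy_queue(toy)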
def update_sql_dairy_queue(queue_name, diary_id, device_id):
    db = pymysql.connect(host='rdsmaqevmuzj6jy.mysql.rds.aliyuncs.com', port=3306, user='work',
                         passwd='workwork', db='doris_test')
    cursor = db.cursor()
    # The column name must not be wrapped in single quotes; the queue is stored as a comma-separated string
    sql = "update device_diary_queue set {}='{}' where device_id = '{}'".format(
        queue_name, ",".join(str(i) for i in diary_id), device_id)
    cursor.execute(sql)
    db.commit()
    db.close()
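# Minimal alternative sketch (not part of the original flow): the same update written with
# pymysql parameter binding, so the queue string and device_id are escaped by the driver.
# Column names cannot be bound, so queue_name is checked against a whitelist first.
def _update_queue_parameterized(cursor, queue_name, diary_id, device_id):
    assert queue_name in ("native_queue", "nearby_queue", "nation_queue", "megacity_queue")
    sql = "update device_diary_queue set {} = %s where device_id = %s".format(queue_name)
    cursor.execute(sql, (",".join(str(i) for i in diary_id), device_id))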
# def router(device_id):
#     user_profile, not_exist = fetch_user_profile(device_id)
#     if not_exist:
#         print('Sorry, we don\'t have you.')
#     else:
#         predict(user_profile)
# Get the update time of the database table
def get_update_time():
    db = pymysql.connect(host='rdsmaqevmuzj6jy.mysql.rds.aliyuncs.com', port=3306, user='work',
                         passwd='workwork', db='doris_test')
    cursor = db.cursor()
    sql = "SELECT `UPDATE_TIME` FROM `information_schema`.`TABLES` " \
          "WHERE `information_schema`.`TABLES`.`TABLE_SCHEMA` = 'doris_test' " \
          "AND `information_schema`.`TABLES`.`TABLE_NAME` = 'device_diary_queue';"
    cursor.execute(sql)
    update_time = cursor.fetchone()[0]
    return update_time
# Multi-process prediction
@@ -113,16 +154,14 @@ def predict_save_to_local(user_profile, instance):
if __name__ == "__main__":
    # The table has no update-time field, so the code below cannot be used
    # sql_update_time_start = get_update_time()
    native_queue_list, nearby_queue_list, nation_queue_list, megacity_queue_list = test_con_sql("device_id")
    name_dict = {"native_queue": native_queue_list, "nearby_queue": nearby_queue_list,
                 "nation_queue": nation_queue_list, "megacity_queue": megacity_queue_list}
    for key in name_dict.keys():
        diary_id = predict(key, name_dict[key], "device_id")
        sql_update_time_end = get_update_time()
        # The table has no update-time field, so the code below cannot be used
        # if sql_update_time_start == sql_update_time_end:
        update_sql_dairy_queue(key, diary_id, "device_id")