Commit e61e0e1a authored by 张彦钊's avatar 张彦钊

add content in predictDiaryLocal.py

parent 5668f217
......@@ -4,6 +4,7 @@ import pandas as pd
import pymysql
from datetime import datetime
import utils
import warnings
# 本地测试脚本
......@@ -21,13 +22,13 @@ def test_con_sql(device_id):
if not df.empty:
df = df.rename(columns={0: "native_queue", 1: "nearby_queue", 2: "nation_queue", 3: "megacity_queue"})
native_queue = df.loc[0, "native_queue"].split(",")
native_queue = [i.strip() for i in native_queue]
native_queue = list(map(lambda x:"diary|"+str(x),native_queue))
nearby_queue = df.loc[0, "nearby_queue"].split(",")
nearby_queue = [i.strip() for i in nearby_queue]
nearby_queue = list(map(lambda x: "diary|" + str(x), nearby_queue))
nation_queue = df.loc[0, "nation_queue"].split(",")
nation_queue = [i.strip() for i in nation_queue]
nation_queue = list(map(lambda x: "diary|" + str(x), nation_queue))
megacity_queue = df.loc[0, "megacity_queue"].split(",")
megacity_queue = [i.strip() for i in megacity_queue]
megacity_queue = list(map(lambda x: "diary|" + str(x), megacity_queue))
db.close()
print(native_queue)
print(nearby_queue)
......@@ -38,6 +39,24 @@ def test_con_sql(device_id):
print("该用户对应的日记队列为空")
# 更新前获取最新的native_queue
def get_native_queue(device_id):
db = pymysql.connect(host='rdsmaqevmuzj6jy.mysql.rds.aliyuncs.com', port=3306, user='work',
passwd='workwork', db='doris_test')
cursor = db.cursor()
sql = "select native_queue from device_diary_queue where device_id = '{}';".format(device_id)
cursor.execute(sql)
result = cursor.fetchall()
df = pd.DataFrame(list(result))
if not df.empty:
native_queue = df.loc[0,0].split(",")
native_queue = list(map(lambda x:"diary|"+str(x),native_queue))
db.close()
return native_queue
else:
return None
# 将device_id、city_id拼接到对应的城市热门日记表。注意:下面预测集特征顺序要与训练集保持一致
def feature_en(x_list, device_id):
data = pd.DataFrame(x_list)
......@@ -53,7 +72,6 @@ def feature_en(x_list, device_id):
# 虽然预测y,但ffm转化需要y,并不影响预测结果
data["y"] = 0
data.to_csv("/Users/mac/utils/result/data.csv",index=False)
print(data)
return data
......@@ -88,25 +106,31 @@ def save_result(queue_name, x_list):
score_df = pd.read_csv("/Users/mac/utils/result/{0}_output.txt".format(queue_name), header=None)
score_df = score_df.rename(columns={0: "score"})
score_df["cid"] = x_list
print("done save_result")
merge_score(x_list, score_df)
def merge_score(x_list, score_df):
db = pymysql.connect(host='rdsmaqevmuzj6jy.mysql.rds.aliyuncs.com', port=3306, user='work',
passwd='workwork', db='zhengxing_test')
cursor = db.cursor()
score_list = []
for i in x_list:
sql = "select score from biz_feed_diary_score where diary_id = '{}';".format(i)
cursor.execute(sql)
if cursor.execute(sql) != 0:
result = cursor.fetchone()
result = cursor.fetchone()[0]
score_list.append(result)
# 没有查到这个diary_id,默认score值是0
else:
score_list.append(0)
db.close()
score_df["score"] = score_df["score"] + score_list
print("done merge_score")
update_dairy_queue(score_df)
......@@ -114,47 +138,40 @@ def update_dairy_queue(score_df):
diary_id = score_df["cid"].values.tolist()
video_id = []
x = 1
while x <= len(diary_id):
while x < len(diary_id):
video_id.append(diary_id[x])
x += 5
not_video_id = list(set(diary_id) - set(video_id))
not_video_id_df = score_df.loc[score_df["cid"].isin(not_video_id)]
not_video_id_df = not_video_id_df.sort_values(by="score", ascending=False)
video_id_df = score_df.loc[score_df["cid"].isin(video_id)]
video_id_df = video_id_df.sort_values(by="score", ascending=False)
not_video_id = not_video_id_df["cid"].values.tolist()
video_id = video_id_df["cid"].values.tolist()
diary_id = not_video_id
i = 1
for j in video_id:
diary_id.insert(i, j)
i += 5
return diary_id
if len(video_id)>0:
not_video_id = list(set(diary_id) - set(video_id))
not_video_id_df = score_df.loc[score_df["cid"].isin(not_video_id)]
not_video_id_df = not_video_id_df.sort_values(by="score", ascending=False)
video_id_df = score_df.loc[score_df["cid"].isin(video_id)]
video_id_df = video_id_df.sort_values(by="score", ascending=False)
not_video_id = not_video_id_df["cid"].values.tolist()
video_id = video_id_df["cid"].values.tolist()
diary_id = not_video_id
i = 1
for j in video_id:
diary_id.insert(i, j)
i += 5
print("done update_dairy_queue")
return diary_id
else:
score_df = score_df.sort_values(by="score", ascending=False)
return score_df["cid"].values.tolist()
def update_sql_dairy_queue(queue_name, diary_id, device_id):
db = pymysql.connect(host='rdsmaqevmuzj6jy.mysql.rds.aliyuncs.com', port=3306, user='work',
passwd='workwork', db='doris_test')
cursor = db.cursor()
sql = "update device_diary_queue set '{}'='{}' where device_id = '{}'".format(queue_name, diary_id, device_id)
sql = "update device_diary_queue set {}='{}' where device_id = '{}'".format(queue_name, diary_id, device_id)
cursor.execute(sql)
print("done update_sql_dairy_queue")
db.close()
# 获取数据库表更新时间
def get_update_time():
db = pymysql.connect(host='rdsmaqevmuzj6jy.mysql.rds.aliyuncs.com', port=3306, user='work',
passwd='workwork', db='doris_test')
cursor = db.cursor()
sql = "SELECT `UPDATE_TIME` FROM `information_schema`.`TABLES` " \
"WHERE `information_schema`.`TABLES`.`TABLE_SCHEMA` = 'doris_test' " \
"AND `information_schema`.`TABLES`.`TABLE_NAME` = 'device_diary_queue';"
cursor.execute(sql)
update_time = cursor.fetchone()[0]
return update_time
# 多进程预测
# TODO 多进程更新
# def multi_predict(predict_list,processes=12):
# pool = Pool(processes)
# for device_id in predict_list:
......@@ -168,16 +185,17 @@ def get_update_time():
if __name__ == "__main__":
# 数据库没有更新时间字段,下面的代码不能使用
# sql_update_time_start = get_update_time()
warnings.filterwarnings("ignore")
device_id = "358035085192742"
native_queue_list, nearby_queue_list, nation_queue_list, megacity_queue_list = test_con_sql(device_id)
name_dict = {"native_queue": native_queue_list, "nearby_queue": nearby_queue_list,
"nation_queue": nation_queue_list, "megacity_queue": megacity_queue_list}
for key in name_dict.keys():
diary_id = predict(key, name_dict[key], device_id)
# sql_update_time_end = get_update_time()
# 数据库没有更新时间字段,下面的代码不能使用
# if sql_update_time_start == sql_update_time_end:
update_sql_dairy_queue(key, diary_id, device_id)
if get_native_queue(device_id) == native_queue_list:
update_sql_dairy_queue(key, diary_id, device_id)
print("end")
else:
print("不需要更新日记队列")
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment