Commit 4b897aba authored by 高雅喆

Merge branch 'master' of git.wanmeizhensuo.com:ML/ffm-baseline

bugfix in plot legend
parents ac2358db e27ee920
......@@ -5,6 +5,7 @@ import pymysql
from datetime import datetime
import utils
import warnings
from multiprocessing import Pool
# Local test script
......@@ -30,10 +31,7 @@ def test_con_sql(device_id):
megacity_queue = df.loc[0, "megacity_queue"].split(",")
megacity_queue = list(map(lambda x: "diary|" + str(x), megacity_queue))
db.close()
print(native_queue)
print(nearby_queue)
print(nation_queue)
print(megacity_queue)
return native_queue, nearby_queue, nation_queue, megacity_queue
else:
print("该用户对应的日记队列为空")
......@@ -72,6 +70,7 @@ def feature_en(x_list, device_id):
# We are predicting y, but the ffm conversion still requires a y column; it does not affect the prediction result
data["y"] = 0
data.to_csv("/Users/mac/utils/result/data.csv",index=False)
return data
......@@ -79,11 +78,10 @@ def feature_en(x_list, device_id):
def transform_ffm_format(df, device_id):
with open("/Users/mac/utils/ffm.pkl", "rb") as f:
ffm_format_pandas = pickle.load(f)
data = ffm_format_pandas.transform(df)
data = ffm_format_pandas.native_transform(df)
now = datetime.now().strftime("%Y-%m-%d-%H-%M")
predict_file_name = "/Users/mac/utils/result/{0}_{1}.csv".format(device_id, now)
data.to_csv(predict_file_name, index=False, header=None)
print("成功将ffm预测文件写到本地")
return predict_file_name
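The file written here has no header and no index because xlearn expects libffm format: each line is a label followed by field:index:value triples. The dummy y added in feature_en supplies that leading label and is ignored at prediction time. An illustrative line (field and index numbers are made up):

0 0:3:1 1:17:1 2:241:1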
......@@ -98,7 +96,6 @@ def predict(queue_name, x_list, device_id):
ffm_model.predict("/Users/mac/utils/model.out",
"/Users/mac/utils/result/{0}_output.txt".format(queue_name))
print("{}预测结束".format(queue_name))
save_result(queue_name, x_list)
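ffm_model here is an xlearn FFM model. Below is a minimal sketch of the usual xlearn prediction flow around the predict call shown in the hunk; the setTest/setSigmoid lines and the example paths are assumptions based on the standard xlearn API, not lines from this commit.

import xlearn as xl

queue_name = "native_queue"                                  # example queue name
predict_file_name = "/Users/mac/utils/result/example.csv"    # libffm file from transform_ffm_format

ffm_model = xl.create_ffm()            # field-aware factorization machine
ffm_model.setTest(predict_file_name)   # file to score
ffm_model.setSigmoid()                 # map raw outputs into (0, 1)
ffm_model.predict("/Users/mac/utils/model.out",
                  "/Users/mac/utils/result/{0}_output.txt".format(queue_name))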
......@@ -106,7 +103,6 @@ def save_result(queue_name, x_list):
score_df = pd.read_csv("/Users/mac/utils/result/{0}_output.txt".format(queue_name), header=None)
score_df = score_df.rename(columns={0: "score"})
score_df["cid"] = x_list
print("done save_result")
merge_score(x_list, score_df)
......@@ -130,7 +126,6 @@ def merge_score(x_list, score_df):
db.close()
score_df["score"] = score_df["score"] + score_list
print("done merge_score")
update_dairy_queue(score_df)
......@@ -154,7 +149,6 @@ def update_dairy_queue(score_df):
for j in video_id:
diary_id.insert(i, j)
i += 5
print("done update_dairy_queue")
return diary_id
else:
score_df = score_df.sort_values(by="score", ascending=False)
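The loop above spreads video diaries through the queue by inserting one every five positions (indexes 0, 5, 10, ...). A small self-contained sketch of the same interleaving with made-up ids:

# Sketch: insert video diary ids into the non-video queue every 5 positions,
# both lists assumed to be already sorted by score.
diary_id = ["d1", "d2", "d3", "d4", "d5", "d6", "d7", "d8"]
video_id = ["v1", "v2"]

i = 0
for j in video_id:
    diary_id.insert(i, j)
    i += 5

print(diary_id)
# -> ['v1', 'd1', 'd2', 'd3', 'd4', 'v2', 'd5', 'd6', 'd7', 'd8']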
......@@ -167,21 +161,16 @@ def update_sql_dairy_queue(queue_name, diary_id, device_id):
cursor = db.cursor()
sql = "update device_diary_queue set {}='{}' where device_id = '{}'".format(queue_name, diary_id, device_id)
cursor.execute(sql)
print("done update_sql_dairy_queue")
db.close()
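The update above interpolates values straight into the SQL string. The sketch below shows a parameterized variant; it assumes, hypothetically, that the queue is first joined back into the comma-separated id string that test_con_sql splits on, and the connection arguments are placeholders rather than the project's real credentials.

import pymysql

def update_queue_parameterized(queue_name, diary_id, device_id):
    # diary_id is assumed to be a list of cids such as ["diary|101", "diary|102"];
    # strip the "diary|" prefix and store the ids back as a comma-separated string.
    id_str = ",".join(x.split("|")[-1] for x in diary_id)
    db = pymysql.connect(host="127.0.0.1", user="root", passwd="", db="test")  # placeholder credentials
    cursor = db.cursor()
    # a column name cannot be bound as a parameter, so queue_name is still formatted in;
    # the values are passed to the driver as parameters.
    sql = "update device_diary_queue set {}=%s where device_id = %s".format(queue_name)
    cursor.execute(sql, (id_str, device_id))
    db.commit()
    db.close()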
# TODO: multiprocess update
# def multi_predict(predict_list,processes=12):
# pool = Pool(processes)
# for device_id in predict_list:
# start = time.time()
# pool.apply_async(router, (device_id,))
# end = time.time()
# print("该用户{}预测耗时{}秒".format(device_id, (end - start)))
#
# pool.close()
# pool.join()
def multi_update(key, name_dict, device_id,native_queue_list):
diary_id = predict(key, name_dict[key], device_id)
if get_native_queue(device_id) == native_queue_list:
update_sql_dairy_queue(key, diary_id, device_id)
print("更新结束")
else:
print("不需要更新日记队列")
if __name__ == "__main__":
......@@ -190,12 +179,10 @@ if __name__ == "__main__":
native_queue_list, nearby_queue_list, nation_queue_list, megacity_queue_list = test_con_sql(device_id)
name_dict = {"native_queue": native_queue_list, "nearby_queue": nearby_queue_list,
"nation_queue": nation_queue_list, "megacity_queue": megacity_queue_list}
pool = Pool(12)
for key in name_dict.keys():
diary_id = predict(key, name_dict[key], device_id)
if get_native_queue(device_id) == native_queue_list:
update_sql_dairy_queue(key, diary_id, device_id)
print("end")
else:
print("不需要更新日记队列")
pool.apply_async(multi_update,(key,name_dict,device_id,native_queue_list,))
pool.close()
pool.join()
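The main block now fans the four queues out to a process pool instead of scoring them one after another; multi_update re-reads the native queue before writing so that a stale prediction is dropped. A minimal, self-contained sketch of the same apply_async pattern, with a stand-in worker instead of the project's multi_update:

from multiprocessing import Pool

def multi_update_stub(key, name_dict, device_id, native_queue_list):
    # stand-in worker: the real script runs predict() here and, if the native
    # queue is unchanged, update_sql_dairy_queue()
    return "{} handled for {}".format(key, device_id)

if __name__ == "__main__":
    name_dict = {"native_queue": [], "nearby_queue": [], "nation_queue": [], "megacity_queue": []}
    pool = Pool(4)  # the script uses Pool(12); four workers cover the four queues
    results = []
    for key in name_dict.keys():
        # every argument in the tuple must be picklable; apply_async returns an AsyncResult
        results.append(pool.apply_async(multi_update_stub, (key, name_dict, "device-xyz", [])))
    pool.close()   # no more tasks will be submitted
    pool.join()    # wait for the workers to finish
    for r in results:
        print(r.get())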
......@@ -131,6 +131,11 @@ class multiFFMFormatPandas:
return pd.Series(result_map)
# Native transform method; no multiprocessing needed
def native_transform(self,df):
t = df.dtypes.to_dict()
return pd.Series({idx: self.transform_row_(row, t) for idx, row in df.iterrows()})
# Multiprocessing computation method
def pool_function(self, df, t):
return {idx: self.transform_row_(row, t) for idx, row in df.iterrows()}
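native_transform walks the rows in the current process, while pool_function is the per-chunk worker used by the multiprocess path. The sketch below shows how the two typically fit together; the chunking in transform() is an assumption about the surrounding code, which this hunk does not show, and transform_row_ is reduced to a placeholder.

import numpy as np
import pandas as pd
from multiprocessing import Pool

class MiniFormatter:
    # placeholder row encoder; the real class emits libffm "field:index:value" strings
    def transform_row_(self, row, t):
        return " ".join(str(v) for v in row.values)

    # single-process path, mirroring native_transform in the diff
    def native_transform(self, df):
        t = df.dtypes.to_dict()
        return pd.Series({idx: self.transform_row_(row, t) for idx, row in df.iterrows()})

    # per-chunk worker, mirroring pool_function in the diff
    def pool_function(self, df, t):
        return {idx: self.transform_row_(row, t) for idx, row in df.iterrows()}

    # assumed shape of the multiprocess path: split the frame, map pool_function, merge
    def transform(self, df, n_proc=4):
        t = df.dtypes.to_dict()
        chunks = np.array_split(df, n_proc)
        with Pool(n_proc) as pool:
            parts = pool.starmap(self.pool_function, [(c, t) for c in chunks])
        result_map = {}
        for part in parts:
            result_map.update(part)
        return pd.Series(result_map)

if __name__ == "__main__":
    frame = pd.DataFrame({"device_id": ["a", "b"], "item_id": [1, 2]})
    print(MiniFormatter().native_transform(frame))

For the single-device frame built at prediction time, native_transform avoids the pool start-up cost, which is presumably why the predict path switched to it.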
......
......@@ -161,6 +161,11 @@ class multiFFMFormatPandas:
return data_list
# Native transform method; no multiprocessing needed
def native_transform(self, df):
t = df.dtypes.to_dict()
return pd.Series({idx: self.transform_row_(row, t) for idx, row in df.iterrows()})
# The method below is not one of the class's original methods; it is newly added. Its purpose is to check whether this user exists in the training data set
def is_feature_index_exist(self, name):
......
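The hunk stops at the signature of is_feature_index_exist; per the comment it only answers whether a name was seen when the formatter was fitted. A hedged, standalone sketch of that kind of lookup is below; the feature_index_ attribute name is hypothetical and not taken from the class.

# Sketch: membership check against a fitted feature dictionary.
# "feature_index_" is a hypothetical attribute name; the real class may differ.
class FeatureLookup:
    def __init__(self):
        self.feature_index_ = {}            # filled during fit, e.g. {"device_id_abc": 0, ...}

    def fit(self, names):
        for i, n in enumerate(names):
            self.feature_index_[n] = i
        return self

    def is_feature_index_exist(self, name):
        # True if this user/feature appeared in the training data
        return name in self.feature_index_

lookup = FeatureLookup().fit(["device_id_abc", "device_id_def"])
print(lookup.is_feature_index_exist("device_id_abc"))  # True
print(lookup.is_feature_index_exist("device_id_xyz"))  # False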