Commit 2b8f52b9 authored by 赵威's avatar 赵威

predict from file

parent 369d99a3
import os import os
import pickle
import random import random
import shutil import shutil
import time import time
...@@ -49,7 +50,6 @@ def main(): ...@@ -49,7 +50,6 @@ def main():
train_df, val_df = train_test_split(train_df, test_size=0.2) train_df, val_df = train_test_split(train_df, test_size=0.2)
# all_features = build_features(df) # all_features = build_features(df)
# params = {"feature_columns": all_features, "hidden_units": [64, 32], "learning_rate": 0.1} # params = {"feature_columns": all_features, "hidden_units": [64, 32], "learning_rate": 0.1}
# model_path = str(Path("~/data/model_tmp/").expanduser()) # model_path = str(Path("~/data/model_tmp/").expanduser())
# if os.path.exists(model_path): # if os.path.exists(model_path):
...@@ -66,8 +66,13 @@ def main(): ...@@ -66,8 +66,13 @@ def main():
# print("save to: " + save_path) # print("save to: " + save_path)
save_path = "/home/gmuser/data/models/1595317247" save_path = "/home/gmuser/data/models/1595317247"
# save_path = str(Path("~/Desktop/models/1595297428").expanduser()) # save_path = str(Path("~/Desktop/models/1595297428").expanduser())
predict_fn = tf.contrib.predictor.from_saved_model(save_path) filename = save_path + "/saved_model.pb"
# tf.saved_model.load
predict_fn = tf.contrib.predictor.from_saved_model(filename)
# for i in range(5): # for i in range(5):
# test_300 = test_df.sample(300) # test_300 = test_df.sample(300)
...@@ -79,13 +84,6 @@ def main(): ...@@ -79,13 +84,6 @@ def main():
# "16195283", "16838351", "17161073", "17297878", "17307484", "17396235", "16418737", "16995481", "17312201", "12237988" # "16195283", "16838351", "17161073", "17297878", "17307484", "17396235", "16418737", "16995481", "17312201", "12237988"
# ] # ]
# df = get_device_df_from_redis()
# df2 = get_diary_df_from_redis()
# redis_device_df = device_feature_engineering(df)
# redis_diary_df = diary_feature_engineering(df2, from_redis=True)
# device_ids = list(redis_device_df["device_id"].values)[:20]
# diary_ids = list(redis_diary_df["card_id"].values)
device_dict = get_device_dict_from_redis() device_dict = get_device_dict_from_redis()
diary_dict = get_diary_dict_from_redis() diary_dict = get_diary_dict_from_redis()
......
...@@ -132,8 +132,8 @@ def model_predict_diary(device_id, diary_ids, device_dict, diary_dict, predict_f ...@@ -132,8 +132,8 @@ def model_predict_diary(device_id, diary_ids, device_dict, diary_dict, predict_f
res_tuple = sorted(zip(diary_ids_res, predictions["output"].tolist()), key=lambda x: x[1], reverse=True) res_tuple = sorted(zip(diary_ids_res, predictions["output"].tolist()), key=lambda x: x[1], reverse=True)
res = [] res = []
for (id, _) in res_tuple: for (id, _) in res_tuple:
res.append(id) res.append(int(id))
print(res) # print(res)
total_1 = (timeit.default_timer() - time_1) total_1 = (timeit.default_timer() - time_1)
print("prediction cost {:.5f}s".format(total_1)) print("prediction cost {:.5f}s".format(total_1))
return res return res
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment