Commit 6885e903 authored by 赵威's avatar 赵威

update function name

parent 17b5c5c8
...@@ -30,7 +30,7 @@ def read_csv_data(dataset_path): ...@@ -30,7 +30,7 @@ def read_csv_data(dataset_path):
return tractate_df, click_df, conversion_df return tractate_df, click_df, conversion_df
def get_tractate_from_redis(): def get_tractate_dict_from_redis():
""" """
return: {tractate_id: {first_demands: [], is_pure_author: 1}} return: {tractate_id: {first_demands: [], is_pure_author: 1}}
""" """
......
...@@ -65,18 +65,18 @@ def main(): ...@@ -65,18 +65,18 @@ def main():
predict_fn = tf.contrib.predictor.from_saved_model(save_path) predict_fn = tf.contrib.predictor.from_saved_model(save_path)
device_dict = device_fe.get_device_dict_from_redis() device_dict = device_fe.get_device_dict_from_redis()
diary_dict = tractate_fe.get_tractate_dict_from_redis() tractate_dict = tractate_fe.get_tractate_dict_from_redis()
device_ids = list(device_dict.keys())[:20] device_ids = list(device_dict.keys())[:20]
diary_ids = list(diary_dict.keys()) tractate_ids = list(tractate_dict.keys())
print(device_dict(device_ids[0]), "\n") print(device_dict(device_ids[0]), "\n")
print(diary_dict(diary_ids[0]), "\n") print(tractate_dict(tractate_ids[0]), "\n")
for i in range(5): for i in range(5):
time_1 = timeit.default_timer() time_1 = timeit.default_timer()
res = model_predict_tractate( res = model_predict_tractate(
random.sample(device_ids, 1)[0], random.sample(diary_ids, 200), device_dict, diary_dict, predict_fn) random.sample(device_ids, 1)[0], random.sample(tractate_ids, 200), device_dict, tractate_dict, predict_fn)
print(res[:10]) print(res[:10])
total_1 = (timeit.default_timer() - time_1) total_1 = (timeit.default_timer() - time_1)
print("total prediction cost {:.5f}s".format(total_1), "\n") print("total prediction cost {:.5f}s".format(total_1), "\n")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment