Commit 2f918780 authored by 赵威's avatar 赵威

update field

parent a01adb19
import os
import time
from datetime import datetime
from pathlib import Path
import shutil
import tensorflow as tf
from sklearn.model_selection import train_test_split
......@@ -28,6 +30,8 @@ def main():
params = {"feature_columns": all_features, "hidden_units": [32], "learning_rate": 0.1}
model_path = str(Path("~/data/model_tmp/").expanduser())
if os.path.exists(model_path):
shutil.rmtree(model_path)
model = tf.estimator.Estimator(model_fn=esmm_model_fn, params=params, model_dir=model_path)
print("train")
......@@ -41,8 +45,9 @@ def main():
# predictions = model.predict(input_fn=lambda: esmm_input_fn(test_df, False))
# print(next(iter(predictions)))
test_300 = test_df.sample(300)
time_1 = time.time()
model_predict(test_df.sample(300), save_path)
model_predict(test_300, save_path)
total_1 = (time.time() - time_1)
print("prediction cost {:.5f} s at {}".format(total_1, datetime.now()))
......
......@@ -111,4 +111,5 @@ def model_predict(inputs, model_path):
example = tf.train.Example(features=tf.train.Features(feature=features))
examples.append(example.SerializeToString())
predictions = predict_fn({"examples": examples})
print(predictions)
# print(predictions)
return predictions
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment