Commit baa58b4f authored by 赵威's avatar 赵威

try predict

parent c451d928
...@@ -38,8 +38,8 @@ def main(): ...@@ -38,8 +38,8 @@ def main():
all_features = build_features(df) all_features = build_features(df)
params = {"feature_columns": all_features, "hidden_units": [64, 32], "learning_rate": 0.1} params = {"feature_columns": all_features, "hidden_units": [64, 32], "learning_rate": 0.1}
model_path = str(Path("~/data/model_tmp/").expanduser()) model_path = str(Path("~/data/model_tmp/").expanduser())
if os.path.exists(model_path): # if os.path.exists(model_path):
shutil.rmtree(model_path) # shutil.rmtree(model_path)
model = tf.estimator.Estimator(model_fn=esmm_model_fn, params=params, model_dir=model_path) model = tf.estimator.Estimator(model_fn=esmm_model_fn, params=params, model_dir=model_path)
train_spec = tf.estimator.TrainSpec(input_fn=lambda: esmm_input_fn(train_df, shuffle=True), max_steps=40000) train_spec = tf.estimator.TrainSpec(input_fn=lambda: esmm_input_fn(train_df, shuffle=True), max_steps=40000)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment