diff --git a/eda/esmm/Model_pipline/dist_predict.py b/eda/esmm/Model_pipline/dist_predict.py index 9e042264dca5d1448fe2ce7e47dacb51e23b0521..828206bce65606269220e7ffaaa29a8896df15b7 100644 --- a/eda/esmm/Model_pipline/dist_predict.py +++ b/eda/esmm/Model_pipline/dist_predict.py @@ -159,7 +159,7 @@ def main(_): } config = tf.estimator.RunConfig().replace(session_config = tf.ConfigProto(device_count={'GPU':0, 'CPU':36}), log_step_count_steps=100, save_summary_steps=100) - Estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir, params=model_params, config=config) + Estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir="hdfs://172.16.32.4:8020/strategy/esmm/model_ckpt/DeepCvrMTL/", params=model_params, config=config) preds = Estimator.predict(input_fn=lambda: input_fn(te_files, num_epochs=1, batch_size=10000), predict_keys=["pctcvr","pctr","pcvr"]) # indices = []