diff --git a/eda/esmm/Model_pipline/dist_predict.py b/eda/esmm/Model_pipline/dist_predict.py
index 30d1f2679ee1162b949f8e310935ec16371b9c45..0b4ecaadfb6826f61d6005f549e6a8686ba2bcb8 100644
--- a/eda/esmm/Model_pipline/dist_predict.py
+++ b/eda/esmm/Model_pipline/dist_predict.py
@@ -141,10 +141,10 @@ def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
     #print(batch_features,batch_labels)
     return batch_features, batch_labels
 
-def main():
+def main(te_file):
     # dt_dir = (date.today() + timedelta(-1)).strftime('%Y%m%d')
     model_dir = "hdfs://172.16.32.4:8020/strategy/esmm/model_ckpt/DeepCvrMTL/"
-    te_files = ["hdfs://172.16.32.4:8020/strategy/esmm/nearby/part-r-00000"]
+    # te_files = ["hdfs://172.16.32.4:8020/strategy/esmm/nearby/part-r-00000"]
     model_params = {
         "field_size": 15,
         "feature_size": 600000,
@@ -159,7 +159,7 @@ def main():
             log_step_count_steps=100, save_summary_steps=100)
     Estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir, params=model_params, config=config)
 
-    preds = Estimator.predict(input_fn=lambda: input_fn(te_files, num_epochs=1, batch_size=10000), predict_keys=["pctcvr","pctr","pcvr"])
+    preds = Estimator.predict(input_fn=lambda: input_fn(te_file, num_epochs=1, batch_size=10000), predict_keys=["pctcvr","pctr","pcvr"])
     # indices = []
     # for prob in preds:
     #     indices.append([prob['pctr'], prob['pcvr'], prob['pctcvr']])
@@ -191,14 +191,12 @@ if __name__ == "__main__":
 
     test = name.repartition(5).map(lambda x: test_map(x))
     print(test)
-    test.collect()
+    print(test.collect())
 
+    tf.logging.set_verbosity(tf.logging.INFO)
     te_files = ["hdfs://172.16.32.4:8020/strategy/esmm/nearby/part-r-00000"]
-
+    main(te_files)
 
     b = time.time()
-    tf.logging.set_verbosity(tf.logging.INFO)
-    # tf.app.run()
-    # main()
     print("耗时(分钟):")
     print((time.time()-b)/60)