Commit 5c4d9390 authored by Your Name's avatar Your Name

bug fix

parent 86dd0ffd
......@@ -141,10 +141,10 @@ def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
#print(batch_features,batch_labels)
return batch_features, batch_labels
def main(te_file):
def main():
# dt_dir = (date.today() + timedelta(-1)).strftime('%Y%m%d')
model_dir = "hdfs://172.16.32.4:8020/strategy/esmm/model_ckpt/DeepCvrMTL/"
# te_files = ["hdfs://172.16.32.4:8020/strategy/esmm/nearby/part-r-00000"]
te_files = ["hdfs://172.16.32.4:8020/strategy/esmm/nearby/part-r-00000"]
model_params = {
"field_size": 15,
"feature_size": 600000,
......@@ -159,7 +159,7 @@ def main(te_file):
log_step_count_steps=100, save_summary_steps=100)
Estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir, params=model_params, config=config)
preds = Estimator.predict(input_fn=lambda: input_fn(te_file, num_epochs=1, batch_size=10000), predict_keys=["pctcvr","pctr","pcvr"])
preds = Estimator.predict(input_fn=lambda: input_fn(te_files, num_epochs=1, batch_size=10000), predict_keys=["pctcvr","pctr","pcvr"])
# with open("/home/gmuser/esmm/nearby/pred.txt", "w") as fo:
# for prob in preds:
......@@ -202,7 +202,7 @@ if __name__ == "__main__":
# indices = rdd_te_files.repartition(2).map(lambda x: main(x))
# print(indices.collect())
main(te_files)
main()
b = time.time()
print("耗时(分钟):")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment