Commit c21cc9dc authored by Your Name's avatar Your Name

predict add sample id

parent 496f925a
......@@ -163,7 +163,7 @@ def main(te_file):
log_step_count_steps=100, save_summary_steps=100)
Estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir, params=model_params, config=config)
preds = Estimator.predict(input_fn=lambda: input_fn(te_file, num_epochs=1, batch_size=10000), predict_keys=["pctcvr","pctr","pcvr"])
preds = Estimator.predict(input_fn=lambda: input_fn(te_file, num_epochs=1, batch_size=10000), predict_keys=["pctcvr","pctr","pcvr","sample_id"])
# with open("/home/gmuser/esmm/nearby/pred.txt", "w") as fo:
# for prob in preds:
......@@ -171,7 +171,7 @@ def main(te_file):
ctcvr = []
for prob in preds:
ctcvr.append(prob['pctcvr'])
ctcvr.append(prob["sample_id"],prob['pctcvr'])
return ctcvr
# indices = []
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment