Commit a67b6780 authored by Your Name

bug fix

parent 7c5e1ade
@@ -160,14 +160,15 @@ def main(te_file):
     Estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir, params=model_params, config=config)
     preds = Estimator.predict(input_fn=lambda: input_fn(te_file, num_epochs=1, batch_size=10000), predict_keys=["pctcvr","pctr","pcvr"])
     # indices = []
     # for prob in preds:
     #     indices.append([prob['pctr'], prob['pcvr'], prob['pctcvr']])
     # return indices
-    with open("/home/gmuser/esmm/nearby/pred.txt", "w") as fo:
-        for prob in preds:
-            fo.write("%f\t%f\t%f\n" % (prob['pctr'], prob['pcvr'], prob['pctcvr']))
-    return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
+    indices = []
+    for prob in preds:
+        indices.append([prob['pctr'], prob['pcvr'], prob['pctcvr']])
+    return indices
 def test_map(x):
     return x * x
@@ -195,10 +196,10 @@ if __name__ == "__main__":
     print(test.collect())
     tf.logging.set_verbosity(tf.logging.INFO)
-    te_files = [["hdfs://172.16.32.4:8020/strategy/esmm/nearby/part-r-00000"]]
+    te_files = [["hdfs://172.16.32.4:8020/strategy/esmm/nearby/part-r-00000"],["hdfs://172.16.32.4:8020/strategy/esmm/native/part-r-00000"]]
     rdd_te_files = spark.sparkContext.parallelize(te_files)
-    now_time = rdd_te_files.map(lambda x: main(x))
-    print(now_time)
+    indices = rdd_te_files.repartition(2).map(lambda x: main(x))
+    print(indices.collect())
     # main(te_files)
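
For reference, the sketch below shows how the driver-side flow introduced by this commit could be exercised end to end. It is an illustration only, not part of the commit; it assumes the new main(te_file) defined above (which returns a list of [pctr, pcvr, pctcvr] triples) and an explicitly created SparkSession in place of whatever setup the surrounding script performs.

    # Illustrative sketch only: drives main() over the two HDFS inputs,
    # mirroring the new code path in this commit. Assumes main(te_file)
    # is defined in this script as shown in the diff above.
    from pyspark.sql import SparkSession

    spark = SparkSession.builder.appName("esmm_predict_sketch").getOrCreate()

    te_files = [["hdfs://172.16.32.4:8020/strategy/esmm/nearby/part-r-00000"],
                ["hdfs://172.16.32.4:8020/strategy/esmm/native/part-r-00000"]]

    # One RDD element per prediction job; repartition(2) lets the two jobs
    # run in parallel, each partition calling main() on its own file list.
    rdd_te_files = spark.sparkContext.parallelize(te_files)
    indices = rdd_te_files.repartition(2).map(lambda x: main(x))

    # collect() brings each partition's list of [pctr, pcvr, pctcvr] triples
    # back to the driver, so the print shows one entry per input file.
    print(indices.collect())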