diff --git a/eda/esmm/Model_pipline/dist_predict.py b/eda/esmm/Model_pipline/dist_predict.py index 683f9b24ff6e61be20050acc2914c98a1a68ca23..fc7ecb953265d9bc51466cf39e35568eba5156fd 100644 --- a/eda/esmm/Model_pipline/dist_predict.py +++ b/eda/esmm/Model_pipline/dist_predict.py @@ -141,7 +141,7 @@ def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False): #print(batch_features,batch_labels) return batch_features, batch_labels -def esmm_predict(dist_data): +def main(_): dt_dir = (date.today() + timedelta(-1)).strftime('%Y%m%d') model_dir = "hdfs://172.16.32.4:8020/strategy/esmm/model_ckpt/DeepCvrMTL/" + dt_dir te_files = ["hdfs://172.16.32.4:8020/strategy/esmm/nearby/part-r-00000"] @@ -157,13 +157,17 @@ def esmm_predict(dist_data): } config = tf.estimator.RunConfig().replace(session_config = tf.ConfigProto(device_count={'GPU':0, 'CPU':36}), log_step_count_steps=100, save_summary_steps=100) - Estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir="hdfs://172.16.32.4:8020/strategy/esmm/model_ckpt/DeepCvrMTL/", params=model_params, config=config) + Estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir=model_dir, params=model_params, config=config) + + preds = Estimator.predict(input_fn=lambda: input_fn(te_files, num_epochs=1, batch_size=10000), predict_keys=["pctcvr","pctr","pcvr"]) + # indices = [] + # for prob in preds: + # indices.append([prob['pctr'], prob['pcvr'], prob['pctcvr']]) + # return indices + with open("/home/gmuser/esmm/nearby/pred.txt", "w") as fo: + for prob in preds: + fo.write("%f\t%f\t%f\n" % (prob['pctr'], prob['pcvr'], prob['pctcvr'])) - preds = Estimator.predict(input_fn=lambda: input_fn(dist_data, num_epochs=1, batch_size=10000), predict_keys=["pctcvr","pctr","pcvr"]) - indices = [] - for prob in preds: - indices.append([prob['pctr'], prob['pcvr'], prob['pctcvr']]) - return indices @@ -184,5 +188,7 @@ if __name__ == "__main__": df.show() b = time.time() + tf.logging.set_verbosity(tf.logging.INFO) + tf.app.run() print("耗时(分钟):") - print((time.time()-b)/60) \ No newline at end of file + print((time.time()-b)/60)