Commit 496f925a authored by Your Name's avatar Your Name

predict add sample id

parent b038c9ff
......@@ -38,6 +38,7 @@ def model_fn(features, labels, mode, params):
tag5_list = features['tag5_list']
tag6_list = features['tag6_list']
tag7_list = features['tag7_list']
number = features['number']
#------build f(x)------
......@@ -58,6 +59,8 @@ def model_fn(features, labels, mode, params):
x_concat = tf.concat([tf.reshape(embedding_id, shape=[-1, common_dims]), app_id, level2, level3, tag1,
tag2, tag3, tag4, tag5, tag6, tag7], axis=1)
sample_id = tf.sparse.to_dense(number)
with tf.name_scope("CVR_Task"):
if mode == tf.estimator.ModeKeys.TRAIN:
train_phase = True
......@@ -90,7 +93,7 @@ def model_fn(features, labels, mode, params):
pcvr = tf.sigmoid(y_cvr)
pctcvr = pctr*pcvr
predictions={"pcvr": pcvr, "pctr": pctr, "pctcvr": pctcvr}
predictions={"pcvr": pcvr, "pctr": pctr, "pctcvr": pctcvr, "sample_id": sample_id}
export_outputs = {tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: tf.estimator.export.PredictOutput(predictions)}
# Provide an estimator spec for `ModeKeys.PREDICT`
if mode == tf.estimator.ModeKeys.PREDICT:
......
......@@ -318,9 +318,10 @@ def main(_):
FLAGS.model_dir = FLAGS.model_dir + FLAGS.dt_dir
#FLAGS.data_dir = FLAGS.data_dir + FLAGS.dt_dir
tr_files = ["hdfs://172.16.32.4:8020/strategy/esmm/tr/part-r-00000"]
tr_files = ["hdfs://172.16.32.4:8020/strategy/esmm/test_tr/part-r-00000"]
va_files = ["hdfs://172.16.32.4:8020/strategy/esmm/va/part-r-00000"]
te_files = ["%s/part-r-00000" % FLAGS.hdfs_dir]
# te_files = ["%s/part-r-00000" % FLAGS.hdfs_dir]
te_files = ["hdfs://172.16.32.4:8020/strategy/esmmtest_nearby/part-r-00000"]
if FLAGS.clear_existing_model:
try:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment