Commit c451d928 authored by 赵威's avatar 赵威

add max_steps

parent af9bf67b
...@@ -42,7 +42,7 @@ def main(): ...@@ -42,7 +42,7 @@ def main():
shutil.rmtree(model_path) shutil.rmtree(model_path)
model = tf.estimator.Estimator(model_fn=esmm_model_fn, params=params, model_dir=model_path) model = tf.estimator.Estimator(model_fn=esmm_model_fn, params=params, model_dir=model_path)
train_spec = tf.estimator.TrainSpec(input_fn=lambda: esmm_input_fn(train_df, shuffle=True)) train_spec = tf.estimator.TrainSpec(input_fn=lambda: esmm_input_fn(train_df, shuffle=True), max_steps=40000)
eval_spec = tf.estimator.EvalSpec(input_fn=lambda: esmm_input_fn(val_df, shuffle=False)) eval_spec = tf.estimator.EvalSpec(input_fn=lambda: esmm_input_fn(val_df, shuffle=False))
tf.estimator.train_and_evaluate(model, train_spec, eval_spec) tf.estimator.train_and_evaluate(model, train_spec, eval_spec)
......
...@@ -58,7 +58,9 @@ def esmm_model_fn(features, labels, mode, params): ...@@ -58,7 +58,9 @@ def esmm_model_fn(features, labels, mode, params):
ctcvr_loss = tf.reduce_sum(tf.compat.v1.losses.log_loss(labels=cvr_labels, predictions=ctcvr_preds)) ctcvr_loss = tf.reduce_sum(tf.compat.v1.losses.log_loss(labels=cvr_labels, predictions=ctcvr_preds))
loss = ctr_loss + ctcvr_loss loss = ctr_loss + ctcvr_loss
ctr_accuracy = tf.compat.v1.metrics.accuracy(labels=ctr_labels, predictions=tf.to_float(tf.greater_equal(ctr_preds, 0.5))) if mode == tf.estimator.ModeKeys.EVAL:
ctr_accuracy = tf.compat.v1.metrics.accuracy(labels=ctr_labels,
predictions=tf.to_float(tf.greater_equal(ctr_preds, 0.5)))
ctcvr_accuracy = tf.compat.v1.metrics.accuracy(labels=cvr_labels, ctcvr_accuracy = tf.compat.v1.metrics.accuracy(labels=cvr_labels,
predictions=tf.to_float(tf.greater_equal(ctcvr_preds, 0.5))) predictions=tf.to_float(tf.greater_equal(ctcvr_preds, 0.5)))
ctr_auc = tf.compat.v1.metrics.auc(labels=ctr_labels, predictions=ctr_preds) ctr_auc = tf.compat.v1.metrics.auc(labels=ctr_labels, predictions=ctr_preds)
...@@ -68,7 +70,6 @@ def esmm_model_fn(features, labels, mode, params): ...@@ -68,7 +70,6 @@ def esmm_model_fn(features, labels, mode, params):
tf.compat.v1.summary.scalar("ctcvr_accuracy", ctcvr_accuracy[1]) tf.compat.v1.summary.scalar("ctcvr_accuracy", ctcvr_accuracy[1])
tf.compat.v1.summary.scalar("ctr_auc", ctr_auc[1]) tf.compat.v1.summary.scalar("ctr_auc", ctr_auc[1])
tf.compat.v1.summary.scalar("ctcvr_auc", ctcvr_auc[1]) tf.compat.v1.summary.scalar("ctcvr_auc", ctcvr_auc[1])
if mode == tf.estimator.ModeKeys.EVAL:
return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics) return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics)
train_op = optimizer.minimize(loss, global_step=tf.compat.v1.train.get_global_step()) train_op = optimizer.minimize(loss, global_step=tf.compat.v1.train.get_global_step())
res = tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op, eval_metric_ops=metrics) res = tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op, eval_metric_ops=metrics)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment