From a2937aedb82b403956316bb4e08e47b63e7f45f9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BC=A0=E5=BD=A6=E9=92=8A?= <zhangyanzhao@igengmei.com> Date: Mon, 25 Mar 2019 14:32:41 +0800 Subject: [PATCH] =?UTF-8?q?=E9=A2=84=E6=B5=8B=E9=9B=86=E5=A2=9E=E5=8A=A0?= =?UTF-8?q?=E5=BA=94=E7=94=A8=E5=88=97=E8=A1=A8=E7=89=B9=E5=BE=81?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- tensnsorflow/es/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensnsorflow/es/train.py b/tensnsorflow/es/train.py index 3b3ce726..b72929cd 100644 --- a/tensnsorflow/es/train.py +++ b/tensnsorflow/es/train.py @@ -112,7 +112,7 @@ def model_fn(features, labels, mode, params): app_id = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=app_list, sp_weights=None, combiner="sum") # x_concat = tf.reshape(embedding_id,shape=[-1, common_dims]) # None * (F * K) - x_concat = tf.concat([embedding_id,app_id], axis=1) + x_concat = tf.concat([tf.reshape(embedding_id,shape=[-1, common_dims]),app_id], axis=1) with tf.name_scope("CVR_Task"): if mode == tf.estimator.ModeKeys.TRAIN: -- 2.18.0