From a2937aedb82b403956316bb4e08e47b63e7f45f9 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=BC=A0=E5=BD=A6=E9=92=8A?= <zhangyanzhao@igengmei.com>
Date: Mon, 25 Mar 2019 14:32:41 +0800
Subject: [PATCH] =?UTF-8?q?=E9=A2=84=E6=B5=8B=E9=9B=86=E5=A2=9E=E5=8A=A0?=
 =?UTF-8?q?=E5=BA=94=E7=94=A8=E5=88=97=E8=A1=A8=E7=89=B9=E5=BE=81?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 tensnsorflow/es/train.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tensnsorflow/es/train.py b/tensnsorflow/es/train.py
index 3b3ce726..b72929cd 100644
--- a/tensnsorflow/es/train.py
+++ b/tensnsorflow/es/train.py
@@ -112,7 +112,7 @@ def model_fn(features, labels, mode, params):
         app_id = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=app_list, sp_weights=None, combiner="sum")
 
         # x_concat = tf.reshape(embedding_id,shape=[-1, common_dims])  # None * (F * K)
-        x_concat = tf.concat([embedding_id,app_id], axis=1)
+        x_concat = tf.concat([tf.reshape(embedding_id,shape=[-1, common_dims]),app_id], axis=1)
 
     with tf.name_scope("CVR_Task"):
         if mode == tf.estimator.ModeKeys.TRAIN:
-- 
2.18.0