Commit 14c098f2 authored by 张彦钊's avatar 张彦钊

修改模型嵌入层

parent d981e8a1
......@@ -98,19 +98,18 @@ def model_fn(features, labels, mode, params):
#------bulid weights------
Feat_Emb = tf.get_variable(name='embeddings', shape=[feature_size, embedding_size], initializer=tf.glorot_normal_initializer())
#------build feaure-------
#{U-A-X-C不需要特殊处理的特征}
feat_ids = features['ids']
ucity_id = features['ucity_id']
clevel1_id = features['clevel1_id']
ccity_name = features['ccity_name']
device_type = features['device_type']
manufacturer = features['manufacturer']
channel = features['channel']
top = features['top']
level2_ids = features['level2_ids']
time = features['time']
stat_date = features['stat_date']
# ucity_id = features['ucity_id']
# clevel1_id = features['clevel1_id']
# ccity_name = features['ccity_name']
# device_type = features['device_type']
# manufacturer = features['manufacturer']
# channel = features['channel']
# top = features['top']
# level2_ids = features['level2_ids']
# time = features['time']
# stat_date = features['stat_date']
if FLAGS.task_type != "infer":
......@@ -119,20 +118,9 @@ def model_fn(features, labels, mode, params):
#------build f(x)------
with tf.variable_scope("Shared-Embedding-layer"):
ucity_id = tf.nn.embedding_lookup(Feat_Emb, ucity_id)
clevel1_id = tf.nn.embedding_lookup(Feat_Emb, clevel1_id)
ccity_name = tf.nn.embedding_lookup(Feat_Emb, ccity_name)
device_type = tf.nn.embedding_lookup(Feat_Emb, device_type)
manufacturer = tf.nn.embedding_lookup(Feat_Emb, manufacturer)
channel = tf.nn.embedding_lookup(Feat_Emb, channel)
top = tf.nn.embedding_lookup(Feat_Emb, top)
level2_ids = tf.nn.embedding_lookup(Feat_Emb, level2_ids)
time = tf.nn.embedding_lookup(Feat_Emb, time)
stat_date = tf.nn.embedding_lookup(Feat_Emb, stat_date)
x_concat = tf.concat([ucity_id,clevel1_id,ccity_name,device_type,manufacturer,
channel,top,level2_ids,time,stat_date],axis=1) # None * (F * K)
embedding_id = tf.nn.embedding_lookup(Feat_Emb,feat_ids)
x_concat = tf.concat(tf.reshape(embedding_id,shape=[-1, common_dims],axis=1)) # None * (F * K)
with tf.name_scope("CVR_Task"):
if mode == tf.estimator.ModeKeys.TRAIN:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment