Commit 14c098f2 authored by 张彦钊's avatar 张彦钊

修改模型嵌入层

parent d981e8a1
...@@ -98,19 +98,18 @@ def model_fn(features, labels, mode, params): ...@@ -98,19 +98,18 @@ def model_fn(features, labels, mode, params):
#------bulid weights------ #------bulid weights------
Feat_Emb = tf.get_variable(name='embeddings', shape=[feature_size, embedding_size], initializer=tf.glorot_normal_initializer()) Feat_Emb = tf.get_variable(name='embeddings', shape=[feature_size, embedding_size], initializer=tf.glorot_normal_initializer())
#------build feaure------- feat_ids = features['ids']
#{U-A-X-C不需要特殊处理的特征}
ucity_id = features['ucity_id'] # ucity_id = features['ucity_id']
clevel1_id = features['clevel1_id'] # clevel1_id = features['clevel1_id']
ccity_name = features['ccity_name'] # ccity_name = features['ccity_name']
device_type = features['device_type'] # device_type = features['device_type']
manufacturer = features['manufacturer'] # manufacturer = features['manufacturer']
channel = features['channel'] # channel = features['channel']
top = features['top'] # top = features['top']
level2_ids = features['level2_ids'] # level2_ids = features['level2_ids']
time = features['time'] # time = features['time']
stat_date = features['stat_date'] # stat_date = features['stat_date']
if FLAGS.task_type != "infer": if FLAGS.task_type != "infer":
...@@ -119,20 +118,9 @@ def model_fn(features, labels, mode, params): ...@@ -119,20 +118,9 @@ def model_fn(features, labels, mode, params):
#------build f(x)------ #------build f(x)------
with tf.variable_scope("Shared-Embedding-layer"): with tf.variable_scope("Shared-Embedding-layer"):
ucity_id = tf.nn.embedding_lookup(Feat_Emb, ucity_id) embedding_id = tf.nn.embedding_lookup(Feat_Emb,feat_ids)
clevel1_id = tf.nn.embedding_lookup(Feat_Emb, clevel1_id)
ccity_name = tf.nn.embedding_lookup(Feat_Emb, ccity_name) x_concat = tf.concat(tf.reshape(embedding_id,shape=[-1, common_dims],axis=1)) # None * (F * K)
device_type = tf.nn.embedding_lookup(Feat_Emb, device_type)
manufacturer = tf.nn.embedding_lookup(Feat_Emb, manufacturer)
channel = tf.nn.embedding_lookup(Feat_Emb, channel)
top = tf.nn.embedding_lookup(Feat_Emb, top)
level2_ids = tf.nn.embedding_lookup(Feat_Emb, level2_ids)
time = tf.nn.embedding_lookup(Feat_Emb, time)
stat_date = tf.nn.embedding_lookup(Feat_Emb, stat_date)
x_concat = tf.concat([ucity_id,clevel1_id,ccity_name,device_type,manufacturer,
channel,top,level2_ids,time,stat_date],axis=1) # None * (F * K)
with tf.name_scope("CVR_Task"): with tf.name_scope("CVR_Task"):
if mode == tf.estimator.ModeKeys.TRAIN: if mode == tf.estimator.ModeKeys.TRAIN:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment