Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in
Toggle navigation
F
ffm-baseline
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ML
ffm-baseline
Commits
aed9fbd2
Commit
aed9fbd2
authored
Jun 24, 2019
by
Your Name
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
test
parent
11e3e4ca
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
227 additions
and
82 deletions
+227
-82
train.py
eda/esmm/Model_pipline/train.py
+227
-82
No files found.
eda/esmm/Model_pipline/train.py
View file @
aed9fbd2
...
...
@@ -49,13 +49,232 @@ tf.app.flags.DEFINE_string("task_type", 'train', "task type {train, infer, eval,
tf
.
app
.
flags
.
DEFINE_boolean
(
"clear_existing_model"
,
False
,
"clear existing model or not"
)
#40362692,0,0,216:9342395:1.0 301:9351665:1.0 205:7702673:1.0 206:8317829:1.0 207:8967741:1.0 508:9356012:2.30259 210:9059239:1.0 210:9042796:1.0 210:9076972:1.0 210:9103884:1.0 210:9063064:1.0 127_14:3529789:2.3979 127_14:3806412:2.70805
# def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
# print('Parsing', filenames)
# def _parse_fn(record):
# features = {
# "y": tf.FixedLenFeature([], tf.float32),
# "z": tf.FixedLenFeature([], tf.float32),
# "ids": tf.FixedLenFeature([FLAGS.field_size], tf.int64),
# "app_list": tf.VarLenFeature(tf.int64),
# "level2_list": tf.VarLenFeature(tf.int64),
# "level3_list": tf.VarLenFeature(tf.int64),
# "tag1_list": tf.VarLenFeature(tf.int64),
# "tag2_list": tf.VarLenFeature(tf.int64),
# "tag3_list": tf.VarLenFeature(tf.int64),
# "tag4_list": tf.VarLenFeature(tf.int64),
# "tag5_list": tf.VarLenFeature(tf.int64),
# "tag6_list": tf.VarLenFeature(tf.int64),
# "tag7_list": tf.VarLenFeature(tf.int64),
# "number": tf.VarLenFeature(tf.int64),
# "uid": tf.VarLenFeature(tf.string),
# "city": tf.VarLenFeature(tf.string),
# "cid_id": tf.VarLenFeature(tf.string)
# }
# parsed = tf.parse_single_example(record, features)
# y = parsed.pop('y')
# z = parsed.pop('z')
# return parsed, {"y": y, "z": z}
#
# # Extract lines from input files using the Dataset API, can pass one filename or filename list
# # dataset = tf.data.TFRecordDataset(filenames).map(_parse_fn, num_parallel_calls=8).prefetch(500000) # multi-thread pre-process then prefetch
#
# # Randomizes input using a window of 256 elements (read into memory)
# # if perform_shuffle:
# # dataset = dataset.shuffle(buffer_size=256)
#
# # epochs from blending together.
# # dataset = dataset.repeat(num_epochs)
# # dataset = dataset.batch(batch_size) # Batch size to use
#
# files = tf.data.Dataset.list_files(filenames)
# dataset = files.apply(
# tf.data.experimental.parallel_interleave(
# lambda file: tf.data.TFRecordDataset(file),
# cycle_length=8
# )
# )
#
# dataset = dataset.apply(tf.data.experimental.map_and_batch(map_func=_parse_fn, batch_size=batch_size, num_parallel_calls=8))
# dataset = dataset.prefetch(10000)
#
#
# # dataset = dataset.padded_batch(batch_size, padded_shapes=({"feeds_ids": [None], "feeds_vals": [None], "title_ids": [None]}, [None])) #不定长补齐
# #return dataset.make_one_shot_iterator()
# iterator = dataset.make_one_shot_iterator()
# batch_features, batch_labels = iterator.get_next()
# #return tf.reshape(batch_ids,shape=[-1,field_size]), tf.reshape(batch_vals,shape=[-1,field_size]), batch_labels
# #print("-"*100)
# #print(batch_features,batch_labels)
# return batch_features, batch_labels
# def model_fn(features, labels, mode, params):
# """Bulid Model function f(x) for Estimator."""
# #------hyperparameters----
# field_size = params["field_size"]
# feature_size = params["feature_size"]
# embedding_size = params["embedding_size"]
# l2_reg = params["l2_reg"]
# learning_rate = params["learning_rate"]
# #optimizer = params["optimizer"]
# layers = list(map(int, params["deep_layers"].split(',')))
# dropout = list(map(float, params["dropout"].split(',')))
# ctr_task_wgt = params["ctr_task_wgt"]
# common_dims = field_size*embedding_size
#
# #------bulid weights------
# Feat_Emb = tf.get_variable(name='embeddings', shape=[feature_size, embedding_size], initializer=tf.glorot_normal_initializer())
#
# feat_ids = features['ids']
# app_list = features['app_list']
# level2_list = features['level2_list']
# level3_list = features['level3_list']
# tag1_list = features['tag1_list']
# tag2_list = features['tag2_list']
# tag3_list = features['tag3_list']
# tag4_list = features['tag4_list']
# tag5_list = features['tag5_list']
# tag6_list = features['tag6_list']
# tag7_list = features['tag7_list']
# number = features['number']
# uid = features['uid']
# city = features['city']
# cid_id = features['cid_id']
#
# if FLAGS.task_type != "infer":
# y = labels['y']
# z = labels['z']
#
# #------build f(x)------
# with tf.variable_scope("Shared-Embedding-layer"):
# embedding_id = tf.nn.embedding_lookup(Feat_Emb,feat_ids)
# app_id = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=app_list, sp_weights=None, combiner="sum")
# level2 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=level2_list, sp_weights=None, combiner="sum")
# level3 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=level3_list, sp_weights=None, combiner="sum")
# tag1 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag1_list, sp_weights=None, combiner="sum")
# tag2 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag2_list, sp_weights=None, combiner="sum")
# tag3 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag3_list, sp_weights=None, combiner="sum")
# tag4 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag4_list, sp_weights=None, combiner="sum")
# tag5 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag5_list, sp_weights=None, combiner="sum")
# tag6 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag6_list, sp_weights=None, combiner="sum")
# tag7 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag7_list, sp_weights=None, combiner="sum")
#
# # x_concat = tf.reshape(embedding_id,shape=[-1, common_dims]) # None * (F * K)
# x_concat = tf.concat([tf.reshape(embedding_id, shape=[-1, common_dims]), app_id, level2, level3, tag1,
# tag2, tag3, tag4, tag5, tag6, tag7], axis=1)
#
# sample_id = tf.sparse.to_dense(number)
# uid = tf.sparse.to_dense(uid,default_value="")
# city = tf.sparse.to_dense(city,default_value="")
# cid_id = tf.sparse.to_dense(cid_id,default_value="")
#
# with tf.name_scope("CVR_Task"):
# if mode == tf.estimator.ModeKeys.TRAIN:
# train_phase = True
# else:
# train_phase = False
# x_cvr = x_concat
# for i in range(len(layers)):
# x_cvr = tf.contrib.layers.fully_connected(inputs=x_cvr, num_outputs=layers[i], \
# weights_regularizer=tf.contrib.layers.l2_regularizer(l2_reg), scope='cvr_mlp%d' % i)
#
# if FLAGS.batch_norm:
# x_cvr = batch_norm_layer(x_cvr, train_phase=train_phase, scope_bn='cvr_bn_%d' %i) #放在RELU之后 https://github.com/ducha-aiki/caffenet-benchmark/blob/master/batchnorm.md#bn----before-or-after-relu
# if mode == tf.estimator.ModeKeys.TRAIN:
# x_cvr = tf.nn.dropout(x_cvr, keep_prob=dropout[i]) #Apply Dropout after all BN layers and set dropout=0.8(drop_ratio=0.2)
#
# y_cvr = tf.contrib.layers.fully_connected(inputs=x_cvr, num_outputs=1, activation_fn=tf.identity, \
# weights_regularizer=tf.contrib.layers.l2_regularizer(l2_reg), scope='cvr_out')
# y_cvr = tf.reshape(y_cvr,shape=[-1])
#
# with tf.name_scope("CTR_Task"):
# if mode == tf.estimator.ModeKeys.TRAIN:
# train_phase = True
# else:
# train_phase = False
#
# x_ctr = x_concat
# for i in range(len(layers)):
# x_ctr = tf.contrib.layers.fully_connected(inputs=x_ctr, num_outputs=layers[i], \
# weights_regularizer=tf.contrib.layers.l2_regularizer(l2_reg), scope='ctr_mlp%d' % i)
#
# if FLAGS.batch_norm:
# x_ctr = batch_norm_layer(x_ctr, train_phase=train_phase, scope_bn='ctr_bn_%d' %i) #放在RELU之后 https://github.com/ducha-aiki/caffenet-benchmark/blob/master/batchnorm.md#bn----before-or-after-relu
# if mode == tf.estimator.ModeKeys.TRAIN:
# x_ctr = tf.nn.dropout(x_ctr, keep_prob=dropout[i]) #Apply Dropout after all BN layers and set dropout=0.8(drop_ratio=0.2)
#
# y_ctr = tf.contrib.layers.fully_connected(inputs=x_ctr, num_outputs=1, activation_fn=tf.identity, \
# weights_regularizer=tf.contrib.layers.l2_regularizer(l2_reg), scope='ctr_out')
# y_ctr = tf.reshape(y_ctr,shape=[-1])
#
# with tf.variable_scope("MTL-Layer"):
# pctr = tf.sigmoid(y_ctr)
# pcvr = tf.sigmoid(y_cvr)
# pctcvr = pctr*pcvr
#
#
# predictions={"pctcvr": pctcvr, "sample_id": sample_id, "uid":uid, "city":city, "cid_id":cid_id}
# export_outputs = {tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: tf.estimator.export.PredictOutput(predictions)}
# # Provide an estimator spec for `ModeKeys.PREDICT`
# if mode == tf.estimator.ModeKeys.PREDICT:
# return tf.estimator.EstimatorSpec(
# mode=mode,
# predictions=predictions,
# export_outputs=export_outputs)
#
# if FLAGS.task_type != "infer":
# #------bulid loss------
# ctr_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=y_ctr, labels=y))
# #cvr_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=y_ctcvr, labels=z))
# cvr_loss = tf.reduce_mean(tf.losses.log_loss(predictions=pctcvr, labels=z))
# loss = ctr_task_wgt * ctr_loss + (1 -ctr_task_wgt) * cvr_loss + l2_reg * tf.nn.l2_loss(Feat_Emb)
#
# tf.summary.scalar('ctr_loss', ctr_loss)
# tf.summary.scalar('cvr_loss', cvr_loss)
#
# # Provide an estimator spec for `ModeKeys.EVAL`
# eval_metric_ops = {
# # "CTR_AUC": tf.metrics.auc(y, pctr),
# #"CTR_F1": tf.contrib.metrics.f1_score(y,pctr),
# #"CTR_Precision": tf.metrics.precision(y,pctr),
# #"CTR_Recall": tf.metrics.recall(y,pctr),
# # "CVR_AUC": tf.metrics.auc(z, pcvr),
# "CTCVR_AUC": tf.metrics.auc(z, pctcvr)
# }
# if mode == tf.estimator.ModeKeys.EVAL:
# return tf.estimator.EstimatorSpec(
# mode=mode,
# predictions=predictions,
# loss=loss,
# eval_metric_ops=eval_metric_ops)
#
# #------bulid optimizer------
# if FLAGS.optimizer == 'Adam':
# optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate, beta1=0.9, beta2=0.999, epsilon=1e-8)
# elif FLAGS.optimizer == 'Adagrad':
# optimizer = tf.train.AdagradOptimizer(learning_rate=learning_rate, initial_accumulator_value=1e-8)
# elif FLAGS.optimizer == 'Momentum':
# optimizer = tf.train.MomentumOptimizer(learning_rate=learning_rate, momentum=0.95)
# elif FLAGS.optimizer == 'ftrl':
# optimizer = tf.train.FtrlOptimizer(learning_rate)
#
# train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())
#
# # Provide an estimator spec for `ModeKeys.TRAIN` modes
# if mode == tf.estimator.ModeKeys.TRAIN:
# return tf.estimator.EstimatorSpec(
# mode=mode,
# predictions=predictions,
# loss=loss,
# train_op=train_op)
def
input_fn
(
filenames
,
batch_size
=
32
,
num_epochs
=
1
,
perform_shuffle
=
False
):
print
(
'Parsing'
,
filenames
)
def
_parse_fn
(
record
):
features
=
{
"y"
:
tf
.
FixedLenFeature
([],
tf
.
float32
),
"z"
:
tf
.
FixedLenFeature
([],
tf
.
float32
),
"ids"
:
tf
.
FixedLenFeature
([
FLAGS
.
field_size
],
tf
.
int64
),
"ids"
:
tf
.
FixedLenFeature
([
15
],
tf
.
int64
),
"app_list"
:
tf
.
VarLenFeature
(
tf
.
int64
),
"level2_list"
:
tf
.
VarLenFeature
(
tf
.
int64
),
"level3_list"
:
tf
.
VarLenFeature
(
tf
.
int64
),
...
...
@@ -77,29 +296,17 @@ def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
return
parsed
,
{
"y"
:
y
,
"z"
:
z
}
# Extract lines from input files using the Dataset API, can pass one filename or filename list
# dataset = tf.data.TFRecordDataset(filenames).map(_parse_fn, num_parallel_calls=8
).prefetch(500000) # multi-thread pre-process then prefetch
dataset
=
tf
.
data
.
TFRecordDataset
(
filenames
)
.
map
(
_parse_fn
,
num_parallel_calls
=
10
)
.
prefetch
(
500000
)
# multi-thread pre-process then prefetch
# Randomizes input using a window of 256 elements (read into memory)
#
if perform_shuffle:
#
dataset = dataset.shuffle(buffer_size=256)
if
perform_shuffle
:
dataset
=
dataset
.
shuffle
(
buffer_size
=
256
)
# epochs from blending together.
# dataset = dataset.repeat(num_epochs)
# dataset = dataset.batch(batch_size) # Batch size to use
files
=
tf
.
data
.
Dataset
.
list_files
(
filenames
)
dataset
=
files
.
apply
(
tf
.
data
.
experimental
.
parallel_interleave
(
lambda
file
:
tf
.
data
.
TFRecordDataset
(
file
),
cycle_length
=
8
)
)
dataset
=
dataset
.
apply
(
tf
.
data
.
experimental
.
map_and_batch
(
map_func
=
_parse_fn
,
batch_size
=
batch_size
,
num_parallel_calls
=
8
))
dataset
=
dataset
.
prefetch
(
10000
)
dataset
=
dataset
.
repeat
(
num_epochs
)
dataset
=
dataset
.
batch
(
batch_size
)
# Batch size to use
# dataset = dataset.padded_batch(batch_size, padded_shapes=({"feeds_ids": [None], "feeds_vals": [None], "title_ids": [None]}, [None])) #不定长补齐
#return dataset.make_one_shot_iterator()
iterator
=
dataset
.
make_one_shot_iterator
()
batch_features
,
batch_labels
=
iterator
.
get_next
()
...
...
@@ -141,9 +348,6 @@ def model_fn(features, labels, mode, params):
city
=
features
[
'city'
]
cid_id
=
features
[
'cid_id'
]
if
FLAGS
.
task_type
!=
"infer"
:
y
=
labels
[
'y'
]
z
=
labels
[
'z'
]
#------build f(x)------
with
tf
.
variable_scope
(
"Shared-Embedding-layer"
):
...
...
@@ -177,12 +381,6 @@ def model_fn(features, labels, mode, params):
for
i
in
range
(
len
(
layers
)):
x_cvr
=
tf
.
contrib
.
layers
.
fully_connected
(
inputs
=
x_cvr
,
num_outputs
=
layers
[
i
],
\
weights_regularizer
=
tf
.
contrib
.
layers
.
l2_regularizer
(
l2_reg
),
scope
=
'cvr_mlp
%
d'
%
i
)
if
FLAGS
.
batch_norm
:
x_cvr
=
batch_norm_layer
(
x_cvr
,
train_phase
=
train_phase
,
scope_bn
=
'cvr_bn_
%
d'
%
i
)
#放在RELU之后 https://github.com/ducha-aiki/caffenet-benchmark/blob/master/batchnorm.md#bn----before-or-after-relu
if
mode
==
tf
.
estimator
.
ModeKeys
.
TRAIN
:
x_cvr
=
tf
.
nn
.
dropout
(
x_cvr
,
keep_prob
=
dropout
[
i
])
#Apply Dropout after all BN layers and set dropout=0.8(drop_ratio=0.2)
y_cvr
=
tf
.
contrib
.
layers
.
fully_connected
(
inputs
=
x_cvr
,
num_outputs
=
1
,
activation_fn
=
tf
.
identity
,
\
weights_regularizer
=
tf
.
contrib
.
layers
.
l2_regularizer
(
l2_reg
),
scope
=
'cvr_out'
)
y_cvr
=
tf
.
reshape
(
y_cvr
,
shape
=
[
-
1
])
...
...
@@ -197,12 +395,6 @@ def model_fn(features, labels, mode, params):
for
i
in
range
(
len
(
layers
)):
x_ctr
=
tf
.
contrib
.
layers
.
fully_connected
(
inputs
=
x_ctr
,
num_outputs
=
layers
[
i
],
\
weights_regularizer
=
tf
.
contrib
.
layers
.
l2_regularizer
(
l2_reg
),
scope
=
'ctr_mlp
%
d'
%
i
)
if
FLAGS
.
batch_norm
:
x_ctr
=
batch_norm_layer
(
x_ctr
,
train_phase
=
train_phase
,
scope_bn
=
'ctr_bn_
%
d'
%
i
)
#放在RELU之后 https://github.com/ducha-aiki/caffenet-benchmark/blob/master/batchnorm.md#bn----before-or-after-relu
if
mode
==
tf
.
estimator
.
ModeKeys
.
TRAIN
:
x_ctr
=
tf
.
nn
.
dropout
(
x_ctr
,
keep_prob
=
dropout
[
i
])
#Apply Dropout after all BN layers and set dropout=0.8(drop_ratio=0.2)
y_ctr
=
tf
.
contrib
.
layers
.
fully_connected
(
inputs
=
x_ctr
,
num_outputs
=
1
,
activation_fn
=
tf
.
identity
,
\
weights_regularizer
=
tf
.
contrib
.
layers
.
l2_regularizer
(
l2_reg
),
scope
=
'ctr_out'
)
y_ctr
=
tf
.
reshape
(
y_ctr
,
shape
=
[
-
1
])
...
...
@@ -212,8 +404,7 @@ def model_fn(features, labels, mode, params):
pcvr
=
tf
.
sigmoid
(
y_cvr
)
pctcvr
=
pctr
*
pcvr
predictions
=
{
"pctcvr"
:
pctcvr
,
"sample_id"
:
sample_id
,
"uid"
:
uid
,
"city"
:
city
,
"cid_id"
:
cid_id
}
predictions
=
{
"pcvr"
:
pcvr
,
"pctr"
:
pctr
,
"pctcvr"
:
pctcvr
,
"sample_id"
:
sample_id
,
"uid"
:
uid
,
"city"
:
city
,
"cid_id"
:
cid_id
}
export_outputs
=
{
tf
.
saved_model
.
signature_constants
.
DEFAULT_SERVING_SIGNATURE_DEF_KEY
:
tf
.
estimator
.
export
.
PredictOutput
(
predictions
)}
# Provide an estimator spec for `ModeKeys.PREDICT`
if
mode
==
tf
.
estimator
.
ModeKeys
.
PREDICT
:
...
...
@@ -222,52 +413,6 @@ def model_fn(features, labels, mode, params):
predictions
=
predictions
,
export_outputs
=
export_outputs
)
if
FLAGS
.
task_type
!=
"infer"
:
#------bulid loss------
ctr_loss
=
tf
.
reduce_mean
(
tf
.
nn
.
sigmoid_cross_entropy_with_logits
(
logits
=
y_ctr
,
labels
=
y
))
#cvr_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=y_ctcvr, labels=z))
cvr_loss
=
tf
.
reduce_mean
(
tf
.
losses
.
log_loss
(
predictions
=
pctcvr
,
labels
=
z
))
loss
=
ctr_task_wgt
*
ctr_loss
+
(
1
-
ctr_task_wgt
)
*
cvr_loss
+
l2_reg
*
tf
.
nn
.
l2_loss
(
Feat_Emb
)
tf
.
summary
.
scalar
(
'ctr_loss'
,
ctr_loss
)
tf
.
summary
.
scalar
(
'cvr_loss'
,
cvr_loss
)
# Provide an estimator spec for `ModeKeys.EVAL`
eval_metric_ops
=
{
# "CTR_AUC": tf.metrics.auc(y, pctr),
#"CTR_F1": tf.contrib.metrics.f1_score(y,pctr),
#"CTR_Precision": tf.metrics.precision(y,pctr),
#"CTR_Recall": tf.metrics.recall(y,pctr),
# "CVR_AUC": tf.metrics.auc(z, pcvr),
"CTCVR_AUC"
:
tf
.
metrics
.
auc
(
z
,
pctcvr
)
}
if
mode
==
tf
.
estimator
.
ModeKeys
.
EVAL
:
return
tf
.
estimator
.
EstimatorSpec
(
mode
=
mode
,
predictions
=
predictions
,
loss
=
loss
,
eval_metric_ops
=
eval_metric_ops
)
#------bulid optimizer------
if
FLAGS
.
optimizer
==
'Adam'
:
optimizer
=
tf
.
train
.
AdamOptimizer
(
learning_rate
=
learning_rate
,
beta1
=
0.9
,
beta2
=
0.999
,
epsilon
=
1e-8
)
elif
FLAGS
.
optimizer
==
'Adagrad'
:
optimizer
=
tf
.
train
.
AdagradOptimizer
(
learning_rate
=
learning_rate
,
initial_accumulator_value
=
1e-8
)
elif
FLAGS
.
optimizer
==
'Momentum'
:
optimizer
=
tf
.
train
.
MomentumOptimizer
(
learning_rate
=
learning_rate
,
momentum
=
0.95
)
elif
FLAGS
.
optimizer
==
'ftrl'
:
optimizer
=
tf
.
train
.
FtrlOptimizer
(
learning_rate
)
train_op
=
optimizer
.
minimize
(
loss
,
global_step
=
tf
.
train
.
get_global_step
())
# Provide an estimator spec for `ModeKeys.TRAIN` modes
if
mode
==
tf
.
estimator
.
ModeKeys
.
TRAIN
:
return
tf
.
estimator
.
EstimatorSpec
(
mode
=
mode
,
predictions
=
predictions
,
loss
=
loss
,
train_op
=
train_op
)
def
batch_norm_layer
(
x
,
train_phase
,
scope_bn
):
bn_train
=
tf
.
contrib
.
layers
.
batch_norm
(
x
,
decay
=
FLAGS
.
batch_norm_decay
,
center
=
True
,
scale
=
True
,
updates_collections
=
None
,
is_training
=
True
,
reuse
=
None
,
scope
=
scope_bn
)
bn_infer
=
tf
.
contrib
.
layers
.
batch_norm
(
x
,
decay
=
FLAGS
.
batch_norm_decay
,
center
=
True
,
scale
=
True
,
updates_collections
=
None
,
is_training
=
False
,
reuse
=
True
,
scope
=
scope_bn
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment