Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in
Toggle navigation
F
ffm-baseline
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
ML
ffm-baseline
Commits
4292bf5b
Commit
4292bf5b
authored
May 21, 2019
by
张彦钊
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
修改测试文件
parent
a3b3de0c
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
7 additions
and
7 deletions
+7
-7
train_multi.py
tensnsorflow/train_multi.py
+7
-7
No files found.
tensnsorflow/train_multi.py
View file @
4292bf5b
...
...
@@ -8,10 +8,7 @@
import
shutil
import
os
import
json
import
glob
from
datetime
import
date
,
timedelta
import
random
import
tensorflow
as
tf
#################### CMD Arguments ####################
...
...
@@ -37,7 +34,8 @@ tf.app.flags.DEFINE_string("deep_layers", '256,128,64', "deep layers")
tf
.
app
.
flags
.
DEFINE_string
(
"dropout"
,
'0.5,0.5,0.5'
,
"dropout rate"
)
tf
.
app
.
flags
.
DEFINE_boolean
(
"batch_norm"
,
False
,
"perform batch normaization (True or False)"
)
tf
.
app
.
flags
.
DEFINE_float
(
"batch_norm_decay"
,
0.9
,
"decay for the moving average(recommend trying decay=0.9)"
)
tf
.
app
.
flags
.
DEFINE_string
(
"data_dir"
,
''
,
"data dir"
)
tf
.
app
.
flags
.
DEFINE_string
(
"hdfs_dir"
,
''
,
"hdfs dir"
)
tf
.
app
.
flags
.
DEFINE_string
(
"local_dir"
,
''
,
"local dir"
)
tf
.
app
.
flags
.
DEFINE_string
(
"dt_dir"
,
''
,
"data dt partition"
)
tf
.
app
.
flags
.
DEFINE_string
(
"model_dir"
,
''
,
"model check point dir"
)
tf
.
app
.
flags
.
DEFINE_string
(
"servable_model_dir"
,
''
,
"export servable model for TensorFlow Serving"
)
...
...
@@ -301,7 +299,8 @@ def main(_):
print
(
'task_type '
,
FLAGS
.
task_type
)
print
(
'model_dir '
,
FLAGS
.
model_dir
)
print
(
'data_dir '
,
FLAGS
.
data_dir
)
print
(
'hdfs_dir '
,
FLAGS
.
hdfs_dir
)
print
(
'local_dir '
,
FLAGS
.
local_dir
)
print
(
'dt_dir '
,
FLAGS
.
dt_dir
)
print
(
'num_epochs '
,
FLAGS
.
num_epochs
)
print
(
'feature_size '
,
FLAGS
.
feature_size
)
...
...
@@ -320,6 +319,7 @@ def main(_):
path
=
"hdfs:///strategy/esmm/"
tr_files
=
[
path
+
"tr/part-r-00000"
]
va_files
=
[
path
+
"va/part-r-00000"
]
te_files
=
[
"
%
s/part-r-00000"
%
FLAGS
.
hdfs_dir
]
# tr_files = glob.glob("%s/tr/*tfrecord" % FLAGS.data_dir)
# random.shuffle(tr_files)
...
...
@@ -366,9 +366,9 @@ def main(_):
print
(
'
%
s:
%
s'
%
(
key
,
value
))
elif
FLAGS
.
task_type
==
'infer'
:
preds
=
Estimator
.
predict
(
input_fn
=
lambda
:
input_fn
(
te_files
,
num_epochs
=
1
,
batch_size
=
FLAGS
.
batch_size
),
predict_keys
=
[
"pctcvr"
,
"pctr"
,
"pcvr"
])
with
open
(
FLAGS
.
data
_dir
+
"/pred.txt"
,
"w"
)
as
fo
:
with
open
(
FLAGS
.
local
_dir
+
"/pred.txt"
,
"w"
)
as
fo
:
print
(
"-"
*
100
)
with
open
(
FLAGS
.
data
_dir
+
"/pred.txt"
,
"w"
)
as
fo
:
with
open
(
FLAGS
.
local
_dir
+
"/pred.txt"
,
"w"
)
as
fo
:
for
prob
in
preds
:
fo
.
write
(
"
%
f
\t
%
f
\t
%
f
\n
"
%
(
prob
[
'pctr'
],
prob
[
'pcvr'
],
prob
[
'pctcvr'
]))
elif
FLAGS
.
task_type
==
'export'
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment