ML / ffm-baseline · Commits

Commit 8f983d69, authored Jun 11, 2019 by 王志伟 (merge commit; parents fc52c288 and b452ad9c)

    Collect statistics on icon data

Showing 1 changed file with 20 additions and 7 deletions:

eda/esmm/Model_pipline/train.py (+20, -7)
...
@@ -46,7 +46,7 @@ tf.app.flags.DEFINE_string("servable_model_dir", '', "export servable model for
tf.app.flags.DEFINE_string("task_type", 'train', "task type {train, infer, eval, export}")
tf.app.flags.DEFINE_boolean("clear_existing_model", False, "clear existing model or not")
#40362692,0,0,216:9342395:1.0 301:9351665:1.0 205:7702673:1.0 206:8317829:1.0 207:8967741:1.0 508:9356012:2.30259 210:9059239:1.0 210:9042796:1.0 210:9076972:1.0 210:9103884:1.0 210:9063064:1.0 127_14:3529789:2.3979 127_14:3806412:2.70805
def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
    print('Parsing', filenames)
    def _parse_fn(record):
...
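For context, here is a minimal sketch (not part of this commit) of how TF 1.x flags such as `task_type` and `clear_existing_model` are typically consumed. The `main` wiring and the `model_dir` flag below are assumptions for illustration, not code from this repository:

    import shutil
    import tensorflow as tf

    # Redeclare the two flags from the diff so the sketch is self-contained;
    # "model_dir" is a hypothetical flag added for this example.
    tf.app.flags.DEFINE_string("task_type", 'train', "task type {train, infer, eval, export}")
    tf.app.flags.DEFINE_boolean("clear_existing_model", False, "clear existing model or not")
    tf.app.flags.DEFINE_string("model_dir", './model_ckpt', "checkpoint dir (hypothetical)")
    FLAGS = tf.app.flags.FLAGS

    def main(_):
        if FLAGS.clear_existing_model:
            # Wipe previous checkpoints before retraining.
            shutil.rmtree(FLAGS.model_dir, ignore_errors=True)
        print("task_type:", FLAGS.task_type)

    if __name__ == "__main__":
        tf.app.run()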
@@ -71,17 +71,29 @@ def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
        return parsed, {"y": y, "z": z}
    # Extract lines from input files using the Dataset API, can pass one filename or filename list
    dataset = tf.data.TFRecordDataset(filenames).map(_parse_fn, num_parallel_calls=10).prefetch(500000)    # multi-thread pre-process then prefetch
    # dataset = tf.data.TFRecordDataset(filenames).map(_parse_fn, num_parallel_calls=8).prefetch(500000)   # multi-thread pre-process then prefetch
    # Randomizes input using a window of 256 elements (read into memory)
    if perform_shuffle:
        dataset = dataset.shuffle(buffer_size=256)
    # if perform_shuffle:
    #     dataset = dataset.shuffle(buffer_size=256)
    # epochs from blending together.
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)    # Batch size to use
    # dataset = dataset.padded_batch(batch_size, padded_shapes=({"feeds_ids": [None], "feeds_vals": [None], "title_ids": [None]}, [None]))    # pad variable-length sequences
    # dataset = dataset.repeat(num_epochs)
    # dataset = dataset.batch(batch_size)    # Batch size to use
    files = tf.data.Dataset.list_files(filenames)
    dataset = files.apply(tf.data.experimental.parallel_interleave(lambda file: tf.data.TFRecordDataset(file), cycle_length=8))
    dataset = dataset.apply(tf.data.experimental.map_and_batch(map_func=_parse_fn, batch_size=batch_size, num_parallel_calls=8))
    dataset = dataset.prefetch(10000)
    # dataset = dataset.padded_batch(batch_size, padded_shapes=({"feeds_ids": [None], "feeds_vals": [None], "title_ids": [None]}, [None]))    # pad variable-length sequences
    #return dataset.make_one_shot_iterator()
    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
...
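The new pipeline relies on `tf.data.experimental.parallel_interleave` and `tf.data.experimental.map_and_batch`, both of which were later deprecated in favor of core `tf.data` ops. A hedged sketch of an equivalent pipeline using only non-experimental ops (assumes TF >= 1.14, where `Dataset.interleave` accepts `num_parallel_calls`; `_parse_fn`, `filenames`, and `batch_size` as defined above). This is an alternative reading, not code from the commit:

    # Read up to 8 TFRecord files in parallel, then parse, batch, and prefetch.
    files = tf.data.Dataset.list_files(filenames)
    dataset = files.interleave(tf.data.TFRecordDataset,
                               cycle_length=8,
                               num_parallel_calls=8)
    dataset = dataset.map(_parse_fn, num_parallel_calls=8)
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(10000)    # overlap preprocessing with training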
@@ -90,6 +102,7 @@ def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
    #print(batch_features,batch_labels)
    return batch_features, batch_labels

def model_fn(features, labels, mode, params):
    """Build Model function f(x) for Estimator."""
    #------hyperparameters----
...
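For orientation, a minimal usage sketch of how `input_fn` and `model_fn` plug into a `tf.estimator.Estimator` in TF 1.x; the file list, model directory, and hyperparameter dict below are hypothetical, not taken from this diff:

    # Sketch only: Estimator wiring with hypothetical inputs and params.
    train_files = ["./data/tr.tfrecord"]        # hypothetical input list
    model_params = {"learning_rate": 0.001}     # hypothetical hyperparameters

    estimator = tf.estimator.Estimator(model_fn=model_fn,
                                       model_dir='./model_ckpt',
                                       params=model_params)
    estimator.train(input_fn=lambda: input_fn(train_files,
                                              batch_size=32,
                                              num_epochs=1,
                                              perform_shuffle=True))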