Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in
Toggle navigation
S
serviceRec
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
郭羽
serviceRec
Commits
11fb57e9
Commit
11fb57e9
authored
3 years ago
by
宋柯
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
模型调试
parent
81455a71
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
9 additions
and
6 deletions
+9
-6
train_service_sk.py
train/train_service_sk.py
+9
-6
No files found.
train/train_service_sk.py
View file @
11fb57e9
...
...
@@ -76,12 +76,12 @@ def input_fn(csv_path, epoch, shuffle, batch_size):
'ITEM_MULTI_CATEGORY_second_solutions'
:
'-1'
,
'ITEM_MULTI_CATEGORY_second_positions'
:
'-1'
,
'ITEM_MULTI_CATEGORY_projects'
:
'-1'
,
'ITEM_NUMERIC_sku_price'
:
0.0
},
0.0
)
dataset
=
dataset
.
map
(
parse_line
,
num_parallel_calls
=
2
)
dataset
=
dataset
.
map
(
parse_line
,
num_parallel_calls
=
8
)
dataset
=
dataset
.
padded_batch
(
batch_size
,
padded_shapes
,
padding_values
=
padding_values
)
if
shuffle
:
dataset
=
dataset
.
shuffle
(
1000
)
.
prefetch
(
1000
0
)
.
repeat
(
epoch
)
dataset
=
dataset
.
shuffle
(
1000
)
.
prefetch
(
512
*
1
0
)
.
repeat
(
epoch
)
else
:
dataset
=
dataset
.
prefetch
(
1000
0
)
.
repeat
(
epoch
)
dataset
=
dataset
.
prefetch
(
512
*
1
0
)
.
repeat
(
epoch
)
return
dataset
...
...
@@ -205,11 +205,14 @@ print(device_lib.list_local_devices())
distribution
=
tf
.
distribute
.
MirroredStrategy
()
session_config
=
tf
.
compat
.
v1
.
ConfigProto
(
log_device_placement
=
True
,
allow_soft_placement
=
True
)
# session_config = tf.compat.v1.ConfigProto(log_device_placement = True, allow_soft_placement = True)
session_config
=
tf
.
compat
.
v1
.
ConfigProto
(
allow_soft_placement
=
True
)
session_config
.
gpu_options
.
allow_growth
=
True
# config = tf.estimator.RunConfig(save_checkpoints_steps = 10000, train_distribute = distribution, eval_distribute = distribution)
config
=
tf
.
estimator
.
RunConfig
(
save_checkpoints_steps
=
10000
)
config
=
tf
.
estimator
.
RunConfig
(
save_checkpoints_steps
=
10000
,
session_config
=
session_config
)
wideAndDeepModel
=
tf
.
estimator
.
DNNLinearCombinedClassifier
(
model_dir
=
BASE_DIR
+
'model'
,
linear_feature_columns
=
linear_feature_columns
,
...
...
@@ -224,7 +227,7 @@ wideAndDeepModel = tf.estimator.DNNLinearCombinedClassifier(model_dir = BASE_DIR
hooks
=
[]
train_spec
=
tf
.
estimator
.
TrainSpec
(
input_fn
=
lambda
:
input_fn
(
BASE_DIR
+
'train_samples.csv'
,
20
,
True
,
128
),
hooks
=
hooks
)
train_spec
=
tf
.
estimator
.
TrainSpec
(
input_fn
=
lambda
:
input_fn
(
BASE_DIR
+
'train_samples.csv'
,
20
,
True
,
512
),
hooks
=
hooks
)
serving_feature_spec
=
tf
.
feature_column
.
make_parse_example_spec
(
linear_feature_columns
+
dnn_feature_columns
)
...
...
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment