Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in
Toggle navigation
S
serviceRec
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
郭羽
serviceRec
Commits
6a2b8057
Commit
6a2b8057
authored
Dec 20, 2021
by
宋柯
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
模型调试
parent
a2a2801e
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
256 additions
and
0 deletions
+256
-0
eval_service_sk.py
train/eval_service_sk.py
+256
-0
No files found.
train/eval_service_sk.py
0 → 100644
View file @
6a2b8057
import tensorflow as tf

# Verbose pipeline/estimator logging (TF 1.x API).
tf.logging.set_verbosity(tf.logging.INFO)

import sys  # NOTE(review): `sys` appears unused in this file — confirm before removing.
import os

# Ask TF to grow GPU memory on demand rather than pre-allocating it all.
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

# Root directory for sample CSVs, vocabulary files and model checkpoints.
BASE_DIR = '/data/files/wideAndDeep/'
def input_fn(csv_path, epoch, shuffle, batch_size):
    """Build a tf.data input pipeline over a '|'-separated sample file.

    Args:
        csv_path: path to a text file with one sample per line; fields are
            separated by '|', multi-valued fields are ','-separated inside
            their field. The label sits at field index 5.
        epoch: number of passes over the data.
        shuffle: True for training (shuffle examples), False for evaluation.
        batch_size: number of examples per padded batch.

    Returns:
        A tf.data.Dataset yielding (features_dict, label) batches, where
        variable-length multi-category features are padded per batch.
    """
    # One spec per feature: (name, '|'-field index, kind). This single table
    # drives parsing, padded_shapes and padding_values, so the three can
    # never drift out of sync (the original spelled each out by hand).
    CAT, MULTI, NUM = 'category', 'multi_category', 'numeric'
    specs = [
        ('ITEM_CATEGORY_card_id', 0, CAT),
        ('USER_CATEGORY_device_id', 2, CAT),
        ('USER_CATEGORY_os', 3, CAT),
        ('USER_CATEGORY_user_city_id', 4, CAT),
        ('USER_MULTI_CATEGORY_second_solutions', 6, MULTI),
        ('USER_MULTI_CATEGORY_second_demands', 7, MULTI),
        ('USER_MULTI_CATEGORY_second_positions', 8, MULTI),
        ('USER_MULTI_CATEGORY_projects', 9, MULTI),
        ('ITEM_NUMERIC_click_count_sum', 10, NUM),
        ('ITEM_NUMERIC_click_count_avg', 11, NUM),
        ('ITEM_NUMERIC_click_count_stddev', 12, NUM),
        ('ITEM_NUMERIC_exp_count_sum', 13, NUM),
        ('ITEM_NUMERIC_exp_count_avg', 14, NUM),
        ('ITEM_NUMERIC_exp_count_stddev', 15, NUM),
        ('ITEM_NUMERIC_discount', 16, NUM),
        ('ITEM_NUMERIC_case_count', 17, NUM),
        ('ITEM_NUMERIC_sales_count', 18, NUM),
        ('ITEM_CATEGORY_service_type', 19, CAT),
        ('ITEM_CATEGORY_merchant_id', 20, CAT),
        ('ITEM_CATEGORY_doctor_type', 21, CAT),
        ('ITEM_CATEGORY_doctor_id', 22, CAT),
        ('ITEM_CATEGORY_doctor_famous', 23, CAT),
        ('ITEM_CATEGORY_hospital_id', 24, CAT),
        ('ITEM_CATEGORY_hospital_city_tag_id', 25, CAT),
        ('ITEM_CATEGORY_hospital_type', 26, CAT),
        ('ITEM_CATEGORY_hospital_is_high_quality', 27, CAT),
        ('ITEM_MULTI_CATEGORY_second_demands', 28, MULTI),
        ('ITEM_MULTI_CATEGORY_second_solutions', 29, MULTI),
        ('ITEM_MULTI_CATEGORY_second_positions', 30, MULTI),
        ('ITEM_MULTI_CATEGORY_projects', 31, MULTI),
        ('ITEM_NUMERIC_sku_price', 32, NUM),
    ]
    LABEL_INDEX = 5

    def parse_line(line_tensor):
        # Keep empty fields so the field indices stay aligned.
        splits = tf.compat.v1.string_split(
            [line_tensor], delimiter='|', skip_empty=False).values
        features = {}
        for name, idx, kind in specs:
            if kind == CAT:
                features[name] = splits[idx]
            elif kind == MULTI:
                features[name] = tf.compat.v1.string_split(
                    [splits[idx]], delimiter=',').values
            else:  # NUM
                features[name] = tf.compat.v1.string_to_number(splits[idx])
        label = tf.compat.v1.string_to_number(splits[LABEL_INDEX])
        return features, label

    # Scalars pad to (); variable-length multi-category fields pad to [-1]
    # (longest value in the batch).
    padded_shapes = (
        {name: [-1] if kind == MULTI else () for name, _, kind in specs},
        ())
    # String features pad with the out-of-vocabulary token '-1', numeric
    # features with 0.0; the label is never actually padded (scalar).
    padding_values = (
        {name: 0.0 if kind == NUM else '-1' for name, _, kind in specs},
        0.0)

    dataset = tf.data.TextLineDataset(csv_path)
    dataset = dataset.map(parse_line, num_parallel_calls=8)
    if shuffle:
        # BUG FIX: the original shuffled *after* padded_batch, which only
        # reorders whole batches of 512; shuffling examples before batching
        # is what example-level shuffling requires.
        dataset = dataset.shuffle(1000)
    dataset = dataset.repeat(epoch)
    dataset = dataset.padded_batch(
        batch_size, padded_shapes, padding_values=padding_values)
    # BUG FIX: the original prefetch(512 * 100) ran on *batched* elements,
    # buffering up to 51,200 batches; one batch of lookahead is enough to
    # overlap input processing with the training step.
    return dataset.prefetch(1)
def _bucketized(feature_name, bucket_boundaries):
    """Numeric column split into len(bucket_boundaries) + 1 buckets."""
    return tf.feature_column.bucketized_column(
        tf.feature_column.numeric_column(feature_name), bucket_boundaries)


# Coarse buckets for count-style sums/averages.
boundaries = [0, 10, 100]
ITEM_NUMERIC_click_count_sum_fc = _bucketized('ITEM_NUMERIC_click_count_sum', boundaries)
ITEM_NUMERIC_exp_count_sum_fc = _bucketized('ITEM_NUMERIC_exp_count_sum', boundaries)
ITEM_NUMERIC_click_count_avg_fc = _bucketized('ITEM_NUMERIC_click_count_avg', boundaries)
ITEM_NUMERIC_exp_count_avg_fc = _bucketized('ITEM_NUMERIC_exp_count_avg', boundaries)

# Standard deviations live on a much smaller scale.
boundaries = [0, 0.01, 0.1]
ITEM_NUMERIC_click_count_stddev_fc = _bucketized('ITEM_NUMERIC_click_count_stddev', boundaries)
ITEM_NUMERIC_exp_count_stddev_fc = _bucketized('ITEM_NUMERIC_exp_count_stddev', boundaries)

# Discount gets an extra boundary at 1.
boundaries = [0, 0.01, 0.1, 1]
ITEM_NUMERIC_discount_fc = _bucketized('ITEM_NUMERIC_discount', boundaries)

# Back to the coarse count buckets for the remaining features
# (note: sku_price deliberately reuses these same boundaries, as before).
boundaries = [0, 10, 100]
ITEM_NUMERIC_case_count_fc = _bucketized('ITEM_NUMERIC_case_count', boundaries)
ITEM_NUMERIC_sales_count_fc = _bucketized('ITEM_NUMERIC_sales_count', boundaries)
ITEM_NUMERIC_sku_price_fc = _bucketized('ITEM_NUMERIC_sku_price', boundaries)
def _vocab_column(feature_name):
    """Categorical column backed by BASE_DIR/<feature_name>_vocab.csv."""
    return tf.feature_column.categorical_column_with_vocabulary_file(
        feature_name, BASE_DIR + feature_name + '_vocab.csv')


# User-side categorical features.
USER_CATEGORY_device_id_fc = _vocab_column('USER_CATEGORY_device_id')
USER_CATEGORY_os_fc = _vocab_column('USER_CATEGORY_os')
USER_CATEGORY_user_city_id_fc = _vocab_column('USER_CATEGORY_user_city_id')
USER_MULTI_CATEGORY__second_solutions_fc = _vocab_column('USER_MULTI_CATEGORY_second_solutions')
USER_MULTI_CATEGORY__second_positions_fc = _vocab_column('USER_MULTI_CATEGORY_second_positions')
USER_MULTI_CATEGORY__second_demands_fc = _vocab_column('USER_MULTI_CATEGORY_second_demands')
USER_MULTI_CATEGORY__projects_fc = _vocab_column('USER_MULTI_CATEGORY_projects')

# Item-side categorical features.
ITEM_CATEGORY_card_id_fc = _vocab_column('ITEM_CATEGORY_card_id')
ITEM_CATEGORY_service_type_fc = _vocab_column('ITEM_CATEGORY_service_type')
ITEM_CATEGORY_merchant_id_fc = _vocab_column('ITEM_CATEGORY_merchant_id')
ITEM_CATEGORY_doctor_type_fc = _vocab_column('ITEM_CATEGORY_doctor_type')
ITEM_CATEGORY_doctor_id_fc = _vocab_column('ITEM_CATEGORY_doctor_id')
ITEM_CATEGORY_doctor_famous_fc = _vocab_column('ITEM_CATEGORY_doctor_famous')
ITEM_CATEGORY_hospital_id_fc = _vocab_column('ITEM_CATEGORY_hospital_id')
ITEM_CATEGORY_hospital_city_tag_id_fc = _vocab_column('ITEM_CATEGORY_hospital_city_tag_id')
ITEM_CATEGORY_hospital_type_fc = _vocab_column('ITEM_CATEGORY_hospital_type')
ITEM_CATEGORY_hospital_is_high_quality_fc = _vocab_column('ITEM_CATEGORY_hospital_is_high_quality')
ITEM_MULTI_CATEGORY__second_solutions_fc = _vocab_column('ITEM_MULTI_CATEGORY_second_solutions')
ITEM_MULTI_CATEGORY__second_positions_fc = _vocab_column('ITEM_MULTI_CATEGORY_second_positions')
ITEM_MULTI_CATEGORY__second_demands_fc = _vocab_column('ITEM_MULTI_CATEGORY_second_demands')
ITEM_MULTI_CATEGORY__projects_fc = _vocab_column('ITEM_MULTI_CATEGORY_projects')
def embedding_fc(categorical_column, dim):
    """Return a trainable embedding of *categorical_column* with *dim* units."""
    return tf.feature_column.embedding_column(categorical_column, dimension=dim)
# Wide (linear) side of the model: the bucketized numeric columns plus a
# 1-d embedding of each item-side categorical column.
_LINEAR_CATEGORICAL_FCS = (
    ITEM_CATEGORY_card_id_fc,
    ITEM_CATEGORY_service_type_fc,
    ITEM_CATEGORY_merchant_id_fc,
    ITEM_CATEGORY_doctor_type_fc,
    ITEM_CATEGORY_doctor_id_fc,
    ITEM_CATEGORY_doctor_famous_fc,
    ITEM_CATEGORY_hospital_id_fc,
    ITEM_CATEGORY_hospital_city_tag_id_fc,
    ITEM_CATEGORY_hospital_type_fc,
    ITEM_CATEGORY_hospital_is_high_quality_fc,
    ITEM_MULTI_CATEGORY__projects_fc,
    ITEM_MULTI_CATEGORY__second_demands_fc,
    ITEM_MULTI_CATEGORY__second_positions_fc,
    ITEM_MULTI_CATEGORY__second_solutions_fc,
)
linear_feature_columns = [
    ITEM_NUMERIC_click_count_sum_fc,
    ITEM_NUMERIC_exp_count_sum_fc,
    ITEM_NUMERIC_click_count_avg_fc,
    ITEM_NUMERIC_exp_count_avg_fc,
    ITEM_NUMERIC_click_count_stddev_fc,
    ITEM_NUMERIC_exp_count_stddev_fc,
    ITEM_NUMERIC_discount_fc,
    ITEM_NUMERIC_case_count_fc,
    ITEM_NUMERIC_sales_count_fc,
    ITEM_NUMERIC_sku_price_fc,
] + [embedding_fc(fc, 1) for fc in _LINEAR_CATEGORICAL_FCS]
# Deep (DNN) side of the model: every column — user categoricals, bucketized
# numerics and item categoricals alike — enters the network as an 8-d
# embedding. Tuple order matches the original list exactly.
_DNN_SOURCE_FCS = (
    USER_CATEGORY_device_id_fc,
    USER_CATEGORY_os_fc,
    USER_CATEGORY_user_city_id_fc,
    USER_MULTI_CATEGORY__second_solutions_fc,
    USER_MULTI_CATEGORY__second_positions_fc,
    USER_MULTI_CATEGORY__second_demands_fc,
    USER_MULTI_CATEGORY__projects_fc,
    ITEM_NUMERIC_click_count_sum_fc,
    ITEM_NUMERIC_exp_count_sum_fc,
    ITEM_NUMERIC_click_count_avg_fc,
    ITEM_NUMERIC_exp_count_avg_fc,
    ITEM_NUMERIC_click_count_stddev_fc,
    ITEM_NUMERIC_exp_count_stddev_fc,
    ITEM_NUMERIC_discount_fc,
    ITEM_NUMERIC_case_count_fc,
    ITEM_NUMERIC_sales_count_fc,
    ITEM_NUMERIC_sku_price_fc,
    ITEM_CATEGORY_card_id_fc,
    ITEM_CATEGORY_service_type_fc,
    ITEM_CATEGORY_merchant_id_fc,
    ITEM_CATEGORY_doctor_type_fc,
    ITEM_CATEGORY_doctor_id_fc,
    ITEM_CATEGORY_doctor_famous_fc,
    ITEM_CATEGORY_hospital_id_fc,
    ITEM_CATEGORY_hospital_city_tag_id_fc,
    ITEM_CATEGORY_hospital_type_fc,
    ITEM_CATEGORY_hospital_is_high_quality_fc,
    ITEM_MULTI_CATEGORY__projects_fc,
    ITEM_MULTI_CATEGORY__second_demands_fc,
    ITEM_MULTI_CATEGORY__second_positions_fc,
    ITEM_MULTI_CATEGORY__second_solutions_fc,
)
dnn_feature_columns = [embedding_fc(fc, 8) for fc in _DNN_SOURCE_FCS]
# NOTE(review): `os` is already imported at the top of the file; this
# re-import is redundant (but harmless).
import os

# Expose GPUs 0-2 to TensorFlow.
# NOTE(review): the value contains spaces ("0, 1, 2"); CUDA documents a
# comma-separated list — confirm the runtime tolerates the spaces,
# otherwise use "0,1,2".
os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2"

# Log the devices TF can actually see (CPU + visible GPUs) for debugging.
from tensorflow.python.client import device_lib

print(device_lib.list_local_devices())
# Multi-GPU mirrored strategy.
# NOTE(review): `distribution` is currently unused — the RunConfig below
# does not pass train_distribute/eval_distribute (see the commented-out
# variant), so training runs single-replica.
distribution = tf.distribute.MirroredStrategy()
# session_config = tf.compat.v1.ConfigProto(log_device_placement = True, allow_soft_placement = True)
# Fall back to CPU for ops that have no GPU kernel.
session_config = tf.compat.v1.ConfigProto(allow_soft_placement=True)
# Grow GPU memory on demand instead of reserving it all up front.
session_config.gpu_options.allow_growth = True
# config = tf.estimator.RunConfig(save_checkpoints_steps = 10000, train_distribute = distribution, eval_distribute = distribution)
# Checkpoint every 10k steps with the session settings above.
config = tf.estimator.RunConfig(save_checkpoints_steps=10000,
                                session_config=session_config)
# Wide & Deep binary classifier: the linear ("wide") part takes the
# bucketized/1-d-embedded item columns, the DNN ("deep") part the 8-d
# embeddings. Checkpoints go to BASE_DIR/model.
wideAndDeepModel = tf.estimator.DNNLinearCombinedClassifier(
    model_dir=BASE_DIR + 'model',
    linear_feature_columns=linear_feature_columns,
    dnn_feature_columns=dnn_feature_columns,
    dnn_hidden_units=[128, 32],  # two hidden layers
    dnn_dropout=0.5,             # NOTE(review): 0.5 is aggressive — confirm against eval AUC
    config=config)
# early_stopping = tf.contrib.estimator.stop_if_no_decrease_hook(wideAndDeepModel, eval_dir = wideAndDeepModel.eval_dir(), metric_name='auc', max_steps_without_decrease=1000, min_steps = 100)
# early_stopping = tf.contrib.estimator.stop_if_no_increase_hook(wideAndDeepModel, metric_name = 'auc', max_steps_without_increase = 1000, min_steps = 1000)
# No training hooks active (early-stopping variants above are disabled).
hooks = []
# Training input: 20 epochs, shuffled, batches of 512.
# NOTE(review): train_spec is only consumed by tf.estimator.train_and_evaluate,
# which is commented out at the bottom of this file — it is dead in the
# current eval-only run.
train_spec = tf.estimator.TrainSpec(
    input_fn=lambda: input_fn(BASE_DIR + 'train_samples.csv', 20, True, 512),
    hooks=hooks)
# Parse spec for serialized tf.Example protos covering every feature the
# model consumes (wide + deep columns).
serving_feature_spec = tf.feature_column.make_parse_example_spec(
    linear_feature_columns + dnn_feature_columns)
# Serving-time input fn that parses tf.Example protos with that spec.
serving_input_receiver_fn = (
    tf.estimator.export.build_parsing_serving_input_receiver_fn(
        serving_feature_spec))
# Export a SavedModel whenever eval AUC improves; keep the 3 best exports.
exporter = tf.estimator.BestExporter(
    name="best_exporter",
    compare_fn=lambda best_eval_result, current_eval_result:
        current_eval_result['auc'] > best_eval_result['auc'],
    serving_input_receiver_fn=serving_input_receiver_fn,
    exports_to_keep=3)
# Evaluation spec: one full pass (epoch=1, steps=None), no shuffling, large
# 2**15-example batches; re-evaluate at most every 120 s and run the
# BestExporter after each evaluation.
# NOTE(review): like train_spec, this is only used by the commented-out
# tf.estimator.train_and_evaluate call below.
eval_spec = tf.estimator.EvalSpec(
    input_fn=lambda: input_fn(BASE_DIR + 'eval_samples.csv', 1, False, 2 ** 15),
    steps=None,
    throttle_secs=120,
    exporters=exporter)
# def my_auc(labels, predictions):
#     return {'auc_pr_careful_interpolation': tf.metrics.auc(labels, predictions['logistic'], curve='ROC',
#                                                            summation_method='careful_interpolation')}
# wideAndDeepModel = tf.contrib.estimator.add_metrics(wideAndDeepModel, my_auc)
# tf.estimator.train_and_evaluate(wideAndDeepModel, train_spec, eval_spec)
# Eval-only entry point: run a single evaluation pass over the eval set
# using the latest checkpoint found in model_dir; metrics (incl. AUC) are
# logged. NOTE(review): the returned metrics dict is discarded here.
wideAndDeepModel.evaluate(
    lambda: input_fn(BASE_DIR + 'eval_samples.csv', 1, False, 2 ** 15))
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment