郭羽 / serviceRec · Commits

Commit 1a556d84
Authored Dec 21, 2021 by 宋柯
Parent: 22d0a85e

Commit message: 模型调试 ("model debugging")
Showing 2 changed files with 325 additions and 0 deletions:

- train/down_vocab_tfrecord.py (+45, -0)
- train/train_service_sk_tfrecord.py (+280, -0)
train/down_vocab_tfrecord.py (new file, mode 100644)
```python
# Download per-field vocabularies from Redis and the prepared TFRecord
# samples from HDFS into save_dir, for the wide-and-deep training job.
import redis
import sys
import os
import json  # unused in this revision


def getRedisConn():
    pool = redis.ConnectionPool(host="172.16.50.145", password="XfkMCCdWDIU%ls$h", port=6379, db=0)
    conn = redis.Redis(connection_pool=pool)
    # conn = redis.Redis(host="172.16.50.145", port=6379, password="XfkMCCdWDIU%ls$h", db=0)
    # conn = redis.Redis(host="172.18.51.10", port=6379, db=0, decode_responses=True)  # test
    return conn


# The output directory can be overridden from the command line.
if len(sys.argv) == 2:
    save_dir = sys.argv[1]
else:
    save_dir = '/data/files/wideAndDeep/'
print('save_dir: ', save_dir)

if not os.path.exists(save_dir):
    print('mkdir save_dir: ', save_dir)
    os.makedirs(save_dir)

conn = getRedisConn()

# "strategy:all:vocab" holds a single element: the string repr of a list of
# per-field vocabulary keys, hence the eval() below.
vocab_keys = conn.lrange("strategy:all:vocab", 0, -1)
print("vocab_keys: ", vocab_keys[0])
vocab_keys = eval(vocab_keys[0])

# Dump each field's vocabulary to <field>_vocab.csv, one token per line.
for vocab_key in vocab_keys:
    print('vocab_key: ', vocab_key)
    splits = vocab_key.split(":")
    field = splits[1]
    filename = field + "_vocab.csv"
    print('filename: ', filename)
    with open(os.path.join(save_dir, filename), 'w') as f:
        texts = conn.lrange(vocab_key, 0, -1)
        texts = list(filter(lambda x: x != '', eval(texts[0])))
        print('texts: ', len(texts))
        f.write('\n'.join(texts))

# Pull the pre-built TFRecord samples down from HDFS next to the vocabularies.
os.system("hdfs dfs -getmerge /strategy/train_samples_tfrecord {save_dir}train_samples.tfrecord".format(save_dir=save_dir))
os.system("hdfs dfs -getmerge /strategy/eval_samples_tfrecord {save_dir}eval_samples.tfrecord".format(save_dir=save_dir))
```
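One hardening note: the script eval()s strings pulled from Redis, which would execute arbitrary code if a key were ever poisoned. A safer drop-in, assuming the stored payloads really are plain Python-literal lists of strings, is ast.literal_eval; a minimal sketch:

```python
import ast

# Drop-in replacement for the eval() calls above: literal_eval only accepts
# Python literals (lists, strings, numbers, ...), never expressions or calls.
vocab_keys = ast.literal_eval(conn.lrange("strategy:all:vocab", 0, -1)[0])
```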
train/train_service_sk_tfrecord.py (new file, mode 100644)
```python
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.INFO)
import sys
import os

os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

BASE_DIR = '/data/files/wideAndDeep/'


def input_fn(csv_path, epoch, shuffle, batch_size):
    dataset = tf.data.TFRecordDataset(csv_path, buffer_size=1024, num_parallel_reads=2)

    # Parse spec: scalar features are FixedLenFeature with a default value;
    # multi-valued categorical features are VarLenFeature.
    dics = {
        'ITEM_CATEGORY_card_id': tf.FixedLenFeature((), tf.string, default_value='-1'),
        'USER_CATEGORY_device_id': tf.FixedLenFeature((), tf.string, default_value='-1'),
        'USER_CATEGORY_os': tf.FixedLenFeature((), tf.string, default_value='-1'),
        'USER_CATEGORY_user_city_id': tf.FixedLenFeature((), tf.string, default_value='-1'),
        'USER_MULTI_CATEGORY_second_solutions': tf.VarLenFeature(tf.string),
        'USER_MULTI_CATEGORY_second_demands': tf.VarLenFeature(tf.string),
        'USER_MULTI_CATEGORY_second_positions': tf.VarLenFeature(tf.string),
        'USER_MULTI_CATEGORY_projects': tf.VarLenFeature(tf.string),
        'ITEM_NUMERIC_click_count_sum': tf.FixedLenFeature((), tf.float32, default_value=0),
        'ITEM_NUMERIC_click_count_avg': tf.FixedLenFeature((), tf.float32, default_value=0),
        'ITEM_NUMERIC_click_count_stddev': tf.FixedLenFeature((), tf.float32, default_value=0),
        'ITEM_NUMERIC_exp_count_sum': tf.FixedLenFeature((), tf.float32, default_value=0),
        'ITEM_NUMERIC_exp_count_avg': tf.FixedLenFeature((), tf.float32, default_value=0),
        'ITEM_NUMERIC_exp_count_stddev': tf.FixedLenFeature((), tf.float32, default_value=0),
        'ITEM_NUMERIC_discount': tf.FixedLenFeature((), tf.float32, default_value=0),
        'ITEM_NUMERIC_case_count': tf.FixedLenFeature((), tf.float32, default_value=0),
        'ITEM_NUMERIC_sales_count': tf.FixedLenFeature((), tf.float32, default_value=0),
        'ITEM_CATEGORY_service_type': tf.FixedLenFeature((), tf.string, default_value='-1'),
        'ITEM_CATEGORY_merchant_id': tf.FixedLenFeature((), tf.string, default_value='-1'),
        'ITEM_CATEGORY_doctor_type': tf.FixedLenFeature((), tf.string, default_value='-1'),
        'ITEM_CATEGORY_doctor_id': tf.FixedLenFeature((), tf.string, default_value='-1'),
        'ITEM_CATEGORY_doctor_famous': tf.FixedLenFeature((), tf.string, default_value='-1'),
        'ITEM_CATEGORY_hospital_id': tf.FixedLenFeature((), tf.string, default_value='-1'),
        'ITEM_CATEGORY_hospital_city_tag_id': tf.FixedLenFeature((), tf.string, default_value='-1'),
        'ITEM_CATEGORY_hospital_type': tf.FixedLenFeature((), tf.string, default_value='-1'),
        'ITEM_CATEGORY_hospital_is_high_quality': tf.FixedLenFeature((), tf.string, default_value='-1'),
        'ITEM_MULTI_CATEGORY_second_demands': tf.VarLenFeature(tf.string),
        'ITEM_MULTI_CATEGORY_second_solutions': tf.VarLenFeature(tf.string),
        'ITEM_MULTI_CATEGORY_second_positions': tf.VarLenFeature(tf.string),
        'ITEM_MULTI_CATEGORY_projects': tf.VarLenFeature(tf.string),
        'ITEM_NUMERIC_sku_price': tf.FixedLenFeature((), tf.float32, default_value=0),
        'label': tf.FixedLenFeature((), tf.int64, default_value=0),
    }

    multi_category_keys = [k for k, v in dics.items() if isinstance(v, tf.VarLenFeature)]

    def parse_serialized_example(serialized_example):
        parsed_example = tf.parse_single_example(serialized_example, dics)
        # Densify the variable-length features so they can be padded-batched.
        for key in multi_category_keys:
            parsed_example[key] = tf.sparse_tensor_to_dense(parsed_example[key], default_value='-1')
        return parsed_example, parsed_example.pop('label')

    # Scalars pad to (); multi-valued features pad to the longest list in the
    # batch ([-1]). String features pad with '-1', numerics with 0.0, and the
    # label component with int64 zero.
    padded_shapes = ({k: [-1] if k in multi_category_keys else ()
                      for k in dics if k != 'label'}, ())
    padding_values = ({k: '-1' if dics[k].dtype == tf.string else 0.0
                       for k in dics if k != 'label'},
                      tf.constant(0, dtype=tf.int64))

    dataset = dataset.map(parse_serialized_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    if shuffle:
        dataset = dataset.shuffle(1024)
    dataset = dataset.padded_batch(batch_size, padded_shapes, padding_values=padding_values)
    # Fix: prefetch() and repeat() return new datasets; the original code
    # called them without assigning the result, so they had no effect.
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    dataset = dataset.repeat(epoch)
    return dataset
```
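For reference, a minimal sketch (not part of this commit) of how one sample in this layout could be serialized. Only a few representative features are shown, the values and the /tmp path are illustrative, and FixedLen features omitted from the proto parse from their default_value (VarLen features parse as empty):

```python
import tensorflow as tf

def _bytes(values):
    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[v.encode() for v in values]))

def _floats(values):
    return tf.train.Feature(float_list=tf.train.FloatList(value=values))

def _ints(values):
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

example = tf.train.Example(features=tf.train.Features(feature={
    'ITEM_CATEGORY_card_id': _bytes(['4567']),             # scalar string (illustrative id)
    'USER_MULTI_CATEGORY_projects': _bytes(['p1', 'p2']),  # variable-length list
    'ITEM_NUMERIC_sku_price': _floats([99.0]),             # scalar float
    'label': _ints([1]),                                   # click / no-click
}))

with tf.python_io.TFRecordWriter('/tmp/demo.tfrecord') as writer:
    writer.write(example.SerializeToString())
```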
The feature-column definitions (train/train_service_sk_tfrecord.py, continued):

```python
def bucketized_fc(name, boundaries):
    # Bucketize a numeric feature with left-inclusive boundaries.
    return tf.feature_column.bucketized_column(
        tf.feature_column.numeric_column(name), boundaries)

# Count-like statistics share one set of boundaries; standard deviations and
# discount use finer-grained ones.
ITEM_NUMERIC_click_count_sum_fc = bucketized_fc('ITEM_NUMERIC_click_count_sum', [0, 10, 100])
ITEM_NUMERIC_exp_count_sum_fc = bucketized_fc('ITEM_NUMERIC_exp_count_sum', [0, 10, 100])
ITEM_NUMERIC_click_count_avg_fc = bucketized_fc('ITEM_NUMERIC_click_count_avg', [0, 10, 100])
ITEM_NUMERIC_exp_count_avg_fc = bucketized_fc('ITEM_NUMERIC_exp_count_avg', [0, 10, 100])
ITEM_NUMERIC_click_count_stddev_fc = bucketized_fc('ITEM_NUMERIC_click_count_stddev', [0, 0.01, 0.1])
ITEM_NUMERIC_exp_count_stddev_fc = bucketized_fc('ITEM_NUMERIC_exp_count_stddev', [0, 0.01, 0.1])
ITEM_NUMERIC_discount_fc = bucketized_fc('ITEM_NUMERIC_discount', [0, 0.01, 0.1, 1])
ITEM_NUMERIC_case_count_fc = bucketized_fc('ITEM_NUMERIC_case_count', [0, 10, 100])
ITEM_NUMERIC_sales_count_fc = bucketized_fc('ITEM_NUMERIC_sales_count', [0, 10, 100])
ITEM_NUMERIC_sku_price_fc = bucketized_fc('ITEM_NUMERIC_sku_price', [0, 10, 100])


def vocab_fc(name):
    # Categorical column backed by the <name>_vocab.csv file that
    # train/down_vocab_tfrecord.py downloaded from Redis.
    return tf.feature_column.categorical_column_with_vocabulary_file(
        name, BASE_DIR + name + '_vocab.csv')

USER_CATEGORY_device_id_fc = vocab_fc('USER_CATEGORY_device_id')
USER_CATEGORY_os_fc = vocab_fc('USER_CATEGORY_os')
USER_CATEGORY_user_city_id_fc = vocab_fc('USER_CATEGORY_user_city_id')
USER_MULTI_CATEGORY__second_solutions_fc = vocab_fc('USER_MULTI_CATEGORY_second_solutions')
USER_MULTI_CATEGORY__second_positions_fc = vocab_fc('USER_MULTI_CATEGORY_second_positions')
USER_MULTI_CATEGORY__second_demands_fc = vocab_fc('USER_MULTI_CATEGORY_second_demands')
USER_MULTI_CATEGORY__projects_fc = vocab_fc('USER_MULTI_CATEGORY_projects')
ITEM_CATEGORY_card_id_fc = vocab_fc('ITEM_CATEGORY_card_id')
ITEM_CATEGORY_service_type_fc = vocab_fc('ITEM_CATEGORY_service_type')
ITEM_CATEGORY_merchant_id_fc = vocab_fc('ITEM_CATEGORY_merchant_id')
ITEM_CATEGORY_doctor_type_fc = vocab_fc('ITEM_CATEGORY_doctor_type')
ITEM_CATEGORY_doctor_id_fc = vocab_fc('ITEM_CATEGORY_doctor_id')
ITEM_CATEGORY_doctor_famous_fc = vocab_fc('ITEM_CATEGORY_doctor_famous')
ITEM_CATEGORY_hospital_id_fc = vocab_fc('ITEM_CATEGORY_hospital_id')
ITEM_CATEGORY_hospital_city_tag_id_fc = vocab_fc('ITEM_CATEGORY_hospital_city_tag_id')
ITEM_CATEGORY_hospital_type_fc = vocab_fc('ITEM_CATEGORY_hospital_type')
ITEM_CATEGORY_hospital_is_high_quality_fc = vocab_fc('ITEM_CATEGORY_hospital_is_high_quality')
ITEM_MULTI_CATEGORY__second_solutions_fc = vocab_fc('ITEM_MULTI_CATEGORY_second_solutions')
ITEM_MULTI_CATEGORY__second_positions_fc = vocab_fc('ITEM_MULTI_CATEGORY_second_positions')
ITEM_MULTI_CATEGORY__second_demands_fc = vocab_fc('ITEM_MULTI_CATEGORY_second_demands')
ITEM_MULTI_CATEGORY__projects_fc = vocab_fc('ITEM_MULTI_CATEGORY_projects')


def embedding_fc(categorical_column, dim):
    return tf.feature_column.embedding_column(categorical_column, dim)
```
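A side note on the bucket boundaries above: bucketized_column puts a value v into bucket i when boundaries[i-1] <= v < boundaries[i], with one extra bucket below the first boundary and one at or above the last. A plain-Python mirror of that rule (an illustration, not TF code):

```python
import bisect

def bucket_index(value, boundaries):
    # Mirrors tf.feature_column.bucketized_column: boundaries [0, 10, 100]
    # yield buckets (-inf, 0), [0, 10), [10, 100), [100, +inf).
    return bisect.bisect_right(boundaries, value)

assert bucket_index(-5, [0, 10, 100]) == 0   # below every boundary
assert bucket_index(3, [0, 10, 100]) == 1    # falls in [0, 10)
assert bucket_index(250, [0, 10, 100]) == 3  # at or above the last boundary
```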
```python
# Wide part: bucketized numeric columns plus 1-dim embeddings of the item
# categoricals.
linear_feature_columns = [
    ITEM_NUMERIC_click_count_sum_fc,
    ITEM_NUMERIC_exp_count_sum_fc,
    ITEM_NUMERIC_click_count_avg_fc,
    ITEM_NUMERIC_exp_count_avg_fc,
    ITEM_NUMERIC_click_count_stddev_fc,
    ITEM_NUMERIC_exp_count_stddev_fc,
    ITEM_NUMERIC_discount_fc,
    ITEM_NUMERIC_case_count_fc,
    ITEM_NUMERIC_sales_count_fc,
    ITEM_NUMERIC_sku_price_fc,
    embedding_fc(ITEM_CATEGORY_card_id_fc, 1),
    embedding_fc(ITEM_CATEGORY_service_type_fc, 1),
    embedding_fc(ITEM_CATEGORY_merchant_id_fc, 1),
    embedding_fc(ITEM_CATEGORY_doctor_type_fc, 1),
    embedding_fc(ITEM_CATEGORY_doctor_id_fc, 1),
    embedding_fc(ITEM_CATEGORY_doctor_famous_fc, 1),
    embedding_fc(ITEM_CATEGORY_hospital_id_fc, 1),
    embedding_fc(ITEM_CATEGORY_hospital_city_tag_id_fc, 1),
    embedding_fc(ITEM_CATEGORY_hospital_type_fc, 1),
    embedding_fc(ITEM_CATEGORY_hospital_is_high_quality_fc, 1),
    embedding_fc(ITEM_MULTI_CATEGORY__projects_fc, 1),
    embedding_fc(ITEM_MULTI_CATEGORY__second_demands_fc, 1),
    embedding_fc(ITEM_MULTI_CATEGORY__second_positions_fc, 1),
    embedding_fc(ITEM_MULTI_CATEGORY__second_solutions_fc, 1),
]

# Deep part: user features plus the same item features, embedded at 8 dims.
dnn_feature_columns = [
    embedding_fc(USER_CATEGORY_device_id_fc, 8),
    embedding_fc(USER_CATEGORY_os_fc, 8),
    embedding_fc(USER_CATEGORY_user_city_id_fc, 8),
    embedding_fc(USER_MULTI_CATEGORY__second_solutions_fc, 8),
    embedding_fc(USER_MULTI_CATEGORY__second_positions_fc, 8),
    embedding_fc(USER_MULTI_CATEGORY__second_demands_fc, 8),
    embedding_fc(USER_MULTI_CATEGORY__projects_fc, 8),
    embedding_fc(ITEM_NUMERIC_click_count_sum_fc, 8),
    embedding_fc(ITEM_NUMERIC_exp_count_sum_fc, 8),
    embedding_fc(ITEM_NUMERIC_click_count_avg_fc, 8),
    embedding_fc(ITEM_NUMERIC_exp_count_avg_fc, 8),
    embedding_fc(ITEM_NUMERIC_click_count_stddev_fc, 8),
    embedding_fc(ITEM_NUMERIC_exp_count_stddev_fc, 8),
    embedding_fc(ITEM_NUMERIC_discount_fc, 8),
    embedding_fc(ITEM_NUMERIC_case_count_fc, 8),
    embedding_fc(ITEM_NUMERIC_sales_count_fc, 8),
    embedding_fc(ITEM_NUMERIC_sku_price_fc, 8),
    embedding_fc(ITEM_CATEGORY_card_id_fc, 8),
    embedding_fc(ITEM_CATEGORY_service_type_fc, 8),
    embedding_fc(ITEM_CATEGORY_merchant_id_fc, 8),
    embedding_fc(ITEM_CATEGORY_doctor_type_fc, 8),
    embedding_fc(ITEM_CATEGORY_doctor_id_fc, 8),
    embedding_fc(ITEM_CATEGORY_doctor_famous_fc, 8),
    embedding_fc(ITEM_CATEGORY_hospital_id_fc, 8),
    embedding_fc(ITEM_CATEGORY_hospital_city_tag_id_fc, 8),
    embedding_fc(ITEM_CATEGORY_hospital_type_fc, 8),
    embedding_fc(ITEM_CATEGORY_hospital_is_high_quality_fc, 8),
    embedding_fc(ITEM_MULTI_CATEGORY__projects_fc, 8),
    embedding_fc(ITEM_MULTI_CATEGORY__second_demands_fc, 8),
    embedding_fc(ITEM_MULTI_CATEGORY__second_positions_fc, 8),
    embedding_fc(ITEM_MULTI_CATEGORY__second_solutions_fc, 8),
]
```
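The two lists reuse the same categorical columns at different embedding widths: 1 dimension for the wide part, 8 for the deep part. A self-contained sketch (not from the commit) of the vocabulary-file/embedding mechanics, using a throwaway vocab file so it runs without the real *_vocab.csv artifacts; all names and values here are illustrative:

```python
import os
import tempfile
import tensorflow as tf

# Throwaway two-token vocabulary standing in for USER_CATEGORY_os_vocab.csv.
vocab_path = os.path.join(tempfile.mkdtemp(), 'os_vocab.csv')
with open(vocab_path, 'w') as f:
    f.write('ios\nandroid')

os_fc = tf.feature_column.categorical_column_with_vocabulary_file('USER_CATEGORY_os', vocab_path)
os_emb = tf.feature_column.embedding_column(os_fc, 8)

# 'windows' is out-of-vocabulary; with no OOV buckets configured it should
# contribute a zero vector.
features = {'USER_CATEGORY_os': tf.constant([['ios'], ['windows']])}
dense = tf.feature_column.input_layer(features, [os_emb])

with tf.compat.v1.Session() as sess:
    sess.run(tf.compat.v1.global_variables_initializer())
    sess.run(tf.compat.v1.tables_initializer())
    print(sess.run(dense).shape)  # (2, 8)
```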
Device setup, the estimator, and the train/eval loop (train/train_service_sk_tfrecord.py, continued):

```python
os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1, 2"

from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())

distribution = tf.distribute.MirroredStrategy()

# session_config = tf.compat.v1.ConfigProto(log_device_placement=True, allow_soft_placement=True)
session_config = tf.compat.v1.ConfigProto(allow_soft_placement=True)
session_config.gpu_options.allow_growth = True

# Note: `distribution` is created but never passed to RunConfig, so training
# does not actually use MirroredStrategy; the commented-out line below would
# enable multi-GPU training.
# config = tf.estimator.RunConfig(save_checkpoints_steps=10000, train_distribute=distribution, eval_distribute=distribution)
config = tf.estimator.RunConfig(save_checkpoints_steps=10000, session_config=session_config)

wideAndDeepModel = tf.estimator.DNNLinearCombinedClassifier(
    model_dir=BASE_DIR + 'model_tfrecord',
    linear_feature_columns=linear_feature_columns,
    dnn_feature_columns=dnn_feature_columns,
    dnn_hidden_units=[128, 32],
    dnn_dropout=0.5,
    config=config,
)

# early_stopping = tf.contrib.estimator.stop_if_no_decrease_hook(wideAndDeepModel, eval_dir=wideAndDeepModel.eval_dir(), metric_name='auc', max_steps_without_decrease=1000, min_steps=100)
# early_stopping = tf.contrib.estimator.stop_if_no_increase_hook(wideAndDeepModel, metric_name='auc', max_steps_without_increase=1000, min_steps=1000)
hooks = []

# 20 epochs, shuffled, batches of 512.
train_spec = tf.estimator.TrainSpec(
    input_fn=lambda: input_fn(BASE_DIR + 'train_samples.tfrecord', 20, True, 512),
    hooks=hooks,
)

serving_feature_spec = tf.feature_column.make_parse_example_spec(
    linear_feature_columns + dnn_feature_columns)
serving_input_receiver_fn = (
    tf.estimator.export.build_parsing_serving_input_receiver_fn(serving_feature_spec))

# Keep the three exports with the best eval AUC.
exporter = tf.estimator.BestExporter(
    name="best_exporter",
    compare_fn=lambda best_eval_result, current_eval_result:
        current_eval_result['auc'] > best_eval_result['auc'],
    serving_input_receiver_fn=serving_input_receiver_fn,
    exports_to_keep=3,
)

# Evaluate the full eval set (steps=None) in 2**15-example batches, at most
# once every 120 seconds.
eval_spec = tf.estimator.EvalSpec(
    input_fn=lambda: input_fn(BASE_DIR + 'eval_samples.tfrecord', 1, False, 2 ** 15),
    steps=None,
    throttle_secs=120,
    exporters=exporter,
)

# def my_auc(labels, predictions):
#     return {'auc_pr_careful_interpolation': tf.metrics.auc(labels, predictions['logistic'], curve='ROC',
#                                                            summation_method='careful_interpolation')}
# wideAndDeepModel = tf.contrib.estimator.add_metrics(wideAndDeepModel, my_auc)

tf.estimator.train_and_evaluate(wideAndDeepModel, train_spec, eval_spec)

wideAndDeepModel.evaluate(lambda: input_fn(BASE_DIR + 'eval_samples.tfrecord', 1, False, 2 ** 15))
```
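Finally, a hedged sketch (not in this commit) of querying one of the exported SavedModels with tf.contrib.predictor from TF 1.x. The timestamped export path and the signature keys ('inputs'/'scores' for the classifier's default classification signature) depend on what BestExporter actually wrote, so treat the names below as assumptions to verify:

```python
from tensorflow.contrib import predictor  # TF 1.x only

# Hypothetical path: BestExporter writes timestamped directories under
# <model_dir>/export/best_exporter/.
export_dir = BASE_DIR + 'model_tfrecord/export/best_exporter/1640000000'

predict_fn = predictor.from_saved_model(export_dir)

# build_parsing_serving_input_receiver_fn expects serialized tf.train.Example
# protos; `example` is the demo proto built in the earlier sketch.
result = predict_fn({'inputs': [example.SerializeToString()]})
print(result['scores'])  # per-class probabilities
```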