Skip to content
Projects
Groups
Snippets
Help
Loading...
Sign in
Toggle navigation
S
serviceRec
Project
Project
Details
Activity
Cycle Analytics
Repository
Repository
Files
Commits
Branches
Tags
Contributors
Graph
Compare
Charts
Issues
0
Issues
0
List
Board
Labels
Milestones
Merge Requests
0
Merge Requests
0
CI / CD
CI / CD
Pipelines
Jobs
Schedules
Charts
Wiki
Wiki
Snippets
Snippets
Members
Members
Collapse sidebar
Close sidebar
Activity
Graph
Charts
Create a new issue
Jobs
Commits
Issue Boards
Open sidebar
郭羽
serviceRec
Commits
094d9856
Commit
094d9856
authored
3 years ago
by
郭羽
Browse files
Options
Browse Files
Download
Email Patches
Plain Diff
美购精排模型
parent
539709e3
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
20 additions
and
14 deletions
+20
-14
train.py
mlp/train.py
+20
-14
No files found.
mlp/train.py
View file @
094d9856
...
...
@@ -75,17 +75,21 @@ def getTrainColumns(train_columns,data_vocab):
# 离散特征
for
feature
in
train_columns
:
if
data_vocab
.
get
(
feature
):
if
feature
.
startswith
(
"userRatedHistory"
)
or
feature
.
count
(
"__"
)
>
0
or
feature
in
embedding_columns
:
cat_col
=
tf
.
feature_column
.
categorical_column_with_vocabulary_list
(
key
=
feature
,
vocabulary_list
=
data_vocab
[
feature
])
col
=
tf
.
feature_column
.
embedding_column
(
cat_col
,
10
)
columns
.
append
(
col
)
dataColumns
.
append
(
feature
)
elif
feature
in
one_hot_columns
or
feature
.
count
(
"Bucket"
)
>
0
:
cat_col
=
tf
.
feature_column
.
categorical_column_with_vocabulary_list
(
key
=
feature
,
vocabulary_list
=
data_vocab
[
feature
])
col
=
tf
.
feature_column
.
indicator_column
(
cat_col
)
columns
.
append
(
col
)
dataColumns
.
append
(
feature
)
cat_col
=
tf
.
feature_column
.
categorical_column_with_vocabulary_list
(
key
=
feature
,
vocabulary_list
=
data_vocab
[
feature
])
col
=
tf
.
feature_column
.
embedding_column
(
cat_col
,
10
)
columns
.
append
(
col
)
dataColumns
.
append
(
feature
)
# if feature.startswith("userRatedHistory") or feature.count("__") > 0 or feature in embedding_columns:
# cat_col = tf.feature_column.categorical_column_with_vocabulary_list(key=feature, vocabulary_list=data_vocab[feature])
# col = tf.feature_column.embedding_column(cat_col, 10)
# columns.append(col)
# dataColumns.append(feature)
#
# elif feature in one_hot_columns or feature.count("Bucket") > 0:
# cat_col = tf.feature_column.categorical_column_with_vocabulary_list(key=feature, vocabulary_list=data_vocab[feature])
# col = tf.feature_column.indicator_column(cat_col)
# columns.append(col)
# dataColumns.append(feature)
elif
feature
in
ITEM_NUMBER_COLUMNS
:
col
=
tf
.
feature_column
.
numeric_column
(
feature
)
...
...
@@ -129,7 +133,7 @@ def evaluate(model,test_dataset):
def
predict
(
model_path
,
df
):
print
(
"加载模型中..."
)
model_new
=
tf
.
keras
.
models
.
load_model
(
"service_fm_v3"
)
model_new
=
tf
.
keras
.
models
.
load_model
(
model_path
)
# model_new.summary()
print
(
"模型加载完成..."
)
# model = tf.keras.models.model_from_json(model.to_json)
...
...
@@ -180,7 +184,9 @@ if __name__ == '__main__':
timestmp4
=
int
(
round
(
time
.
time
()))
print
(
"读取数据耗时h:{}"
.
format
((
timestmp4
-
timestmp3
)
/
60
/
60
))
# evaluate(model,test_data)
predict
(
model_file
,
test_data
)
evaluate
(
model
,
test_data
)
predict
(
model_file
,
df_test
)
pass
This diff is collapsed.
Click to expand it.
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment