Commit 094d9856 authored by 郭羽's avatar 郭羽

美购精排模型

parent 539709e3
......@@ -75,17 +75,21 @@ def getTrainColumns(train_columns,data_vocab):
# 离散特征
for feature in train_columns:
if data_vocab.get(feature):
if feature.startswith("userRatedHistory") or feature.count("__") > 0 or feature in embedding_columns:
cat_col = tf.feature_column.categorical_column_with_vocabulary_list(key=feature, vocabulary_list=data_vocab[feature])
cat_col = tf.feature_column.categorical_column_with_vocabulary_list(key=feature,vocabulary_list=data_vocab[feature])
col = tf.feature_column.embedding_column(cat_col, 10)
columns.append(col)
dataColumns.append(feature)
elif feature in one_hot_columns or feature.count("Bucket") > 0:
cat_col = tf.feature_column.categorical_column_with_vocabulary_list(key=feature, vocabulary_list=data_vocab[feature])
col = tf.feature_column.indicator_column(cat_col)
columns.append(col)
dataColumns.append(feature)
# if feature.startswith("userRatedHistory") or feature.count("__") > 0 or feature in embedding_columns:
# cat_col = tf.feature_column.categorical_column_with_vocabulary_list(key=feature, vocabulary_list=data_vocab[feature])
# col = tf.feature_column.embedding_column(cat_col, 10)
# columns.append(col)
# dataColumns.append(feature)
#
# elif feature in one_hot_columns or feature.count("Bucket") > 0:
# cat_col = tf.feature_column.categorical_column_with_vocabulary_list(key=feature, vocabulary_list=data_vocab[feature])
# col = tf.feature_column.indicator_column(cat_col)
# columns.append(col)
# dataColumns.append(feature)
elif feature in ITEM_NUMBER_COLUMNS:
col = tf.feature_column.numeric_column(feature)
......@@ -129,7 +133,7 @@ def evaluate(model,test_dataset):
def predict(model_path,df):
print("加载模型中...")
model_new = tf.keras.models.load_model("service_fm_v3")
model_new = tf.keras.models.load_model(model_path)
# model_new.summary()
print("模型加载完成...")
# model = tf.keras.models.model_from_json(model.to_json)
......@@ -180,7 +184,9 @@ if __name__ == '__main__':
timestmp4 = int(round(time.time()))
print("读取数据耗时h:{}".format((timestmp4 - timestmp3)/60/60))
# evaluate(model,test_data)
predict(model_file,test_data)
evaluate(model,test_data)
predict(model_file,df_test)
pass
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment