Commit e3a53854 authored by 郭羽

美购精排模型

parent d3235bf6
......@@ -70,6 +70,7 @@ def getDataSet(df,shuffleSize = 10000,batchSize=128):
def getTrainColumns(train_columns,data_vocab):
columns = []
dataColumns = []
# 离散特征
for feature in train_columns:
if data_vocab.get(feature):
......@@ -77,15 +78,20 @@ def getTrainColumns(train_columns,data_vocab):
cat_col = tf.feature_column.categorical_column_with_vocabulary_list(key=feature, vocabulary_list=data_vocab[feature])
col = tf.feature_column.embedding_column(cat_col, 10)
columns.append(col)
dataColumns.append(feature)
elif feature in one_hot_columns or feature.count("Bucket") > 0:
cat_col = tf.feature_column.categorical_column_with_vocabulary_list(key=feature, vocabulary_list=data_vocab[feature])
col = tf.feature_column.indicator_column(cat_col)
columns.append(col)
dataColumns.append(feature)
elif feature in ITEM_NUMBER_COLUMNS or feature.endswith("RatingAvg") or feature.endswith("RatingStddev"):
col = tf.feature_column.numeric_column(feature)
columns.append(col)
return columns
dataColumns.append(feature)
return columns,dataColumns
def train(columns,train_dataset):
......@@ -148,8 +154,12 @@ if __name__ == '__main__':
timestmp2 = int(round(time.time()))
print("读取数据耗时s:{}".format(timestmp2 - timestmp1))
# df_train = df_train[list(data_vocab.keys()) + ITEM_NUMBER_COLUMNS + ["label"]]
# df_test = df_test[list(data_vocab.keys()) + ITEM_NUMBER_COLUMNS + ["label"]]
# 获取训练列
columns = df_train.columns.tolist()
trainColumns, datasColumns = getTrainColumns(columns, data_vocab)
df_train = df_train[datasColumns + ["label"]]
df_test = df_test[datasColumns + ["label"]]
trainSize = df_train["label"].count()
testSize = df_test["label"].count()
......@@ -158,16 +168,14 @@ if __name__ == '__main__':
# 数据类型转换
df_train = csvTypeConvert(df_train,data_vocab)
df_test = csvTypeConvert(df_test,data_vocab)
columns = df_train.columns.tolist()
# 获取训练数据
train_data = getDataSet(df_train,shuffleSize=trainSize,)
test_data = getDataSet(df_test,shuffleSize=testSize)
# 获取训练列
columns = getTrainColumns(columns,data_vocab)
timestmp3 = int(round(time.time()))
model = train(columns,train_data)
model = train(trainColumns,train_data)
timestmp4 = int(round(time.time()))
print("读取数据耗时h:{}".format((timestmp4 - timestmp3)/60/60))
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment