Commit c6881729 authored by 张彦钊

change ffm_train.fit_transform argument

parent fe7e0490
...@@ -2,6 +2,7 @@ import pymysql ...@@ -2,6 +2,7 @@ import pymysql
import pandas as pd import pandas as pd
from utils import * from utils import *
from config import * from config import *
import numpy as np
# 候选集cid只能从训练数据集cid中选择 # 候选集cid只能从训练数据集cid中选择
...@@ -34,11 +35,52 @@ def get_cityList(): ...@@ -34,11 +35,52 @@ def get_cityList():
print("成功获取全国城市列表") print("成功获取全国城市列表")
return cityList return cityList
# # 多线程方法获取全国城市热门日记
def pool_method(i, sql, allCitiesTop3000):
    """Fetch one city's top diaries, pad them to 3000 rows, and save as CSV.

    Args:
        i: City id. Compared against the ``city_id`` column of
            ``allCitiesTop3000`` and used in the output file name.
        sql: Query handed to ``con_sql``; columns 0 and 1 of its result are
            renamed to ``city_id`` and ``cid``.
        allCitiesTop3000: DataFrame of the nationwide top diaries (must have
            a ``city_id`` column); used as the padding source.

    Returns:
        None. The result is written to
        ``DIRECTORY_PATH + "diaryTestSet/<i>DiaryTop3000.csv"``.
    """
    data = con_sql(sql)
    data = data.rename(columns={0: "city_id", 1: "cid"})
    data = filter_cid(data)

    shortfall = 3000 - data.shape[0]
    if shortfall > 0:
        # Pad with nationwide top diaries, excluding this city's own rows.
        other_cities = allCitiesTop3000[allCitiesTop3000["city_id"] != i]
        # head() is positional. The previous label-based `.loc[:n - 1]` could
        # return fewer than n rows because filtering leaves gaps in the index.
        padding = other_cities.head(shortfall)
        # DataFrame.append was removed in pandas 2.0; concat is equivalent.
        data = pd.concat([data, padding])

    file_name = DIRECTORY_PATH + "diaryTestSet/{0}DiaryTop3000.csv".format(i)
    data.to_csv(file_name, index=False)
    print("成功保存{}地区DiaryTop3000".format(i))
# 把城市列表切分成n份,然后拼接成一个列表
def split_cityList(cityList, n):
    """Split ``cityList`` into at most ``n`` contiguous, near-equal chunks.

    The chunks preserve the original order, and concatenating them gives back
    ``cityList``. When ``len(cityList)`` is not divisible by ``n`` the leading
    chunks are one element longer. Empty chunks are never produced, so fewer
    than ``n`` chunks are returned when ``cityList`` has fewer than ``n``
    items.

    Args:
        cityList: Sliceable sequence of cities (e.g. a list).
        n: Desired number of chunks; must be a positive integer.

    Returns:
        A list of slices of ``cityList``.

    Raises:
        ValueError: If ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError("n must be a positive integer")

    # Integer arithmetic keeps the slice bounds as ints (np.rint yields floats).
    base, extra = divmod(len(cityList), n)
    new_list = []
    start = 0
    for k in range(n):
        size = base + 1 if k < extra else base
        if size == 0:
            # Fewer items than chunks: the remaining chunks would be empty.
            break
        new_list.append(cityList[start:start + size])
        start += size
    return new_list
# 多线程方法获取全国城市热门日记
# Creates a Pool with `processes` workers, submits one apply_async task per
# entry of `data_list`, then merges each task's result into `result_map`.
# NOTE(review): `data_list`, `t` and `self` are neither parameters nor assigned
# in this function, so as shown the loop below would raise NameError. The
# submitted callable is `self.pool_function`, not the `pool_method` defined
# earlier in this file - presumably left over from a template; confirm.
# NOTE(review): `cityList` and `allCitiesTop3000` are fetched but never used
# below, and `result_map` is built but never returned.
def multi_get_eachCityDiaryTop3000(processes):
# NOTE(review): each of the next two lines carries the same statement twice on
# one line - presumably a side-by-side diff artifact; keep one copy of each.
cityList = get_cityList() cityList = get_cityList()
allCitiesTop3000 = get_allCitiesDiaryTop3000() allCitiesTop3000 = get_allCitiesDiaryTop3000()
pool = Pool(processes)
# Replace every entry of data_list with the AsyncResult of its submitted task.
for i in range(len(data_list)):
data_list[i] = pool.apply_async(self.pool_function, (data_list[i], t,))
result_map = {}
# .get() blocks until the task finishes; its result is merged via dict.update.
for i in data_list:
result_map.update(i.get())
pool.close()
pool.join()
def get_eachCityDiaryTop3000(): def get_eachCityDiaryTop3000():
......
...@@ -75,7 +75,7 @@ def ffm_transform(data, test_number, validation_number): ...@@ -75,7 +75,7 @@ def ffm_transform(data, test_number, validation_number):
print("Start ffm transform") print("Start ffm transform")
start = time.time() start = time.time()
ffm_train = multiFFMFormatPandas() ffm_train = multiFFMFormatPandas()
data = ffm_train.fit_transform(data, y='y',n=1000000,processes=6) data = ffm_train.fit_transform(data, y='y',n=200000,processes=6)
with open(DIRECTORY_PATH+"ffm.pkl", "wb") as f: with open(DIRECTORY_PATH+"ffm.pkl", "wb") as f:
pickle.dump(ffm_train, f) pickle.dump(ffm_train, f)
......
...@@ -18,9 +18,9 @@ if __name__ == "__main__": ...@@ -18,9 +18,9 @@ if __name__ == "__main__":
train() train()
end = time.time() end = time.time()
print("训练模型耗时{}分".format((end-start)/60)) print("训练模型耗时{}分".format((end-start)/60))
print('---------------prepare candidates--------------') # print('---------------prepare candidates--------------')
get_eachCityDiaryTop3000() # get_eachCityDiaryTop3000()
print("end") # print("end")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment