Commit 021a4e92 authored by 高雅喆

Merge branch 'master' of git.wanmeizhensuo.com:ML/ffm-baseline

update outline and transform rate to %
parents 164b7897 d0a857c9
......@@ -2,6 +2,7 @@ import pymysql
import pandas as pd
from utils import *
from config import *
import numpy as np
# 候选集cid只能从训练数据集cid中选择
......@@ -34,11 +35,52 @@ def get_cityList():
print("成功获取全国城市列表")
return cityList
# # 多线程方法获取全国城市热门日记
def multi_get_eachCityDiaryTop3000():
    """Load the inputs for the per-city top-diary export.

    Calls ``get_cityList()`` and ``get_allCitiesDiaryTop3000()`` and keeps
    both results in locals; nothing is returned.

    NOTE(review): nothing here dispatches work to ``pool_method`` yet, so
    this looks like a work-in-progress stub -- confirm before relying on it.
    """
    cities = get_cityList()
    nationwide_top = get_allCitiesDiaryTop3000()
def pool_method(i, sql, allCitiesTop3000):
    """Build and save the top-3000 diary candidate list for one city.

    Runs ``sql`` through ``con_sql``, names the two result columns
    ``city_id`` and ``cid``, passes the frame through ``filter_cid``, pads it
    to 3000 rows with rows from ``allCitiesTop3000`` when it is short, and
    writes it to ``<DIRECTORY_PATH>diaryTestSet/<i>DiaryTop3000.csv``.

    Args:
        i: City identifier; used to exclude that city's rows from the
            padding and to build the output file name.
        sql: Query passed to ``con_sql``. Presumably selects (city_id, cid)
            pairs for city ``i`` -- verify against the caller.
        allCitiesTop3000: DataFrame with at least a ``city_id`` column,
            used as the source of padding rows.

    Returns:
        None. The result is written to disk as a side effect.
    """
    data = con_sql(sql)
    data = data.rename(columns={0: "city_id", 1: "cid"})
    data = filter_cid(data)
    n = 3000 - data.shape[0]
    if n > 0:
        # Drop this city's diaries from the nationwide top-3000 list so the
        # padding does not reuse the city's own rows.
        other_cities = allCitiesTop3000[allCitiesTop3000["city_id"] != i]
        # Take the first n rows by position. A label slice (.loc[:n - 1])
        # would come up short because the filter above removes labels.
        temp = other_cities.head(n)
        # DataFrame.append was removed in pandas 2.0; pd.concat gives the
        # same result (original indexes kept).
        data = pd.concat([data, temp])
    file_name = DIRECTORY_PATH + "diaryTestSet/{0}DiaryTop3000.csv".format(i)
    data.to_csv(file_name, index=False)
    print("成功保存{}地区DiaryTop3000".format(i))
# 把城市列表切分成n份,然后拼接成一个列表
# def split_cityList(cityList,n):
# l = len(cityList)
# step = np.rint(l/n)
# new_list = []
# x = 0
# while True:
# if x + step < :
# data_list.append(data.iloc[x:x + step])
# x = x + step + 1
# else:
# data_list.append(data.iloc[x:data.__len__()])
# break
# 多线程方法获取全国城市热门日记
# def multi_get_eachCityDiaryTop3000(processes):
# cityList = get_cityList()
# allCitiesTop3000 = get_allCitiesDiaryTop3000()
#
# pool = Pool(processes)
# for i in range(len(data_list)):
# data_list[i] = pool.apply_async(self.pool_function, (data_list[i], t,))
#
# result_map = {}
# for i in data_list:
# result_map.update(i.get())
# pool.close()
# pool.join()
def get_eachCityDiaryTop3000():
......
......@@ -75,7 +75,7 @@ def ffm_transform(data, test_number, validation_number):
print("Start ffm transform")
start = time.time()
ffm_train = multiFFMFormatPandas()
data = ffm_train.fit_transform(data, y='y',n=50000,processes=5)
data = ffm_train.fit_transform(data, y='y',n=200000,processes=6)
with open(DIRECTORY_PATH+"ffm.pkl", "wb") as f:
pickle.dump(ffm_train, f)
......
......@@ -18,9 +18,9 @@ if __name__ == "__main__":
train()
end = time.time()
print("训练模型耗时{}分".format((end-start)/60))
print('---------------prepare candidates--------------')
get_eachCityDiaryTop3000()
print("end")
# print('---------------prepare candidates--------------')
# get_eachCityDiaryTop3000()
# print("end")
......
......@@ -89,7 +89,7 @@ class multiFFMFormatPandas:
return self
def fit_transform(self, df, y=None,n=10000,processes=5):
def fit_transform(self, df, y=None,n=1000000,processes=6):
# n是每个线程运行最大的数据条数,processes是线程数
self.fit(df, y)
n = n
......@@ -112,7 +112,7 @@ class multiFFMFormatPandas:
ffm.append('{}:{}:{}'.format(self.field_index_[col], self.feature_index_[col], val))
return ' '.join(ffm)
def transform(self, df,n=10000,processes=1):
def transform(self, df,n=10000,processes=2):
# n是每个线程运行最大的数据条数,processes是线程数
t = df.dtypes.to_dict()
data_list = self.data_split_line(df,n)
......@@ -120,11 +120,12 @@ class multiFFMFormatPandas:
# 设置进程的数量
pool = Pool(processes)
print("总进度: " + str(len(data_list)))
result_map = {}
for i in range(len(data_list)):
data_list[i] = pool.apply_async(self.pool_function, (data_list[i], t,))
result_map.update(data_list[i].get())
result_map = {}
for i in data_list:
result_map.update(i.get())
pool.close()
pool.join()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment