Commit 83b62060 authored by 张彦钊

test pickle

parent 400cb6a0
data/
*.pyc
import pymysql
import pandas as pd
# NOTE(review): credentials are hardcoded in source — move to config/env vars.
db = pymysql.connect(host='10.66.157.22', port=4000, user='root',
                     passwd='3SYz54LS9#^9sBvC', db='jerry_test')


def get_data(sql):
    """Run *sql* on the module-level connection and return the result set.

    Returns a pandas DataFrame built from all fetched rows, with rows
    containing NaN dropped (columns are positional: 0, 1, ...).
    """
    cursor = db.cursor()
    try:
        cursor.execute(sql)
        rows = cursor.fetchall()
    finally:
        # BUG FIX: the original never closed the cursor (resource leak).
        cursor.close()
    return pd.DataFrame(list(rows)).dropna()
# Fetch the nationwide top-2000 diaries by click count.
# BUG FIX: the original SQL had no FROM clause; the per-city query below
# reads from data_feed_click, so this one must too.
sql = ("select city_id,cid from data_feed_click "
       "where cid_type = 'diary' order by click_count_choice desc limit 2000")
allCitiesTop2000 = get_data(sql)
allCitiesTop2000 = allCitiesTop2000.rename(columns={0: "city_id", 1: "cid"})
# BUG FIX: paths used backslashes ("\home\..."), which are not valid
# separators on the Linux host these paths point at.
allCitiesTop2000.to_csv("/home/zhangyanzhao/diaryTestSet/allCitiesTop2000.csv")
print("成功获取全国日记点击量TOP2000")

# Fetch the list of distinct city ids that appear in click data.
sql = "select distinct city_id from data_feed_click"
cityList = get_data(sql)
cityList.to_csv("/home/zhangyanzhao/diaryTestSet/cityList.csv")
cityList = cityList[0].values.tolist()
print("成功获取城市列表")
# For each city, take its own top-2000 diaries; if a city has fewer than
# 2000, pad with the nationwide top diaries (excluding that city's own).
for i in cityList:
    sql = ("select city_id,cid from data_feed_click "
           "where cid_type = 'diary' and city_id = {0} "
           "order by click_count_choice desc limit 2000".format(i))
    data = get_data(sql)
    data = data.rename(columns={0: "city_id", 1: "cid"})
    if data.shape[0] < 2000:
        n = 2000 - data.shape[0]
        # BUG FIX: .loc[:n-1] is label-based; after filtering out this
        # city's rows the labels are no longer contiguous, so it could
        # return fewer than n padding rows. head(n) is positional.
        temp = allCitiesTop2000[allCitiesTop2000["city_id"] != i].head(n)
        data = data.append(temp)
    # BUG FIX: backslash path replaced with a valid Unix path.
    file_name = "/home/zhangyanzhao/diaryTestSet/{0}DiaryTop2000.csv".format(i)
    data.to_csv(file_name)
print("end")
from utils import *
from config import *
def get_allCitiesDiaryTop2000():
    """Fetch the nationwide top-2000 diaries by click count.

    Persists the result to DIRECTORY_PATH/diaryTestSet/ and returns a
    DataFrame with columns city_id, cid.
    """
    # BUG FIX: the original SQL was missing "from data_feed_click"
    # (the table every other query in this module reads from).
    sql = ("select city_id,cid from data_feed_click "
           "where cid_type = 'diary' order by click_count_choice desc limit 2000")
    allCitiesTop2000 = con_sql(sql)
    allCitiesTop2000 = allCitiesTop2000.rename(columns={0: "city_id", 1: "cid"})
    allCitiesTop2000.to_csv(DIRECTORY_PATH + "diaryTestSet/allCitiesDiaryTop2000.csv")
    print("成功获取全国日记点击量TOP2000")
    return allCitiesTop2000
def get_cityList():
    """Return the distinct city ids seen in click data as a plain list.

    Side effect: writes the raw query result to
    DIRECTORY_PATH/diaryTestSet/cityList.csv.
    """
    query = "select distinct city_id from data_feed_click"
    cities = con_sql(query)
    cities.to_csv(DIRECTORY_PATH + "diaryTestSet/cityList.csv")
    city_ids = cities[0].values.tolist()
    print("成功获取全国城市列表")
    return city_ids
def get_eachCityDiaryTop2000():
    """Write each city's top-2000 diaries to CSV.

    Cities with fewer than 2000 diaries are padded with the nationwide
    top diaries (excluding that city's own entries).
    """
    cityList = get_cityList()
    allCitiesTop2000 = get_allCitiesDiaryTop2000()
    for i in cityList:
        sql = ("select city_id,cid from data_feed_click "
               "where cid_type = 'diary' and city_id = {0} "
               "order by click_count_choice desc limit 2000".format(i))
        data = con_sql(sql)
        data = data.rename(columns={0: "city_id", 1: "cid"})
        if data.shape[0] < 2000:
            n = 2000 - data.shape[0]
            # BUG FIX: .loc[:n-1] is label-based; on the filtered frame the
            # labels are non-contiguous, so it could yield fewer than n
            # padding rows. head(n) takes the first n rows positionally.
            temp = allCitiesTop2000[allCitiesTop2000["city_id"] != i].head(n)
            data = data.append(temp)
        file_name = DIRECTORY_PATH + "diaryTestSet/{0}DiaryTop2000.csv".format(i)
        data.to_csv(file_name)


if __name__ == "__main__":
    get_eachCityDiaryTop2000()
......
......@@ -3,15 +3,20 @@ from config import *
print("Start training")
ffm_model = xl.create_ffm()
# Train/validation/test files and the model file are all named by the
# date range and hyperparameters, so runs don't overwrite each other.
# (Diff residue duplicating the old hard-coded file names was removed.)
ffm_model.setTrain(DIRECTORY_PATH + "train{0}-{1}.csv".format(DATA_START_DATE, VALIDATION_DATE))
ffm_model.setValidate(DIRECTORY_PATH + "validation{0}.csv".format(VALIDATION_DATE))
lr = 0.03
l2_lambda = 0.002
param = {'task': 'binary', 'lr': lr, 'lambda': l2_lambda, 'metric': 'auc'}
model_path = DIRECTORY_PATH + "model_{0}-{1}_lr{2}_lambda{3}.out".format(
    DATA_START_DATE, DATA_END_DATE, lr, l2_lambda)
ffm_model.fit(param, model_path)

print("predicting")
ffm_model.setTest(DIRECTORY_PATH + "test{0}.csv".format(TEST_DATE))
ffm_model.setSigmoid()
# CONSISTENCY FIX: predict used hard-coded "0.03"/"0.002" strings while
# fit used the lr/l2_lambda variables; both now share model_path so the
# names cannot drift apart.
ffm_model.predict(model_path,
                  DIRECTORY_PATH + "testset{0}_output_model_{1}-{2}_lr{3}_lambda{4}.txt".format(
                      TEST_DATE, DATA_START_DATE, DATA_END_DATE, lr, l2_lambda))
from utils import *
import datetime
import pickle
if __name__ == '__main__':
    # Smoke test: the FFM encoder must survive a pickle round trip and
    # produce the same encoding after reload.
    # (Diff residue removed: the duplicate read_csv of "data/raw-exposure.csv",
    # and the broken hour line `lambda x: lambda x: ...` — a double lambda
    # that stored functions instead of hours, referencing a "time" column
    # that is no longer loaded.)
    data = pd.read_csv("../data/test-data/raw-exposure.csv")[["cid", "device_id"]]
    data["y"] = 1
    test_data = data.tail(5)
    ffm = FFMFormatPandas()
    data = ffm.fit_transform(data, y='y')
    data.to_csv("ffm_data.csv", index=False)
    with open("ffm.object", "wb") as f:
        pickle.dump(ffm, f)
    # NOTE(review): unpickling is safe here only because this process just
    # wrote the file — never pickle.load untrusted input.
    with open("ffm.object", "rb") as f:
        ffm = pickle.load(f)
    result = ffm.transform(test_data)
    print(result)
    # The last 5 rows of the persisted encoding should match `result`.
    data_1 = pd.read_csv("ffm_data.csv", header=None).tail(5)
    print(data_1)
......
......@@ -68,7 +68,7 @@ test = data.loc[:test_number]
# Report the test-set size and persist it (no index, no header row).
print("测试集大小")
print(test.shape[0])
test.to_csv(DIRECTORY_PATH + "test{0}.csv".format(TEST_DATE), index=False, header=None)
# NOTE: the test set's date must be later than the validation set's,
# otherwise the data split below may be wrong.
# .loc slicing is label-based and inclusive at BOTH ends, so this takes
# rows test_number+1 through test_number+validation_number.
validation = data.loc[(test_number + 1):(test_number + validation_number)]
print("验证集大小")
print(validation.shape[0])
......
......@@ -54,7 +54,7 @@ class FFMFormatPandas:
def transform_row_(self, row, t):
ffm = []
if self.y != None:
if self.y is not None:
ffm.append(str(row.loc[row.index == self.y][0]))
if self.y is None:
ffm.append(str(0))
......@@ -62,7 +62,7 @@ class FFMFormatPandas:
for col, val in row.loc[row.index != self.y].to_dict().items():
col_type = t[col]
name = '{}_{}'.format(col, val)
if col_type.kind == 'O':
if col_type.kind == 'O':
ffm.append('{}:{}:1'.format(self.field_index_[col], self.feature_index_[name]))
elif col_type.kind == 'i':
ffm.append('{}:{}:{}'.format(self.field_index_[col], self.feature_index_[col], val))
......@@ -70,4 +70,4 @@ class FFMFormatPandas:
def transform(self, df):
    """Encode every row of *df* into FFM format.

    Returns a pandas Series of FFM-formatted strings, indexed like *df*.
    """
    # Compute the dtype map once; transform_row_ needs each column's kind.
    t = df.dtypes.to_dict()
    # BUG FIX: the diff residue left this return statement duplicated
    # (old line without trailing newline + re-added line); keep one.
    return pd.Series({idx: self.transform_row_(row, t) for idx, row in df.iterrows()})
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment