Commit 3fcaa79f authored by 张彦钊's avatar 张彦钊

日记候选集从top2000改为3000

parent aeddd470
...@@ -16,17 +16,17 @@ def filter_cid(df): ...@@ -16,17 +16,17 @@ def filter_cid(df):
return df return df
def get_allCitiesDiaryTop3000():
    """Fetch the nationwide top-3000 diaries by click count and save them as CSV.

    Runs one query against ``data_feed_click`` (diary rows grouped by ``cid``,
    ordered by the highest ``click_count_choice``, limited to 3000), labels
    the two positional result columns ``city_id`` and ``cid``, passes the
    frame through ``filter_cid`` and writes it to
    ``diaryTestSet/allCitiesDiaryTop3000.csv`` under ``DIRECTORY_PATH``.

    Returns:
        The filtered frame that was written to disk.
    """
    query = (
        "select city_id,cid from data_feed_click "
        "where cid_type = 'diary' group by cid order by max(click_count_choice) desc limit 3000"
    )
    # con_sql presumably returns a DataFrame whose columns are labelled 0 and 1
    # (the rename below relies on that) -- TODO confirm against con_sql.
    ranked = filter_cid(con_sql(query).rename(columns={0: "city_id", 1: "cid"}))
    output_path = DIRECTORY_PATH + "diaryTestSet/allCitiesDiaryTop3000.csv"
    ranked.to_csv(output_path, index=False)
    print("成功获取全国日记点击量TOP3000")
    return ranked
def get_cityList(): def get_cityList():
...@@ -40,28 +40,28 @@ def get_cityList(): ...@@ -40,28 +40,28 @@ def get_cityList():
return cityList return cityList
def get_eachCityDiaryTop3000():
    """Write one CSV of top-clicked diaries per city, padded from the nationwide list.

    For every city returned by ``get_cityList()`` the city's own diaries are
    queried (at most 3000, ordered by the highest ``click_count_choice``) and
    passed through ``filter_cid``.  When fewer than 3000 rows remain, the
    shortfall is filled with the leading rows of the nationwide top-3000 list
    that belong to other cities.  Each result is saved to
    ``diaryTestSet/<city_id>DiaryTop3000.csv`` under ``DIRECTORY_PATH``.

    Returns:
        None. The per-city CSV files are the only output.
    """
    # Imported locally so this function does not depend on a module-level
    # ``pd`` alias being present in the file.
    import pandas as pd

    cityList = get_cityList()
    allCitiesTop3000 = get_allCitiesDiaryTop3000()
    for i in cityList:
        sql = "select city_id,cid from data_feed_click " \
              "where cid_type = 'diary' and city_id = '{0}' group by cid " \
              "order by max(click_count_choice) desc limit 3000".format(i)
        data = con_sql(sql)
        data = data.rename(columns={0: "city_id", 1: "cid"})
        data = filter_cid(data)
        if data.shape[0] < 3000:
            n = 3000 - data.shape[0]
            # Drop this city's own diaries from the nationwide list, then take
            # the first n remaining rows BY POSITION.  The previous
            # ``.loc[:n - 1]`` sliced by index label, and the labels have gaps
            # once rows are filtered out, so it could return fewer than n rows.
            temp = allCitiesTop3000[allCitiesTop3000["city_id"] != i].head(n)
            # DataFrame.append was removed in pandas 2.0; concat is the
            # equivalent call (row labels are kept, as append did).
            data = pd.concat([data, temp])
        file_name = DIRECTORY_PATH + "diaryTestSet/{0}DiaryTop3000.csv".format(i)
        data.to_csv(file_name, index=False)
if __name__ == "__main__":
    # Script entry point: regenerate the per-city top-3000 diary CSV files.
    get_eachCityDiaryTop3000()
...@@ -11,7 +11,7 @@ from userProfile import fetch_user_profile ...@@ -11,7 +11,7 @@ from userProfile import fetch_user_profile
# 将device_id、city_id拼接到对应的城市热门日记表。注意:下面预测集特征顺序要与训练集保持一致 # 将device_id、city_id拼接到对应的城市热门日记表。注意:下面预测集特征顺序要与训练集保持一致
def feature_en(user_profile): def feature_en(user_profile):
file_name = DIRECTORY_PATH + "diaryTestSet/{0}DiaryTop2000.csv".format(user_profile['city_id']) file_name = DIRECTORY_PATH + "diaryTestSet/{0}DiaryTop3000.csv".format(user_profile['city_id'])
data = pd.read_csv(file_name) data = pd.read_csv(file_name)
data["device_id"] = user_profile['device_id'] data["device_id"] = user_profile['device_id']
...@@ -24,7 +24,7 @@ def feature_en(user_profile): ...@@ -24,7 +24,7 @@ def feature_en(user_profile):
data.loc[data["minute"] == 0, ["minute"]] = 60 data.loc[data["minute"] == 0, ["minute"]] = 60
data["hour"] = data["hour"].astype("category") data["hour"] = data["hour"].astype("category")
data["minute"] = data["minute"].astype("category") data["minute"] = data["minute"].astype("category")
# 虽然预测y,但ffm转化需要y,并不影响预测结果
data["y"] = 0 data["y"] = 0
data = data.drop("city_id", axis=1) data = data.drop("city_id", axis=1)
print(data.head(2)) print(data.head(2))
......
...@@ -71,7 +71,7 @@ class FFMFormatPandas: ...@@ -71,7 +71,7 @@ class FFMFormatPandas:
def transform(self, df):
    """Convert every row of *df* with ``transform_row_`` and collect the results.

    Each row is handed to ``self.transform_row_`` together with a mapping of
    column name to dtype.  The results are returned as a ``pd.Series`` keyed
    by the frame's index labels.
    """
    column_dtypes = df.dtypes.to_dict()
    converted = {}
    for row_label, row in df.iterrows():
        converted[row_label] = self.transform_row_(row, column_dtypes)
    return pd.Series(converted)
# 下面这个方法不是这个类原有的方法,是新增的。目的是用来判断这个用户是不是在训练数据集中存在
def is_feature_index_exist(self, name): def is_feature_index_exist(self, name):
if name in self.feature_index_: if name in self.feature_index_:
return True return True
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment