import pymysql
import pandas as pd

from utils import *
from config import *

# Target list length per city (also the size of the nationwide fallback list).
# NOTE: the SQL "limit 3000" clauses and the output file names below still
# spell out 3000 literally; keep them in sync if this value changes.
_TOP_N = 3000


def _load_data_set_cid():
    """Read the cids that appear in the training data set.

    :return: list of cids from the ``cid`` column of
        ``DIRECTORY_PATH + "data_set_cid.csv"``.
    """
    return pd.read_csv(DIRECTORY_PATH + "data_set_cid.csv")["cid"].values.tolist()


def filter_cid(df, data_set_cid=None):
    """Keep only rows whose ``cid`` also appears in the training data set.

    Candidate-set cids may only be chosen from cids present in the training set.

    :param df: DataFrame with a ``cid`` column (may be empty).
    :param data_set_cid: optional pre-loaded collection of allowed cids. When
        None, the training-set CSV is read on every call. Pass it in to avoid
        re-reading the file in loops.
    :return: ``df`` restricted to allowed cids; an empty ``df`` is returned as-is.
    """
    if df.empty:
        return df
    if data_set_cid is None:
        data_set_cid = _load_data_set_cid()
    return df.loc[df["cid"].isin(data_set_cid)]


def get_allCitiesDiaryTop3000(data_set_cid=None):
    """Fetch the nationwide top-3000 diaries by click count and save them to CSV.

    The result is filtered to training-set cids and written to
    ``DIRECTORY_PATH + "diaryTestSet/allCitiesDiaryTop3000.csv"``.

    :param data_set_cid: optional pre-loaded allowed cids, forwarded to
        :func:`filter_cid`.
    :return: DataFrame with columns ``city_id`` and ``cid``.
    """
    # NOTE(review): city_id is neither grouped nor aggregated here, so MySQL
    # returns an indeterminate city_id per cid (and rejects the query when
    # ONLY_FULL_GROUP_BY is enabled). Callers must not rely on it to identify
    # which city a diary belongs to.
    sql = "select city_id,cid from data_feed_click " \
          "where cid_type = 'diary' group by cid order by max(click_count_choice) desc limit 3000"
    allCitiesTop3000 = con_sql(sql)
    # con_sql appears to return positional integer column labels; name them.
    allCitiesTop3000 = allCitiesTop3000.rename(columns={0: "city_id", 1: "cid"})
    allCitiesTop3000 = filter_cid(allCitiesTop3000, data_set_cid)
    allCitiesTop3000.to_csv(DIRECTORY_PATH + "diaryTestSet/allCitiesDiaryTop3000.csv", index=False)
    print("成功获取全国日记点击量TOP3000")
    return allCitiesTop3000


def get_cityList():
    """Fetch the distinct city ids from ``data_feed_click`` and save them to CSV.

    The raw query result is written to
    ``DIRECTORY_PATH + "diaryTestSet/cityList.csv"``.

    :return: list of city ids (the first result column of the query).
    """
    sql = "select distinct city_id from data_feed_click"
    cityList = con_sql(sql)
    cityList.to_csv(DIRECTORY_PATH + "diaryTestSet/cityList.csv", index=False)
    cityList = cityList[0].values.tolist()
    print("成功获取全国城市列表")
    return cityList


def get_eachCityDiaryTop3000():
    """Build and save a top-3000 diary list for every city.

    Each city's most-clicked diaries (restricted to training-set cids) are
    padded with nationwide top diaries when fewer than 3000 remain. One CSV is
    written per city at ``DIRECTORY_PATH + "diaryTestSet/{city}DiaryTop3000.csv"``.
    """
    cityList = get_cityList()
    # Load the allowed-cid list once instead of re-reading the CSV per city.
    data_set_cid = _load_data_set_cid()
    allCitiesTop3000 = get_allCitiesDiaryTop3000(data_set_cid)
    for city_id in cityList:
        # city_id comes from our own database (get_cityList), not from users.
        sql = "select city_id,cid from data_feed_click " \
              "where cid_type = 'diary' and city_id = '{0}' group by cid " \
              "order by max(click_count_choice) desc limit 3000".format(city_id)
        data = con_sql(sql)
        data = data.rename(columns={0: "city_id", 1: "cid"})
        data = filter_cid(data, data_set_cid)

        shortfall = _TOP_N - data.shape[0]
        if shortfall > 0:
            existing_cids = data["cid"] if "cid" in data.columns else []
            # Pad with nationwide diaries that are not attributed to this city
            # and are not already in its list. The nationwide city_id is
            # arbitrary per cid (see get_allCitiesDiaryTop3000), so also
            # dedupe on cid.
            candidates = allCitiesTop3000[
                (allCitiesTop3000["city_id"] != city_id)
                & ~allCitiesTop3000["cid"].isin(existing_cids)
            ]
            # Take the first `shortfall` rows by position. Label-based .loc
            # slicing would be wrong here, because filtering leaves gaps in
            # the index.
            data = pd.concat([data, candidates.head(shortfall)])

        file_name = DIRECTORY_PATH + "diaryTestSet/{0}DiaryTop3000.csv".format(city_id)
        data.to_csv(file_name, index=False)
        print("成功保存{}地区DiaryTop3000".format(city_id))


if __name__ == "__main__":
    get_eachCityDiaryTop3000()