import pymysql
import pandas as pd
from utils import *
from config import *
import numpy as np
import time

# Number of diaries kept per city and in the nationwide fallback list.
_TOP_N = 3000


def filter_cid(df):
    """Keep only the rows whose ``cid`` appears in the training data set.

    Candidate cids may only be chosen from the cids listed in the ``cid``
    column of ``data_set_cid.csv`` under ``DIRECTORY_PATH``.

    :param df: DataFrame with a ``cid`` column (may be empty).
    :return: the filtered DataFrame; an empty input is returned unchanged.
    """
    data_set_cid = pd.read_csv(DIRECTORY_PATH + "data_set_cid.csv")["cid"].values.tolist()
    if not df.empty:
        df = df.loc[df["cid"].isin(data_set_cid)]
    return df


def _city_top_sql(city):
    """Build the query for the most-clicked diaries of one city.

    NOTE(review): the SQL is assembled with string formatting. The city ids
    used by this module come from ``get_cityList`` (i.e. from the database
    itself); do not pass untrusted input here.
    """
    return "select city_id,cid from data_feed_click2 " \
           "where cid_type = 'diary' and city_id = '{0}' group by cid " \
           "order by max(click_count_choice) desc limit {1}".format(city, _TOP_N)


def get_allCitiesDiaryTop3000():
    """Fetch the nationwide top-3000 diaries by click count.

    The result is restricted to training-set cids via ``filter_cid`` and
    written to ``diaryTestSet/allCitiesDiaryTop3000.csv``.

    :return: DataFrame with columns ``city_id`` and ``cid``.
    """
    sql = "select city_id,cid from data_feed_click2 " \
          "where cid_type = 'diary' group by cid order by max(click_count_choice) desc limit 3000"
    # con_sql comes from the star imports; presumably it returns a DataFrame
    # with integer column labels, hence the rename below -- verify in utils.
    allCitiesTop3000 = con_sql(sql)
    allCitiesTop3000 = allCitiesTop3000.rename(columns={0: "city_id", 1: "cid"})
    allCitiesTop3000 = filter_cid(allCitiesTop3000)
    allCitiesTop3000.to_csv(DIRECTORY_PATH + "diaryTestSet/allCitiesDiaryTop3000.csv", index=False)
    return allCitiesTop3000


def get_cityList():
    """Fetch the list of distinct city ids.

    The raw query result is also written to ``diaryTestSet/cityList.csv``.

    :return: list of city ids (values of the first result column).
    """
    sql = "select distinct city_id from data_feed_click2"
    cityList = con_sql(sql)
    cityList.to_csv(DIRECTORY_PATH + "diaryTestSet/cityList.csv", index=False)
    cityList = cityList[0].values.tolist()
    return cityList


def pool_method(city, sql, allCitiesTop3000):
    """Export the top diaries of one city, padded with nationwide diaries.

    Runs ``sql``, keeps training-set cids only, and if fewer than 3000 rows
    remain, pads the result with rows from ``allCitiesTop3000`` that belong
    to other cities. The result is written to
    ``diaryTestSet/<city>DiaryTop3000.csv``.

    :param city: city id the query was built for.
    :param sql: query returning ``city_id`` and ``cid`` for that city.
    :param allCitiesTop3000: nationwide top diaries (``city_id``, ``cid``).
    """
    data = con_sql(sql)
    data = data.rename(columns={0: "city_id", 1: "cid"})
    data = filter_cid(data)

    missing = _TOP_N - data.shape[0]
    if missing > 0:
        # Drop this city's own diaries from the nationwide list.
        others = allCitiesTop3000[allCitiesTop3000["city_id"] != city]
        # Take rows by position: the index has gaps after filtering, so a
        # label slice such as .loc[:missing - 1] would return too few rows.
        data = pd.concat([data, others.head(missing)])

    file_name = DIRECTORY_PATH + "diaryTestSet/{0}DiaryTop3000.csv".format(city)
    data.to_csv(file_name, index=False)


def get_eachCityDiaryTop3000():
    """Export the top-3000 diaries of every city, one city after another.

    Cities with fewer than 3000 diaries are padded from the nationwide
    top-3000 list (see ``pool_method``).
    """
    cityList = get_cityList()
    allCitiesTop3000 = get_allCitiesDiaryTop3000()
    for city in cityList:
        pool_method(city, _city_top_sql(city), allCitiesTop3000)


def multi_get_eachCityDiaryTop3000(processes=8):
    """Export the top diaries of every city using a worker pool.

    :param processes: number of workers handed to ``Pool``.
    """
    city_list = get_cityList()
    allCitiesTop3000 = get_allCitiesDiaryTop3000()
    # NOTE(review): Pool is not imported explicitly in this file; presumably
    # it is provided by the star imports from utils/config -- confirm.
    pool = Pool(processes)
    pending = []
    for city in city_list:
        sql = _city_top_sql(city)
        pending.append((city, pool.apply_async(pool_method, (city, sql, allCitiesTop3000,))))
    pool.close()
    pool.join()

    # apply_async hides worker exceptions until the result is fetched, so
    # report failed cities instead of dropping them silently.
    for city, result in pending:
        try:
            result.get()
        except Exception as exc:
            print("failed to export top diaries for city {0}: {1!r}".format(city, exc))


if __name__ == "__main__":
    start = time.time()
    multi_get_eachCityDiaryTop3000()
    end = time.time()
    print("获取各城市热门日记耗时{}分".format((end-start)/60))