# diaryCandidateSet.py — build per-city diary candidate sets (top-3000 by clicks)
import time
from multiprocessing import Pool

import numpy as np
import pandas as pd
import pymysql

from config import *
from utils import *


# 候选集cid只能从训练数据集cid中选择
# Candidate cids may only be drawn from the training data set's cids.
def filter_cid(df):
    """Keep only rows whose "cid" appears in the training data set.

    :param df: DataFrame expected to carry a "cid" column.
    :return: the filtered DataFrame; *df* unchanged when it is empty.
    """
    if df.empty:
        return df
    # Read the whitelist only when there is something to filter
    # (the original read the CSV unconditionally on every call).
    data_set_cid = pd.read_csv(DIRECTORY_PATH + "data_set_cid.csv")["cid"].values.tolist()
    return df.loc[df["cid"].isin(data_set_cid)]


def get_allCitiesDiaryTop3000():
    """Fetch the nationwide top-3000 diaries by click count, cache to CSV, return them."""
    sql = ("select city_id,cid from data_feed_click2 "
           "where cid_type = 'diary' group by cid "
           "order by max(click_count_choice) desc limit 3000")
    top_diaries = con_sql(sql).rename(columns={0: "city_id", 1: "cid"})
    # Candidates must come from the training-set cids.
    top_diaries = filter_cid(top_diaries)
    top_diaries.to_csv(DIRECTORY_PATH + "diaryTestSet/allCitiesDiaryTop3000.csv", index=False)
    return top_diaries


def get_cityList():
    """Return the distinct city ids found in the click log (also cached to CSV)."""
    city_df = con_sql("select distinct city_id from data_feed_click2")
    city_df.to_csv(DIRECTORY_PATH + "diaryTestSet/cityList.csv", index=False)
    return city_df[0].values.tolist()


def get_eachCityDiaryTop3000():
    """Build each city's top-3000 diary candidate file (serial version).

    For every city, fetch its diaries ordered by click count; if fewer
    than 3000 survive filtering, pad with nationwide top diaries from
    OTHER cities. One CSV is written per city under diaryTestSet/.
    """
    cityList = get_cityList()
    allCitiesTop3000 = get_allCitiesDiaryTop3000()
    for i in cityList:
        sql = "select city_id,cid from data_feed_click2 " \
              "where cid_type = 'diary' and city_id = '{0}' group by cid " \
              "order by max(click_count_choice) desc limit 3000".format(i)
        data = con_sql(sql)
        data = data.rename(columns={0: "city_id", 1: "cid"})
        data = filter_cid(data)
        if data.shape[0] < 3000:
            n = 3000 - data.shape[0]
            # Pad with nationwide top diaries, excluding this city's own.
            # FIX: the original used .loc[:n - 1], which slices by index
            # *label* on an already-filtered frame and can return the wrong
            # number of rows; .head(n) takes exactly the first n survivors.
            temp = allCitiesTop3000[allCitiesTop3000["city_id"] != i].head(n)
            # pd.concat replaces DataFrame.append (removed in pandas 2.0).
            data = pd.concat([data, temp], ignore_index=True)

        file_name = DIRECTORY_PATH + "diaryTestSet/{0}DiaryTop3000.csv".format(i)
        data.to_csv(file_name, index=False)


def pool_method(city, sql, allCitiesTop3000):
    """Worker: build and write one city's top-3000 diary candidate CSV.

    :param city: city id; used for padding exclusion and the output filename.
    :param sql: pre-built per-city query (built by the caller).
    :param allCitiesTop3000: nationwide top diaries used as padding.
    """
    data = con_sql(sql)
    data = data.rename(columns={0: "city_id", 1: "cid"})
    data = filter_cid(data)
    if data.shape[0] < 3000:
        n = 3000 - data.shape[0]
        # Pad with nationwide top diaries, excluding this city's own.
        # FIX: .loc[:n - 1] slices by index *label* on a filtered frame and
        # can return the wrong number of rows; .head(n) takes exactly n.
        temp = allCitiesTop3000[allCitiesTop3000["city_id"] != city].head(n)
        # pd.concat replaces DataFrame.append (removed in pandas 2.0).
        data = pd.concat([data, temp], ignore_index=True)

    file_name = DIRECTORY_PATH + "diaryTestSet/{0}DiaryTop3000.csv".format(city)
    data.to_csv(file_name, index=False)


# 多线程方法获取全国城市热门日记
# Multiprocess version: build every city's hot-diary candidate file in parallel.
def multi_get_eachCityDiaryTop3000(processes=8):
    """Fan each city's candidate build out to a process pool.

    :param processes: number of worker processes (default 8).
    :raises: re-raises any exception a worker hit (see .get() below).
    """
    city_list = get_cityList()
    allCitiesTop3000 = get_allCitiesDiaryTop3000()
    pool = Pool(processes)
    results = []
    for city in city_list:
        sql = "select city_id,cid from data_feed_click2 " \
          "where cid_type = 'diary' and city_id = '{0}' group by cid " \
          "order by max(click_count_choice) desc limit 3000".format(city)

        results.append(pool.apply_async(pool_method, (city, sql, allCitiesTop3000,)))

    pool.close()
    pool.join()
    # FIX: apply_async swallows worker exceptions unless .get() is called;
    # without this loop a failed city silently produces no output file.
    for r in results:
        r.get()


if __name__ == "__main__":
    # Time the full multi-process candidate-set build across all cities.
    start_time = time.time()
    multi_get_eachCityDiaryTop3000()
    elapsed_minutes = (time.time() - start_time) / 60
    print("获取各城市热门日记耗时{}分".format(elapsed_minutes))