# -*- coding: utf-8 -*-
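# Feature engineering for the ESMM ranking model: builds global id maps from
# TiDB tables, assembles training/prediction features with Spark, and writes
# TFRecord files plus csv id lists for the native and nearby queues.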
import pymysql
from pyspark.conf import SparkConf
import pytispark.pytispark as pti
from pyspark.sql import SparkSession
import datetime
import pandas as pd
import time
from pyspark import StorageLevel


def app_list_func(x, mapping):
    # Turn a comma-separated id string into a list of integer indices;
    # ids missing from the mapping fall back to 0.
    ids = str(x).split(",")
    return [mapping.get(i, 0) for i in ids]


def get_list(db, sql, n):
    # Fetch a column of comma-separated id strings, flatten it to the set of
    # unique ids, and index each id consecutively starting at n.
    cursor = db.cursor()
    cursor.execute(sql)
    result = cursor.fetchall()
    values = list(set([i[0] for i in result]))
    app_list_unique = []
    for value in values:
        app_list_unique.extend(str(value).split(","))
    app_list_unique = list(set(app_list_unique))
    number = len(app_list_unique)
    app_list_map = dict(zip(app_list_unique, list(range(n, number + n))))
    db.close()
    return number, app_list_map


def get_map():
    # Build id -> index maps for the three multi-hot features. Indices below 16
    # are reserved as fallback defaults, so app_list ids start at 16, level2
    # ids follow them, and level3 ids follow both.
    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
    sql = "select app_list from device_app_list"
    a = time.time()
    apps_number, app_list_map = get_list(db, sql, 16)
    print("applist")
    print((time.time() - a) / 60)
    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
    sql = "select level2_ids from diary_feat"
    b = time.time()
    leve2_number, leve2_map = get_list(db, sql, 16 + apps_number)
    print("level2")
    print((time.time() - b) / 60)
    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
    sql = "select level3_ids from diary_feat"
    c = time.time()
    leve3_number, leve3_map = get_list(db, sql, 16 + leve2_number + apps_number)
    print("level3")
    print((time.time() - c) / 60)
    return apps_number, app_list_map, leve2_number, leve2_map, leve3_number, leve3_map


def get_unique(db, sql):
    # Return the distinct values of a single-column query.
    cursor = db.cursor()
    cursor.execute(sql)
    result = cursor.fetchall()
    v = list(set([i[0] for i in result]))
    db.close()
    print(sql)
    print(len(v))
    return v


def con_sql(db, sql):
    # Run a query and return the full result set as a pandas DataFrame.
    cursor = db.cursor()
    cursor.execute(sql)
    result = cursor.fetchall()
    df = pd.DataFrame(list(result))
    db.close()
    return df


def get_pre_number():
    # Print how many rows are waiting in the prediction table.
    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
    sql = "select count(*) from esmm_pre_data"
    cursor = db.cursor()
    cursor.execute(sql)
    result = cursor.fetchone()[0]
    print("prediction set size:")
    print(result)
    db.close()


def feature_engineer():
    apps_number, app_list_map, level2_number, leve2_map, level3_number, leve3_map = get_map()
    # Indices 16-27 are per-column fallback ids: a null cell is filled with its
    # own column name (via na.fill below), which maps to one of these slots.
    app_list_map["app_list"] = 16
    leve3_map["level3_ids"] = 17
    leve3_map["search_tag3"] = 18
    leve2_map["level2_ids"] = 19
    leve2_map["tag1"] = 20
    leve2_map["tag2"] = 21
    leve2_map["tag3"] = 22
    leve2_map["tag4"] = 23
    leve2_map["tag5"] = 24
    leve2_map["tag6"] = 25
    leve2_map["tag7"] = 26
    leve2_map["search_tag2"] = 27

    # Collect the distinct values of every categorical feature column; each
    # value later gets its own slot in value_map.
    unique_values = []
    unique_value_sources = [
        ("stat_date", "esmm_train_data_dwell"),
        ("ucity_id", "esmm_train_data_dwell"),
        ("ccity_name", "esmm_train_data_dwell"),
        ("time", "cid_time_cut"),
        ("device_type", "user_feature"),
        ("manufacturer", "user_feature"),
        ("channel", "user_feature"),
        ("top", "cid_type_top"),
        ("price_min", "knowledge"),
        ("treatment_method", "knowledge"),
        ("price_max", "knowledge"),
        ("treatment_time", "knowledge"),
        ("maintain_time", "knowledge"),
        ("recover_time", "knowledge"),
    ]
    for column, table in unique_value_sources:
        db = pymysql.connect(host='172.16.40.158', port=4000, user='root',
                             passwd='3SYz54LS9#^9sBvC', db='jerry_test')
        sql = "select distinct {} from {}".format(column, table)
        unique_values.extend(get_unique(db, sql))

    # The newest stat_date is the validation day; training data starts three
    # days earlier.
    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
    sql = "select max(stat_date) from esmm_train_data_dwell"
    validate_date = con_sql(db, sql)[0].values.tolist()[0]
    print("validate_date:" + validate_date)
    temp = datetime.datetime.strptime(validate_date, "%Y-%m-%d")
    start = (temp - datetime.timedelta(days=3)).strftime("%Y-%m-%d")
    print(start)

    # hospital_id is joined in from the eagle schema, so connect without a
    # default database and qualify the table names in the query.
    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC')
    sql = "select distinct doctor.hospital_id from jerry_test.esmm_train_data_dwell e " \
          "left join eagle.src_zhengxing_api_service service on e.diary_service_id = service.id " \
          "left join eagle.src_zhengxing_api_doctor doctor on service.doctor_id = doctor.id " \
          "where e.stat_date >= '{}'".format(start)
    unique_values.extend(get_unique(db, sql))
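    # The feature names themselves are appended as values so that null cells,
    # filled with their own column name via na.fill, resolve to a valid index.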
    features = ["ucity_id", "ccity_name", "device_type", "manufacturer",
                "channel", "top", "time", "stat_date", "hospital_id",
                "treatment_method", "price_min", "price_max", "treatment_time", "maintain_time", "recover_time",
                "app_list", "level3_ids", "level2_ids", "tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7",
                "search_tag2", "search_tag3"]
    unique_values.extend(features)
    print("unique_values length")
    print(len(unique_values))
    print("特征维度:")
    print(apps_number + level2_number + level3_number + len(unique_values))

    # Categorical values are indexed after the 28 reserved slots and all
    # multi-hot id ranges, so every value gets a globally unique index.
    temp = list(range(28 + apps_number + level2_number + level3_number,
                      28 + apps_number + level2_number + level3_number + len(unique_values)))
    value_map = dict(zip(unique_values, temp))

    # sql = "select e.y,e.z,e.stat_date,e.ucity_id,feat.level2_ids,e.ccity_name,u.device_type,u.manufacturer," \
    #       "u.channel,c.top,cut.time,dl.app_list,feat.level3_ids,doctor.hospital_id," \
    #       "wiki.tag as tag1,question.tag as tag2,search.tag as tag3,budan.tag as tag4," \
    #       "ot.tag as tag5,sixin.tag as tag6,cart.tag as tag7,doris.search_tag2,doris.search_tag3," \
    #       "k.treatment_method,k.price_min,k.price_max,k.treatment_time,k.maintain_time,k.recover_time " \
    #       "from jerry_test.esmm_train_data_dwell e left join jerry_test.user_feature u on e.device_id = u.device_id " \
    #       "left join jerry_test.cid_type_top c on e.device_id = c.device_id " \
    #       "left join jerry_test.cid_time_cut cut on e.cid_id = cut.cid " \
    #       "left join jerry_test.device_app_list dl on e.device_id = dl.device_id " \
    #       "left join jerry_test.diary_feat feat on e.cid_id = feat.diary_id " \
    #       "left join jerry_test.knowledge k on feat.level2 = k.level2_id " \
    #       "left join jerry_test.wiki_tag wiki on e.device_id = wiki.device_id " \
    #       "left join jerry_test.question_tag question on e.device_id = question.device_id " \
    #       "left join jerry_test.search_tag search on e.device_id = search.device_id " \
    #       "left join jerry_test.budan_tag budan on e.device_id = budan.device_id " \
    #       "left join jerry_test.order_tag ot on e.device_id = ot.device_id " \
    #       "left join jerry_test.sixin_tag sixin on e.device_id = sixin.device_id " \
    #       "left join jerry_test.cart_tag cart on e.device_id = cart.device_id " \
    #       "left join eagle.src_zhengxing_api_service service on e.diary_service_id = service.id " \
    #       "left join eagle.src_zhengxing_api_doctor doctor on service.doctor_id = doctor.id " \
    #       "left join jerry_test.search_doris doris on e.device_id = doris.device_id and e.stat_date = doris.get_date " \
    #       "where e.stat_date >= '{}'".format(start)
    #
    # df = spark.sql(sql)
    #
    # df = df.drop_duplicates(["ucity_id", "level2_ids", "ccity_name", "device_type", "manufacturer",
    #                          "channel", "top", "time", "stat_date", "app_list", "hospital_id", "level3_ids",
    #                          "tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7"])
    #
    # df = df.na.fill(dict(zip(features, features)))
    #
    # rdd = df.select("stat_date", "y", "z", "app_list", "level2_ids", "level3_ids",
    #                 "tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7",
    #                 "ucity_id", "ccity_name", "device_type", "manufacturer", "channel", "top", "time",
    #                 "hospital_id", "treatment_method", "price_min", "price_max", "treatment_time",
    #                 "maintain_time", "recover_time", "search_tag2", "search_tag3")\
    #     .rdd.repartition(200).map(
    #     lambda x: (x[0], float(x[1]), float(x[2]), app_list_func(x[3], app_list_map), app_list_func(x[4], leve2_map),
    #                app_list_func(x[5], leve3_map), app_list_func(x[6], leve2_map), app_list_func(x[7], leve2_map),
    #                app_list_func(x[8], leve2_map), app_list_func(x[9], leve2_map), app_list_func(x[10], leve2_map),
    #                app_list_func(x[11], leve2_map), app_list_func(x[12], leve2_map),
    #                [value_map.get(x[0], 1), value_map.get(x[13], 2), value_map.get(x[14], 3), value_map.get(x[15], 4),
    #                 value_map.get(x[16], 5), value_map.get(x[17], 6), value_map.get(x[18], 7), value_map.get(x[19], 8),
    #                 value_map.get(x[20], 9), value_map.get(x[21], 10),
    #                 value_map.get(x[22], 11), value_map.get(x[23], 12), value_map.get(x[24], 13),
    #                 value_map.get(x[25], 14), value_map.get(x[26], 15)],
    #                app_list_func(x[27], leve2_map), app_list_func(x[28], leve3_map)
    #                ))
    #
    #
    # rdd.persist(storageLevel= StorageLevel.MEMORY_ONLY_SER)
    #
    # # TODO: remove the train filter below after going live, because the most
    # # recent day's data should also be used as training data
    #
    # train = rdd.map(
    #     lambda x: (x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8], x[9],
    #                x[10], x[11], x[12], x[13], x[14], x[15]))
    # f = time.time()
    # spark.createDataFrame(train).toDF("y", "z", "app_list", "level2_list", "level3_list",
    #                                   "tag1_list", "tag2_list", "tag3_list", "tag4_list",
    #                                   "tag5_list", "tag6_list", "tag7_list", "ids", "search_tag2_list","search_tag3_list") \
    #     .repartition(1).write.format("tfrecords").save(path=path + "tr/", mode="overwrite")
    # h = time.time()
    # print("train tfrecord done")
    # print((h - f) / 60)
    #
    # print("训练集样本总量:")
    # print(rdd.count())
    #
    # get_pre_number()
    #
    # test = rdd.filter(lambda x: x[0] == validate_date).map(
    #     lambda x: (x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8], x[9],
    #                x[10], x[11], x[12], x[13], x[14], x[15]))
    #
    # spark.createDataFrame(test).toDF("y", "z", "app_list", "level2_list", "level3_list",
    #                                  "tag1_list", "tag2_list", "tag3_list", "tag4_list",
    #                                  "tag5_list", "tag6_list", "tag7_list", "ids", "search_tag2_list","search_tag3_list") \
    #     .repartition(1).write.format("tfrecords").save(path=path + "va/", mode="overwrite")
    #
    # print("va tfrecord done")
    #
    # rdd.unpersist()

    return validate_date, value_map, app_list_map, leve2_map, leve3_map


def get_predict(date, value_map, app_list_map, leve2_map, leve3_map):
    # Build prediction features from esmm_pre_data and write the csv id lists
    # and TFRecord inputs for the native and nearby queues.
    sql = "select e.y,e.z,e.label,e.ucity_id,feat.level2_ids,e.ccity_name," \
          "u.device_type,u.manufacturer,u.channel,c.top,e.device_id,e.cid_id,cut.time," \
          "dl.app_list,e.hospital_id,feat.level3_ids," \
          "wiki.tag as tag1,question.tag as tag2,search.tag as tag3,budan.tag as tag4," \
          "ot.tag as tag5,sixin.tag as tag6,cart.tag as tag7,doris.search_tag2,doris.search_tag3," \
          "k.treatment_method,k.price_min,k.price_max,k.treatment_time,k.maintain_time,k.recover_time " \
          "from jerry_test.esmm_pre_data e " \
          "left join jerry_test.user_feature u on e.device_id = u.device_id " \
          "left join jerry_test.cid_type_top c on e.device_id = c.device_id " \
          "left join jerry_test.cid_time_cut cut on e.cid_id = cut.cid " \
          "left join jerry_test.device_app_list dl on e.device_id = dl.device_id " \
          "left join jerry_test.diary_feat feat on e.cid_id = feat.diary_id " \
          "left join jerry_test.wiki_tag wiki on e.device_id = wiki.device_id " \
          "left join jerry_test.question_tag question on e.device_id = question.device_id " \
          "left join jerry_test.search_tag search on e.device_id = search.device_id " \
          "left join jerry_test.budan_tag budan on e.device_id = budan.device_id " \
          "left join jerry_test.order_tag ot on e.device_id = ot.device_id " \
          "left join jerry_test.sixin_tag sixin on e.device_id = sixin.device_id " \
          "left join jerry_test.cart_tag cart on e.device_id = cart.device_id " \
          "left join jerry_test.knowledge k on feat.level2 = k.level2_id " \
          "left join jerry_test.search_doris doris on e.device_id = doris.device_id and e.stat_date = doris.get_date " \
          "limit 600000"
张彦钊's avatar
张彦钊 committed
301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336

    features = ["ucity_id", "ccity_name", "device_type", "manufacturer",
                "channel", "top", "time", "hospital_id",
                "treatment_method", "price_min", "price_max", "treatment_time", "maintain_time", "recover_time",
                "app_list", "level3_ids", "level2_ids", "tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7",
                "search_tag2", "search_tag3"]

    df = spark.sql(sql)
    df = df.drop_duplicates(["ucity_id", "device_id", "cid_id"])

    # Fill nulls with the column's own name so they map to the reserved
    # fallback indices in value_map and the id maps.
    df = df.na.fill(dict(zip(features, features)))
    f = time.time()
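    # Output tuple layout: (label, y, z, ucity_id, device_id, cid_id,
    # ten multi-hot id lists, the 15-slot "ids" vector, search_tag2 list,
    # search_tag3 list).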
    rdd = df.select("label", "y", "z", "ucity_id", "device_id", "cid_id", "app_list", "level2_ids", "level3_ids",
                    "tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7",
                    "ucity_id", "ccity_name", "device_type", "manufacturer", "channel", "top", "time",
                    "hospital_id", "treatment_method", "price_min", "price_max", "treatment_time",
                    "maintain_time", "recover_time", "search_tag2", "search_tag3") \
        .rdd.repartition(200).map(lambda x: (x[0], float(x[1]), float(x[2]), x[3], x[4], x[5],
                                             app_list_func(x[6], app_list_map), app_list_func(x[7], leve2_map),
                                             app_list_func(x[8], leve3_map), app_list_func(x[9], leve2_map),
                                             app_list_func(x[10], leve2_map), app_list_func(x[11], leve2_map),
                                             app_list_func(x[12], leve2_map), app_list_func(x[13], leve2_map),
                                             app_list_func(x[14], leve2_map), app_list_func(x[15], leve2_map),
                                             [value_map.get(date, 1), value_map.get(x[16], 2),
                                              value_map.get(x[17], 3), value_map.get(x[18], 4),
                                              value_map.get(x[19], 5), value_map.get(x[20], 6),
                                              value_map.get(x[21], 7), value_map.get(x[22], 8),
                                              value_map.get(x[23], 9), value_map.get(x[24], 10),
                                              value_map.get(x[25], 11), value_map.get(x[26], 12),
                                              value_map.get(x[27], 13), value_map.get(x[28], 14),
                                              value_map.get(x[29], 15)], app_list_func(x[30], leve2_map),
                                             app_list_func(x[31], leve3_map)))

    rdd.persist(storageLevel=StorageLevel.MEMORY_ONLY_SER)
    print(rdd.count())

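    # label == 0 marks the native queue, label == 1 the nearby queue; each gets
    # a (city, uid, cid_id) csv plus a TFRecord feature file.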
    native_pre = spark.createDataFrame(rdd.filter(lambda x: x[0] == 0).map(lambda x: (x[3], x[4], x[5]))) \
        .toDF("city", "uid", "cid_id")
    print("native csv")
    native_pre.toPandas().to_csv(local_path + "native.csv", header=True)
    spark.createDataFrame(rdd.filter(lambda x: x[0] == 0)
                          .map(lambda x: (x[1], x[2], x[6], x[7], x[8], x[9], x[10], x[11],
                                          x[12], x[13], x[14], x[15], x[16], x[17], x[18]))) \
        .toDF("y", "z", "app_list", "level2_list", "level3_list", "tag1_list", "tag2_list", "tag3_list", "tag4_list",
              "tag5_list", "tag6_list", "tag7_list", "ids", "search_tag2_list","search_tag3_list") \
        .repartition(1).write.format("tfrecords").save(path=path+"native/", mode="overwrite")
    print("native tfrecord done")
    h = time.time()
    print((h - f) / 60)

    nearby_pre = spark.createDataFrame(rdd.filter(lambda x: x[0] == 1).map(lambda x: (x[3], x[4], x[5]))) \
        .toDF("city", "uid", "cid_id")
    print("nearby csv")
    nearby_pre.toPandas().to_csv(local_path + "nearby.csv", header=True)

    spark.createDataFrame(rdd.filter(lambda x: x[0] == 1)
                          .map(lambda x: (x[1], x[2], x[6], x[7], x[8], x[9], x[10], x[11],
                                          x[12], x[13], x[14], x[15], x[16], x[17], x[18]))) \
        .toDF("y", "z", "app_list", "level2_list", "level3_list", "tag1_list", "tag2_list", "tag3_list", "tag4_list",
              "tag5_list", "tag6_list", "tag7_list", "ids", "search_tag2_list","search_tag3_list") \
        .repartition(1).write.format("tfrecords").save(path=path + "nearby/", mode="overwrite")
    print("nearby tfrecord done")


if __name__ == '__main__':
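    # Spark session with TiSpark so the TiDB databases (jerry_test, eagle)
    # are visible to Spark SQL.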
    sparkConf = SparkConf().set("spark.hive.mapred.supports.subdirectories", "true") \
        .set("spark.hadoop.mapreduce.input.fileinputformat.input.dir.recursive", "true") \
        .set("spark.tispark.plan.allow_index_double_read", "false") \
        .set("spark.tispark.plan.allow_index_read", "true") \
        .set("spark.sql.extensions", "org.apache.spark.sql.TiExtensions") \
        .set("spark.tispark.pd.addresses", "172.16.40.158:2379").set("spark.io.compression.codec", "lzf")\
        .set("spark.driver.maxResultSize", "8g").set("spark.sql.avro.compression.codec","snappy")

    spark = SparkSession.builder.config(conf=sparkConf).enableHiveSupport().getOrCreate()
    ti = pti.TiContext(spark)
    ti.tidbMapDatabase("jerry_test")
    ti.tidbMapDatabase("eagle")
    spark.sparkContext.setLogLevel("WARN")
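    # TFRecords go to HDFS; the native/nearby csv id lists go to local disk.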
    path = "hdfs:///strategy/esmm/"
    local_path = "/home/gmuser/esmm/"

    validate_date, value_map, app_list_map, leve2_map, leve3_map = feature_engineer()
    get_predict(validate_date, value_map, app_list_map, leve2_map, leve3_map)

    spark.stop()