# multi.py (6.83 KB)
# -*- coding: utf-8 -*-
import pymysql
from pyspark.conf import SparkConf
import pytispark.pytispark as pti
# from pyspark.sql import SQLContext
from pyspark.sql import SparkSession
import datetime
import pandas as pd
import subprocess


def app_list_func(x, l):
    """Encode a comma-separated token string as a list of integer ids.

    Args:
        x: the raw multi-value cell (e.g. a ``level2_ids`` value). It is
            passed through ``str()`` first, so ``None`` becomes the single
            token ``"None"``.
        l: mapping from token string to integer id.

    Returns:
        One id per token, in input order; tokens absent from ``l`` are
        encoded as ``0``.
    """
    # dict.get is a single hash lookup (``i in l.keys()`` is a list scan on Python 2).
    return [l.get(token, 0) for token in str(x).split(",")]


def multi_hot(df, column, n):
    """Build a token -> id vocabulary for a comma-separated multi-value column.

    Every distinct cell of ``column`` is collected to the driver, split on
    ``","``, and the union of the resulting tokens is given consecutive
    integer ids starting at ``n``.

    Args:
        df: Spark DataFrame that contains ``column``.
        column: name of the comma-separated column (e.g. ``"level2_ids"``).
        n: first id to hand out.

    Returns:
        ``(number, app_list_map)``: ``number`` is the vocabulary size, and
        ``app_list_map`` maps each token to an id in ``range(n, n + number)``.
    """
    distinct_values = df.select(column).distinct().rdd.map(lambda x: x[0]).collect()
    tokens = set()
    for value in distinct_values:
        # str() means a NULL cell contributes the token "None".
        tokens.update(str(value).split(","))
    # Sort so the same tokens always get the same ids; set iteration order of
    # strings can depend on the (randomized) string hash and vary per run.
    app_list_unique = sorted(tokens)
    number = len(app_list_unique)
    app_list_map = dict(zip(app_list_unique, range(n, n + number)))
    return number, app_list_map


def feature():
    """Write the training/validation tfrecords and return the vocabularies.

    The newest ``stat_date`` in ``esmm_train_data`` is the validation day.
    Rows with ``stat_date >= validation day - 2 days`` are loaded; rows of
    the validation day go to ``path + "va/"`` and the rest to
    ``path + "tr/"``, each as a single tfrecord file.

    Relies on module-level ``spark`` (SparkSession) and ``path`` (output
    prefix) being defined before this is called.

    Returns:
        ``(validate_date, value_map, app_list_map)``:
          * ``validate_date``: the validation day as a ``"%Y-%m-%d"`` string.
          * ``value_map``: distinct ``ucity_id`` / ``stat_date`` value -> id.
          * ``app_list_map``: ``level2_ids`` token -> id.
    """
    # NOTE(review): credentials are hard-coded in source; consider reading
    # them from configuration or the environment instead.
    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
    sql = "select max(stat_date) from esmm_train_data"
    # con_sql closes db after fetching.
    validate_date = con_sql(db, sql)[0].values.tolist()[0]
    print("validate_date:" + validate_date)
    validate_day = datetime.datetime.strptime(validate_date, "%Y-%m-%d")
    start = (validate_day - datetime.timedelta(days=2)).strftime("%Y-%m-%d")
    print(start)

    sql = "select e.y,e.z,e.stat_date,e.ucity_id,feat.level2_ids " \
          "from jerry_test.esmm_train_data e " \
          "left join jerry_test.diary_feat feat on e.cid_id = feat.diary_id " \
          "where e.stat_date >= '{}'".format(start)

    df = spark.sql(sql)

    features = ["ucity_id", "stat_date"]
    # NULLs in each categorical column are replaced by the column's own name,
    # so a missing value still receives an id in value_map.
    df = df.na.fill(dict(zip(features, features)))

    # level2_ids tokens get ids 1 .. apps_number.
    apps_number, app_list_map = multi_hot(df, "level2_ids", 1)

    unique_values = []
    for i in features:
        unique_values.extend(df.select(i).distinct().rdd.map(lambda x: x[0]).collect())
    # Categorical ids start at 2 + apps_number, after the token ids.
    ids = list(range(2 + apps_number,
                     2 + apps_number + len(unique_values)))
    value_map = dict(zip(unique_values, ids))

    rdd = df.select("level2_ids", "stat_date", "ucity_id", "y", "z").rdd

    def encode(row):
        # row = (level2_ids, stat_date, ucity_id, y, z)
        return (app_list_func(row[0], app_list_map),
                [value_map[row[2]], value_map[row[1]]],
                float(row[3]), float(row[4]))

    rdd.persist()
    try:
        # TODO: after launch, delete the train filter below, because the most
        # recent day's data should also be used as training data.
        train = rdd.filter(lambda x: x[1] != validate_date).map(encode)
        test = rdd.filter(lambda x: x[1] == validate_date).map(encode)

        spark.createDataFrame(test).toDF("level2_ids", "ids", "y", "z") \
            .repartition(1).write.format("tfrecords").save(path=path+"va/", mode="overwrite")

        print("va write done")
        spark.createDataFrame(train).toDF("level2_ids", "ids", "y", "z") \
            .repartition(1).write.format("tfrecords").save(path=path+"tr/", mode="overwrite")

        print("done")
    finally:
        # Release the cached RDD even when one of the writes fails.
        rdd.unpersist()

    return validate_date, value_map, app_list_map


def get_predict(date, value_map, app_list_map):
    """Write the prediction inputs for the "native" and "nearby" candidates.

    Reads ``esmm_pre_data`` left-joined with ``diary_feat``, encodes each row
    with the given vocabularies, and for ``label == 0`` ("native") and
    ``label == 1`` ("nearby") writes:
      * ``local_path + "<name>.csv"`` with the (city, uid, cid_id) columns;
      * a single-file tfrecord under ``path + "<name>/"`` with the features.
    Both outputs come from the same persisted RDD, presumably so that csv
    rows and tfrecord records line up one-to-one — confirm with consumers.

    Relies on module-level ``spark``, ``path`` and ``local_path``.

    Args:
        date: value looked up in ``value_map`` for the date feature
            (falls back to id 299998 when absent).
        value_map: categorical value -> id; an unseen city maps to 299999.
        app_list_map: ``level2_ids`` token -> id.
    """
    # NOTE(review): "limit 50000" without ORDER BY caps the prediction set at
    # an arbitrary subset of rows; confirm this is intended.
    sql = "select e.y,e.z,e.label,e.ucity_id,feat.level2_ids,e.device_id,e.cid_id from esmm_pre_data e " \
          "left join diary_feat feat on e.cid_id = feat.diary_id limit 50000"

    features = ["ucity_id"]
    df = spark.sql(sql)
    df = df.na.fill(dict(zip(features, features)))
    # Encoded row layout: 0 level2_ids ids, 1 city, 2 device_id, 3 cid_id,
    # 4 label, 5 y, 6 z, 7 [city id, date id]
    rdd = df.select("level2_ids", "ucity_id", "device_id", "cid_id", "label", "y", "z") \
        .rdd.map(lambda x: (app_list_func(x[0], app_list_map), x[1], x[2], x[3], x[4], float(x[5]), float(x[6]),
                            [value_map.get(x[1], 299999), value_map.get(date, 299998)]))

    rdd.persist()
    try:
        _write_predict_split(rdd, 0, "native")
        _write_predict_split(rdd, 1, "nearby")
    finally:
        # Release the cached RDD even when a write fails.
        rdd.unpersist()


def _write_predict_split(rdd, label, name):
    """Write the key csv and the feature tfrecord for rows with ``label``."""
    part = rdd.filter(lambda x: x[4] == label)

    keys = spark.createDataFrame(part.map(lambda x: (x[1], x[2], x[3]))) \
        .toDF("city", "uid", "cid_id")
    print(name)
    keys.toPandas().to_csv(local_path + name + ".csv", header=True)
    # TODO: write the csv file like this instead:
    # keys.coalesce(1).write.format('com.databricks.spark.csv').save(path+name+"/",header = 'true')

    # The prediction tfrecord must be written as a single file so that record
    # order is preserved.
    spark.createDataFrame(part.map(lambda x: (x[0], x[5], x[6], x[7]))) \
        .toDF("level2_ids", "y", "z", "ids").coalesce(1).write.format("tfrecords") \
        .save(path=path + name + "/", mode="overwrite")


def con_sql(db, sql):
    """Run ``sql`` on ``db`` and return all result rows as a DataFrame.

    The cursor and the connection are closed before returning, including
    when the query raises, so ``db`` cannot be reused afterwards.

    Args:
        db: an open DB-API connection (e.g. from ``pymysql.connect``).
        sql: the query to execute.

    Returns:
        A ``pandas.DataFrame`` with one row per fetched row and integer
        column labels ``0, 1, ...`` in select-list order.
    """
    try:
        cursor = db.cursor()
        try:
            cursor.execute(sql)
            result = cursor.fetchall()
        finally:
            cursor.close()
    finally:
        db.close()
    return pd.DataFrame(list(result))

def get_filename(dir_in, count=200, pre_add="hdfs://172.16.32.4:8020/strategy/esmm/"):
    """Return the paths of the ``part-r-NNNNN`` files under ``pre_add + dir_in``.

    Args:
        dir_in: sub-directory below ``pre_add`` (e.g. ``"tr"``).
        count: how many part files to list; defaults to 200, the value that
            used to be hard-coded.
        pre_add: URI prefix that ``dir_in`` is appended to.

    Returns:
        ``[pre_add + dir_in + "/part-r-00000", ...]`` for indices
        ``0 .. count - 1``, each index zero-padded to five digits.
    """
    return ["{}{}/part-r-{:05d}".format(pre_add, dir_in, i) for i in range(count)]

if __name__ == '__main__':
    # Full pipeline, currently disabled. It configures Spark with TiSpark,
    # maps the TiDB database, builds the training tfrecords, then the
    # prediction inputs. Uncomment to run it; it defines the module-level
    # `spark`, `path` and `local_path` used by feature() / get_predict().
    # sparkConf = SparkConf().set("spark.hive.mapred.supports.subdirectories", "true") \
    #     .set("spark.hadoop.mapreduce.input.fileinputformat.input.dir.recursive", "true") \
    #     .set("spark.tispark.plan.allow_index_double_read", "false") \
    #     .set("spark.tispark.plan.allow_index_read", "true") \
    #     .set("spark.sql.extensions", "org.apache.spark.sql.TiExtensions") \
    #     .set("spark.tispark.pd.addresses", "172.16.40.158:2379").set("spark.io.compression.codec", "lzf")\
    #     .set("spark.driver.maxResultSize", "8g").set("spark.sql.avro.compression.codec","snappy")
    #
    # spark = SparkSession.builder.config(conf=sparkConf).enableHiveSupport().getOrCreate()
    # ti = pti.TiContext(spark)
    # ti.tidbMapDatabase("jerry_test")
    # # ti.tidbMapDatabase("eagle")
    # spark.sparkContext.setLogLevel("WARN")
    # path = "hdfs:///strategy/esmm/"
    # local_path = "/home/gmuser/esmm/"
    #
    # validate_date, value_map, app_list_map = feature()
    # get_predict(validate_date, value_map, app_list_map)

    # Scratch check: list the local test files in a random order.
    import random
    import glob

    test_file_pattern = "/home/gmuser/test/*"
    tr_files = glob.glob(test_file_pattern)
    random.shuffle(tr_files)
    print("tr_files:", tr_files)