# -*- coding: utf-8 -*-
import pymysql
from pyspark.conf import SparkConf
import pytispark.pytispark as pti
# from pyspark.sql import SQLContext
from pyspark.sql import SparkSession
import datetime
import pandas as pd


def app_list_func(x, l):
    # Map a comma-separated id string to a list of integer indices using the
    # lookup dict l; ids that are not in the dict are encoded as 0.
    b = str(x).split(",")
    e = []
    for i in b:
        if i in l.keys():
            e.append(l[i])
        else:
            e.append(0)
    return e


def multi_hot(df, column, n):
    # Collect every distinct id appearing in the comma-separated values of
    # `column` and build a map id -> index, with indices starting at n.
    # Returns the vocabulary size and the map.
    v = df.select(column).distinct().rdd.map(lambda x: x[0]).collect()
    app_list_value = [str(i).split(",") for i in v]
    app_list_unique = []
    for i in app_list_value:
        app_list_unique.extend(i)
    app_list_unique = list(set(app_list_unique))
    number = len(app_list_unique)
    app_list_map = dict(zip(app_list_unique, list(range(n, number + n))))
    return number, app_list_map


def feature():
    # Build training/validation tfrecords from esmm_train_data joined with
    # diary_feat; the most recent stat_date is held out as the validation set.
    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
    sql = "select max(stat_date) from esmm_train_data"
    validate_date = con_sql(db, sql)[0].values.tolist()[0]
    print("validate_date:" + validate_date)
    temp = datetime.datetime.strptime(validate_date, "%Y-%m-%d")
    start = (temp - datetime.timedelta(days=2)).strftime("%Y-%m-%d")
    print(start)
    sql = "select e.y,e.z,e.stat_date,e.ucity_id,feat.level2_ids " \
          "from jerry_test.esmm_train_data e " \
          "left join jerry_test.diary_feat feat on e.cid_id = feat.diary_id " \
          "where e.stat_date >= '{}'".format(start)
    df = spark.sql(sql)
    features = ["ucity_id", "stat_date"]
    # Fill nulls with the column name itself as a placeholder value.
    df = df.na.fill(dict(zip(features, features)))
    apps_number, app_list_map = multi_hot(df, "level2_ids", 1)
    unique_values = []
    for i in features:
        unique_values.extend(df.select(i).distinct().rdd.map(lambda x: x[0]).collect())
    # Categorical values get indices starting right after the level2_ids vocabulary.
    temp = list(range(2 + apps_number, 2 + apps_number + len(unique_values)))
    value_map = dict(zip(unique_values, temp))
    rdd = df.select("level2_ids", "stat_date", "ucity_id", "y", "z").rdd
    rdd.persist()
    # TODO: after this goes live, remove the train filter below, because the most
    # recent day's data should also be used as training data.
    train = rdd.filter(lambda x: x[1] != validate_date) \
        .map(lambda x: (app_list_func(x[0], app_list_map), [value_map[x[2]], value_map[x[1]]],
                        float(x[3]), float(x[4])))
    test = rdd.filter(lambda x: x[1] == validate_date) \
        .map(lambda x: (app_list_func(x[0], app_list_map), [value_map[x[2]], value_map[x[1]]],
                        float(x[3]), float(x[4])))
    spark.createDataFrame(test).toDF("level2_ids", "ids", "y", "z") \
        .repartition(1).write.format("tfrecords").save(path=path + "va/", mode="overwrite")
    print("va write done")
    spark.createDataFrame(train).toDF("level2_ids", "ids", "y", "z") \
        .repartition(1).write.format("tfrecords").save(path=path + "tr/", mode="overwrite")
    print("done")
    rdd.unpersist()
    return validate_date, value_map, app_list_map


def get_predict(date, value_map, app_list_map):
    # Build the prediction inputs: id csv files plus tfrecords for the
    # "native" (label 0) and "nearby" (label 1) candidate sets.
    sql = "select e.y,e.z,e.label,e.ucity_id,feat.level2_ids,e.device_id,e.cid_id from esmm_pre_data e " \
          "left join diary_feat feat on e.cid_id = feat.diary_id limit 50000"
    features = ["ucity_id"]
    df = spark.sql(sql)
    df = df.na.fill(dict(zip(features, features)))
    rdd = df.select("level2_ids", "ucity_id", "device_id", "cid_id", "label", "y", "z") \
        .rdd.map(lambda x: (app_list_func(x[0], app_list_map), x[1], x[2], x[3], x[4], float(x[5]), float(x[6]),
                            [value_map.get(x[1], 299999), value_map.get(date, 299998)]))
    rdd.persist()
    native_pre = spark.createDataFrame(rdd.filter(lambda x: x[4] == 0).map(lambda x: (x[1], x[2], x[3]))) \
        .toDF("city", "uid", "cid_id")
    print("native")
    native_pre.toPandas().to_csv(local_path + "native.csv", header=True)
    # TODO: switch the csv write to the line below
    # native_pre.coalesce(1).write.format('com.databricks.spark.csv').save(path+"native/",header = 'true')
    # The prediction tfrecords must be written as a single file so the row order is preserved.
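    # (added note) coalesce(1) in the writes below collapses the data to one
    # partition so each tfrecord set lands in a single part file; the downstream
    # predictor is assumed to rely on these rows staying aligned with the ids
    # written to native.csv / nearby.csv.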
    spark.createDataFrame(rdd.filter(lambda x: x[4] == 0).map(lambda x: (x[0], x[5], x[6], x[7]))) \
        .toDF("level2_ids", "y", "z", "ids").coalesce(1).write.format("tfrecords") \
        .save(path=path + "native/", mode="overwrite")
    nearby_pre = spark.createDataFrame(rdd.filter(lambda x: x[4] == 1).map(lambda x: (x[1], x[2], x[3]))) \
        .toDF("city", "uid", "cid_id")
    print("nearby")
    nearby_pre.toPandas().to_csv(local_path + "nearby.csv", header=True)
    spark.createDataFrame(rdd.filter(lambda x: x[4] == 1).map(lambda x: (x[0], x[5], x[6], x[7]))) \
        .toDF("level2_ids", "y", "z", "ids").coalesce(1).write.format("tfrecords") \
        .save(path=path + "nearby/", mode="overwrite")
    rdd.unpersist()


def con_sql(db, sql):
    # Run a query over a pymysql connection and return the rows as a pandas DataFrame.
    cursor = db.cursor()
    cursor.execute(sql)
    result = cursor.fetchall()
    df = pd.DataFrame(list(result))
    db.close()
    return df


if __name__ == '__main__':
    # NOTE: feature() and get_predict() rely on the module-level names spark,
    # path and local_path created in the block below; the Spark/TiSpark pipeline
    # is currently commented out, so only the TiDB check at the bottom runs.
    # sparkConf = SparkConf().set("spark.hive.mapred.supports.subdirectories", "true") \
    #     .set("spark.hadoop.mapreduce.input.fileinputformat.input.dir.recursive", "true") \
    #     .set("spark.tispark.plan.allow_index_double_read", "false") \
    #     .set("spark.tispark.plan.allow_index_read", "true") \
    #     .set("spark.sql.extensions", "org.apache.spark.sql.TiExtensions") \
    #     .set("spark.tispark.pd.addresses", "172.16.40.158:2379").set("spark.io.compression.codec", "lzf") \
    #     .set("spark.driver.maxResultSize", "8g").set("spark.sql.avro.compression.codec", "snappy")
    #
    # spark = SparkSession.builder.config(conf=sparkConf).enableHiveSupport().getOrCreate()
    # ti = pti.TiContext(spark)
    # ti.tidbMapDatabase("jerry_test")
    # # ti.tidbMapDatabase("eagle")
    # spark.sparkContext.setLogLevel("WARN")
    # path = "hdfs:///strategy/esmm/"
    # local_path = "/home/gmuser/esmm/"
    #
    # validate_date, value_map, app_list_map = feature()
    # get_predict(validate_date, value_map, app_list_map)

    # [path + "tr/part-r-00000"]
    import subprocess
    # spark = SparkSession.builder.getOrCreate()
    # b = [("a", 1), ("a", 1), ("b", 3), ("a", 2)]
    # rdd = spark.sparkContext.parallelize(b)
    # df = spark.createDataFrame(rdd).toDF("id", "n")
    # df.show()
    # df.createOrReplaceTempView("df")
    # t = spark.sql("select id from df").map()
    # print(t)

    # Connectivity check: read a few device ids straight from TiDB.
    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
    sql = "select device_id from esmm_train_data limit 10"
    cursor = db.cursor()
    cursor.execute(sql)
    result = cursor.fetchall()
    print(result)
    a = list(set([i[0] for i in result]))
    print(a)
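    # ------------------------------------------------------------------
    # Hedged sketch (added for illustration, not part of the original job):
    # a minimal, Spark-free check of how app_list_func() encodes a
    # comma-separated level2_ids string with a dict shaped like the
    # app_list_map returned by multi_hot(). The ids and index values below
    # are made up; any id missing from the map is encoded as 0.
    # ------------------------------------------------------------------
    example_app_list_map = {"7": 1, "13": 2, "25": 3}  # hypothetical level2 id -> index
    print(app_list_func("7,25,999", example_app_list_map))  # expected: [1, 3, 0]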