Commit 37030b82 authored by 张彦钊

add test file

parent 15ef275a
...@@ -2,12 +2,11 @@ ...@@ -2,12 +2,11 @@
import pymysql import pymysql
from pyspark.conf import SparkConf from pyspark.conf import SparkConf
import pytispark.pytispark as pti import pytispark.pytispark as pti
# from pyspark.sql import SQLContext
from pyspark.sql import SparkSession from pyspark.sql import SparkSession
import datetime import datetime
import pandas as pd import pandas as pd
import subprocess import time
import tensorflow as tf from pyspark import StorageLevel
def app_list_func(x,l): def app_list_func(x,l):
...@@ -21,8 +20,11 @@ def app_list_func(x,l): ...@@ -21,8 +20,11 @@ def app_list_func(x,l):
return e return e
def multi_hot(df,column,n): def get_list(db,sql,n):
v = df.select(column).distinct().rdd.map(lambda x: x[0]).collect() cursor = db.cursor()
cursor.execute(sql)
result = cursor.fetchall()
v = list(set([i[0] for i in result]))
app_list_value = [str(i).split(",") for i in v] app_list_value = [str(i).split(",") for i in v]
app_list_unique = [] app_list_unique = []
for i in app_list_value: for i in app_list_value:
...@@ -30,182 +32,359 @@ def multi_hot(df,column,n): ...@@ -30,182 +32,359 @@ def multi_hot(df,column,n):
app_list_unique = list(set(app_list_unique)) app_list_unique = list(set(app_list_unique))
number = len(app_list_unique) number = len(app_list_unique)
app_list_map = dict(zip(app_list_unique, list(range(n, number + n)))) app_list_map = dict(zip(app_list_unique, list(range(n, number + n))))
return number,app_list_map db.close()
return number, app_list_map
def feature(): def get_map():
db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test') db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
sql = "select max(stat_date) from esmm_train_data" sql = "select app_list from device_app_list"
validate_date = con_sql(db, sql)[0].values.tolist()[0] a = time.time()
print("validate_date:" + validate_date) apps_number, app_list_map = get_list(db,sql,16)
temp = datetime.datetime.strptime(validate_date, "%Y-%m-%d") print("applist")
start = (temp - datetime.timedelta(days=2)).strftime("%Y-%m-%d") print((time.time()-a)/60)
print(start) db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
sql = "select level2_ids from diary_feat"
sql = "select e.y,e.z,e.stat_date,e.ucity_id,feat.level2_ids " \ b = time.time()
"from jerry_test.esmm_train_data e " \ leve2_number, leve2_map = get_list(db, sql, 16+apps_number)
"left join jerry_test.diary_feat feat on e.cid_id = feat.diary_id " \ print("leve2")
"where e.stat_date >= '{}'".format(start) print((time.time() - b) / 60)
db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
df = spark.sql(sql) sql = "select level3_ids from diary_feat"
c = time.time()
features = ["ucity_id","stat_date"] leve3_number, leve3_map = get_list(db, sql, 16+leve2_number+apps_number)
print((time.time() - c) / 60)
return apps_number, app_list_map,leve2_number, leve2_map,leve3_number, leve3_map
df = df.na.fill(dict(zip(features,features)))
apps_number, app_list_map = multi_hot(df,"level2_ids",1) def get_unique(db,sql):
cursor = db.cursor()
cursor.execute(sql)
result = cursor.fetchall()
v = list(set([i[0] for i in result]))
db.close()
print(sql)
print(len(v))
return v
unique_values = [] def con_sql(db,sql):
for i in features: cursor = db.cursor()
unique_values.extend(df.select(i).distinct().rdd.map(lambda x: x[0]).collect()) cursor.execute(sql)
temp = list(range(2 + apps_number, result = cursor.fetchall()
2 + apps_number + len(unique_values))) df = pd.DataFrame(list(result))
value_map = dict(zip(unique_values, temp)) db.close()
return df
rdd = df.select("level2_ids","stat_date","ucity_id","y","z").rdd
rdd.persist()
# TODO 上线后把下面train fliter 删除,因为最近一天的数据也要作为训练集
train = rdd.filter(lambda x: x[1]!= validate_date)\ def get_pre_number():
.map(lambda x: (app_list_func(x[0], app_list_map),[value_map[x[2]],value_map[x[1]]], float(x[3]),float(x[4]))) db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
test = rdd.filter(lambda x: x[1]== validate_date)\ sql = "select count(*) from esmm_pre_data"
.map(lambda x: (app_list_func(x[0], app_list_map),[value_map[x[2]],value_map[x[1]]], float(x[3]),float(x[4]))) cursor = db.cursor()
cursor.execute(sql)
result = cursor.fetchone()[0]
print("预测集数量:")
print(result)
db.close()
spark.createDataFrame(test).toDF("level2_ids","ids","y","z")\
.repartition(1).write.format("tfrecords").save(path=path+"va/", mode="overwrite")
print("va write done") def feature_engineer():
spark.createDataFrame(train).toDF("level2_ids","ids","y","z") \ apps_number, app_list_map, level2_number, leve2_map, level3_number, leve3_map = get_map()
.repartition(1).write.format("tfrecords").save(path=path+"tr/", mode="overwrite") app_list_map["app_list"] = 16
leve3_map["level3_ids"] = 17
leve3_map["search_tag3"] = 18
leve2_map["level2_ids"] = 19
leve2_map["tag1"] = 20
leve2_map["tag2"] = 21
leve2_map["tag3"] = 22
leve2_map["tag4"] = 23
leve2_map["tag5"] = 24
leve2_map["tag6"] = 25
leve2_map["tag7"] = 26
leve2_map["search_tag2"] = 27
print("done") unique_values = []
rdd.unpersist() db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
sql = "select distinct stat_date from esmm_train_data_dwell"
unique_values.extend(get_unique(db,sql))
return validate_date,value_map,app_list_map db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
sql = "select distinct ucity_id from esmm_train_data_dwell"
unique_values.extend(get_unique(db, sql))
db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
sql = "select distinct ccity_name from esmm_train_data_dwell"
unique_values.extend(get_unique(db, sql))
def get_predict(date,value_map,app_list_map): db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
sql = "select e.y,e.z,e.label,e.ucity_id,feat.level2_ids,e.device_id,e.cid_id from esmm_pre_data e " \ sql = "select distinct time from cid_time_cut"
"left join diary_feat feat on e.cid_id = feat.diary_id limit 50000" unique_values.extend(get_unique(db, sql))
features = ["ucity_id"] db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
df = spark.sql(sql) sql = "select distinct device_type from user_feature"
df = df.na.fill(dict(zip(features, features))) unique_values.extend(get_unique(db, sql))
rdd = df.select("level2_ids","ucity_id","device_id","cid_id","label", "y", "z") \
.rdd.map(lambda x: (app_list_func(x[0], app_list_map),x[1],x[2],x[3],x[4],float(x[5]),float(x[6]),
[value_map.get(x[1], 299999),value_map.get(date, 299998)]))
rdd.persist() db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
sql = "select distinct manufacturer from user_feature"
unique_values.extend(get_unique(db, sql))
native_pre = spark.createDataFrame(rdd.filter(lambda x:x[4] == 0).map(lambda x:(x[1],x[2],x[3])))\ db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
.toDF("city","uid","cid_id") sql = "select distinct channel from user_feature"
print("native") unique_values.extend(get_unique(db, sql))
native_pre.toPandas().to_csv(local_path+"native.csv", header=True)
# TODO 写成csv文件改成下面这样
# native_pre.coalesce(1).write.format('com.databricks.spark.csv').save(path+"native/",header = 'true')
# 预测的tfrecord必须写成一个文件,这样可以摆保证顺序 db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
spark.createDataFrame(rdd.filter(lambda x: x[4] == 0).map(lambda x: (x[0],x[5],x[6],x[7]))) \ sql = "select distinct top from cid_type_top"
.toDF("level2_ids","y","z","ids").coalesce(1).write.format("tfrecords") \ unique_values.extend(get_unique(db, sql))
.save(path=path+"native/", mode="overwrite")
nearby_pre = spark.createDataFrame(rdd.filter(lambda x: x[4] == 1).map(lambda x:(x[1],x[2],x[3]))) \ db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
.toDF("city", "uid", "cid_id") sql = "select distinct price_min from knowledge"
print("nearby") unique_values.extend(get_unique(db, sql))
nearby_pre.toPandas().to_csv(local_path+"nearby.csv", header=True)
spark.createDataFrame(rdd.filter(lambda x: x[4] == 1).map(lambda x: (x[0], x[5], x[6], x[7]))) \
.toDF("level2_ids","y","z","ids").coalesce(1).write.format("tfrecords") \
.save(path=path+"nearby/", mode="overwrite")
rdd.unpersist() db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
sql = "select distinct treatment_method from knowledge"
unique_values.extend(get_unique(db, sql))
db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
sql = "select distinct price_max from knowledge"
unique_values.extend(get_unique(db, sql))
def con_sql(db,sql): db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
cursor = db.cursor() sql = "select distinct treatment_time from knowledge"
cursor.execute(sql) unique_values.extend(get_unique(db, sql))
result = cursor.fetchall()
df = pd.DataFrame(list(result))
db.close()
return df
def get_filename(dir_in): db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
pre_add = "hdfs://172.16.32.4:8020/strategy/esmm/" sql = "select distinct maintain_time from knowledge"
x = [] unique_values.extend(get_unique(db, sql))
for i in range(0,200):
if i < 10:
t = pre_add+dir_in+"/part-r-0000"+str(i)
x.append(t)
elif 10 <= i < 100:
t = pre_add + dir_in + "/part-r-000" + str(i)
x.append(t)
elif 100 <= i < 200:
t = pre_add + dir_in + "/part-r-00" + str(i)
x.append(t)
return x
def get_hdfs(dir_in):
pre_path = "hdfs://172.16.32.4:8020"
args = "hdfs dfs -ls " + dir_in + " | awk '{print $8}'"
proc = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
s_output, s_err = proc.communicate()
all_dart_dirs = s_output.split()
a = []
for i in all_dart_dirs:
b = str(i).split("/")[4]
if b[:4] == "part":
tmp = pre_path + str(i)[2:-1]
a.append(tmp)
return a
db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
sql = "select distinct recover_time from knowledge"
unique_values.extend(get_unique(db, sql))
def get_pre_number():
db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test') db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
sql = "select count(*) from esmm_pre_data" sql = "select max(stat_date) from esmm_train_data_dwell"
cursor = db.cursor() validate_date = con_sql(db, sql)[0].values.tolist()[0]
cursor.execute(sql) print("validate_date:" + validate_date)
result = cursor.fetchone()[0] temp = datetime.datetime.strptime(validate_date, "%Y-%m-%d")
print("预测集数量:") start = (temp - datetime.timedelta(days=100)).strftime("%Y-%m-%d")
print(result) print(start)
db.close()
db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC')
sql = "select distinct doctor.hospital_id from jerry_test.esmm_train_data_dwell e " \
"left join eagle.src_zhengxing_api_service service on e.diary_service_id = service.id " \
"left join eagle.src_zhengxing_api_doctor doctor on service.doctor_id = doctor.id " \
"where e.stat_date >= '{}'".format(start)
unique_values.extend(get_unique(db, sql))
features = ["ucity_id", "ccity_name", "device_type", "manufacturer",
"channel", "top", "time", "stat_date", "hospital_id",
"treatment_method", "price_min", "price_max", "treatment_time", "maintain_time", "recover_time",
"app_list", "level3_ids", "level2_ids", "tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7",
"search_tag2", "search_tag3"]
unique_values.extend(features)
print("unique_values length")
print(len(unique_values))
print("特征维度:")
print(apps_number + level2_number + level3_number + len(unique_values))
temp = list(range(28 + apps_number + level2_number + level3_number,
28 + apps_number + level2_number + level3_number + len(unique_values)))
value_map = dict(zip(unique_values, temp))
if __name__ == '__main__': sql = "select e.y,e.z,e.stat_date,e.ucity_id,feat.level2_ids,e.ccity_name,u.device_type,u.manufacturer," \
# get_pre() "u.channel,c.top,cut.time,dl.app_list,feat.level3_ids,doctor.hospital_id," \
# sparkConf = SparkConf().set("spark.hive.mapred.supports.subdirectories", "true") \ "wiki.tag as tag1,question.tag as tag2,search.tag as tag3,budan.tag as tag4," \
# .set("spark.hadoop.mapreduce.input.fileinputformat.input.dir.recursive", "true") \ "ot.tag as tag5,sixin.tag as tag6,cart.tag as tag7,doris.search_tag2,doris.search_tag3," \
# .set("spark.tispark.plan.allow_index_double_read", "false") \ "k.treatment_method,k.price_min,k.price_max,k.treatment_time,k.maintain_time,k.recover_time " \
# .set("spark.tispark.plan.allow_index_read", "true") \ "from jerry_test.esmm_train_data_dwell e left join jerry_test.user_feature u on e.device_id = u.device_id " \
# .set("spark.sql.extensions", "org.apache.spark.sql.TiExtensions") \ "left join jerry_test.cid_type_top c on e.device_id = c.device_id " \
# .set("spark.tispark.pd.addresses", "172.16.40.158:2379").set("spark.io.compression.codec", "lzf")\ "left join jerry_test.cid_time_cut cut on e.cid_id = cut.cid " \
# .set("spark.driver.maxResultSize", "8g").set("spark.sql.avro.compression.codec","snappy") "left join jerry_test.device_app_list dl on e.device_id = dl.device_id " \
"left join jerry_test.diary_feat feat on e.cid_id = feat.diary_id " \
"left join jerry_test.knowledge k on feat.level2 = k.level2_id " \
"left join jerry_test.wiki_tag wiki on e.device_id = wiki.device_id " \
"left join jerry_test.question_tag question on e.device_id = question.device_id " \
"left join jerry_test.search_tag search on e.device_id = search.device_id " \
"left join jerry_test.budan_tag budan on e.device_id = budan.device_id " \
"left join jerry_test.order_tag ot on e.device_id = ot.device_id " \
"left join jerry_test.sixin_tag sixin on e.device_id = sixin.device_id " \
"left join jerry_test.cart_tag cart on e.device_id = cart.device_id " \
"left join eagle.src_zhengxing_api_service service on e.diary_service_id = service.id " \
"left join eagle.src_zhengxing_api_doctor doctor on service.doctor_id = doctor.id " \
"left join jerry_test.search_doris doris on e.device_id = doris.device_id and e.stat_date = doris.get_date " \
"where e.stat_date >= '{}'".format(start)
# df = spark.sql(sql)
# #
# spark = SparkSession.builder.config(conf=sparkConf).enableHiveSupport().getOrCreate() # df = df.drop_duplicates(["ucity_id", "level2_ids", "ccity_name", "device_type", "manufacturer",
# ti = pti.TiContext(spark) # "channel", "top", "time", "stat_date", "app_list", "hospital_id", "level3_ids",
# ti.tidbMapDatabase("jerry_test") # "tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7"])
# # ti.tidbMapDatabase("eagle")
# spark.sparkContext.setLogLevel("WARN")
# path = "hdfs:///strategy/esmm/"
# local_path = "/home/gmuser/esmm/"
# #
# validate_date, value_map, app_list_map = feature() # df = df.na.fill(dict(zip(features, features)))
# get_predict(validate_date, value_map, app_list_map)
# #
# rdd = df.select("stat_date", "y", "z", "app_list", "level2_ids", "level3_ids",
# "tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7",
# "ucity_id", "ccity_name", "device_type", "manufacturer", "channel", "top", "time",
# "hospital_id", "treatment_method", "price_min", "price_max", "treatment_time",
# "maintain_time", "recover_time", "search_tag2", "search_tag3")\
# .rdd.repartition(200).map(
# lambda x: (x[0], float(x[1]), float(x[2]), app_list_func(x[3], app_list_map), app_list_func(x[4], leve2_map),
# app_list_func(x[5], leve3_map), app_list_func(x[6], leve2_map), app_list_func(x[7], leve2_map),
# app_list_func(x[8], leve2_map), app_list_func(x[9], leve2_map), app_list_func(x[10], leve2_map),
# app_list_func(x[11], leve2_map), app_list_func(x[12], leve2_map),
# [value_map.get(x[0], 1), value_map.get(x[13], 2), value_map.get(x[14], 3), value_map.get(x[15], 4),
# value_map.get(x[16], 5), value_map.get(x[17], 6), value_map.get(x[18], 7), value_map.get(x[19], 8),
# value_map.get(x[20], 9), value_map.get(x[21], 10),
# value_map.get(x[22], 11), value_map.get(x[23], 12), value_map.get(x[24], 13),
# value_map.get(x[25], 14), value_map.get(x[26], 15)],
# app_list_func(x[27], leve2_map), app_list_func(x[28], leve3_map)
# ))
# #
# spark = SparkSession.builder.getOrCreate()
# #
# b = [("a", 1), ("a", 1), ("b", 3), ("a", 2)] # rdd.persist(storageLevel= StorageLevel.MEMORY_ONLY_SER)
# rdd = spark.sparkContext.parallelize(b) #
# df = spark.createDataFrame(rdd).toDF("id", "n") # # TODO 上线后把下面train fliter 删除,因为最近一天的数据也要作为训练集
# df.show() #
# df.createOrReplaceTempView("df") # train = rdd.map(
# t = spark.sql("select id from df").map() # lambda x: (x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8], x[9],
# x[10], x[11], x[12], x[13], x[14], x[15]))
# f = time.time()
# spark.createDataFrame(train).toDF("y", "z", "app_list", "level2_list", "level3_list",
# "tag1_list", "tag2_list", "tag3_list", "tag4_list",
# "tag5_list", "tag6_list", "tag7_list", "ids", "search_tag2_list","search_tag3_list") \
# .repartition(1).write.format("tfrecords").save(path=path + "tr/", mode="overwrite")
# h = time.time()
# print("train tfrecord done")
# print((h - f) / 60)
#
# print("训练集样本总量:")
# print(rdd.count())
#
# get_pre_number()
#
# test = rdd.filter(lambda x: x[0] == validate_date).map(
# lambda x: (x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8], x[9],
# x[10], x[11], x[12], x[13], x[14], x[15]))
#
# spark.createDataFrame(test).toDF("y", "z", "app_list", "level2_list", "level3_list",
# "tag1_list", "tag2_list", "tag3_list", "tag4_list",
# "tag5_list", "tag6_list", "tag7_list", "ids", "search_tag2_list","search_tag3_list") \
# .repartition(1).write.format("tfrecords").save(path=path + "va/", mode="overwrite")
#
# print("va tfrecord done")
#
# rdd.unpersist()
return validate_date, value_map, app_list_map, leve2_map, leve3_map
def get_predict(date, value_map, app_list_map, leve2_map, leve3_map):
    """Encode the prediction set and write it out as two TFRecord datasets.

    Reads up to 100000 rows of ``jerry_test.esmm_pre_data`` joined with the
    user / diary / tag / knowledge tables through the module-level ``spark``
    session, converts every feature to integer ids, and writes the rows with
    ``label == 0`` to ``path + "test_native/"`` and the rows with
    ``label == 1`` to ``path + "test_nearby/"``.

    Relies on the module-level names ``spark``, ``path``, ``app_list_func``,
    ``StorageLevel`` and ``time``.

    :param date: value looked up in ``value_map`` for the first slot of ``ids``
    :param value_map: mapping from a raw single-valued feature value to its id
    :param app_list_map: id mapping passed to ``app_list_func`` for ``app_list``
    :param leve2_map: id mapping passed to ``app_list_func`` for ``level2_ids``,
        ``tag1``..``tag7`` and ``search_tag2``
    :param leve3_map: id mapping passed to ``app_list_func`` for ``level3_ids``
        and ``search_tag3``
    :return: None
    """
    sql = "select e.y,e.z,e.label,e.ucity_id,feat.level2_ids,e.ccity_name," \
          "u.device_type,u.manufacturer,u.channel,c.top,e.device_id,e.cid_id,cut.time," \
          "dl.app_list,e.hospital_id,feat.level3_ids," \
          "wiki.tag as tag1,question.tag as tag2,search.tag as tag3,budan.tag as tag4," \
          "ot.tag as tag5,sixin.tag as tag6,cart.tag as tag7,doris.search_tag2,doris.search_tag3," \
          "k.treatment_method,k.price_min,k.price_max,k.treatment_time,k.maintain_time,k.recover_time " \
          "from jerry_test.esmm_pre_data e " \
          "left join jerry_test.user_feature u on e.device_id = u.device_id " \
          "left join jerry_test.cid_type_top c on e.device_id = c.device_id " \
          "left join jerry_test.cid_time_cut cut on e.cid_id = cut.cid " \
          "left join jerry_test.device_app_list dl on e.device_id = dl.device_id " \
          "left join jerry_test.diary_feat feat on e.cid_id = feat.diary_id " \
          "left join jerry_test.wiki_tag wiki on e.device_id = wiki.device_id " \
          "left join jerry_test.question_tag question on e.device_id = question.device_id " \
          "left join jerry_test.search_tag search on e.device_id = search.device_id " \
          "left join jerry_test.budan_tag budan on e.device_id = budan.device_id " \
          "left join jerry_test.order_tag ot on e.device_id = ot.device_id " \
          "left join jerry_test.sixin_tag sixin on e.device_id = sixin.device_id " \
          "left join jerry_test.cart_tag cart on e.device_id = cart.device_id " \
          "left join jerry_test.knowledge k on feat.level2 = k.level2_id " \
          "left join jerry_test.search_doris doris on e.device_id = doris.device_id and e.stat_date = doris.get_date " \
          "limit 100000"

    features = ["ucity_id", "ccity_name", "device_type", "manufacturer",
                "channel", "top", "time", "hospital_id",
                "treatment_method", "price_min", "price_max", "treatment_time", "maintain_time", "recover_time",
                "app_list", "level3_ids", "level2_ids", "tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7",
                "search_tag2", "search_tag3"]

    df = spark.sql(sql)
    df = df.drop_duplicates(["ucity_id", "device_id", "cid_id"])
    # Nulls in a feature column are replaced by the column's own name, which
    # serves as the "missing" placeholder for that feature.
    df = df.na.fill(dict(zip(features, features)))

    f = time.time()

    def encode(x):
        # Positions in x follow the df.select(...) order below:
        #   0 label, 1 y, 2 z, 3 ucity_id, 4 device_id, 5 cid_id,
        #   6 app_list, 7 level2_ids, 8 level3_ids, 9-15 tag1..tag7,
        #   16-29 single-valued features, 30 search_tag2, 31 search_tag3.
        tags = tuple(app_list_func(x[i], leve2_map) for i in range(9, 16))
        # ids has 15 slots; a value missing from value_map falls back to the
        # slot number (1..15). Slot 1 is the date, slots 2..15 are x[16]..x[29].
        ids = [value_map.get(date, 1)]
        ids.extend(value_map.get(x[i], i - 14) for i in range(16, 30))
        head = (x[0], float(x[1]), float(x[2]), x[3], x[4], x[5],
                app_list_func(x[6], app_list_map),
                app_list_func(x[7], leve2_map),
                app_list_func(x[8], leve3_map))
        tail = (ids,
                app_list_func(x[30], leve2_map),
                app_list_func(x[31], leve3_map))
        return head + tags + tail

    rdd = df.select("label", "y", "z", "ucity_id", "device_id", "cid_id", "app_list", "level2_ids", "level3_ids",
                    "tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7",
                    "ucity_id", "ccity_name", "device_type", "manufacturer", "channel", "top", "time",
                    "hospital_id", "treatment_method", "price_min", "price_max", "treatment_time",
                    "maintain_time", "recover_time", "search_tag2", "search_tag3") \
        .rdd.repartition(200).map(encode)

    columns = ("y", "z", "app_list", "level2_list", "level3_list",
               "tag1_list", "tag2_list", "tag3_list", "tag4_list",
               "tag5_list", "tag6_list", "tag7_list", "ids",
               "search_tag2_list", "search_tag3_list")

    def write_label(label, sub_dir):
        # Keep y, z and encoded positions 6..18 (13 feature columns) for the
        # rows carrying this label. repartition(1) makes the output one file.
        rows = rdd.filter(lambda x: x[0] == label) \
            .map(lambda x: (x[1], x[2]) + tuple(x[6:19]))
        spark.createDataFrame(rows).toDF(*columns) \
            .repartition(1).write.format("tfrecords").save(path=path + sub_dir, mode="overwrite")

    # The encoded RDD is consumed three times (count + two writes), so cache it.
    rdd.persist(storageLevel=StorageLevel.MEMORY_ONLY_SER)
    try:
        print(rdd.count())

        # NOTE: the per-label CSV export of (city, uid, cid_id) is disabled.
        write_label(0, "test_native/")
        print("native tfrecord done")
        h = time.time()
        print((h - f) / 60)

        write_label(1, "test_nearby/")
        print("nearby tfrecord done")
    finally:
        # Release the cached partitions even when a write fails.
        rdd.unpersist()
if __name__ == '__main__':
    # Spark configuration: recursive input directories, the TiSpark SQL
    # extension with its PD address, compression codecs and an 8g driver
    # result-size limit.
    sparkConf = (
        SparkConf()
        .set("spark.hive.mapred.supports.subdirectories", "true")
        .set("spark.hadoop.mapreduce.input.fileinputformat.input.dir.recursive", "true")
        .set("spark.tispark.plan.allow_index_double_read", "false")
        .set("spark.tispark.plan.allow_index_read", "true")
        .set("spark.sql.extensions", "org.apache.spark.sql.TiExtensions")
        .set("spark.tispark.pd.addresses", "172.16.40.158:2379")
        .set("spark.io.compression.codec", "lzf")
        .set("spark.driver.maxResultSize", "8g")
        .set("spark.sql.avro.compression.codec", "snappy")
    )

    # `spark` is a module-level name that get_predict reads directly.
    spark = (
        SparkSession.builder
        .config(conf=sparkConf)
        .enableHiveSupport()
        .getOrCreate()
    )

    # Map both TiDB databases referenced by the SQL in this script.
    # NOTE(review): presumably this exposes their tables to spark.sql - confirm.
    ti = pti.TiContext(spark)
    ti.tidbMapDatabase("jerry_test")
    ti.tidbMapDatabase("eagle")

    spark.sparkContext.setLogLevel("WARN")

    # Output locations; `path` is read by get_predict when writing TFRecords.
    path = "hdfs:///strategy/esmm/"
    local_path = "/home/gmuser/esmm/"

    # Build the feature-id mappings, then encode and write the prediction set.
    validate_date, value_map, app_list_map, leve2_map, leve3_map = feature_engineer()
    get_predict(validate_date, value_map, app_list_map, leve2_map, leve3_map)

    spark.stop()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment