Commit ea33c974 authored by 张彦钊

change test file

parent 5d0090c6
@@ -5,6 +5,7 @@ import pytispark.pytispark as pti
from pyspark.sql import SparkSession
import datetime
import pandas as pd
import time
def app_list_func(x,l):
@@ -19,7 +20,12 @@ def app_list_func(x,l):
def multi_hot(df,column,n):
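# Build the vocabulary for a multi-hot column: collect its distinct values to the driver.
# distinct() forces a shuffle, so the collect below is timed.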
a = time.time()
v = df.select(column).distinct().rdd.map(lambda x: x[0]).collect()
b = time.time()
print(column)
print("cost time 分钟")
print((b-a)/60)
app_list_value = [str(i).split(",") for i in v]
app_list_unique = []
for i in app_list_value:
@@ -79,46 +85,59 @@ def feature_engineer():
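# Collect the distinct values of every categorical feature; together they feed the global value_map built below.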
unique_values = []
for i in features:
a = time.time()
unique_values.extend(df.select(i).distinct().rdd.map(lambda x: x[0]).collect())
b = time.time()
print(i)
print((b-a)/60)
temp = list(range(2 + apps_number + level2_number + level3_number,
2 + apps_number + level2_number + level3_number + len(unique_values)))
value_map = dict(zip(unique_values, temp))
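# value_map assigns each raw categorical value an integer id, offset by 2 + apps_number + level2_number + level3_number
# so it does not collide with the ids already used by the multi-hot vocabularies
# (e.g. if that offset is N, the first unique value maps to N, the next to N+1, and so on).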
c = time.time()
rdd = df.select("stat_date","y", "z","app_list","level2_ids","level3_ids",
"tag1","tag2","tag3","tag4","tag5","tag6","tag7",
"ucity_id", "ccity_name","device_type", "manufacturer", "channel", "top", "time",
"hospital_id","treatment_method", "price_min", "price_max", "treatment_time",
"maintain_time","recover_time").rdd
rdd.persist()
# TODO: remove the train filter below after launch, because the most recent day's data should also be included in the training set
train = rdd.filter(lambda x: x[0] != validate_date) \
.map(lambda x: (float(x[1]),float(x[2]),app_list_func(x[3], app_list_map), app_list_func(x[4], leve2_map),
"maintain_time","recover_time").rdd.repartition(200).map(lambda x: (x[0],float(x[1]),float(x[2]),app_list_func(x[3], app_list_map), app_list_func(x[4], leve2_map),
app_list_func(x[5], leve3_map), app_list_func(x[6], leve2_map),app_list_func(x[7], leve2_map),
app_list_func(x[8], leve2_map), app_list_func(x[9], leve2_map),app_list_func(x[10], leve2_map),
app_list_func(x[11], leve2_map),app_list_func(x[12], leve2_map),
[value_map[x[0]], value_map[x[13]],value_map[x[14]], value_map[x[15]], value_map[x[16]],
value_map[x[17]],value_map[x[18]], value_map[x[19]], value_map[x[20]],value_map[x[21]],
value_map[x[22]], value_map[x[23]], value_map[x[24]],value_map[x[25]],value_map[x[26]]]))
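# repartition(200) splits the RDD into 200 partitions before the per-row encoding; each row becomes
# (stat_date, y, z, ten multi-hot id lists, and a 15-element dense id vector looked up in value_map).
# Note that map() is lazy, so the (d-c)/60 timing below measures plan construction, not execution.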
d = time.time()
print("rdd")
print((d-c)/60)
rdd.persist()
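# Cache the encoded RDD so the train/validation filters below reuse it instead of re-running the encoding.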
# TODO: remove the train filter below after launch, because the most recent day's data should also be included in the training set
# train = rdd.filter(lambda x: x[0] != validate_date) \
# .map(lambda x: (float(x[1]),float(x[2]),app_list_func(x[3], app_list_map), app_list_func(x[4], leve2_map),
# app_list_func(x[5], leve3_map), app_list_func(x[6], leve2_map),app_list_func(x[7], leve2_map),
# app_list_func(x[8], leve2_map), app_list_func(x[9], leve2_map),app_list_func(x[10], leve2_map),
# app_list_func(x[11], leve2_map),app_list_func(x[12], leve2_map),
# [value_map[x[0]], value_map[x[13]],value_map[x[14]], value_map[x[15]], value_map[x[16]],
# value_map[x[17]],value_map[x[18]], value_map[x[19]], value_map[x[20]],value_map[x[21]],
# value_map[x[22]], value_map[x[23]], value_map[x[24]],value_map[x[25]],value_map[x[26]]]))
train = rdd.filter(lambda x: x[0] != validate_date).map(lambda x:(x[1],x[2],x[3],x[4],x[5],x[6],x[7],x[8],x[9],
x[10],x[11],x[12],x[13]))
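# Training rows: every record whose stat_date differs from validate_date, with the stat_date field dropped from the tuple.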
f = time.time()
spark.createDataFrame(train).toDF("y","z","app_list","level2_list","level3_list",
"tag1_list","tag2_list","tag3_list","tag4_list",
"tag5_list","tag6_list","tag7_list","ids") \
.coalesce(1).write.format("tfrecords").save(path=path + "tr/", mode="overwrite")
.write.format("tfrecords").save(path=path + "tr/", mode="overwrite")
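# Without coalesce(1), each partition writes its own tfrecord part file, so the save is no longer funneled through a single task.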
h = time.time()
print("train tfrecord done")
print((h-f)/60)
test = rdd.filter(lambda x: x[0] == validate_date) \
.map(lambda x: (float(x[1]), float(x[2]), app_list_func(x[3], app_list_map), app_list_func(x[4], leve2_map),
app_list_func(x[5], leve3_map), app_list_func(x[6], leve2_map), app_list_func(x[7], leve2_map),
app_list_func(x[8], leve2_map), app_list_func(x[9], leve2_map), app_list_func(x[10], leve2_map),
app_list_func(x[11], leve2_map), app_list_func(x[12], leve2_map),
[value_map[x[0]], value_map[x[13]], value_map[x[14]], value_map[x[15]], value_map[x[16]],
value_map[x[17]], value_map[x[18]], value_map[x[19]], value_map[x[20]], value_map[x[21]],
value_map[x[22]], value_map[x[23]], value_map[x[24]], value_map[x[25]], value_map[x[26]]]))
test = rdd.filter(lambda x: x[0] == validate_date).map(lambda x:(x[1],x[2],x[3],x[4],x[5],x[6],x[7],x[8],x[9],
x[10],x[11],x[12],x[13]))
spark.createDataFrame(test).toDF("y", "z", "app_list", "level2_list", "level3_list",
"tag1_list", "tag2_list", "tag3_list", "tag4_list",
"tag5_list", "tag6_list", "tag7_list", "ids") \
.coalesce(1).write.format("tfrecords").save(path=path+"va/", mode="overwrite")
.write.format("tfrecords").save(path=path+"va/", mode="overwrite")
print("va tfrecord done")
@@ -156,12 +175,13 @@ def get_predict(date,value_map,app_list_map,leve2_map,leve3_map):
df = spark.sql(sql)
df = df.na.fill(dict(zip(features, features)))
c = time.time()
rdd = df.select("label", "y", "z","ucity_id","device_id","cid_id","app_list", "level2_ids", "level3_ids",
"tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7",
"ucity_id", "ccity_name", "device_type", "manufacturer", "channel", "top", "time",
"hospital_id", "treatment_method", "price_min", "price_max", "treatment_time",
"maintain_time", "recover_time") \
.rdd.map(lambda x: (x[0],float(x[1]),float(x[2]),x[3],x[4],x[5],
.rdd.repartition(200).map(lambda x: (x[0],float(x[1]),float(x[2]),x[3],x[4],x[5],
app_list_func(x[6], app_list_map),app_list_func(x[7], leve2_map),
app_list_func(x[8], leve3_map), app_list_func(x[9], leve2_map),
app_list_func(x[10], leve2_map),app_list_func(x[11], leve2_map),
@@ -177,7 +197,9 @@ def get_predict(date,value_map,app_list_map,leve2_map,leve3_map):
value_map.get(x[29], 299985)
]))
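# value_map.get(..., 299985): values not seen when the vocabulary was built fall back to the fixed id 299985.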
rdd.persist()
d = time.time()
print("rdd")
print((d-c)/60)
native_pre = spark.createDataFrame(rdd.filter(lambda x:x[0] == 0).map(lambda x:(x[3],x[4],x[5])))\
.toDF("city","uid","cid_id")
print("native csv")
@@ -186,12 +208,15 @@ def get_predict(date,value_map,app_list_map,leve2_map,leve3_map):
# native_pre.coalesce(1).write.format('com.databricks.spark.csv').save(path+"native/",header = 'true')
# The prediction tfrecords must be written to a single file so that row order is preserved
f = time.time()
spark.createDataFrame(rdd.filter(lambda x: x[0] == 0)
.map(lambda x: (x[1],x[2],x[6],x[7],x[8],x[9],x[10],x[11],x[12],x[13],x[14],x[15],x[16]))) \
.toDF("y","z","app_list", "level2_list", "level3_list","tag1_list", "tag2_list", "tag3_list", "tag4_list",
"tag5_list", "tag6_list", "tag7_list", "ids").coalesce(1).write.format("tfrecords") \
.save(path=path+"native/", mode="overwrite")
print("native tfrecord done")
h = time.time()
print((h-f)/60)
native_pre = spark.createDataFrame(rdd.filter(lambda x: x[0] == 1).map(lambda x: (x[3], x[4], x[5]))) \
.toDF("city", "uid", "cid_id")
@@ -128,31 +128,32 @@ def con_sql(db,sql):
if __name__ == '__main__':
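# The Spark session below enables TiSpark so TiDB tables can be queried through Spark SQL, uses lzf compression,
# and raises spark.driver.maxResultSize to 8g, presumably to accommodate the large collect() calls above.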
sparkConf = SparkConf().set("spark.hive.mapred.supports.subdirectories", "true") \
.set("spark.hadoop.mapreduce.input.fileinputformat.input.dir.recursive", "true") \
.set("spark.tispark.plan.allow_index_double_read", "false") \
.set("spark.tispark.plan.allow_index_read", "true") \
.set("spark.sql.extensions", "org.apache.spark.sql.TiExtensions") \
.set("spark.tispark.pd.addresses", "172.16.40.158:2379").set("spark.io.compression.codec", "lzf")\
.set("spark.driver.maxResultSize", "8g").set("spark.sql.avro.compression.codec","snappy")
spark = SparkSession.builder.config(conf=sparkConf).enableHiveSupport().getOrCreate()
ti = pti.TiContext(spark)
ti.tidbMapDatabase("jerry_test")
# ti.tidbMapDatabase("eagle")
spark.sparkContext.setLogLevel("WARN")
path = "hdfs:///strategy/esmm/"
local_path = "/home/gmuser/esmm/"
validate_date, value_map, app_list_map = feature()
get_predict(validate_date, value_map, app_list_map)
# df = spark.read.format("tfrecords").option("recordType", "Example").load("/strategy/va.tfrecord")
# df.show(1)
# print("aa")
# print("aa")
# df = spark.read.format("tfrecords").load("/strategy/esmm/va/part-r-00000")
# df.show(1)
# sparkConf = SparkConf().set("spark.hive.mapred.supports.subdirectories", "true") \
# .set("spark.hadoop.mapreduce.input.fileinputformat.input.dir.recursive", "true") \
# .set("spark.tispark.plan.allow_index_double_read", "false") \
# .set("spark.tispark.plan.allow_index_read", "true") \
# .set("spark.sql.extensions", "org.apache.spark.sql.TiExtensions") \
# .set("spark.tispark.pd.addresses", "172.16.40.158:2379").set("spark.io.compression.codec", "lzf")\
# .set("spark.driver.maxResultSize", "8g").set("spark.sql.avro.compression.codec","snappy")
#
# spark = SparkSession.builder.config(conf=sparkConf).enableHiveSupport().getOrCreate()
# ti = pti.TiContext(spark)
# ti.tidbMapDatabase("jerry_test")
# # ti.tidbMapDatabase("eagle")
# spark.sparkContext.setLogLevel("WARN")
# path = "hdfs:///strategy/esmm/"
# local_path = "/home/gmuser/esmm/"
#
# validate_date, value_map, app_list_map = feature()
# get_predict(validate_date, value_map, app_list_map)
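# For now a minimal smoke test replaces the pipeline: build a tiny DataFrame and collect one column
# to check that the Spark session and the collect path work end to end.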
spark = SparkSession.builder.getOrCreate()
b = [("a", 1), ("a", 1), ("b", 3), ("a", 2)]
rdd = spark.sparkContext.parallelize(b)
df = spark.createDataFrame(rdd).toDF("id","n")
df.show()
t = df.select("id").rdd.map(lambda x:x[0]).collect()
print(t)