Commit 7651d006 authored by 张彦钊's avatar 张彦钊

change test file

parent cdf5bc89
......@@ -71,11 +71,11 @@ def feature_engineer():
.map(lambda x: (app_list_func(x[0], app_list_map),[value_map[x[2]],value_map[x[1]]], x[3],x[4]))
spark.createDataFrame(test).toDF("level2_ids","ids","y","z")\
.repartition(1).write.format("tfrecords").option("recordType", "SequenceExample").save(path=path+"va/", mode="overwrite")
.repartition(1).write.format("tfrecords").save(path=path+"va/", mode="overwrite")
print("va write done")
spark.createDataFrame(train).toDF("level2_ids","ids","y","z") \
.repartition(1).write.format("tfrecords").option("recordType", "SequenceExample").save(path=path+"tr/", mode="overwrite")
.repartition(1).write.format("tfrecords").save(path=path+"tr/", mode="overwrite")
print("done")
rdd.unpersist()
......@@ -103,7 +103,7 @@ def get_predict(date,value_map,app_list_map):
spark.createDataFrame(rdd.filter(lambda x: x[4] == 0).map(lambda x: (x[0],x[5],x[6],x[7]))) \
.toDF("level2_ids","y","z","ids").repartition(1).write.format("tfrecords") \
.option("recordType", "SequenceExample").save(path=path+"native/", mode="overwrite")
.save(path=path+"native/", mode="overwrite")
nearby_pre = spark.createDataFrame(rdd.filter(lambda x: x[4] == 1).map(lambda x:(x[1],x[2],x[3]))) \
.toDF("city", "uid", "cid_id")
......@@ -111,7 +111,7 @@ def get_predict(date,value_map,app_list_map):
nearby_pre.toPandas().to_csv(local_path+"nearby.csv", header=True)
spark.createDataFrame(rdd.filter(lambda x: x[4] == 1).map(lambda x: (x[0], x[5], x[6], x[7]))) \
.toDF("level2_ids","y","z","ids").repartition(1).write.format("tfrecords") \
.option("recordType", "SequenceExample").save(path=path+"nearby/", mode="overwrite")
.save(path=path+"nearby/", mode="overwrite")
rdd.unpersist()
......@@ -135,15 +135,15 @@ if __name__ == '__main__':
.set("spark.driver.maxResultSize", "8g").set("spark.sql.avro.compression.codec","snappy")
spark = SparkSession.builder.config(conf=sparkConf).enableHiveSupport().getOrCreate()
# ti = pti.TiContext(spark)
# ti.tidbMapDatabase("jerry_test")
# # ti.tidbMapDatabase("eagle")
# spark.sparkContext.setLogLevel("WARN")
# path = "hdfs:///strategy/esmm/"
# local_path = "/home/gmuser/esmm/"
#
# validate_date, value_map, app_list_map = feature_engineer()
# get_predict(validate_date, value_map, app_list_map)
ti = pti.TiContext(spark)
ti.tidbMapDatabase("jerry_test")
# ti.tidbMapDatabase("eagle")
spark.sparkContext.setLogLevel("WARN")
path = "hdfs:///strategy/esmm/"
local_path = "/home/gmuser/esmm/"
validate_date, value_map, app_list_map = feature_engineer()
get_predict(validate_date, value_map, app_list_map)
df = spark.read.format("tfrecords").option("recordType", "Example").load("/strategy/va.tfrecord")
df.show(1)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment