Commit fb95c89f authored by 张彦钊

Modify the test file

parent 3f37c5ce
@@ -38,7 +38,7 @@ def feature_engineer():
validate_date = con_sql(db, sql)[0].values.tolist()[0]
print("validate_date:" + validate_date)
temp = datetime.datetime.strptime(validate_date, "%Y-%m-%d")
start = (temp - datetime.timedelta(days=300)).strftime("%Y-%m-%d")
start = (temp - datetime.timedelta(days=3)).strftime("%Y-%m-%d")
print(start)
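For reference, the window arithmetic this hunk changes can be reproduced standalone. A minimal sketch, assuming validate_date is a "YYYY-MM-DD" string (the literal value below is made up):

import datetime

validate_date = "2019-03-14"  # hypothetical; really comes from esmm_train_data
temp = datetime.datetime.strptime(validate_date, "%Y-%m-%d")
start = (temp - datetime.timedelta(days=3)).strftime("%Y-%m-%d")
print(start)  # 2019-03-11, i.e. three days before validate_date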
sql = "select e.y,e.z,e.stat_date,e.ucity_id,feat.level2_ids,e.ccity_name,u.device_type,u.manufacturer," \
@@ -107,11 +107,9 @@ def feature_engineer():
value_map[x[13]], value_map[x[14]], value_map[x[15]], value_map[x[16]],
value_map[x[17]], x[18], x[19]))
# spark.createDataFrame(test).write.csv('/recommend/va', mode='overwrite', header=True)
# spark.createDataFrame(train).write.csv('/recommend/tr', mode='overwrite', header=True)
spark.createDataFrame(test).write.format("avro").save(path="/recommend/va", mode="overwrite")
spark.createDataFrame(train).write.format("avro").save(path="/recommend/tr", mode="overwrite")
a = spark.createDataFrame(train).toPandas()
print(a.shape)
print("done")
rdd.unpersist()
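The writers above switch from CSV to Avro. A minimal round-trip sketch, assuming a live spark session and that the spark-avro module is on the classpath (for Spark 2.4 roughly --packages org.apache.spark:spark-avro_2.11:2.4.0; older clusters use the "com.databricks.spark.avro" format name instead of "avro"):

# Write a toy DataFrame as Avro and read it back; the path is illustrative.
df = spark.createDataFrame([(0, "a"), (1, "b")], ["id", "label"])
df.write.format("avro").save(path="/tmp/avro_demo", mode="overwrite")
spark.read.format("avro").load("/tmp/avro_demo").show()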
@@ -161,7 +159,8 @@ def get_predict(date,value_map,app_list_map,level2_map,level3_map):
.toDF("city","uid","cid_id")
print("native")
print(native_pre.count())
native_pre.write.csv('/recommend', mode='overwrite', header=True)
native_pre.write.format("avro").save(path="/recommend", mode="overwrite")
spark.createDataFrame(rdd.filter(lambda x: x[6] == 0)
.map(lambda x: (x[0], x[1], x[2],x[9],x[10],x[11],x[12],x[13],x[14],x[15],
@@ -169,13 +168,13 @@ def get_predict(date,value_map,app_list_map,level2_map,level3_map):
.toDF("app_list", "level2_ids", "level3_ids","ucity_id",
"ccity_name", "device_type","manufacturer", "channel", "time", "hospital_id",
"treatment_method", "price_min", "price_max", "treatment_time", "maintain_time",
"recover_time", "top","stat_date").write.csv('/recommend/native', mode='overwrite', header=True)
"recover_time", "top","stat_date").write.format("avro").save(path="/recommend/native", mode="overwrite")
nearby_pre = spark.createDataFrame(rdd.filter(lambda x: x[6] == 1).map(lambda x: (x[3], x[4], x[5]))) \
.toDF("city", "uid", "cid_id")
print("nearby")
print(nearby_pre.count())
nearby_pre.write.csv('/recommend', mode='overwrite', header=True)
nearby_pre.write.format("avro").save(path="/recommend", mode="overwrite")
spark.createDataFrame(rdd.filter(lambda x: x[6] == 1)
.map(lambda x: (x[0], x[1], x[2], x[9], x[10], x[11], x[12], x[13], x[14], x[15],
@@ -183,7 +182,7 @@ def get_predict(date,value_map,app_list_map,level2_map,level3_map):
.toDF("app_list", "level2_ids", "level3_ids", "ucity_id",
"ccity_name", "device_type", "manufacturer", "channel", "time", "hospital_id",
"treatment_method", "price_min", "price_max", "treatment_time", "maintain_time",
"recover_time","top","stat_date").write.csv('/recommend/nearby', mode='overwrite', header=True)
"recover_time","top","stat_date").write.format("avro").save(path="/recommend/nearby", mode="overwrite")
rdd.unpersist()
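The native and nearby branches above differ only in the x[6] flag and the output path, so the pattern generalizes. A sketch under those assumptions (the output paths here are illustrative, not the commit's):

# Split one cached RDD into per-branch DataFrames by the 0/1 flag in x[6].
# Note the default argument f=flag: a bare lambda would capture the loop
# variable by reference, so both branches would filter on the last value.
for flag, name in [(0, "native"), (1, "nearby")]:
    part = spark.createDataFrame(rdd.filter(lambda x, f=flag: x[6] == f)
                                    .map(lambda x: (x[3], x[4], x[5]))) \
        .toDF("city", "uid", "cid_id")
    print(name, part.count())
    part.write.format("avro").save(path="/recommend/" + name + "_pre", mode="overwrite")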
@@ -203,7 +202,6 @@ def test():
df.show(6)
df.write.format("avro").save(path="/recommend/tr", mode="overwrite")
# from hdfs import InsecureClient
# from hdfs.ext.dataframe import read_dataframe
# client = InsecureClient('http://nvwa01:50070')
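The commented imports above hint at reading the Avro output back into pandas with HdfsCLI. A sketch of that idea, assuming the hdfs package is installed with its avro extras and that the path points at the Avro files written earlier:

from hdfs import InsecureClient
from hdfs.ext.dataframe import read_dataframe

client = InsecureClient('http://nvwa01:50070')  # WebHDFS endpoint from the commented code
pdf = read_dataframe(client, '/recommend/tr')   # returns a pandas DataFrame
print(pdf.shape)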
@@ -221,22 +219,6 @@ def test():
# spark.sql("select cl_type from online.tl_hdfs_maidian_view where partition_date = '20190312' limit 6").show()
# data = [(0, 18.0), (1, 19.0), (2, 8.0), (3, 5.0), (4, 2.2), (5, 9.2), (6, 14.4)]
# df = spark.createDataFrame(data, ["id", "hour"])
# df.show(6)
# t = df.rdd.map(lambda x:x[0]).collect()
# print(t)
# validate_date = spark.sql("select max(stat_date) from esmm_train_data").rdd.map(lambda x: str(x[0]))
# print(validate_date.count())
# print("validate_date:" + validate_date)
# temp = datetime.datetime.strptime(validate_date, "%Y-%m-%d")
# start = (temp - datetime.timedelta(days=10)).strftime("%Y-%m-%d")
# print(start)
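As written, the commented experiment above would fail at strptime: the query result is left as an RDD rather than a string. A working variant, materializing it first:

validate_date = spark.sql("select max(stat_date) from esmm_train_data") \
    .rdd.map(lambda x: str(x[0])).collect()[0]
print("validate_date:" + validate_date)
temp = datetime.datetime.strptime(validate_date, "%Y-%m-%d")
start = (temp - datetime.timedelta(days=10)).strftime("%Y-%m-%d")
print(start)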
if __name__ == '__main__':
sparkConf = SparkConf().set("spark.hive.mapred.supports.subdirectories", "true") \
.set("spark.hadoop.mapreduce.input.fileinputformat.input.dir.recursive", "true") \
@@ -251,9 +233,9 @@ if __name__ == '__main__':
ti.tidbMapDatabase("jerry_test")
spark.sparkContext.setLogLevel("WARN")
# validate_date, value_map, app_list_map, leve2_map, leve3_map = feature_engineer()
# get_predict(validate_date, value_map, app_list_map, leve2_map, leve3_map)
validate_date, value_map, app_list_map, leve2_map, leve3_map = feature_engineer()
get_predict(validate_date, value_map, app_list_map, leve2_map, leve3_map)
test()
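The fold above hides how sparkConf becomes a session and how TiSpark is attached. A hedged sketch of that setup, following pytispark's public API (the app name is illustrative and the hidden lines may differ):

from pyspark.sql import SparkSession
import pytispark.pytispark as pti

spark = SparkSession.builder.config(conf=sparkConf) \
    .appName("feature_engineering").enableHiveSupport().getOrCreate()
ti = pti.TiContext(spark)
ti.tidbMapDatabase("jerry_test")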