Commit 1963df08 authored by 张彦钊

change test file

parent 6aa49406
......@@ -17,7 +17,7 @@ def app_list_func(x,l):
e.append(l[i])
else:
e.append(0)
return np.array(e)
return e
# return ",".join([str(j) for j in e])
......@@ -101,13 +101,13 @@ def feature_engineer():
spark.createDataFrame(test).toDF("app_list","level2_ids","level3_ids","stat_date","ucity_id", "ccity_name", "device_type", "manufacturer",
"channel", "top", "time", "hospital_id","treatment_method", "price_min",
"price_max", "treatment_time","maintain_time", "recover_time","y","z")\
.repartition(1).write.format("tfrecords").option("recordType", "Example").save(path=path+"va/", mode="overwrite")
.repartition(1).write.format("tfrecords").option("recordType", "SequenceExample").save(path=path+"va/", mode="overwrite")
print("va write done")
spark.createDataFrame(train).toDF("app_list","level2_ids","level3_ids","stat_date","ucity_id", "ccity_name", "device_type", "manufacturer",
"channel", "top", "time", "hospital_id","treatment_method", "price_min",
"price_max", "treatment_time","maintain_time", "recover_time","y","z") \
.repartition(1).write.format("tfrecords").option("recordType", "Example").save(path=path+"tr/", mode="overwrite")
.repartition(1).write.format("tfrecords").option("recordType", "SequenceExample").save(path=path+"tr/", mode="overwrite")
print("done")
rdd.unpersist()
......@@ -166,7 +166,7 @@ def get_predict(date,value_map,app_list_map,level2_map,level3_map):
.toDF("app_list", "level2_ids", "level3_ids","y","z","ucity_id",
"ccity_name", "device_type","manufacturer", "channel", "time", "hospital_id",
"treatment_method", "price_min", "price_max", "treatment_time", "maintain_time",
"recover_time", "top","stat_date").repartition(1).write.format("tfrecords").option("recordType", "Example") \
"recover_time", "top","stat_date").repartition(1).write.format("tfrecords").option("recordType", "SequenceExample") \
.save(path=path+"native/", mode="overwrite")
nearby_pre = spark.createDataFrame(rdd.filter(lambda x: x[6] == 1).map(lambda x: (x[3], x[4], x[5]))) \
......@@ -182,7 +182,7 @@ def get_predict(date,value_map,app_list_map,level2_map,level3_map):
.toDF("app_list", "level2_ids", "level3_ids","y","z", "ucity_id",
"ccity_name", "device_type", "manufacturer", "channel", "time", "hospital_id",
"treatment_method", "price_min", "price_max", "treatment_time", "maintain_time",
"recover_time","top","stat_date").repartition(1).write.format("tfrecords").option("recordType", "Example") \
"recover_time","top","stat_date").repartition(1).write.format("tfrecords").option("recordType", "SequenceExample") \
.save(path=path+"nearby/", mode="overwrite")
rdd.unpersist()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment