Commit 7c69c949 authored by 张彦钊's avatar 张彦钊

Merge branch 'zhao' into 'master'

esmm 预测集加上判断rdd是否为空

See merge request !25
parents a126dc70 9a0aa001
......@@ -336,29 +336,37 @@ def get_predict(date,value_map,app_list_map,leve2_map,leve3_map):
value_map.get(x[29], 15)],
app_list_func(x[30], leve2_map),app_list_func(x[31], leve3_map)))
rdd.persist(storageLevel= StorageLevel.MEMORY_ONLY_SER)
print("预测集样本大小:")
print(rdd.count())
spark.createDataFrame(rdd.filter(lambda x: x[0] == 0)
.map(lambda x: (x[1], x[2], x[6], x[7], x[8], x[9], x[10], x[11],
x[12], x[13], x[14], x[15], x[16], x[17], x[18],x[3],x[4],x[5]))) \
.toDF("y", "z", "app_list", "level2_list", "level3_list", "tag1_list", "tag2_list", "tag3_list", "tag4_list",
"tag5_list", "tag6_list", "tag7_list", "ids", "search_tag2_list","search_tag3_list","city","uid","cid_id") \
.repartition(1).write.format("tfrecords").save(path=path+"native/", mode="overwrite")
print("native tfrecord done")
h = time.time()
print((h-f)/60)
spark.createDataFrame(rdd.filter(lambda x: x[0] == 1)
.map(lambda x: (x[1], x[2], x[6], x[7], x[8], x[9], x[10], x[11],
x[12], x[13], x[14], x[15], x[16], x[17], x[18],x[3],x[4],x[5]))) \
.toDF("y", "z", "app_list", "level2_list", "level3_list", "tag1_list", "tag2_list", "tag3_list", "tag4_list",
"tag5_list", "tag6_list", "tag7_list", "ids", "search_tag2_list","search_tag3_list","city","uid","cid_id") \
.repartition(1).write.format("tfrecords").save(path=path + "nearby/", mode="overwrite")
print("nearby tfrecord done")
if rdd.filter(lambda x: x[0] == 0).count() > 0:
print("预测集native有数据")
spark.createDataFrame(rdd.filter(lambda x: x[0] == 0)
.map(lambda x: (x[1], x[2], x[6], x[7], x[8], x[9], x[10], x[11],
x[12], x[13], x[14], x[15], x[16], x[17], x[18], x[3], x[4], x[5]))) \
.toDF("y", "z", "app_list", "level2_list", "level3_list", "tag1_list", "tag2_list", "tag3_list",
"tag4_list","tag5_list", "tag6_list", "tag7_list", "ids", "search_tag2_list",
"search_tag3_list", "city", "uid","cid_id") \
.repartition(1).write.format("tfrecords").save(path=path + "native/", mode="overwrite")
print("native tfrecord done")
h = time.time()
print((h - f) / 60)
else:
print("预测集native为空")
if rdd.filter(lambda x: x[0] == 1).count() > 0:
print("预测集nearby有数据")
spark.createDataFrame(rdd.filter(lambda x: x[0] == 1)
.map(lambda x: (x[1], x[2], x[6], x[7], x[8], x[9], x[10], x[11],
x[12], x[13], x[14], x[15], x[16], x[17], x[18], x[3], x[4], x[5]))) \
.toDF("y", "z", "app_list", "level2_list", "level3_list", "tag1_list", "tag2_list", "tag3_list",
"tag4_list","tag5_list", "tag6_list", "tag7_list", "ids", "search_tag2_list",
"search_tag3_list", "city", "uid", "cid_id")\
.repartition(1).write.format("tfrecords").save(path=path + "nearby/", mode="overwrite")
print("nearby tfrecord done")
else:
print("预测集nearby为空")
if __name__ == '__main__':
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment