Commit 003a8df1 authored by 张彦钊

Modify the feature engineering file

parent 8d4380f4
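
For orientation: the hunks below repeatedly call two helpers, get_unique and con_sql, which are defined elsewhere in this file and never appear in the diff context. A minimal sketch of what their call sites imply (get_unique returning the first selected column as a flat list, con_sql returning a pandas DataFrame) follows; the bodies are assumptions, not the file's actual implementations:

    import pandas as pd

    def get_unique(db, sql):
        # Assumed sketch: fetch all rows and return the first column as a list,
        # matching unique_values.extend(get_unique(db, sql)) below.
        cursor = db.cursor()
        cursor.execute(sql)
        rows = cursor.fetchall()
        db.close()
        return [row[0] for row in rows]

    def con_sql(db, sql):
        # Assumed sketch: return the result set as a pandas DataFrame, matching
        # con_sql(db, sql)[0].values.tolist()[0] below.
        cursor = db.cursor()
        cursor.execute(sql)
        df = pd.DataFrame(list(cursor.fetchall()))
        db.close()
        return df
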
@@ -104,15 +104,15 @@ def feature_engineer():
     unique_values = []
     db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
-    sql = "select distinct stat_date from esmm_train_data_dwell_precise"
+    sql = "select distinct stat_date from esmm_train_data_dwell_share_test"
     unique_values.extend(get_unique(db,sql))
     db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
-    sql = "select distinct ucity_id from esmm_train_data_dwell_precise"
+    sql = "select distinct ucity_id from esmm_train_data_dwell_share_test"
     unique_values.extend(get_unique(db, sql))
     db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
-    sql = "select distinct ccity_name from esmm_train_data_dwell_precise"
+    sql = "select distinct ccity_name from esmm_train_data_dwell_share_test"
     unique_values.extend(get_unique(db, sql))
     db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
@@ -136,22 +136,49 @@ def feature_engineer():
     unique_values.extend(get_unique(db, sql))
     db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
-    sql = "select max(stat_date) from esmm_train_data_dwell_precise"
+    sql = "select distinct price_min from knowledge"
+    unique_values.extend(get_unique(db, sql))
+    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
+    sql = "select distinct treatment_method from knowledge"
+    unique_values.extend(get_unique(db, sql))
+    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
+    sql = "select distinct price_max from knowledge"
+    unique_values.extend(get_unique(db, sql))
+    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
+    sql = "select distinct treatment_time from knowledge"
+    unique_values.extend(get_unique(db, sql))
+    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
+    sql = "select distinct maintain_time from knowledge"
+    unique_values.extend(get_unique(db, sql))
+    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
+    sql = "select distinct recover_time from knowledge"
+    unique_values.extend(get_unique(db, sql))
+    # unique_values.append("video")
+    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
+    sql = "select max(stat_date) from esmm_train_data_dwell_share_test"
     validate_date = con_sql(db, sql)[0].values.tolist()[0]
     print("validate_date:" + validate_date)
     temp = datetime.datetime.strptime(validate_date, "%Y-%m-%d")
-    start = (temp - datetime.timedelta(days=60)).strftime("%Y-%m-%d")
+    start = (temp - datetime.timedelta(days=180)).strftime("%Y-%m-%d")
     print(start)
     db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC')
-    sql = "select distinct doctor.hospital_id from jerry_test.esmm_train_data_dwell_precise e " \
+    sql = "select distinct doctor.hospital_id from jerry_test.esmm_train_data_dwell_share_test e " \
           "left join eagle.src_zhengxing_api_service service on e.diary_service_id = service.id " \
          "left join eagle.src_zhengxing_api_doctor doctor on service.doctor_id = doctor.id " \
          "where e.stat_date >= '{}'".format(start)
     unique_values.extend(get_unique(db, sql))
     features = ["ucity_id", "ccity_name", "device_type", "manufacturer",
                 "channel", "top", "time", "stat_date", "hospital_id",
-                "app_list", "level3_ids", "level2_ids",
+                "treatment_method", "price_min", "price_max", "treatment_time", "maintain_time", "recover_time",
+                "app_list", "level3_ids", "level2_ids", "tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7",
                 "search_tag2", "search_tag3"]
     unique_values.extend(features)
     print("unique_values length")
@@ -165,14 +192,23 @@ def feature_engineer():
     sql = "select e.y,e.z,e.stat_date,e.ucity_id,feat.level2_ids,e.ccity_name,u.device_type,u.manufacturer," \
           "u.channel,c.top,cut.time,dl.app_list,feat.level3_ids,doctor.hospital_id," \
-          "doris.search_tag2,doris.search_tag3," \
+          "wiki.tag as tag1,question.tag as tag2,search.tag as tag3,budan.tag as tag4," \
+          "ot.tag as tag5,sixin.tag as tag6,cart.tag as tag7,doris.search_tag2,doris.search_tag3," \
+          "k.treatment_method,k.price_min,k.price_max,k.treatment_time,k.maintain_time,k.recover_time," \
           "e.device_id,e.cid_id " \
-          "from jerry_test.esmm_train_data_dwell_precise e " \
-          "left join jerry_test.user_feature u on e.device_id = u.device_id " \
+          "from jerry_test.esmm_train_data_dwell_share_test e left join jerry_test.user_feature u on e.device_id = u.device_id " \
           "left join jerry_test.cid_type_top c on e.device_id = c.device_id " \
           "left join jerry_test.cid_time_cut cut on e.cid_id = cut.cid " \
           "left join jerry_test.device_app_list dl on e.device_id = dl.device_id " \
           "left join jerry_test.diary_feat feat on e.cid_id = feat.diary_id " \
+          "left join jerry_test.knowledge k on feat.level2 = k.level2_id " \
+          "left join jerry_test.wiki_tag wiki on e.device_id = wiki.device_id " \
+          "left join jerry_test.question_tag question on e.device_id = question.device_id " \
+          "left join jerry_test.search_tag search on e.device_id = search.device_id " \
+          "left join jerry_test.budan_tag budan on e.device_id = budan.device_id " \
+          "left join jerry_test.order_tag ot on e.device_id = ot.device_id " \
+          "left join jerry_test.sixin_tag sixin on e.device_id = sixin.device_id " \
+          "left join jerry_test.cart_tag cart on e.device_id = cart.device_id " \
           "left join eagle.src_zhengxing_api_service service on e.diary_service_id = service.id " \
           "left join eagle.src_zhengxing_api_doctor doctor on service.doctor_id = doctor.id " \
           "left join jerry_test.search_doris doris on e.device_id = doris.device_id and e.stat_date = doris.get_date " \
@@ -181,20 +217,27 @@ def feature_engineer():
     df = spark.sql(sql)
     df = df.drop_duplicates(["ucity_id", "level2_ids", "ccity_name", "device_type", "manufacturer",
-                             "channel", "top", "time", "stat_date", "app_list", "hospital_id", "level3_ids"])
+                             "channel", "top", "time", "stat_date", "app_list", "hospital_id", "level3_ids",
+                             "tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7"])
     df = df.na.fill(dict(zip(features, features)))
     rdd = df.select("stat_date", "y", "z", "app_list", "level2_ids", "level3_ids",
+                    "tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7",
                     "ucity_id", "ccity_name", "device_type", "manufacturer", "channel", "top", "time",
-                    "hospital_id", "search_tag2", "search_tag3","cid_id","device_id")\
+                    "hospital_id", "treatment_method", "price_min", "price_max", "treatment_time",
+                    "maintain_time", "recover_time", "search_tag2", "search_tag3","cid_id","device_id")\
         .rdd.repartition(200).map(
-        lambda x: (x[0], float(x[1]), float(x[2]), app_list_func(x[3], app_list_map),
-                   app_list_func(x[4], leve2_map),app_list_func(x[5], leve3_map),
-                   [value_map.get(x[0], 1), value_map.get(x[6], 2), value_map.get(x[7], 3), value_map.get(x[8], 4),
-                    value_map.get(x[9], 5), value_map.get(x[10], 6), value_map.get(x[11], 7), value_map.get(x[12], 8),
-                    value_map.get(x[13], 9)],
-                   app_list_func(x[14], leve2_map), app_list_func(x[15], leve3_map),x[6],x[16],x[17]
+        lambda x: (x[0], float(x[1]), float(x[2]), app_list_func(x[3], app_list_map), app_list_func(x[4], leve2_map),
+                   app_list_func(x[5], leve3_map), app_list_func(x[6], leve2_map), app_list_func(x[7], leve2_map),
+                   app_list_func(x[8], leve2_map), app_list_func(x[9], leve2_map), app_list_func(x[10], leve2_map),
+                   app_list_func(x[11], leve2_map), app_list_func(x[12], leve2_map),
+                   [value_map.get(x[0], 1), value_map.get(x[13], 2), value_map.get(x[14], 3), value_map.get(x[15], 4),
+                    value_map.get(x[16], 5), value_map.get(x[17], 6), value_map.get(x[18], 7), value_map.get(x[19], 8),
+                    value_map.get(x[20], 9), value_map.get(x[21], 10),
+                    value_map.get(x[22], 11), value_map.get(x[23], 12), value_map.get(x[24], 13),
+                    value_map.get(x[25], 14), value_map.get(x[26], 15)],
+                   app_list_func(x[27], leve2_map), app_list_func(x[28], leve3_map),x[13],x[29],x[30]
         ))
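
Worth noting in the map step above: the seven new tag columns are encoded with app_list_func against leve2_map, that is, each tag field is treated as a delimited multi-value string over the level-2 vocabulary, exactly like app_list and level2_ids. app_list_func itself is defined outside the diff; a minimal sketch consistent with these call sites (the comma delimiter and the 0 default are assumptions):

    def app_list_func(value, vocab):
        # Assumed sketch: split a multi-value string and map each token to its
        # vocabulary index, sending unknown tokens to 0.
        return [vocab.get(token, 0) for token in str(value).split(",")]
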
@@ -202,11 +245,13 @@ def feature_engineer():
     # TODO: remove the train filter below after launch, because the most recent day's data should also be part of the training set
-    train = rdd.map(
+    train = rdd.filter(lambda x: x[0] != validate_date).map(
         lambda x: (x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8], x[9],
-                   x[10], x[11]))
+                   x[10], x[11], x[12], x[13], x[14], x[15],x[16],x[17],x[18]))
     f = time.time()
-    spark.createDataFrame(train).toDF("y", "z", "app_list", "level2_list", "level3_list","ids",
+    spark.createDataFrame(train).toDF("y", "z", "app_list", "level2_list", "level3_list",
+                                      "tag1_list", "tag2_list", "tag3_list", "tag4_list",
+                                      "tag5_list", "tag6_list", "tag7_list", "ids",
                                       "search_tag2_list","search_tag3_list","city","cid_id","uid") \
         .repartition(1).write.format("tfrecords").save(path=path + "tr/", mode="overwrite")
     h = time.time()
@@ -216,13 +261,15 @@ def feature_engineer():
     print("训练集样本总量:")
     print(rdd.count())
-    get_pre_number()
+    # get_pre_number()
     test = rdd.filter(lambda x: x[0] == validate_date).map(
         lambda x: (x[1], x[2], x[3], x[4], x[5], x[6], x[7], x[8], x[9],
-                   x[10], x[11]))
+                   x[10], x[11], x[12], x[13], x[14], x[15],x[16],x[17],x[18]))
-    spark.createDataFrame(test).toDF("y", "z", "app_list", "level2_list", "level3_list","ids",
+    spark.createDataFrame(test).toDF("y", "z", "app_list", "level2_list", "level3_list",
+                                     "tag1_list", "tag2_list", "tag3_list", "tag4_list",
+                                     "tag5_list", "tag6_list", "tag7_list", "ids",
                                      "search_tag2_list","search_tag3_list","city","cid_id","uid") \
         .repartition(1).write.format("tfrecords").save(path=path + "va/", mode="overwrite")
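
Both writes use the "tfrecords" data source, which is not built into Spark; it is provided by the spark-tensorflow-connector package, so the job has to be launched with that JAR available. A sketch of one way to wire it in at session creation (the Maven coordinates and version are assumptions):

    from pyspark.sql import SparkSession

    # Assumed sketch: pull in the TFRecord data source at session build time.
    spark = SparkSession.builder \
        .appName("feature_engineering") \
        .config("spark.jars.packages",
                "org.tensorflow:spark-tensorflow-connector_2.11:1.14.0") \
        .getOrCreate()
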
@@ -236,19 +283,30 @@ def feature_engineer():
 def get_predict(date,value_map,app_list_map,leve2_map,leve3_map):
     sql = "select e.y,e.z,e.label,e.ucity_id,feat.level2_ids,e.ccity_name," \
           "u.device_type,u.manufacturer,u.channel,c.top,e.device_id,e.cid_id,cut.time," \
-          "dl.app_list,e.hospital_id,feat.level3_ids,doris.search_tag2,doris.search_tag3 " \
+          "dl.app_list,e.hospital_id,feat.level3_ids," \
+          "wiki.tag as tag1,question.tag as tag2,search.tag as tag3,budan.tag as tag4," \
+          "ot.tag as tag5,sixin.tag as tag6,cart.tag as tag7,doris.search_tag2,doris.search_tag3," \
+          "k.treatment_method,k.price_min,k.price_max,k.treatment_time,k.maintain_time,k.recover_time " \
           "from jerry_test.esmm_pre_data e " \
           "left join jerry_test.user_feature u on e.device_id = u.device_id " \
           "left join jerry_test.cid_type_top c on e.device_id = c.device_id " \
           "left join jerry_test.cid_time_cut cut on e.cid_id = cut.cid " \
           "left join jerry_test.device_app_list dl on e.device_id = dl.device_id " \
           "left join jerry_test.diary_feat feat on e.cid_id = feat.diary_id " \
-          "left join jerry_test.search_doris doris on e.device_id = doris.device_id and e.stat_date = doris.get_date " \
-          "where e.device_id = '355374101417079'"
+          "left join jerry_test.wiki_tag wiki on e.device_id = wiki.device_id " \
+          "left join jerry_test.question_tag question on e.device_id = question.device_id " \
+          "left join jerry_test.search_tag search on e.device_id = search.device_id " \
+          "left join jerry_test.budan_tag budan on e.device_id = budan.device_id " \
+          "left join jerry_test.order_tag ot on e.device_id = ot.device_id " \
+          "left join jerry_test.sixin_tag sixin on e.device_id = sixin.device_id " \
+          "left join jerry_test.cart_tag cart on e.device_id = cart.device_id " \
+          "left join jerry_test.knowledge k on feat.level2 = k.level2_id " \
+          "left join jerry_test.search_doris doris on e.device_id = doris.device_id and e.stat_date = doris.get_date"
     features = ["ucity_id", "ccity_name", "device_type", "manufacturer",
                 "channel", "top", "time", "hospital_id",
-                "app_list", "level3_ids", "level2_ids",
+                "treatment_method", "price_min", "price_max", "treatment_time", "maintain_time", "recover_time",
+                "app_list", "level3_ids", "level2_ids", "tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7",
                 "search_tag2", "search_tag3"]
     df = spark.sql(sql)
@@ -257,17 +315,25 @@ def get_predict(date,value_map,app_list_map,leve2_map,leve3_map):
     df = df.na.fill(dict(zip(features, features)))
     f = time.time()
     rdd = df.select("label", "y", "z", "ucity_id", "device_id", "cid_id", "app_list", "level2_ids", "level3_ids",
+                    "tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7",
                     "ucity_id", "ccity_name", "device_type", "manufacturer", "channel", "top", "time",
-                    "hospital_id", "search_tag2", "search_tag3") \
+                    "hospital_id", "treatment_method", "price_min", "price_max", "treatment_time",
+                    "maintain_time", "recover_time", "search_tag2", "search_tag3") \
         .rdd.repartition(200).map(lambda x: (x[0], float(x[1]), float(x[2]), x[3], x[4], x[5],
                                              app_list_func(x[6], app_list_map), app_list_func(x[7], leve2_map),
-                                             app_list_func(x[8], leve3_map),
-                                             [value_map.get(date, 1), value_map.get(x[9], 2),
-                                              value_map.get(x[10], 3), value_map.get(x[11], 4),
-                                              value_map.get(x[12], 5), value_map.get(x[13], 6),
-                                              value_map.get(x[14], 7), value_map.get(x[15], 8),
-                                              value_map.get(x[16], 9)],
-                                             app_list_func(x[17], leve2_map),app_list_func(x[18], leve3_map)))
+                                             app_list_func(x[8], leve3_map), app_list_func(x[9], leve2_map),
+                                             app_list_func(x[10], leve2_map), app_list_func(x[11], leve2_map),
+                                             app_list_func(x[12], leve2_map), app_list_func(x[13], leve2_map),
+                                             app_list_func(x[14], leve2_map), app_list_func(x[15], leve2_map),
+                                             [value_map.get(date, 1), value_map.get(x[16], 2),
+                                              value_map.get(x[17], 3), value_map.get(x[18], 4),
+                                              value_map.get(x[19], 5), value_map.get(x[20], 6),
+                                              value_map.get(x[21], 7), value_map.get(x[22], 8),
+                                              value_map.get(x[23], 9), value_map.get(x[24], 10),
+                                              value_map.get(x[25], 11), value_map.get(x[26], 12),
+                                              value_map.get(x[27], 13), value_map.get(x[28], 14),
+                                              value_map.get(x[29], 15)],
+                                             app_list_func(x[30], leve2_map),app_list_func(x[31], leve3_map)))
     rdd.persist(storageLevel= StorageLevel.MEMORY_ONLY_SER)
     print("预测集样本大小:")
@@ -275,8 +341,10 @@ def get_predict(date,value_map,app_list_map,leve2_map,leve3_map):
     if rdd.filter(lambda x: x[0] == 0).count() > 0:
         print("预测集native有数据")
         spark.createDataFrame(rdd.filter(lambda x: x[0] == 0)
-                              .map(lambda x: (x[1], x[2], x[6], x[7], x[8], x[9], x[10], x[11], x[3], x[4], x[5]))) \
-            .toDF("y", "z", "app_list", "level2_list", "level3_list", "ids", "search_tag2_list",
+                              .map(lambda x: (x[1], x[2], x[6], x[7], x[8], x[9], x[10], x[11],
+                                              x[12], x[13], x[14], x[15], x[16], x[17], x[18], x[3], x[4], x[5]))) \
+            .toDF("y", "z", "app_list", "level2_list", "level3_list", "tag1_list", "tag2_list", "tag3_list",
+                  "tag4_list","tag5_list", "tag6_list", "tag7_list", "ids", "search_tag2_list",
                   "search_tag3_list", "city", "uid","cid_id") \
             .repartition(1).write.format("tfrecords").save(path=path + "native/", mode="overwrite")
@@ -289,9 +357,11 @@ def get_predict(date,value_map,app_list_map,leve2_map,leve3_map):
     if rdd.filter(lambda x: x[0] == 1).count() > 0:
         print("预测集nearby有数据")
         spark.createDataFrame(rdd.filter(lambda x: x[0] == 1)
-                              .map(lambda x: (x[1], x[2], x[6], x[7], x[8], x[9], x[10], x[11], x[3], x[4], x[5]))) \
-            .toDF("y", "z", "app_list", "level2_list", "level3_list", "ids", "search_tag2_list",
-                  "search_tag3_list", "city", "uid","cid_id")\
+                              .map(lambda x: (x[1], x[2], x[6], x[7], x[8], x[9], x[10], x[11],
+                                              x[12], x[13], x[14], x[15], x[16], x[17], x[18], x[3], x[4], x[5]))) \
+            .toDF("y", "z", "app_list", "level2_list", "level3_list", "tag1_list", "tag2_list", "tag3_list",
+                  "tag4_list","tag5_list", "tag6_list", "tag7_list", "ids", "search_tag2_list",
+                  "search_tag3_list", "city", "uid", "cid_id")\
             .repartition(1).write.format("tfrecords").save(path=path + "nearby/", mode="overwrite")
         print("nearby tfrecord done")
     else:
@@ -316,6 +386,6 @@ if __name__ == '__main__':
     local_path = "/home/gmuser/esmm/"
     validate_date, value_map, app_list_map, leve2_map, leve3_map = feature_engineer()
-    get_predict(validate_date, value_map, app_list_map, leve2_map, leve3_map)
+    # get_predict(validate_date, value_map, app_list_map, leve2_map, leve3_map)
     spark.stop()