Commit 4ba42d54 authored by 张彦钊

change test file

parent 75951d44
@@ -19,22 +19,6 @@ def app_list_func(x,l):
    return e
-def multi_hot(df,column,n):
-    a = time.time()
-    v = list(set(df.select(column).rdd.map(lambda x: x[0]).collect()))
-    b = time.time()
-    print(column)
-    print("cost time (minutes)")
-    print((b-a)/60)
-    app_list_value = [str(i).split(",") for i in v]
-    app_list_unique = []
-    for i in app_list_value:
-        app_list_unique.extend(i)
-    app_list_unique = list(set(app_list_unique))
-    number = len(app_list_unique)
-    app_list_map = dict(zip(app_list_unique, list(range(n, number + n))))
-    return number,app_list_map
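# multi_hot collects the distinct comma-separated ids of a DataFrame column and numbers them
# from n upward; e.g. column values {"1,2", "2,3"} with n=1 yield number=3 and a map such as
# {"1": 1, "2": 2, "3": 3} (set ordering makes the exact assignment arbitrary).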
def get_list(db,sql,n):
    cursor = db.cursor()
    cursor.execute(sql)
@@ -61,30 +45,114 @@ def get_map():
    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
    sql = "select level2_ids from diary_feat"
    b = time.time()
-    leve2_number, leve2_map = get_list(db, sql, apps_number)
+    leve2_number, leve2_map = get_list(db, sql, 1+apps_number)
    print("leve2")
    print((time.time() - b) / 60)
    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
    sql = "select level3_ids from diary_feat"
    c = time.time()
-    leve3_number, leve3_map = get_list(db, sql, leve2_number)
+    leve3_number, leve3_map = get_list(db, sql, 1+leve2_number+apps_number)
    print((time.time() - c) / 60)
    return apps_number, app_list_map,leve2_number, leve2_map,leve3_number, leve3_map
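# The chained offsets (1+apps_number, 1+leve2_number+apps_number) keep the app-list, level2
# and level3 maps in disjoint integer ranges of one shared index space; index 0 appears to be
# reserved, presumably for padding or missing values.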
+def get_unique(db,sql):
+    cursor = db.cursor()
+    cursor.execute(sql)
+    result = cursor.fetchall()
+    v = list(set([i[0] for i in result]))
+    db.close()
+    print(sql)
+    print(len(v))
+    return v
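# get_unique returns the distinct values of a single-column query as a list, closing the
# connection and printing the query and its cardinality for logging.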
+def feature_engineer():
+    apps_number, app_list_map, level2_number, level2_map, level3_number, level3_map = get_map()
+    unique_values = []
+    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
+    sql = "select distinct stat_date from esmm_train_data_dur"
+    unique_values.extend(get_unique(db,sql))
+    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
+    sql = "select distinct ucity_id from esmm_train_data_dur"
+    unique_values.extend(get_unique(db, sql))
+    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
+    sql = "select distinct ccity_name from esmm_train_data_dur"
+    unique_values.extend(get_unique(db, sql))
+    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
+    sql = "select distinct time from cid_time_cut"
+    unique_values.extend(get_unique(db, sql))
-def feature_engineer():
    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
-    sql = "select max(stat_date) from esmm_train_data"
+    sql = "select distinct device_type from user_feature"
+    unique_values.extend(get_unique(db, sql))
+    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
+    sql = "select distinct manufacturer from user_feature"
+    unique_values.extend(get_unique(db, sql))
+    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
+    sql = "select distinct channel from user_feature"
+    unique_values.extend(get_unique(db, sql))
+    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
+    sql = "select distinct top from cid_type_top"
+    unique_values.extend(get_unique(db, sql))
+    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
+    sql = "select distinct price_min from train_Knowledge_network_data"
+    unique_values.extend(get_unique(db, sql))
+    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
+    sql = "select distinct treatment_method from train_Knowledge_network_data"
+    unique_values.extend(get_unique(db, sql))
+    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
+    sql = "select distinct price_max from train_Knowledge_network_data"
+    unique_values.extend(get_unique(db, sql))
+    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
+    sql = "select distinct treatment_time from train_Knowledge_network_data"
+    unique_values.extend(get_unique(db, sql))
+    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
+    sql = "select distinct maintain_time from train_Knowledge_network_data"
+    unique_values.extend(get_unique(db, sql))
+    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
+    sql = "select distinct recover_time from train_Knowledge_network_data"
+    unique_values.extend(get_unique(db, sql))
+    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
+    sql = "select max(stat_date) from esmm_train_data_dur"
    validate_date = con_sql(db, sql)[0].values.tolist()[0]
    print("validate_date:" + validate_date)
    temp = datetime.datetime.strptime(validate_date, "%Y-%m-%d")
    start = (temp - datetime.timedelta(days=300)).strftime("%Y-%m-%d")
    print(start)
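# validate_date is the newest stat_date in esmm_train_data_dur; start reaches back 300 days
# and bounds the hospital_id join below.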
+    db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC')
+    sql = "select doctor.hospital_id from jerry_test.esmm_train_data_dur e " \
+          "left join eagle.src_zhengxing_api_service service on e.diary_service_id = service.id " \
+          "left join eagle.src_zhengxing_api_doctor doctor on service.doctor_id = doctor.id " \
+          "where e.stat_date >= '{}'".format(start)
+    unique_values.extend(get_unique(db, sql))
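# hospital_id is not stored on the training rows; it is recovered by joining
# diary_service_id -> service.doctor_id -> doctor.hospital_id through the eagle schema.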
+    features = ["ucity_id", "ccity_name", "device_type", "manufacturer",
+                "channel", "top", "time", "stat_date", "hospital_id",
+                "treatment_method", "price_min", "price_max", "treatment_time", "maintain_time", "recover_time"]
+    unique_values.extend(features)
+    print("unique_values length")
+    print(len(unique_values))
+    temp = list(range(2 + apps_number + level2_number + level3_number,
+                      2 + apps_number + level2_number + level3_number + len(unique_values)))
+    value_map = dict(zip(unique_values, temp))
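# Every remaining unique categorical value is indexed just above the ranges already taken by
# the app-list/level2/level3 maps. The feature names themselves are appended as keys, which
# appears to pair with the df.na.fill(dict(zip(features, features))) idiom used elsewhere:
# a null filled with its own column name still resolves to a valid index.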
sql = "select e.y,e.z,e.stat_date,e.ucity_id,feat.level2_ids,e.ccity_name,u.device_type,u.manufacturer," \
"u.channel,c.top,cut.time,dl.app_list,feat.level3_ids,doctor.hospital_id," \
"wiki.tag as tag1,question.tag as tag2,search.tag as tag3,budan.tag as tag4," \
@@ -113,36 +181,17 @@ def feature_engineer():
"channel", "top", "time", "stat_date", "app_list", "hospital_id", "level3_ids",
"tag1","tag2","tag3","tag4","tag5","tag6","tag7"])
-    features = ["ucity_id", "ccity_name", "device_type", "manufacturer",
-                "channel", "top", "time", "stat_date", "hospital_id",
-                "treatment_method", "price_min", "price_max", "treatment_time", "maintain_time", "recover_time"]
-    df = df.na.fill(dict(zip(features, features)))
-    apps_number, app_list_map = multi_hot(df, "app_list", 1)
-    level2_number, leve2_map = multi_hot(df, "level2_ids", 1 + apps_number)
-    level3_number, leve3_map = multi_hot(df, "level3_ids", 1 + apps_number + level2_number)
-    unique_values = []
-    for i in features:
-        a = time.time()
-        unique_values.extend(list(set(df.select(i).rdd.map(lambda x: x[0]).collect())))
-        b = time.time()
-        print(i)
-        print((b-a)/60)
-    temp = list(range(2 + apps_number + level2_number + level3_number,
-                      2 + apps_number + level2_number + level3_number + len(unique_values)))
-    value_map = dict(zip(unique_values, temp))
    c = time.time()
    rdd = df.select("stat_date","y", "z","app_list","level2_ids","level3_ids",
                    "tag1","tag2","tag3","tag4","tag5","tag6","tag7",
                    "ucity_id", "ccity_name","device_type", "manufacturer", "channel", "top", "time",
                    "hospital_id","treatment_method", "price_min", "price_max", "treatment_time",
                    "maintain_time","recover_time").rdd.repartition(200).map(lambda x: (x[0],float(x[1]),float(x[2]),app_list_func(x[3], app_list_map), app_list_func(x[4], level2_map),
-                   app_list_func(x[5], leve3_map), app_list_func(x[6], leve2_map),app_list_func(x[7], leve2_map),
-                   app_list_func(x[8], leve2_map), app_list_func(x[9], leve2_map),app_list_func(x[10], leve2_map),
-                   app_list_func(x[11], leve2_map),app_list_func(x[12], leve2_map),
+                   app_list_func(x[5], level3_map), app_list_func(x[6], level2_map),app_list_func(x[7], level2_map),
+                   app_list_func(x[8], level2_map), app_list_func(x[9], level2_map),app_list_func(x[10], level2_map),
+                   app_list_func(x[11], level2_map),app_list_func(x[12], level2_map),
                    [value_map[x[0]], value_map[x[13]],value_map[x[14]], value_map[x[15]], value_map[x[16]],
                     value_map[x[17]],value_map[x[18]], value_map[x[19]], value_map[x[20]],value_map[x[21]],
                     value_map[x[22]], value_map[x[23]], value_map[x[24]],value_map[x[25]],value_map[x[26]]]))
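# Each row becomes (stat_date, y, z, ten multi-hot index lists from app_list/level2/level3/
# tag1..tag7, one dense list of 15 value_map indices); stat_date stays first so the
# train/validation split below can filter on it.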
@@ -151,14 +200,6 @@ def feature_engineer():
    print((d-c)/60)
    rdd.persist()
-    # TODO: after going live, remove the train filter below, because the most recent day's data should also be used for training
-    # train = rdd.filter(lambda x: x[0] != validate_date) \
-    #     .map(lambda x: (float(x[1]),float(x[2]),app_list_func(x[3], app_list_map), app_list_func(x[4], leve2_map),
-    #                     app_list_func(x[5], leve3_map), app_list_func(x[6], leve2_map),app_list_func(x[7], leve2_map),
-    #                     app_list_func(x[8], leve2_map), app_list_func(x[9], leve2_map),app_list_func(x[10], leve2_map),
-    #                     app_list_func(x[11], leve2_map),app_list_func(x[12], leve2_map),
-    #                     [value_map[x[0]], value_map[x[13]],value_map[x[14]], value_map[x[15]], value_map[x[16]],
-    #                      value_map[x[17]],value_map[x[18]], value_map[x[19]], value_map[x[20]],value_map[x[21]],
-    #                      value_map[x[22]], value_map[x[23]], value_map[x[24]],value_map[x[25]],value_map[x[26]]]))
    train = rdd.filter(lambda x: x[0] != validate_date).map(lambda x:(x[1],x[2],x[3],x[4],x[5],x[6],x[7],x[8],x[9],
                                                                      x[10],x[11],x[12],x[13]))
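# Rows from the newest day (validate_date) are held out for validation; the leading
# stat_date element is dropped once the split is made.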
@@ -183,10 +224,10 @@ def feature_engineer():
    rdd.unpersist()
-    return validate_date,value_map,app_list_map,leve2_map,leve3_map
+    return validate_date,value_map,app_list_map,level2_map,level3_map
-def get_predict(date,value_map,app_list_map,leve2_map,leve3_map):
+def get_predict():
sql = "select e.y,e.z,e.label,e.ucity_id,feat.level2_ids,e.ccity_name," \
"u.device_type,u.manufacturer,u.channel,c.top,e.device_id,e.cid_id,cut.time," \
"dl.app_list,e.hospital_id,feat.level3_ids," \
@@ -222,12 +263,12 @@ def get_predict(date,value_map,app_list_map,leve2_map,leve3_map):
"hospital_id", "treatment_method", "price_min", "price_max", "treatment_time",
"maintain_time", "recover_time") \
.rdd.repartition(200).map(lambda x: (x[0],float(x[1]),float(x[2]),x[3],x[4],x[5],
app_list_func(x[6], app_list_map),app_list_func(x[7], leve2_map),
app_list_func(x[8], leve3_map), app_list_func(x[9], leve2_map),
app_list_func(x[10], leve2_map),app_list_func(x[11], leve2_map),
app_list_func(x[12], leve2_map), app_list_func(x[13], leve2_map),
app_list_func(x[14], leve2_map), app_list_func(x[15], leve2_map),
[value_map.get(date, 299999),value_map.get(x[16], 299998),
app_list_func(x[6], app_list_map),app_list_func(x[7], level2_map),
app_list_func(x[8], level3_map), app_list_func(x[9], level2_map),
app_list_func(x[10], level2_map),app_list_func(x[11], level2_map),
app_list_func(x[12], level2_map), app_list_func(x[13], level2_map),
app_list_func(x[14], level2_map), app_list_func(x[15], level2_map),
[value_map.get(validate_date, 299999),value_map.get(x[16], 299998),
value_map.get(x[17], 299997),value_map.get(x[18], 299996),
value_map.get(x[19], 299995), value_map.get(x[20], 299994),
value_map.get(x[21], 299993), value_map.get(x[22], 299992),
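# Values unseen at training time fall back to distinct sentinel indices (299999, 299998, ...),
# apparently chosen to sit well above anything value_map can produce.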
@@ -286,25 +327,25 @@ def con_sql(db,sql):
if __name__ == '__main__':
-    # sparkConf = SparkConf().set("spark.hive.mapred.supports.subdirectories", "true") \
-    #     .set("spark.hadoop.mapreduce.input.fileinputformat.input.dir.recursive", "true") \
-    #     .set("spark.tispark.plan.allow_index_double_read", "false") \
-    #     .set("spark.tispark.plan.allow_index_read", "true") \
-    #     .set("spark.sql.extensions", "org.apache.spark.sql.TiExtensions") \
-    #     .set("spark.tispark.pd.addresses", "172.16.40.158:2379").set("spark.io.compression.codec", "lzf")\
-    #     .set("spark.driver.maxResultSize", "8g").set("spark.sql.avro.compression.codec","snappy")
-    #
-    # spark = SparkSession.builder.config(conf=sparkConf).enableHiveSupport().getOrCreate()
-    # ti = pti.TiContext(spark)
-    # ti.tidbMapDatabase("jerry_test")
-    # ti.tidbMapDatabase("eagle")
-    # spark.sparkContext.setLogLevel("WARN")
-    # path = "hdfs:///strategy/esmm/"
-    # local_path = "/home/gmuser/esmm/"
-    #
-    # validate_date, value_map, app_list_map, leve2_map, leve3_map = feature_engineer()
-    # get_predict(validate_date, value_map, app_list_map, leve2_map, leve3_map)
-    get_map()
+    sparkConf = SparkConf().set("spark.hive.mapred.supports.subdirectories", "true") \
+        .set("spark.hadoop.mapreduce.input.fileinputformat.input.dir.recursive", "true") \
+        .set("spark.tispark.plan.allow_index_double_read", "false") \
+        .set("spark.tispark.plan.allow_index_read", "true") \
+        .set("spark.sql.extensions", "org.apache.spark.sql.TiExtensions") \
+        .set("spark.tispark.pd.addresses", "172.16.40.158:2379").set("spark.io.compression.codec", "lzf") \
+        .set("spark.driver.maxResultSize", "8g").set("spark.sql.avro.compression.codec","snappy")
+    spark = SparkSession.builder.config(conf=sparkConf).enableHiveSupport().getOrCreate()
+    ti = pti.TiContext(spark)
+    ti.tidbMapDatabase("jerry_test")
+    ti.tidbMapDatabase("eagle")
+    spark.sparkContext.setLogLevel("WARN")
+    path = "hdfs:///strategy/esmm/"
+    local_path = "/home/gmuser/esmm/"
+    validate_date, value_map, app_list_map, level2_map, level3_map = feature_engineer()
+    get_predict()
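# Driver flow: build the TiSpark-enabled session, map the jerry_test and eagle TiDB databases
# into Spark SQL, run feature_engineer() to build the maps and training set, then score with
# get_predict(), which reads validate_date, value_map and the id maps from module scope.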