Commit b4bfaa68 authored by 张彦钊's avatar 张彦钊

change test file

parent 7eb0395f
This diff is collapsed.
...@@ -17,7 +17,6 @@ def app_list_func(x,l): ...@@ -17,7 +17,6 @@ def app_list_func(x,l):
else: else:
e.append(0) e.append(0)
return e return e
# return ",".join([str(j) for j in e])
def multi_hot(df,column,n): def multi_hot(df,column,n):
...@@ -32,7 +31,7 @@ def multi_hot(df,column,n): ...@@ -32,7 +31,7 @@ def multi_hot(df,column,n):
return number,app_list_map return number,app_list_map
def feature_engineer(): def feature():
db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test') db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
sql = "select max(stat_date) from esmm_train_data" sql = "select max(stat_date) from esmm_train_data"
validate_date = con_sql(db, sql)[0].values.tolist()[0] validate_date = con_sql(db, sql)[0].values.tolist()[0]
...@@ -99,9 +98,9 @@ def get_predict(date,value_map,app_list_map): ...@@ -99,9 +98,9 @@ def get_predict(date,value_map,app_list_map):
native_pre = spark.createDataFrame(rdd.filter(lambda x:x[4] == 0).map(lambda x:(x[1],x[2],x[3])))\ native_pre = spark.createDataFrame(rdd.filter(lambda x:x[4] == 0).map(lambda x:(x[1],x[2],x[3])))\
.toDF("city","uid","cid_id") .toDF("city","uid","cid_id")
print("native") print("native")
# native_pre.toPandas().to_csv(local_path+"native.csv", header=True) native_pre.toPandas().to_csv(local_path+"native.csv", header=True)
# TODO 写成csv文件改成下面这样
native_pre.coalesce(1).write.format('com.databricks.spark.csv').save(path+"hello.csv",header = 'true') # native_pre.coalesce(1).write.format('com.databricks.spark.csv').save(path+"native/",header = 'true')
# 预测的tfrecord必须写成一个文件,这样可以摆保证顺序 # 预测的tfrecord必须写成一个文件,这样可以摆保证顺序
spark.createDataFrame(rdd.filter(lambda x: x[4] == 0).map(lambda x: (x[0],x[5],x[6],x[7]))) \ spark.createDataFrame(rdd.filter(lambda x: x[4] == 0).map(lambda x: (x[0],x[5],x[6],x[7]))) \
...@@ -145,7 +144,7 @@ if __name__ == '__main__': ...@@ -145,7 +144,7 @@ if __name__ == '__main__':
path = "hdfs:///strategy/esmm/" path = "hdfs:///strategy/esmm/"
local_path = "/home/gmuser/esmm/" local_path = "/home/gmuser/esmm/"
validate_date, value_map, app_list_map = feature_engineer() validate_date, value_map, app_list_map = feature()
get_predict(validate_date, value_map, app_list_map) get_predict(validate_date, value_map, app_list_map)
# df = spark.read.format("tfrecords").option("recordType", "Example").load("/strategy/va.tfrecord") # df = spark.read.format("tfrecords").option("recordType", "Example").load("/strategy/va.tfrecord")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment