Commit 4408b21c authored by 张彦钊

change test file

parent 8cb7057b
@@ -66,9 +66,9 @@ def feature_engineer():
     # TODO: after launch, remove the train filter below, because the most recent day's data should also be used as training data
     train = rdd.filter(lambda x: x[1] != validate_date)\
-        .map(lambda x: (app_list_func(x[0], app_list_map), [value_map[x[2]], value_map[x[1]]], x[3], x[4]))
+        .map(lambda x: (app_list_func(x[0], app_list_map), [value_map[x[2]], value_map[x[1]]], float(x[3]), float(x[4])))
     test = rdd.filter(lambda x: x[1] == validate_date)\
-        .map(lambda x: (app_list_func(x[0], app_list_map), [value_map[x[2]], value_map[x[1]]], x[3], x[4]))
+        .map(lambda x: (app_list_func(x[0], app_list_map), [value_map[x[2]], value_map[x[1]]], float(x[3]), float(x[4])))
     spark.createDataFrame(test).toDF("level2_ids", "ids", "y", "z")\
         .repartition(1).write.format("tfrecords").save(path=path + "va/", mode="overwrite")
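
The float(...) casts matter because the tfrecords writer infers each feature's type from the DataFrame column type: integer columns are written as int64_list features and float columns as float_list features, and the reader's schema must match. A minimal sketch of that constraint (plain TensorFlow, no Spark; the literal values are illustrative only):

import tensorflow as tf

# "y" written as an int64 feature (what the old, uncast columns produced).
as_int = tf.train.Example(features=tf.train.Features(feature={
    "y": tf.train.Feature(int64_list=tf.train.Int64List(value=[1]))}))

# "y" written as a float feature (what float(x[3]) now produces).
as_float = tf.train.Example(features=tf.train.Features(feature={
    "y": tf.train.Feature(float_list=tf.train.FloatList(value=[1.0]))}))

# Parsing as_int with tf.FixedLenFeature([], tf.float32) fails at runtime,
# which is why this commit changes the writer-side cast and the reader-side
# dtype together.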
@@ -91,7 +91,7 @@ def get_predict(date,value_map,app_list_map):
     df = spark.sql(sql)
     df = df.na.fill(dict(zip(features, features)))
     rdd = df.select("level2_ids", "ucity_id", "device_id", "cid_id", "label", "y", "z") \
-        .rdd.map(lambda x: (app_list_func(x[0], app_list_map), x[1], x[2], x[3], x[4], x[5], x[6],
+        .rdd.map(lambda x: (app_list_func(x[0], app_list_map), x[1], x[2], x[3], x[4], float(x[5]), float(x[6]),
                             [value_map.get(x[1], 300000), value_map.get(date, 299999)]))
     rdd.persist()
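
One detail worth noting in get_predict: value_map.get(...) maps unseen cities and dates to the reserved indices 300000 and 299999 instead of raising KeyError. A tiny illustration of that fallback pattern (the vocabulary below is made up; only the two reserved indices come from the diff):

value_map = {"beijing": 1, "shanghai": 2}        # hypothetical vocabulary
city_idx = value_map.get("unseen_city", 300000)  # out-of-vocabulary city -> reserved id
date_idx = value_map.get("2019-01-01", 299999)   # unseen date -> reserved id
print(city_idx, date_idx)                        # 300000 299999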
@@ -145,12 +145,12 @@ if __name__ == '__main__':
     validate_date, value_map, app_list_map = feature_engineer()
     get_predict(validate_date, value_map, app_list_map)
-    df = spark.read.format("tfrecords").option("recordType", "Example").load("/strategy/va.tfrecord")
-    df.show(1)
-    print("aa")
-    print("aa")
-    df = spark.read.format("tfrecords").load("/strategy/esmm/va/part-r-00000")
-    df.show(1)
+    # df = spark.read.format("tfrecords").option("recordType", "Example").load("/strategy/va.tfrecord")
+    # df.show(1)
+    # print("aa")
+    # print("aa")
+    # df = spark.read.format("tfrecords").load("/strategy/esmm/va/part-r-00000")
+    # df.show(1)
@@ -47,8 +47,8 @@ def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
     print('Parsing', filenames)
     def _parse_fn(record):
         features = {
-            "y": tf.FixedLenFeature([], tf.int64),
-            "z": tf.FixedLenFeature([], tf.int64),
+            "y": tf.FixedLenFeature([], tf.float32),
+            "z": tf.FixedLenFeature([], tf.float32),
             "ids": tf.FixedLenFeature([FLAGS.field_size], tf.int64),
             "level2_ids": tf.VarLenFeature(tf.int64)
         }
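
For context, a sketch of how a schema like the one above is typically consumed in a TF 1.x input pipeline (the field_size default of 2 and the shard file name are assumptions for illustration, not taken from this commit):

import tensorflow as tf

def _parse_fn(record, field_size=2):  # field_size assumed; the real code uses FLAGS.field_size
    features = {
        "y": tf.FixedLenFeature([], tf.float32),
        "z": tf.FixedLenFeature([], tf.float32),
        "ids": tf.FixedLenFeature([field_size], tf.int64),
        "level2_ids": tf.VarLenFeature(tf.int64),
    }
    parsed = tf.parse_single_example(record, features)
    # Split the two labels out of the feature dict, as ESMM-style models usually do.
    labels = {"y": parsed.pop("y"), "z": parsed.pop("z")}
    return parsed, labels

# Hypothetical usage against the validation shard written above.
dataset = tf.data.TFRecordDataset(["va/part-r-00000"]).map(_parse_fn)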