Commit 4408b21c authored by 张彦钊

change test file

parent 8cb7057b
......@@ -66,9 +66,9 @@ def feature_engineer():
# TODO 上线后把下面 train filter 删除,因为最近一天的数据也要作为训练集
train = rdd.filter(lambda x: x[1]!= validate_date)\
.map(lambda x: (app_list_func(x[0], app_list_map),[value_map[x[2]],value_map[x[1]]], x[3],x[4]))
.map(lambda x: (app_list_func(x[0], app_list_map),[value_map[x[2]],value_map[x[1]]], float(x[3]),float(x[4])))
test = rdd.filter(lambda x: x[1]== validate_date)\
.map(lambda x: (app_list_func(x[0], app_list_map),[value_map[x[2]],value_map[x[1]]], x[3],x[4]))
.map(lambda x: (app_list_func(x[0], app_list_map),[value_map[x[2]],value_map[x[1]]], float(x[3]),float(x[4])))
spark.createDataFrame(test).toDF("level2_ids","ids","y","z")\
.repartition(1).write.format("tfrecords").save(path=path+"va/", mode="overwrite")
......@@ -91,7 +91,7 @@ def get_predict(date,value_map,app_list_map):
df = spark.sql(sql)
df = df.na.fill(dict(zip(features, features)))
rdd = df.select("level2_ids","ucity_id","device_id","cid_id","label", "y", "z") \
.rdd.map(lambda x: (app_list_func(x[0], app_list_map),x[1],x[2],x[3],x[4],x[5],x[6],
.rdd.map(lambda x: (app_list_func(x[0], app_list_map),x[1],x[2],x[3],x[4],float(x[5]),float(x[6]),
[value_map.get(x[1], 300000),value_map.get(date, 299999)]))
rdd.persist()
......@@ -145,12 +145,12 @@ if __name__ == '__main__':
validate_date, value_map, app_list_map = feature_engineer()
get_predict(validate_date, value_map, app_list_map)
df = spark.read.format("tfrecords").option("recordType", "Example").load("/strategy/va.tfrecord")
df.show(1)
print("aa")
print("aa")
df = spark.read.format("tfrecords").load("/strategy/esmm/va/part-r-00000")
df.show(1)
# df = spark.read.format("tfrecords").option("recordType", "Example").load("/strategy/va.tfrecord")
# df.show(1)
# print("aa")
# print("aa")
# df = spark.read.format("tfrecords").load("/strategy/esmm/va/part-r-00000")
# df.show(1)
......
......@@ -47,8 +47,8 @@ def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
print('Parsing', filenames)
def _parse_fn(record):
features = {
"y": tf.FixedLenFeature([], tf.int64),
"z": tf.FixedLenFeature([], tf.int64),
"y": tf.FixedLenFeature([], tf.float32),
"z": tf.FixedLenFeature([], tf.float32),
"ids": tf.FixedLenFeature([FLAGS.field_size], tf.int64),
"level2_ids": tf.VarLenFeature(tf.int64)
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment