Commit e87b46e6 authored by 张彦钊's avatar 张彦钊

修复生成tfrecord文件时的bug

parent 8252e77d
......@@ -95,6 +95,7 @@ def get_data():
train[i] = train[i].map(value_map)
test[i] = test[i].map(value_map)
print(train[["clevel2_id", "channel","app_list"]].head(6))
print("train shape")
print(train.shape)
print("test shape")
......@@ -148,8 +149,8 @@ def get_predict(date,value_map,app_list_map,level2_map):
df["clevel2_id"] = df["clevel2_id"].fillna("lost_na")
df["clevel2_id"] = df["clevel2_id"].apply(app_list_func, args=(level2_map,))
print("predict shape")
print(df.shape)
# print("predict shape")
# print(df.shape)
df["uid"] = df["device_id"]
df["city"] = df["ucity_id"]
features = ["ucity_id", "ccity_name", "device_type", "manufacturer",
......@@ -177,13 +178,13 @@ def get_predict(date,value_map,app_list_map,level2_map):
print("native")
print(native_pre.shape)
print(native_pre.head())
native_pre[["uid","city","cid_id"]].to_csv(path+"native.csv",index=False)
write_csv(native_pre, "native",200000)
print("nearby")
print(nearby_pre.shape)
print(nearby_pre.head())
nearby_pre[["uid","city","cid_id"]].to_csv(path+"nearby.csv",index=False)
write_csv(nearby_pre, "nearby", 160000)
......
......@@ -53,7 +53,7 @@ def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
features = {
"y": tf.FixedLenFeature([], tf.float32),
"z": tf.FixedLenFeature([], tf.float32),
"ids": tf.FixedLenFeature([8], tf.int64),
"ids": tf.FixedLenFeature([FLAGS.field_size], tf.int64),
"app_list": tf.VarLenFeature(tf.int64),
"level2_list": tf.VarLenFeature(tf.int64)
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment