Commit e87b46e6 authored by 张彦钊's avatar 张彦钊

修复生成tfrecord文件时的bug

parent 8252e77d
...@@ -95,6 +95,7 @@ def get_data(): ...@@ -95,6 +95,7 @@ def get_data():
train[i] = train[i].map(value_map) train[i] = train[i].map(value_map)
test[i] = test[i].map(value_map) test[i] = test[i].map(value_map)
print(train[["clevel2_id", "channel","app_list"]].head(6))
print("train shape") print("train shape")
print(train.shape) print(train.shape)
print("test shape") print("test shape")
...@@ -148,8 +149,8 @@ def get_predict(date,value_map,app_list_map,level2_map): ...@@ -148,8 +149,8 @@ def get_predict(date,value_map,app_list_map,level2_map):
df["clevel2_id"] = df["clevel2_id"].fillna("lost_na") df["clevel2_id"] = df["clevel2_id"].fillna("lost_na")
df["clevel2_id"] = df["clevel2_id"].apply(app_list_func, args=(level2_map,)) df["clevel2_id"] = df["clevel2_id"].apply(app_list_func, args=(level2_map,))
print("predict shape") # print("predict shape")
print(df.shape) # print(df.shape)
df["uid"] = df["device_id"] df["uid"] = df["device_id"]
df["city"] = df["ucity_id"] df["city"] = df["ucity_id"]
features = ["ucity_id", "ccity_name", "device_type", "manufacturer", features = ["ucity_id", "ccity_name", "device_type", "manufacturer",
...@@ -177,13 +178,13 @@ def get_predict(date,value_map,app_list_map,level2_map): ...@@ -177,13 +178,13 @@ def get_predict(date,value_map,app_list_map,level2_map):
print("native") print("native")
print(native_pre.shape) print(native_pre.shape)
print(native_pre.head())
native_pre[["uid","city","cid_id"]].to_csv(path+"native.csv",index=False) native_pre[["uid","city","cid_id"]].to_csv(path+"native.csv",index=False)
write_csv(native_pre, "native",200000) write_csv(native_pre, "native",200000)
print("nearby") print("nearby")
print(nearby_pre.shape) print(nearby_pre.shape)
print(nearby_pre.head())
nearby_pre[["uid","city","cid_id"]].to_csv(path+"nearby.csv",index=False) nearby_pre[["uid","city","cid_id"]].to_csv(path+"nearby.csv",index=False)
write_csv(nearby_pre, "nearby", 160000) write_csv(nearby_pre, "nearby", 160000)
......
...@@ -53,7 +53,7 @@ def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False): ...@@ -53,7 +53,7 @@ def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
features = { features = {
"y": tf.FixedLenFeature([], tf.float32), "y": tf.FixedLenFeature([], tf.float32),
"z": tf.FixedLenFeature([], tf.float32), "z": tf.FixedLenFeature([], tf.float32),
"ids": tf.FixedLenFeature([8], tf.int64), "ids": tf.FixedLenFeature([FLAGS.field_size], tf.int64),
"app_list": tf.VarLenFeature(tf.int64), "app_list": tf.VarLenFeature(tf.int64),
"level2_list": tf.VarLenFeature(tf.int64) "level2_list": tf.VarLenFeature(tf.int64)
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment