Commit c6a363bc authored by 张彦钊's avatar 张彦钊

split file

parent d583afd3
......@@ -63,19 +63,32 @@ def get_data():
test = df[df["stat_date"] == validate_date+"stat_date"]
for i in features:
train[i] = train[i].map(value_map)
train[i] = train[i].astype('int64')
test[i] = test[i].map(value_map)
test[i] = test[i].astype('int64')
print("train shape")
print(train.shape)
print("test shape")
print(test.shape)
train.to_csv(path + "tr.csv", index=False)
test.to_csv(path + "va.csv", index=False)
write_csv(train, "tr",100000)
write_csv(test, "va",80000)
return validate_date,value_map
def write_csv(df,name,n):
for i in range(0, df.shape[0], n):
if i == 0:
temp = df.iloc[0:n]
elif i + n > df.shape[0]:
temp = df.iloc[i:]
else:
temp = df.loc[i:i + n]
temp.to_csv(path + name+ "/{}{}.csv".format(name,i), index=False)
def get_predict(date,value_map):
db = pymysql.connect(host='10.66.157.22', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
sql = "select e.y,e.z,e.label,e.ucity_id,e.clevel1_id,e.ccity_name," \
......@@ -108,15 +121,21 @@ def get_predict(date,value_map):
for i in features:
native_pre[i] = native_pre[i].map(value_map)
# TODO 没有覆盖到的类别会处理成na,暂时用0填充,后续完善一下
native_pre[i] = native_pre[i].fillna(0)
native_pre[i] = native_pre[i].astype('int64')
nearby_pre[i] = nearby_pre[i].map(value_map)
# TODO 没有覆盖到的类别会处理成na,暂时用0填充,后续完善一下
nearby_pre[i] = nearby_pre[i].fillna(0)
nearby_pre[i] = nearby_pre[i].astype('int64')
print("native")
print(native_pre.shape)
native_pre.to_csv(path + "native.csv", index=False)
write_csv(native_pre, "native",200000)
print("nearby")
print(nearby_pre.shape)
nearby_pre.to_csv(path + "nearby.csv",index=False)
write_csv(nearby_pre, "nearby", 160000)
if __name__ == '__main__':
......
......@@ -33,7 +33,7 @@ def gen_tfrecords(in_file):
"y": tf.train.Feature(int64_list=tf.train.Int64List(value=[df["y"][i]])),
"z": tf.train.Feature(int64_list=tf.train.Int64List(value=[df["z"][i]])),
"top": tf.train.Feature(int64_list=tf.train.Int64List(value=[df["top"][i]])),
"channel":tf.train.Feature(int64_list=tf.train.Int64List(value=[df["channel"][i]])),
"channel": tf.train.Feature(int64_list=tf.train.Int64List(value=[df["channel"][i]])),
"ucity_id": tf.train.Feature(int64_list=tf.train.Int64List(value=[df["ucity_id"][i]])),
"clevel1_id": tf.train.Feature(int64_list=tf.train.Int64List(value=[df["clevel1_id"][i]])),
"ccity_name": tf.train.Feature(int64_list=tf.train.Int64List(value=[df["ccity_name"][i]])),
......@@ -43,6 +43,7 @@ def gen_tfrecords(in_file):
"time": tf.train.Feature(int64_list=tf.train.Int64List(value=[df["time"][i]])),
"stat_date": tf.train.Feature(int64_list=tf.train.Int64List(value=[df["stat_date"][i]]))
})
example = tf.train.Example(features = tf.train.Features(feature = features))
serialized = example.SerializeToString()
tfrecord_out.write(serialized)
......
......@@ -64,6 +64,7 @@ def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
"time": tf.FixedLenFeature([], tf.int64),
"stat_date": tf.FixedLenFeature([], tf.int64)
}
parsed = tf.parse_single_example(record, features)
y = parsed.pop('y')
z = parsed.pop('z')
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment