Commit f406c2a4 authored by 张彦钊

pandas 映射

parent 4e17c892
...@@ -59,8 +59,8 @@ def get_data(): ...@@ -59,8 +59,8 @@ def get_data():
value_map = {v: k for k, v in enumerate(unique_values)} value_map = {v: k for k, v in enumerate(unique_values)}
df = df.drop("device_id", axis=1) df = df.drop("device_id", axis=1)
train = df[df["stat_date"] != validate_date] train = df[df["stat_date"] != validate_date+"stat_date"]
test = df[df["stat_date"] == validate_date] test = df[df["stat_date"] == validate_date+"stat_date"]
for i in features: for i in features:
train[i] = train[i].map(value_map) train[i] = train[i].map(value_map)
test[i] = test[i].map(value_map) test[i] = test[i].map(value_map)
......
...@@ -26,8 +26,7 @@ def gen_tfrecords(in_file): ...@@ -26,8 +26,7 @@ def gen_tfrecords(in_file):
out_file = os.path.join(FLAGS.output_dir, basename) out_file = os.path.join(FLAGS.output_dir, basename)
tfrecord_out = tf.python_io.TFRecordWriter(out_file) tfrecord_out = tf.python_io.TFRecordWriter(out_file)
df = pd.read_csv(in_file) df = pd.read_csv(in_file)
["", "", "", "device_type", "manufacturer",
, "level2_ids", "time", "stat_date"]
for i in range(df.shape[0]): for i in range(df.shape[0]):
features = tf.train.Features(feature={ features = tf.train.Features(feature={
"y": tf.train.Feature(int64_list=tf.train.Int64List(value=[df["y"][i]])), "y": tf.train.Feature(int64_list=tf.train.Int64List(value=[df["y"][i]])),
...@@ -37,11 +36,13 @@ def gen_tfrecords(in_file): ...@@ -37,11 +36,13 @@ def gen_tfrecords(in_file):
"ucity_id": tf.train.Feature(int64_list=tf.train.Int64List(value=[df["ucity_id"][i]])), "ucity_id": tf.train.Feature(int64_list=tf.train.Int64List(value=[df["ucity_id"][i]])),
"clevel1_id": tf.train.Feature(int64_list=tf.train.Int64List(value=[df["clevel1_id"][i]])), "clevel1_id": tf.train.Feature(int64_list=tf.train.Int64List(value=[df["clevel1_id"][i]])),
"ccity_name": tf.train.Feature(int64_list=tf.train.Int64List(value=[df["ccity_name"][i]])), "ccity_name": tf.train.Feature(int64_list=tf.train.Int64List(value=[df["ccity_name"][i]])),
"channel": tf.train.Feature(int64_list=tf.train.Int64List(value=[df["channel"][i]])), "device_type": tf.train.Feature(int64_list=tf.train.Int64List(value=[df["device_type"][i]])),
"manufacturer": tf.train.Feature(int64_list=tf.train.Int64List(value=[df["manufacturer"][i]])),
"level2_ids": tf.train.Feature(int64_list=tf.train.Int64List(value=[df["level2_ids"][i]])),
"time": tf.train.Feature(int64_list=tf.train.Int64List(value=[df["time"][i]])),
"stat_date": tf.train.Feature(int64_list=tf.train.Int64List(value=[df["stat_date"][i]]))
}) })
example = tf.train.Example(features = tf.train.Features(feature = features))
example = tf.train.Example(features = tf.train.Features(feature = feature))
serialized = example.SerializeToString() serialized = example.SerializeToString()
tfrecord_out.write(serialized) tfrecord_out.write(serialized)
tfrecord_out.close() tfrecord_out.close()
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment