# to_tfrecord.py
#coding=utf-8

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import pandas as pd
import os
import glob

import tensorflow as tf
import numpy as np
from multiprocessing import Pool as ThreadPool

# Handles for the TF1 flag registry and logging module used throughout
# this script.
flags = tf.app.flags
FLAGS = flags.FLAGS
LOG = tf.logging

# Command-line options; tf.app.run parses them before calling main().
flags.DEFINE_string("input_dir", "./", "input dir")    # scanned for *.csv
flags.DEFINE_string("output_dir", "./", "output dir")  # receives *.tfrecord
flags.DEFINE_integer("threads", 16, "threads num")     # multiprocessing pool size

# Scalar columns packed, in this order, into the int64 "ids" feature.
# NOTE(review): "ccity_name" is cast to int64 like the others, so it
# presumably already holds an encoded id rather than a raw name -- confirm.
_ID_FEATS = ["ucity_id", "ccity_name", "device_type", "manufacturer",
             "channel", "top", "time", "stat_date", "hospital_id",
             "method", "min", "max", "treatment_time", "maintain_time",
             "recover_time"]

# (CSV column, Example feature key) pairs for the comma-separated id lists.
_LIST_FEATS = [
    ("app_list", "app_list"),
    ("clevel2_id", "level2_list"),
    ("level3_ids", "level3_list"),
    ("tag1", "tag1_list"),
    ("tag2", "tag2_list"),
    ("tag3", "tag3_list"),
    ("tag4", "tag4_list"),
    ("tag5", "tag5_list"),
    ("tag6", "tag6_list"),
    ("tag7", "tag7_list"),
]


def _split_ids(cell):
    """Parse one comma-separated cell such as "3,17,42" into an int64 array.

    Raises ValueError if any piece is not an integer literal. A missing cell
    (NaN) stringifies to "nan", so it raises as well.
    """
    # np.int64 replaces the np.int alias (deprecated in NumPy 1.20, removed
    # in 1.24).
    return np.array(str(cell).split(",")).astype(np.int64)


def _int64_feature(values):
    """Wrap a sequence of integers in an int64-list tf.train.Feature."""
    return tf.train.Feature(int64_list=tf.train.Int64List(value=values))


def _float_feature(value):
    """Wrap a single number in a one-element float-list tf.train.Feature."""
    return tf.train.Feature(float_list=tf.train.FloatList(value=[value]))


def gen_tfrecords(in_file):
    """Convert one CSV file into a TFRecord file of tf.train.Example protos.

    The output is written to FLAGS.output_dir as "<csv basename>.tfrecord".
    Each CSV row becomes one Example containing:
      * "y", "z": one-element float features from the columns of that name;
      * "ids": the _ID_FEATS columns cast to int64, in _ID_FEATS order
        (float-valued cells are truncated toward zero by the cast);
      * one int64-list feature per _LIST_FEATS entry, parsed from that
        column's comma-separated cell.

    Args:
      in_file: path to the input CSV; the first row is read as the header.

    Raises:
      KeyError: if a required column is missing from the CSV.
      ValueError: if a list cell is not a comma-separated list of integers.
    """
    basename = os.path.basename(in_file) + ".tfrecord"
    out_file = os.path.join(FLAGS.output_dir, basename)
    df = pd.read_csv(in_file)

    # Extract every column once as a positional array instead of doing a
    # DataFrame lookup per cell inside the row loop. read_csv is called
    # without index_col, so row position i is also row label i.
    y_col = df["y"].values
    z_col = df["z"].values
    id_rows = df[_ID_FEATS].values
    list_cols = [(key, df[col].values) for col, key in _LIST_FEATS]

    # The context manager closes (and flushes) the writer even when a row
    # fails to convert, instead of leaking the open output file.
    with tf.python_io.TFRecordWriter(out_file) as writer:
        for i in range(len(df)):
            feature = {
                "y": _float_feature(y_col[i]),
                "z": _float_feature(z_col[i]),
                "ids": _int64_feature(id_rows[i].astype(np.int64)),
            }
            for key, col in list_cols:
                feature[key] = _int64_feature(_split_ids(col[i]))
            example = tf.train.Example(
                features=tf.train.Features(feature=feature))
            writer.write(example.SerializeToString())

def main(_):
    """Convert every *.csv in FLAGS.input_dir to a TFRecord file.

    The files are handed to gen_tfrecords through a multiprocessing pool of
    FLAGS.threads workers. These are processes despite the ThreadPool alias.

    Args:
      _: positional argv supplied by tf.app.run; unused.
    """
    # makedirs (unlike mkdir) also creates missing parent directories, so a
    # nested --output_dir such as out/2019/part works.
    if not os.path.isdir(FLAGS.output_dir):
        os.makedirs(FLAGS.output_dir)
    file_list = glob.glob(os.path.join(FLAGS.input_dir, "*.csv"))
    print("total files: %d" % len(file_list))

    pool = ThreadPool(FLAGS.threads)  # Sets the pool size
    try:
        pool.map(gen_tfrecords, file_list)
    finally:
        # Release the pool even when a conversion raises; join() lets any
        # in-flight conversions finish rather than being killed at exit.
        pool.close()
        pool.join()


if __name__ == "__main__":
    # Enable INFO-level logging, then let tf.app.run parse the command-line
    # flags and dispatch to this module's main().
    LOG.set_verbosity(LOG.INFO)
    tf.app.run(main=main)