#coding=utf-8 from __future__ import absolute_import from __future__ import division from __future__ import print_function import pandas as pd import sys import os import glob import tensorflow as tf import numpy as np import re from multiprocessing import Pool as ThreadPool flags = tf.app.flags FLAGS = flags.FLAGS LOG = tf.logging tf.app.flags.DEFINE_string("input_dir", "./", "input dir") tf.app.flags.DEFINE_string("output_dir", "./", "output dir") tf.app.flags.DEFINE_integer("threads", 16, "threads num") def gen_tfrecords(in_file): basename = os.path.basename(in_file) + ".tfrecord" out_file = os.path.join(FLAGS.output_dir, basename) tfrecord_out = tf.python_io.TFRecordWriter(out_file) df = pd.read_csv(in_file) for i in range(df.shape[0]): feats = ["ucity_id", "clevel1_id", "ccity_name", "device_type", "manufacturer", "channel", "top", "l1", "time", "stat_date","l2"] id = np.array([]) for j in feats: id = np.append(id,df[j][i]) features = tf.train.Features(feature={ "y": tf.train.Feature(float_list=tf.train.FloatList(value=[df["y"][i]])), "z": tf.train.Feature(float_list=tf.train.FloatList(value=[df["z"][i]])), "ids": tf.train.Feature(int64_list=tf.train.Int64List(value=id.astype(np.int))) }) example = tf.train.Example(features = features) serialized = example.SerializeToString() tfrecord_out.write(serialized) tfrecord_out.close() def main(_): if not os.path.exists(FLAGS.output_dir): os.mkdir(FLAGS.output_dir) file_list = glob.glob(os.path.join(FLAGS.input_dir, "*.csv")) print("total files: %d" % len(file_list)) pool = ThreadPool(FLAGS.threads) # Sets the pool size pool.map(gen_tfrecords, file_list) pool.close() pool.join() if __name__ == "__main__": tf.logging.set_verbosity(tf.logging.INFO) tf.app.run()