Commit 4f2cdd96 authored by 张彦钊's avatar 张彦钊

Delete test.py

parent 5cc22188
#coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import pandas as pd
import sys
import os
import glob
import tensorflow as tf
import numpy as np
import re
from multiprocessing import Pool as ThreadPool
flags = tf.app.flags
FLAGS = flags.FLAGS
LOG = tf.logging
tf.app.flags.DEFINE_string("input_dir", "./", "input dir")
tf.app.flags.DEFINE_string("output_dir", "./", "output dir")
tf.app.flags.DEFINE_integer("threads", 16, "threads num")
def gen_tfrecords(in_file):
basename = os.path.basename(in_file) + ".tfrecord"
out_file = os.path.join(FLAGS.output_dir, basename)
tfrecord_out = tf.python_io.TFRecordWriter(out_file)
df = pd.read_csv(in_file)
for i in range(df.shape[0]):
feats = ["ucity_id", "clevel1_id", "ccity_name", "device_type", "manufacturer",
"channel", "top", "l1", "time", "stat_date","l2"]
id = np.array([])
for j in feats:
id = np.append(id,df[j][i])
features = tf.train.Features(feature={
"y": tf.train.Feature(float_list=tf.train.FloatList(value=[df["y"][i]])),
"z": tf.train.Feature(float_list=tf.train.FloatList(value=[df["z"][i]])),
"ids": tf.train.Feature(int64_list=tf.train.Int64List(value=id.astype(np.int)))
})
example = tf.train.Example(features = features)
serialized = example.SerializeToString()
tfrecord_out.write(serialized)
tfrecord_out.close()
def main(_):
if not os.path.exists(FLAGS.output_dir):
os.mkdir(FLAGS.output_dir)
file_list = glob.glob(os.path.join(FLAGS.input_dir, "*.csv"))
print("total files: %d" % len(file_list))
pool = ThreadPool(FLAGS.threads) # Sets the pool size
pool.map(gen_tfrecords, file_list)
pool.close()
pool.join()
if __name__ == "__main__":
tf.logging.set_verbosity(tf.logging.INFO)
tf.app.run()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment