#coding=utf-8
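"""Convert CSV files into TFRecord files, one output file per input CSV.

Every *.csv under --input_dir is converted in parallel and written to
--output_dir. Typical invocation (the script file name here is a placeholder):

    python csv_to_tfrecord.py --input_dir=./data --output_dir=./tfrecords --threads=16
"""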

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import pandas as pd
import sys
import os
import glob

import tensorflow as tf
import numpy as np
import re
from multiprocessing import Pool  # process pool: one worker process per CSV file

flags = tf.app.flags
FLAGS = flags.FLAGS
LOG = tf.logging

tf.app.flags.DEFINE_string("input_dir", "./", "input dir")
tf.app.flags.DEFINE_string("output_dir", "./", "output dir")
tf.app.flags.DEFINE_integer("threads", 16, "threads num")

def gen_tfrecords(in_file):
    """Convert one CSV file into one TFRecord file of serialized tf.train.Examples."""
    basename = os.path.basename(in_file) + ".tfrecord"  # e.g. foo.csv -> foo.csv.tfrecord
    out_file = os.path.join(FLAGS.output_dir, basename)
    tfrecord_out = tf.python_io.TFRecordWriter(out_file)
    df = pd.read_csv(in_file)

    # Id columns packed together into a single int64 feature, in this fixed order.
    feats = ["ucity_id", "clevel1_id", "ccity_name", "device_type", "manufacturer",
             "channel", "top", "l1", "time", "stat_date", "l2"]
    for i in range(df.shape[0]):
        ids = np.array([df[j][i] for j in feats]).astype(np.int64)
        features = tf.train.Features(feature={
            "y": tf.train.Feature(float_list=tf.train.FloatList(value=[df["y"][i]])),
            "z": tf.train.Feature(float_list=tf.train.FloatList(value=[df["z"][i]])),
            "ids": tf.train.Feature(int64_list=tf.train.Int64List(value=ids))
        })

        example = tf.train.Example(features=features)
        tfrecord_out.write(example.SerializeToString())
    tfrecord_out.close()

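# Illustrative only, not called anywhere in this script: a parse function matching the
# Example schema written by gen_tfrecords (scalar floats "y" and "z", plus the 11 id
# columns packed into "ids"). A downstream input pipeline could map it over a
# TFRecordDataset of the generated files.
def parse_example(serialized):
    spec = {
        "y": tf.FixedLenFeature([], tf.float32),
        "z": tf.FixedLenFeature([], tf.float32),
        "ids": tf.FixedLenFeature([11], tf.int64),  # 11 id columns, same order as feats
    }
    return tf.parse_single_example(serialized, spec)
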
def main(_):
    if not os.path.exists(FLAGS.output_dir):
        os.makedirs(FLAGS.output_dir)
    file_list = glob.glob(os.path.join(FLAGS.input_dir, "*.csv"))
    LOG.info("total files: %d", len(file_list))

    pool = Pool(FLAGS.threads)  # pool size = number of worker processes
    pool.map(gen_tfrecords, file_list)
    pool.close()
    pool.join()


if __name__ == "__main__":
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.app.run()