# get_tfrecord.py (5.03 KB)
#coding=utf-8

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import sys
import os
import glob

import tensorflow as tf
import numpy as np
import re
from multiprocessing import Pool as ThreadPool

# Short aliases for the TensorFlow flag and logging modules used in this file.
flags = tf.app.flags
FLAGS = flags.FLAGS
LOG = tf.logging

# Command-line options: where to read the input files, where to write the
# generated TFRecord files, and how many pool workers to start.
flags.DEFINE_string("input_dir", "./", "input dir")
flags.DEFINE_string("output_dir", "./", "output dir")
flags.DEFINE_integer("threads", 16, "threads num")

# Keep the field order and the number of fields fixed.
#User_Fileds = set(['101','109_14','110_14','127_14','150_14','121','122','124','125','126','127','128','129'])
#Ad_Fileds = set(['205','206','207','210','216'])
#Context_Fileds = set(['508','509','702','853','301'])
#Common_Fileds = {'1':'1','2':'2','3':'3','4':'4','5':'5','6':'6','7':'7','8':'8','9':'9','10':'10','11':'11','12':'12','13':'13','14':'14','15':'15','16':'16','17':'17','18':'18','19':'19','20':'20','21':'21','22':'22','23':'23'}
# Plain fields: field id -> id emitted when the field is absent from a record.
# Every field '1'..'8' falls back to its own id.
Common_Fileds = {str(i): str(i) for i in range(1, 9)}

# User multi-hot features: field id -> (output feature name prefix, id emitted
# when the field is absent from a record).
UMH_Fileds = {
    '109_14': ('u_cat', '12'),
    '110_14': ('u_shop', '13'),
    '127_14': ('u_brand', '14'),
    '150_14': ('u_int', '15'),
}

# Ad features for DIN: field id -> (output feature name prefix, id emitted
# when the field is absent from a record).
Ad_Fileds = {
    '206': ('a_cat', '16'),
    '207': ('a_shop', '17'),
    '210': ('a_int', '18'),
    '216': ('a_brand', '19'),
}

#40362692,0,0,216:9342395:1.0 301:9351665:1.0 205:7702673:1.0 206:8317829:1.0 207:8967741:1.0 508:9356012:2.30259 210:9059239:1.0 210:9042796:1.0 210:9076972:1.0 210:9103884:1.0 210:9063064:1.0 127_14:3529789:2.3979 127_14:3806412:2.70805
def _int64_feature(values):
    """Wrap int-like values (ints or digit strings) as an int64 tf.train.Feature."""
    return tf.train.Feature(
        int64_list=tf.train.Int64List(value=[int(v) for v in values]))


def _float_feature(values):
    """Wrap float-like values (floats or numeric strings) as a float tf.train.Feature."""
    return tf.train.Feature(
        float_list=tf.train.FloatList(value=[float(v) for v in values]))


def _group_features(feature_str):
    """Group a 'field:id:value field:id:value ...' string by field.

    Returns:
        dict mapping field -> (ids, vals), two parallel lists of the raw
        (unconverted) id and value strings in order of appearance.

    Raises:
        ValueError: if the token count is not a multiple of three.
    """
    tokens = re.split('[ :]', feature_str)
    if len(tokens) % 3 != 0:
        raise ValueError(
            "malformed feature string, expected field:id:value triples: %r"
            % feature_str)
    grouped = {}
    for field, feat_id, feat_val in zip(tokens[0::3], tokens[1::3], tokens[2::3]):
        ids, vals = grouped.setdefault(field, ([], []))
        ids.append(feat_id)
        vals.append(feat_val)
    return grouped


def _build_example(label_y, label_z, grouped):
    """Assemble the tf.train.Example for one record from its grouped features."""
    # 1. Labels.
    feature = {
        "y": _float_feature([label_y]),
        "z": _float_feature([label_z]),
    }

    # 2. Fields that need no special handling: their ids are concatenated into
    #    a single "feat_ids" list; a missing field contributes its default id.
    common_ids = []
    for field, default_id in Common_Fileds.items():
        if field in grouped:
            common_ids.extend(grouped[field][0])
        else:
            common_ids.append(default_id)
    feature["feat_ids"] = _int64_feature(common_ids)

    # 3. Special fields, each written as its own feature. User multi-hot
    #    fields carry ids and values; a missing field becomes (default id, 1.0).
    for field, (fname, default_id) in UMH_Fileds.items():
        ids, vals = grouped.get(field, ([default_id], [1.0]))
        feature[fname + "ids"] = _int64_feature(ids)
        feature[fname + "vals"] = _float_feature(vals)

    # Ad fields carry ids only; a missing field becomes its default id.
    for field, (fname, default_id) in Ad_Fileds.items():
        ids = grouped[field][0] if field in grouped else [default_id]
        feature[fname + "ids"] = _int64_feature(ids)

    return tf.train.Example(features=tf.train.Features(feature=feature))


def gen_tfrecords(in_file):
    """Convert one input file into `<FLAGS.output_dir>/<basename>.tfrecord`.

    For each line, only the part after the last tab is used. It must split on
    ',' into exactly four fields; lines that do not are skipped. Fields 1 and 2
    are written as the float labels "y" and "z"; field 3 is a run of
    `field:id:value` triples separated by spaces. Field 0 is ignored.

    Args:
        in_file: path of the input file to convert.

    Raises:
        ValueError: if a label, id or value cannot be converted to a number,
            or the feature string is not made of complete triples.
    """
    basename = os.path.basename(in_file) + ".tfrecord"
    out_file = os.path.join(FLAGS.output_dir, basename)
    tfrecord_out = tf.python_io.TFRecordWriter(out_file)
    # try/finally so the writer is closed (and its buffer flushed) even when a
    # bad line raises part-way through the file.
    try:
        with open(in_file) as fi:
            for line_no, line in enumerate(fi, 1):
                line = line.strip().split('\t')[-1]
                fields = line.strip().split(',')
                if len(fields) != 4:
                    continue

                try:
                    grouped = _group_features(fields[3])
                except ValueError as err:
                    # Same exception type as before, but say where it happened.
                    raise ValueError("%s:%d: %s" % (in_file, line_no, err))

                example = _build_example(
                    float(fields[1]), float(fields[2]), grouped)
                tfrecord_out.write(example.SerializeToString())
    finally:
        tfrecord_out.close()

def main(_):
    """Convert every `*.csv` file in FLAGS.input_dir using a worker pool.

    Creates FLAGS.output_dir (including missing parent directories) if it does
    not exist, then maps `gen_tfrecords` over the matched files with a pool of
    FLAGS.threads workers.

    Args:
        _: unused; positional argument supplied by `tf.app.run`.
    """
    if not os.path.exists(FLAGS.output_dir):
        # makedirs rather than mkdir so a nested output path also works.
        os.makedirs(FLAGS.output_dir)
    file_list = glob.glob(os.path.join(FLAGS.input_dir, "*.csv"))
    print("total files: %d" % len(file_list))
    if not file_list:
        # Nothing to convert; do not start worker processes for no work.
        return

    pool = ThreadPool(FLAGS.threads)  # Sets the pool size
    try:
        pool.map(gen_tfrecords, file_list)
    finally:
        # Always release the workers, even when a conversion raised.
        pool.close()
        pool.join()


if __name__ == "__main__":
    # Script entry point: set log verbosity to INFO, then let tf.app.run
    # parse the command-line flags and call main().
    LOG.set_verbosity(LOG.INFO)
    tf.app.run(main=main)