Commit 4f8b6803 authored by 王志伟
parents 96d74920 c13effaa
#coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import os
import glob
import tensorflow as tf
import numpy as np
import re
from multiprocessing import Pool as ThreadPool
flags = tf.app.flags
FLAGS = flags.FLAGS
LOG = tf.logging
tf.app.flags.DEFINE_string("input_dir", "./", "input dir")
tf.app.flags.DEFINE_string("output_dir", "./", "output dir")
tf.app.flags.DEFINE_integer("threads", 16, "number of worker processes")
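# With these flags, a typical invocation would look like the following
# (the script filename is an assumption; it is not shown in this commit):
#   python get_tfrecord.py --input_dir=./raw_data --output_dir=./tfrecords --threads=16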
# The field dicts below fix both the order and the number of emitted features.
# NOTE: they are iterated with .items(), so a stable order requires Python 3.7+
# dicts (or collections.OrderedDict on older interpreters).
#User_Fields = set(['101','109_14','110_14','127_14','150_14','121','122','124','125','126','127','128','129'])
#Ad_Fields = set(['205','206','207','210','216'])
#Context_Fields = set(['508','509','702','853','301'])
#Common_Fields = {'1':'1','2':'2','3':'3','4':'4','5':'5','6':'6','7':'7','8':'8','9':'9','10':'10','11':'11','12':'12','13':'13','14':'14','15':'15','16':'16','17':'17','18':'18','19':'19','20':'20','21':'21','22':'22','23':'23'}
Common_Fields = {'1':'1','2':'2','3':'3','4':'4','5':'5','6':'6','7':'7','8':'8'}  # field id -> default feature id
UMH_Fields = {'109_14':('u_cat','12'),'110_14':('u_shop','13'),'127_14':('u_brand','14'),'150_14':('u_int','15')}  # user multi-hot fields: field id -> (name, default id)
Ad_Fields = {'206':('a_cat','16'),'207':('a_shop','17'),'210':('a_int','18'),'216':('a_brand','19')}  # ad fields for DIN: field id -> (name, default id)
#40362692,0,0,216:9342395:1.0 301:9351665:1.0 205:7702673:1.0 206:8317829:1.0 207:8967741:1.0 508:9356012:2.30259 210:9059239:1.0 210:9042796:1.0 210:9076972:1.0 210:9103884:1.0 210:9063064:1.0 127_14:3529789:2.3979 127_14:3806412:2.70805
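# Each input line is: sample_id,y,z,<feature string>, where the feature string is
# space-separated "field:feature_id:value" triples; re.split('[ :]') below
# flattens these for the (-1,3) reshape.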
def gen_tfrecords(in_file):
    basename = os.path.basename(in_file) + ".tfrecord"
    out_file = os.path.join(FLAGS.output_dir, basename)
    tfrecord_out = tf.python_io.TFRecordWriter(out_file)
    with open(in_file) as fi:
        for line in fi:
            line = line.strip().split('\t')[-1]    # keep only the last tab-separated column
            fields = line.strip().split(',')
            if len(fields) != 4:
                continue
            #1 labels: fields[1] and fields[2] (e.g. click and conversion)
            y = [float(fields[1])]
            z = [float(fields[2])]
            feature = {
                "y": tf.train.Feature(float_list = tf.train.FloatList(value=y)),
                "z": tf.train.Feature(float_list = tf.train.FloatList(value=z))
            }
            splits = re.split('[ :]', fields[3])
            ffv = np.reshape(splits, (-1, 3))    # each row is a [field_id, feature_id, value] triple
            #common_mask = np.array([v in Common_Fields for v in ffv[:,0]])
            #af_mask = np.array([v in Ad_Fields for v in ffv[:,0]])
            #cf_mask = np.array([v in Context_Fields for v in ffv[:,0]])
            #2 common fields that need no special handling
            feat_ids = np.array([])
            #feat_vals = np.array([])
            for f, def_id in Common_Fields.items():
                if f in ffv[:,0]:
                    mask = np.array(f == ffv[:,0])
                    feat_ids = np.append(feat_ids, ffv[mask,1])
                    #feat_vals = np.append(feat_vals, ffv[mask,2].astype(np.float))
                else:    # field absent from this line: fall back to its default feature id
                    feat_ids = np.append(feat_ids, def_id)
                    #feat_vals = np.append(feat_vals, 1.0)
            feature.update({"feat_ids": tf.train.Feature(int64_list=tf.train.Int64List(value=feat_ids.astype(np.int)))})
                           #"feat_vals": tf.train.Feature(float_list=tf.train.FloatList(value=feat_vals))})
            #3 special fields handled separately
            for f, (fname, def_id) in UMH_Fields.items():    # user multi-hot fields keep both ids and values
                if f in ffv[:,0]:
                    mask = np.array(f == ffv[:,0])
                    feat_ids = ffv[mask,1]
                    feat_vals = ffv[mask,2]
                else:
                    feat_ids = np.array([def_id])
                    feat_vals = np.array([1.0])
                feature.update({fname+"ids": tf.train.Feature(int64_list=tf.train.Int64List(value=feat_ids.astype(np.int))),
                                fname+"vals": tf.train.Feature(float_list=tf.train.FloatList(value=feat_vals.astype(np.float)))})
            for f, (fname, def_id) in Ad_Fields.items():    # ad fields keep ids only
                if f in ffv[:,0]:
                    mask = np.array(f == ffv[:,0])
                    feat_ids = ffv[mask,1]
                else:
                    feat_ids = np.array([def_id])
                feature.update({fname+"ids": tf.train.Feature(int64_list=tf.train.Int64List(value=feat_ids.astype(np.int)))})
            # serialize to Example
            example = tf.train.Example(features = tf.train.Features(feature = feature))
            serialized = example.SerializeToString()
            tfrecord_out.write(serialized)
            #num_lines += 1
            #if num_lines % 10000 == 0:
            #    print("Process %d" % num_lines)
    tfrecord_out.close()
def main(_):
    if not os.path.exists(FLAGS.output_dir):
        os.mkdir(FLAGS.output_dir)
    file_list = glob.glob(os.path.join(FLAGS.input_dir, "*.csv"))
    print("total files: %d" % len(file_list))
    pool = ThreadPool(FLAGS.threads)    # despite the alias, multiprocessing.Pool is a process pool
    pool.map(gen_tfrecords, file_list)
    pool.close()
    pool.join()

if __name__ == "__main__":
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.app.run()
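A minimal TF 1.x reader sketch for consuming the records written above. Feature names follow the writer's feature dict; the fixed length of 8 for feat_ids assumes each common field occurs exactly once per line, and the glob pattern is an assumption, not part of the commit:

# reader sketch for the records written above (TF 1.x API).
# Assumptions: feat_ids holds exactly one id per Common_Fields entry (8 total);
# multi-hot user/ad fields are variable-length, so they use VarLenFeature.
import tensorflow as tf

def parse_fn(serialized):
    spec = {
        "y": tf.FixedLenFeature([], tf.float32),
        "z": tf.FixedLenFeature([], tf.float32),
        "feat_ids": tf.FixedLenFeature([8], tf.int64),
    }
    for name in ["u_cat", "u_shop", "u_brand", "u_int"]:     # UMH_Fields
        spec[name + "ids"] = tf.VarLenFeature(tf.int64)
        spec[name + "vals"] = tf.VarLenFeature(tf.float32)
    for name in ["a_cat", "a_shop", "a_int", "a_brand"]:     # Ad_Fields
        spec[name + "ids"] = tf.VarLenFeature(tf.int64)
    return tf.parse_single_example(serialized, spec)

dataset = tf.data.TFRecordDataset(tf.gfile.Glob("./*.tfrecord"))
dataset = dataset.map(parse_fn).batch(32)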
@@ -11,7 +11,6 @@ import os
 import json
 import glob
 from datetime import date, timedelta
-from time import time
 import random
 import tensorflow as tf
...