Commit 8f983d69 authored by 王志伟

Collect statistics on icon data

parents fc52c288 b452ad9c
@@ -46,7 +46,7 @@ tf.app.flags.DEFINE_string("servable_model_dir", '', "export servable model for
tf.app.flags.DEFINE_string("task_type", 'train', "task type {train, infer, eval, export}")
tf.app.flags.DEFINE_boolean("clear_existing_model", False, "clear existing model or not")
#40362692,0,0,216:9342395:1.0 301:9351665:1.0 205:7702673:1.0 206:8317829:1.0 207:8967741:1.0 508:9356012:2.30259 210:9059239:1.0 210:9042796:1.0 210:9076972:1.0 210:9103884:1.0 210:9063064:1.0 127_14:3529789:2.3979 127_14:3806412:2.70805
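# A reading of the sample line above (inferred from _parse_fn, not documented in this commit):
# <sample_id>,<y>,<z> then space-separated <field>:<feature_id>:<value> triples,
# where y and z are the two labels that input_fn returns.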
def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
print('Parsing', filenames)
def _parse_fn(record):
@@ -71,17 +71,29 @@ def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
return parsed, {"y": y, "z": z}
# Extract lines from the input files using the Dataset API; accepts a single filename or a list of filenames
dataset = tf.data.TFRecordDataset(filenames).map(_parse_fn, num_parallel_calls=10).prefetch(500000) # multi-thread pre-process then prefetch
# dataset = tf.data.TFRecordDataset(filenames).map(_parse_fn, num_parallel_calls=8).prefetch(500000) # multi-thread pre-process then prefetch
# Randomizes input using a window of 256 elements (read into memory)
if perform_shuffle:
dataset = dataset.shuffle(buffer_size=256)
# if perform_shuffle:
# dataset = dataset.shuffle(buffer_size=256)
# Repeat the dataset num_epochs times (shuffling before repeat keeps separate epochs from blending together)
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size) # Batch size to use
# dataset = dataset.padded_batch(batch_size, padded_shapes=({"feeds_ids": [None], "feeds_vals": [None], "title_ids": [None]}, [None])) # pad variable-length features to equal length
# dataset = dataset.repeat(num_epochs)
# dataset = dataset.batch(batch_size) # Batch size to use
files = tf.data.Dataset.list_files(filenames)
dataset = files.apply(
tf.data.experimental.parallel_interleave(
lambda file: tf.data.TFRecordDataset(file),
cycle_length=8
)
)
dataset = dataset.apply(tf.data.experimental.map_and_batch(map_func=_parse_fn, batch_size=batch_size, num_parallel_calls=8))
dataset = dataset.prefetch(10000)
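# Pipeline notes: parallel_interleave pulls records from up to cycle_length=8
# TFRecord files concurrently and interleaves them, which raises I/O throughput
# on sharded input; map_and_batch fuses _parse_fn with batching so parsing
# overlaps batch assembly; the trailing prefetch keeps parsed batches queued
# ahead of the training step.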
# dataset = dataset.padded_batch(batch_size, padded_shapes=({"feeds_ids": [None], "feeds_vals": [None], "title_ids": [None]}, [None])) # pad variable-length features to equal length
#return dataset.make_one_shot_iterator()
iterator = dataset.make_one_shot_iterator()
batch_features, batch_labels = iterator.get_next()
@@ -90,6 +102,7 @@ def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
#print(batch_features,batch_labels)
return batch_features, batch_labels
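# A minimal usage sketch (train_files and the hyperparameter values here are
# hypothetical, not taken from this commit):
#   estimator = tf.estimator.Estimator(model_fn=model_fn, params=params)
#   estimator.train(input_fn=lambda: input_fn(
#       train_files, batch_size=256, num_epochs=1, perform_shuffle=True))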
def model_fn(features, labels, mode, params):
"""Bulid Model function f(x) for Estimator."""
#------hyperparameters----