Commit 72cbd9da authored by Your Name

Parallel data transformation

parent 856be02b
@@ -71,15 +71,18 @@ def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
         return parsed, {"y": y, "z": z}
     # Extract lines from input files using the Dataset API; can pass one filename or a filename list
-    dataset = tf.data.TFRecordDataset(filenames).map(_parse_fn, num_parallel_calls=8).prefetch(500000)  # multi-threaded pre-processing, then prefetch
+    # dataset = tf.data.TFRecordDataset(filenames).map(_parse_fn, num_parallel_calls=8).prefetch(500000)  # multi-threaded pre-processing, then prefetch
     # Randomizes input using a window of 256 elements (read into memory)
-    if perform_shuffle:
-        dataset = dataset.shuffle(buffer_size=256)
+    # if perform_shuffle:
+    #     dataset = dataset.shuffle(buffer_size=256)
     # epochs from blending together.
-    dataset = dataset.repeat(num_epochs)
-    dataset = dataset.batch(batch_size)  # Batch size to use
+    # dataset = dataset.repeat(num_epochs)
+    # dataset = dataset.batch(batch_size)  # Batch size to use
+    dataset = tf.data.TFRecordDataset(filenames).apply(tf.contrib.data.map_and_batch(map_func=_parse_fn, batch_size=batch_size))
     # dataset = dataset.padded_batch(batch_size, padded_shapes=({"feeds_ids": [None], "feeds_vals": [None], "title_ids": [None]}, [None]))  # pad variable-length features to a common length
     # return dataset.make_one_shot_iterator()
...
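For reference, a minimal sketch of the fused pipeline this commit switches to, assuming TF 1.x with tf.contrib available. The feature spec inside _parse_fn is hypothetical (the real schema is defined in the unchanged part of the file above this hunk), and num_parallel_batches=4 is an illustrative choice: map_and_batch fuses per-record parsing with batching and processes up to num_parallel_batches * batch_size records concurrently, which is what makes the transformation parallel.

import tensorflow as tf

def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
    def _parse_fn(record):
        # Hypothetical feature spec, for illustration only; the real schema
        # lives earlier in this file and is not part of the diff above.
        features = {
            "y": tf.FixedLenFeature([], tf.float32),
            "z": tf.FixedLenFeature([], tf.float32),
            "feat_ids": tf.FixedLenFeature([39], tf.int64),
        }
        parsed = tf.parse_single_example(record, features)
        y = parsed.pop("y")
        z = parsed.pop("z")
        return parsed, {"y": y, "z": z}

    dataset = tf.data.TFRecordDataset(filenames)
    if perform_shuffle:
        # Shuffle raw serialized records before parsing
        dataset = dataset.shuffle(buffer_size=256)
    # Fused map+batch: replaces the separate .map(...).batch(...) stages and
    # parses several batches' worth of records in parallel.
    dataset = dataset.apply(tf.contrib.data.map_and_batch(
        map_func=_parse_fn,
        batch_size=batch_size,
        num_parallel_batches=4))
    dataset = dataset.repeat(num_epochs)
    dataset = dataset.prefetch(1)  # overlap host pre-processing with training
    return dataset.make_one_shot_iterator().get_next()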