Commit bf312222 authored by Your Name's avatar Your Name

map and batch merge

parent b0494cfa
...@@ -81,11 +81,12 @@ def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False): ...@@ -81,11 +81,12 @@ def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
# dataset = dataset.repeat(num_epochs) # dataset = dataset.repeat(num_epochs)
# dataset = dataset.batch(batch_size) # Batch size to use # dataset = dataset.batch(batch_size) # Batch size to use
dataset = tf.data.TFRecordDataset(filenames).apply(tf.contrib.data.map_and_batch(map_func=_parse_fn, batch_size=batch_size, num_parallel_calls=8))
dataset = tf.data.TFRecordDataset(filenames).apply(tf.data.experimental.map_and_batch(map_func=_parse_fn, batch_size=batch_size, num_parallel_calls=8))
dataset = dataset.prefetch(500000) dataset = dataset.prefetch(500000)
# dataset = dataset.padded_batch(batch_size, padded_shapes=({"feeds_ids": [None], "feeds_vals": [None], "title_ids": [None]}, [None])) #不定长补齐
# dataset = dataset.padded_batch(batch_size, padded_shapes=({"feeds_ids": [None], "feeds_vals": [None], "title_ids": [None]}, [None])) #不定长补齐
#return dataset.make_one_shot_iterator() #return dataset.make_one_shot_iterator()
iterator = dataset.make_one_shot_iterator() iterator = dataset.make_one_shot_iterator()
batch_features, batch_labels = iterator.get_next() batch_features, batch_labels = iterator.get_next()
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment