Commit 11fb57e9 authored by 宋柯's avatar 宋柯

模型调试

parent 81455a71
...@@ -76,12 +76,12 @@ def input_fn(csv_path, epoch, shuffle, batch_size): ...@@ -76,12 +76,12 @@ def input_fn(csv_path, epoch, shuffle, batch_size):
'ITEM_MULTI_CATEGORY_second_solutions': '-1', 'ITEM_MULTI_CATEGORY_second_positions': '-1', 'ITEM_MULTI_CATEGORY_second_solutions': '-1', 'ITEM_MULTI_CATEGORY_second_positions': '-1',
'ITEM_MULTI_CATEGORY_projects': '-1', 'ITEM_NUMERIC_sku_price': 0.0}, 0.0) 'ITEM_MULTI_CATEGORY_projects': '-1', 'ITEM_NUMERIC_sku_price': 0.0}, 0.0)
dataset = dataset.map(parse_line, num_parallel_calls=2) dataset = dataset.map(parse_line, num_parallel_calls=8)
dataset = dataset.padded_batch(batch_size, padded_shapes, padding_values=padding_values) dataset = dataset.padded_batch(batch_size, padded_shapes, padding_values=padding_values)
if shuffle: if shuffle:
dataset = dataset.shuffle(1000).prefetch(10000).repeat(epoch) dataset = dataset.shuffle(1000).prefetch(512 * 10).repeat(epoch)
else: else:
dataset = dataset.prefetch(10000).repeat(epoch) dataset = dataset.prefetch(512 * 10).repeat(epoch)
return dataset return dataset
...@@ -205,11 +205,14 @@ print(device_lib.list_local_devices()) ...@@ -205,11 +205,14 @@ print(device_lib.list_local_devices())
distribution = tf.distribute.MirroredStrategy() distribution = tf.distribute.MirroredStrategy()
session_config = tf.compat.v1.ConfigProto(log_device_placement = True, allow_soft_placement = True) # session_config = tf.compat.v1.ConfigProto(log_device_placement = True, allow_soft_placement = True)
session_config = tf.compat.v1.ConfigProto(allow_soft_placement = True)
session_config.gpu_options.allow_growth = True
# config = tf.estimator.RunConfig(save_checkpoints_steps = 10000, train_distribute = distribution, eval_distribute = distribution) # config = tf.estimator.RunConfig(save_checkpoints_steps = 10000, train_distribute = distribution, eval_distribute = distribution)
config = tf.estimator.RunConfig(save_checkpoints_steps = 10000) config = tf.estimator.RunConfig(save_checkpoints_steps = 10000, session_config = session_config)
wideAndDeepModel = tf.estimator.DNNLinearCombinedClassifier(model_dir = BASE_DIR + 'model', wideAndDeepModel = tf.estimator.DNNLinearCombinedClassifier(model_dir = BASE_DIR + 'model',
linear_feature_columns = linear_feature_columns, linear_feature_columns = linear_feature_columns,
...@@ -224,7 +227,7 @@ wideAndDeepModel = tf.estimator.DNNLinearCombinedClassifier(model_dir = BASE_DIR ...@@ -224,7 +227,7 @@ wideAndDeepModel = tf.estimator.DNNLinearCombinedClassifier(model_dir = BASE_DIR
hooks = [] hooks = []
train_spec = tf.estimator.TrainSpec(input_fn = lambda: input_fn(BASE_DIR + 'train_samples.csv', 20, True, 128), hooks = hooks) train_spec = tf.estimator.TrainSpec(input_fn = lambda: input_fn(BASE_DIR + 'train_samples.csv', 20, True, 512), hooks = hooks)
serving_feature_spec = tf.feature_column.make_parse_example_spec( serving_feature_spec = tf.feature_column.make_parse_example_spec(
linear_feature_columns + dnn_feature_columns) linear_feature_columns + dnn_feature_columns)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment