Commit d8b86606 authored by 宋柯's avatar 宋柯

模型调试

parent 04b1805f
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.INFO)
import sys
import time
import os
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
start = time.time()
BASE_DIR = '/data/files/wideAndDeep/'
def input_fn(csv_path, epoch, shuffle, batch_size):
......@@ -78,10 +81,10 @@ def input_fn(csv_path, epoch, shuffle, batch_size):
'ITEM_MULTI_CATEGORY_second_solutions': '-1', 'ITEM_MULTI_CATEGORY_second_positions': '-1',
'ITEM_MULTI_CATEGORY_projects': '-1', 'ITEM_NUMERIC_sku_price': 0.0}, 0.0)
dataset = dataset.map(parse_line, num_parallel_calls=8).cache()
dataset = dataset.map(parse_line, num_parallel_calls = 8).cache()
dataset = dataset.padded_batch(batch_size, padded_shapes, padding_values=padding_values)
if shuffle:
dataset = dataset.shuffle(1000).prefetch(512 * 100).repeat(epoch)
dataset = dataset.shuffle(2048).prefetch(512 * 100).repeat(epoch)
else:
dataset = dataset.prefetch(512 * 100).repeat(epoch)
......@@ -216,7 +219,7 @@ session_config.gpu_options.allow_growth = True
config = tf.estimator.RunConfig(save_checkpoints_steps = 3000, session_config = session_config)
wideAndDeepModel = tf.estimator.DNNLinearCombinedClassifier(model_dir = BASE_DIR + 'model',
wideAndDeepModel = tf.estimator.DNNLinearCombinedClassifier(model_dir = BASE_DIR + 'model_csv',
linear_feature_columns = linear_feature_columns,
dnn_feature_columns = dnn_feature_columns,
dnn_hidden_units = [128, 32],
......@@ -254,3 +257,5 @@ eval_spec = tf.estimator.EvalSpec(input_fn = lambda: input_fn(BASE_DIR + 'eval_s
tf.estimator.train_and_evaluate(wideAndDeepModel, train_spec, eval_spec)
wideAndDeepModel.evaluate(lambda: input_fn(BASE_DIR + 'eval_samples.csv', 1, False, 2 ** 15))
print("训练耗时: {}s".format(time.time() - start))
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment