Commit 1a556d84 authored by 宋柯

Model debugging

parent 22d0a85e
import redis
import sys
import os
import json
def getRedisConn():
    # Production Redis; decode_responses is not set, so values come back as bytes.
    pool = redis.ConnectionPool(host="172.16.50.145", password="XfkMCCdWDIU%ls$h", port=6379, db=0)
    conn = redis.Redis(connection_pool=pool)
    # conn = redis.Redis(host="172.16.50.145", port=6379, password="XfkMCCdWDIU%ls$h", db=0)
    # conn = redis.Redis(host="172.18.51.10", port=6379, db=0, decode_responses=True)  # test
    return conn
if len(sys.argv) == 2:
    save_dir = sys.argv[1]
else:
    save_dir = '/data/files/wideAndDeep/'
print('save_dir: ', save_dir)
if not os.path.exists(save_dir):
    print('mkdir save_dir: ', save_dir)
    os.makedirs(save_dir)
conn = getRedisConn()
# "strategy:all:vocab" holds a single list element: a serialized Python list of
# vocab key names. eval() accepts bytes in Python 3, so the raw value parses directly.
vocab_keys = conn.lrange("strategy:all:vocab", 0, -1)
print("vocab_keys: ", vocab_keys[0])
vocab_keys = eval(vocab_keys[0])
for vocab_key in vocab_keys:
    print('vocab_key: ', vocab_key)
    # Keys look like "<prefix>:<field>:..."; the field name becomes the CSV filename.
    splits = vocab_key.split(":")
    field = splits[1]
    filename = field + "_vocab.csv"
    print('filename: ', filename)
    with open(os.path.join(save_dir, filename), 'w') as f:
        # Each vocab key likewise holds one serialized list of terms; drop empty strings.
        texts = conn.lrange(vocab_key, 0, -1)
        texts = list(filter(lambda x: x != '', eval(texts[0])))
        print('texts: ', len(texts))
        f.write('\n'.join(texts))
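# The two getmerge calls below pull the train/eval sets out of HDFS.
# "hdfs dfs -getmerge" concatenates every part file into one local file;
# TFRecord is a record-framed format, so concatenated valid TFRecord parts
# remain a valid TFRecord file.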
os.system("hdfs dfs -getmerge /strategy/train_samples_tfrecord {save_dir}train_samples.tfrecord".format(save_dir = save_dir))
os.system("hdfs dfs -getmerge /strategy/eval_samples_tfrecord {save_dir}eval_samples.tfrecord".format(save_dir = save_dir))
import tensorflow as tf
import sys
import os

tf.logging.set_verbosity(tf.logging.INFO)
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

BASE_DIR = '/data/files/wideAndDeep/'
def input_fn(tfrecord_path, epoch, shuffle, batch_size):
    # Read the merged TFRecord file produced by the export step above.
    dataset = tf.data.TFRecordDataset(tfrecord_path, buffer_size=1024, num_parallel_reads=2)
    # Feature spec for parsing serialized tf.Example protos.
    feature_spec = {
        'ITEM_CATEGORY_card_id': tf.FixedLenFeature((), tf.string, default_value='-1'),
        'USER_CATEGORY_device_id': tf.FixedLenFeature((), tf.string, default_value='-1'),
        'USER_CATEGORY_os': tf.FixedLenFeature((), tf.string, default_value='-1'),
        'USER_CATEGORY_user_city_id': tf.FixedLenFeature((), tf.string, default_value='-1'),
        'USER_MULTI_CATEGORY_second_solutions': tf.VarLenFeature(tf.string),
        'USER_MULTI_CATEGORY_second_demands': tf.VarLenFeature(tf.string),
        'USER_MULTI_CATEGORY_second_positions': tf.VarLenFeature(tf.string),
        'USER_MULTI_CATEGORY_projects': tf.VarLenFeature(tf.string),
        'ITEM_NUMERIC_click_count_sum': tf.FixedLenFeature((), tf.float32, default_value=0),
        'ITEM_NUMERIC_click_count_avg': tf.FixedLenFeature((), tf.float32, default_value=0),
        'ITEM_NUMERIC_click_count_stddev': tf.FixedLenFeature((), tf.float32, default_value=0),
        'ITEM_NUMERIC_exp_count_sum': tf.FixedLenFeature((), tf.float32, default_value=0),
        'ITEM_NUMERIC_exp_count_avg': tf.FixedLenFeature((), tf.float32, default_value=0),
        'ITEM_NUMERIC_exp_count_stddev': tf.FixedLenFeature((), tf.float32, default_value=0),
        'ITEM_NUMERIC_discount': tf.FixedLenFeature((), tf.float32, default_value=0),
        'ITEM_NUMERIC_case_count': tf.FixedLenFeature((), tf.float32, default_value=0),
        'ITEM_NUMERIC_sales_count': tf.FixedLenFeature((), tf.float32, default_value=0),
        'ITEM_CATEGORY_service_type': tf.FixedLenFeature((), tf.string, default_value='-1'),
        'ITEM_CATEGORY_merchant_id': tf.FixedLenFeature((), tf.string, default_value='-1'),
        'ITEM_CATEGORY_doctor_type': tf.FixedLenFeature((), tf.string, default_value='-1'),
        'ITEM_CATEGORY_doctor_id': tf.FixedLenFeature((), tf.string, default_value='-1'),
        'ITEM_CATEGORY_doctor_famous': tf.FixedLenFeature((), tf.string, default_value='-1'),
        'ITEM_CATEGORY_hospital_id': tf.FixedLenFeature((), tf.string, default_value='-1'),
        'ITEM_CATEGORY_hospital_city_tag_id': tf.FixedLenFeature((), tf.string, default_value='-1'),
        'ITEM_CATEGORY_hospital_type': tf.FixedLenFeature((), tf.string, default_value='-1'),
        'ITEM_CATEGORY_hospital_is_high_quality': tf.FixedLenFeature((), tf.string, default_value='-1'),
        'ITEM_MULTI_CATEGORY_second_demands': tf.VarLenFeature(tf.string),
        'ITEM_MULTI_CATEGORY_second_solutions': tf.VarLenFeature(tf.string),
        'ITEM_MULTI_CATEGORY_second_positions': tf.VarLenFeature(tf.string),
        'ITEM_MULTI_CATEGORY_projects': tf.VarLenFeature(tf.string),
        'ITEM_NUMERIC_sku_price': tf.FixedLenFeature((), tf.float32, default_value=0),
        'label': tf.FixedLenFeature((), tf.int64, default_value=0),
    }
    def parse_serialized_example(serialized_example):
        parsed_example = tf.parse_single_example(serialized_example, feature_spec)
        # Densify every multi-category feature so it can be padded_batch'ed below.
        multi_category_keys = [
            'USER_MULTI_CATEGORY_second_solutions', 'USER_MULTI_CATEGORY_second_demands',
            'USER_MULTI_CATEGORY_second_positions', 'USER_MULTI_CATEGORY_projects',
            'ITEM_MULTI_CATEGORY_second_demands', 'ITEM_MULTI_CATEGORY_second_solutions',
            'ITEM_MULTI_CATEGORY_second_positions', 'ITEM_MULTI_CATEGORY_projects',
        ]
        for key in multi_category_keys:
            parsed_example[key] = tf.sparse_tensor_to_dense(parsed_example[key], default_value='-1')
        # Tuple elements evaluate left to right, so pop() removes 'label' from the
        # same dict object that is returned as the features.
        return parsed_example, parsed_example.pop('label')
    # padded_batch needs explicit shapes and padding values: variable-length
    # string features pad with '-1' to the longest list in the batch ([-1]);
    # scalar features and the int64 label are fixed-shape.
    padded_shapes = ({'ITEM_CATEGORY_card_id': (), 'USER_CATEGORY_device_id': (), 'USER_CATEGORY_os': (),
                      'USER_CATEGORY_user_city_id': (), 'USER_MULTI_CATEGORY_second_solutions': [-1],
                      'USER_MULTI_CATEGORY_second_demands': [-1], 'USER_MULTI_CATEGORY_second_positions': [-1],
                      'USER_MULTI_CATEGORY_projects': [-1], 'ITEM_NUMERIC_click_count_sum': (),
                      'ITEM_NUMERIC_click_count_avg': (), 'ITEM_NUMERIC_click_count_stddev': (),
                      'ITEM_NUMERIC_exp_count_sum': (), 'ITEM_NUMERIC_exp_count_avg': (),
                      'ITEM_NUMERIC_exp_count_stddev': (), 'ITEM_NUMERIC_discount': (), 'ITEM_NUMERIC_case_count': (),
                      'ITEM_NUMERIC_sales_count': (), 'ITEM_CATEGORY_service_type': (), 'ITEM_CATEGORY_merchant_id': (),
                      'ITEM_CATEGORY_doctor_type': (), 'ITEM_CATEGORY_doctor_id': (), 'ITEM_CATEGORY_doctor_famous': (),
                      'ITEM_CATEGORY_hospital_id': (), 'ITEM_CATEGORY_hospital_city_tag_id': (),
                      'ITEM_CATEGORY_hospital_type': (), 'ITEM_CATEGORY_hospital_is_high_quality': (),
                      'ITEM_MULTI_CATEGORY_second_demands': [-1], 'ITEM_MULTI_CATEGORY_second_solutions': [-1],
                      'ITEM_MULTI_CATEGORY_second_positions': [-1], 'ITEM_MULTI_CATEGORY_projects': [-1],
                      'ITEM_NUMERIC_sku_price': ()}, ())
    padding_values = ({'ITEM_CATEGORY_card_id': '-1', 'USER_CATEGORY_device_id': '-1', 'USER_CATEGORY_os': '-1',
                       'USER_CATEGORY_user_city_id': '-1', 'USER_MULTI_CATEGORY_second_solutions': '-1',
                       'USER_MULTI_CATEGORY_second_demands': '-1', 'USER_MULTI_CATEGORY_second_positions': '-1',
                       'USER_MULTI_CATEGORY_projects': '-1', 'ITEM_NUMERIC_click_count_sum': 0.0,
                       'ITEM_NUMERIC_click_count_avg': 0.0, 'ITEM_NUMERIC_click_count_stddev': 0.0,
                       'ITEM_NUMERIC_exp_count_sum': 0.0, 'ITEM_NUMERIC_exp_count_avg': 0.0,
                       'ITEM_NUMERIC_exp_count_stddev': 0.0, 'ITEM_NUMERIC_discount': 0.0,
                       'ITEM_NUMERIC_case_count': 0.0, 'ITEM_NUMERIC_sales_count': 0.0,
                       'ITEM_CATEGORY_service_type': '-1', 'ITEM_CATEGORY_merchant_id': '-1',
                       'ITEM_CATEGORY_doctor_type': '-1', 'ITEM_CATEGORY_doctor_id': '-1',
                       'ITEM_CATEGORY_doctor_famous': '-1', 'ITEM_CATEGORY_hospital_id': '-1',
                       'ITEM_CATEGORY_hospital_city_tag_id': '-1', 'ITEM_CATEGORY_hospital_type': '-1',
                       'ITEM_CATEGORY_hospital_is_high_quality': '-1', 'ITEM_MULTI_CATEGORY_second_demands': '-1',
                       'ITEM_MULTI_CATEGORY_second_solutions': '-1', 'ITEM_MULTI_CATEGORY_second_positions': '-1',
                       'ITEM_MULTI_CATEGORY_projects': '-1', 'ITEM_NUMERIC_sku_price': 0.0},
                      tf.constant(0, dtype=tf.int64))
    dataset = dataset.map(parse_serialized_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    if shuffle:
        dataset = dataset.shuffle(1024)
    dataset = dataset.padded_batch(batch_size, padded_shapes, padding_values=padding_values)
    # prefetch() and repeat() return new datasets rather than mutating in place;
    # the results must be assigned back or they are silently discarded.
    dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
    dataset = dataset.repeat(epoch)
    return dataset
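# A minimal sketch for eyeballing one parsed batch locally (a debugging aid,
# not part of the training flow; uses the TF 1.x session API and assumes
# train_samples.tfrecord already exists under BASE_DIR):
#
#   ds = input_fn(BASE_DIR + 'train_samples.tfrecord', epoch=1, shuffle=False, batch_size=2)
#   features, label = ds.make_one_shot_iterator().get_next()
#   with tf.Session() as sess:
#       print(sess.run(label))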
boundaries = [0, 10, 100]
ITEM_NUMERIC_click_count_sum_fc = tf.feature_column.bucketized_column(tf.feature_column.numeric_column('ITEM_NUMERIC_click_count_sum'), boundaries)
ITEM_NUMERIC_exp_count_sum_fc = tf.feature_column.bucketized_column(tf.feature_column.numeric_column('ITEM_NUMERIC_exp_count_sum'), boundaries)
ITEM_NUMERIC_click_count_avg_fc = tf.feature_column.bucketized_column(tf.feature_column.numeric_column('ITEM_NUMERIC_click_count_avg'), boundaries)
ITEM_NUMERIC_exp_count_avg_fc = tf.feature_column.bucketized_column(tf.feature_column.numeric_column('ITEM_NUMERIC_exp_count_avg'), boundaries)
boundaries = [0, 0.01, 0.1]
ITEM_NUMERIC_click_count_stddev_fc = tf.feature_column.bucketized_column(tf.feature_column.numeric_column('ITEM_NUMERIC_click_count_stddev'), boundaries)
ITEM_NUMERIC_exp_count_stddev_fc = tf.feature_column.bucketized_column(tf.feature_column.numeric_column('ITEM_NUMERIC_exp_count_stddev'), boundaries)
boundaries = [0, 0.01, 0.1, 1]
ITEM_NUMERIC_discount_fc = tf.feature_column.bucketized_column(tf.feature_column.numeric_column('ITEM_NUMERIC_discount'), boundaries)
boundaries = [0, 10, 100]
ITEM_NUMERIC_case_count_fc = tf.feature_column.bucketized_column(tf.feature_column.numeric_column('ITEM_NUMERIC_case_count'), boundaries)
ITEM_NUMERIC_sales_count_fc = tf.feature_column.bucketized_column(tf.feature_column.numeric_column('ITEM_NUMERIC_sales_count'), boundaries)
ITEM_NUMERIC_sku_price_fc = tf.feature_column.bucketized_column(tf.feature_column.numeric_column('ITEM_NUMERIC_sku_price'), boundaries)
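# Note: with boundaries [b0, b1, b2], bucketized_column yields len(boundaries) + 1
# buckets: (-inf, b0), [b0, b1), [b1, b2), [b2, +inf).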
USER_CATEGORY_device_id_fc = tf.feature_column.categorical_column_with_vocabulary_file('USER_CATEGORY_device_id', BASE_DIR + 'USER_CATEGORY_device_id_vocab.csv')
USER_CATEGORY_os_fc = tf.feature_column.categorical_column_with_vocabulary_file('USER_CATEGORY_os', BASE_DIR + 'USER_CATEGORY_os_vocab.csv')
USER_CATEGORY_user_city_id_fc = tf.feature_column.categorical_column_with_vocabulary_file('USER_CATEGORY_user_city_id', BASE_DIR + 'USER_CATEGORY_user_city_id_vocab.csv')
USER_MULTI_CATEGORY__second_solutions_fc = tf.feature_column.categorical_column_with_vocabulary_file('USER_MULTI_CATEGORY_second_solutions', BASE_DIR + 'USER_MULTI_CATEGORY_second_solutions_vocab.csv')
USER_MULTI_CATEGORY__second_positions_fc = tf.feature_column.categorical_column_with_vocabulary_file('USER_MULTI_CATEGORY_second_positions', BASE_DIR + 'USER_MULTI_CATEGORY_second_positions_vocab.csv')
USER_MULTI_CATEGORY__second_demands_fc = tf.feature_column.categorical_column_with_vocabulary_file('USER_MULTI_CATEGORY_second_demands', BASE_DIR + 'USER_MULTI_CATEGORY_second_demands_vocab.csv')
USER_MULTI_CATEGORY__projects_fc = tf.feature_column.categorical_column_with_vocabulary_file('USER_MULTI_CATEGORY_projects', BASE_DIR + 'USER_MULTI_CATEGORY_projects_vocab.csv')
ITEM_CATEGORY_card_id_fc = tf.feature_column.categorical_column_with_vocabulary_file('ITEM_CATEGORY_card_id', BASE_DIR + 'ITEM_CATEGORY_card_id_vocab.csv')
ITEM_CATEGORY_service_type_fc = tf.feature_column.categorical_column_with_vocabulary_file('ITEM_CATEGORY_service_type', BASE_DIR + 'ITEM_CATEGORY_service_type_vocab.csv')
ITEM_CATEGORY_merchant_id_fc = tf.feature_column.categorical_column_with_vocabulary_file('ITEM_CATEGORY_merchant_id', BASE_DIR + 'ITEM_CATEGORY_merchant_id_vocab.csv')
ITEM_CATEGORY_doctor_type_fc = tf.feature_column.categorical_column_with_vocabulary_file('ITEM_CATEGORY_doctor_type', BASE_DIR + 'ITEM_CATEGORY_doctor_type_vocab.csv')
ITEM_CATEGORY_doctor_id_fc = tf.feature_column.categorical_column_with_vocabulary_file('ITEM_CATEGORY_doctor_id', BASE_DIR + 'ITEM_CATEGORY_doctor_id_vocab.csv')
ITEM_CATEGORY_doctor_famous_fc = tf.feature_column.categorical_column_with_vocabulary_file('ITEM_CATEGORY_doctor_famous', BASE_DIR + 'ITEM_CATEGORY_doctor_famous_vocab.csv')
ITEM_CATEGORY_hospital_id_fc = tf.feature_column.categorical_column_with_vocabulary_file('ITEM_CATEGORY_hospital_id', BASE_DIR + 'ITEM_CATEGORY_hospital_id_vocab.csv')
ITEM_CATEGORY_hospital_city_tag_id_fc = tf.feature_column.categorical_column_with_vocabulary_file('ITEM_CATEGORY_hospital_city_tag_id', BASE_DIR + 'ITEM_CATEGORY_hospital_city_tag_id_vocab.csv')
ITEM_CATEGORY_hospital_type_fc = tf.feature_column.categorical_column_with_vocabulary_file('ITEM_CATEGORY_hospital_type', BASE_DIR + 'ITEM_CATEGORY_hospital_type_vocab.csv')
ITEM_CATEGORY_hospital_is_high_quality_fc = tf.feature_column.categorical_column_with_vocabulary_file('ITEM_CATEGORY_hospital_is_high_quality', BASE_DIR + 'ITEM_CATEGORY_hospital_is_high_quality_vocab.csv')
ITEM_MULTI_CATEGORY__second_solutions_fc = tf.feature_column.categorical_column_with_vocabulary_file('ITEM_MULTI_CATEGORY_second_solutions', BASE_DIR + 'ITEM_MULTI_CATEGORY_second_solutions_vocab.csv')
ITEM_MULTI_CATEGORY__second_positions_fc = tf.feature_column.categorical_column_with_vocabulary_file('ITEM_MULTI_CATEGORY_second_positions', BASE_DIR + 'ITEM_MULTI_CATEGORY_second_positions_vocab.csv')
ITEM_MULTI_CATEGORY__second_demands_fc = tf.feature_column.categorical_column_with_vocabulary_file('ITEM_MULTI_CATEGORY_second_demands', BASE_DIR + 'ITEM_MULTI_CATEGORY_second_demands_vocab.csv')
ITEM_MULTI_CATEGORY__projects_fc = tf.feature_column.categorical_column_with_vocabulary_file('ITEM_MULTI_CATEGORY_projects', BASE_DIR + 'ITEM_MULTI_CATEGORY_projects_vocab.csv')
def embedding_fc(categorical_column, dim):
    return tf.feature_column.embedding_column(categorical_column, dim)
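# Note: with dim=1, embedding_fc gives each category a single learned scalar
# (multi-valued features are mean-combined by default), so the wide/linear part
# below consumes a dense value per feature rather than the raw sparse column.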
linear_feature_columns = [
    ITEM_NUMERIC_click_count_sum_fc,
    ITEM_NUMERIC_exp_count_sum_fc,
    ITEM_NUMERIC_click_count_avg_fc,
    ITEM_NUMERIC_exp_count_avg_fc,
    ITEM_NUMERIC_click_count_stddev_fc,
    ITEM_NUMERIC_exp_count_stddev_fc,
    ITEM_NUMERIC_discount_fc,
    ITEM_NUMERIC_case_count_fc,
    ITEM_NUMERIC_sales_count_fc,
    ITEM_NUMERIC_sku_price_fc,
    embedding_fc(ITEM_CATEGORY_card_id_fc, 1),
    embedding_fc(ITEM_CATEGORY_service_type_fc, 1),
    embedding_fc(ITEM_CATEGORY_merchant_id_fc, 1),
    embedding_fc(ITEM_CATEGORY_doctor_type_fc, 1),
    embedding_fc(ITEM_CATEGORY_doctor_id_fc, 1),
    embedding_fc(ITEM_CATEGORY_doctor_famous_fc, 1),
    embedding_fc(ITEM_CATEGORY_hospital_id_fc, 1),
    embedding_fc(ITEM_CATEGORY_hospital_city_tag_id_fc, 1),
    embedding_fc(ITEM_CATEGORY_hospital_type_fc, 1),
    embedding_fc(ITEM_CATEGORY_hospital_is_high_quality_fc, 1),
    embedding_fc(ITEM_MULTI_CATEGORY__projects_fc, 1),
    embedding_fc(ITEM_MULTI_CATEGORY__second_demands_fc, 1),
    embedding_fc(ITEM_MULTI_CATEGORY__second_positions_fc, 1),
    embedding_fc(ITEM_MULTI_CATEGORY__second_solutions_fc, 1),
]
dnn_feature_columns = [
    embedding_fc(USER_CATEGORY_device_id_fc, 8),
    embedding_fc(USER_CATEGORY_os_fc, 8),
    embedding_fc(USER_CATEGORY_user_city_id_fc, 8),
    embedding_fc(USER_MULTI_CATEGORY__second_solutions_fc, 8),
    embedding_fc(USER_MULTI_CATEGORY__second_positions_fc, 8),
    embedding_fc(USER_MULTI_CATEGORY__second_demands_fc, 8),
    embedding_fc(USER_MULTI_CATEGORY__projects_fc, 8),
    embedding_fc(ITEM_NUMERIC_click_count_sum_fc, 8),
    embedding_fc(ITEM_NUMERIC_exp_count_sum_fc, 8),
    embedding_fc(ITEM_NUMERIC_click_count_avg_fc, 8),
    embedding_fc(ITEM_NUMERIC_exp_count_avg_fc, 8),
    embedding_fc(ITEM_NUMERIC_click_count_stddev_fc, 8),
    embedding_fc(ITEM_NUMERIC_exp_count_stddev_fc, 8),
    embedding_fc(ITEM_NUMERIC_discount_fc, 8),
    embedding_fc(ITEM_NUMERIC_case_count_fc, 8),
    embedding_fc(ITEM_NUMERIC_sales_count_fc, 8),
    embedding_fc(ITEM_NUMERIC_sku_price_fc, 8),
    embedding_fc(ITEM_CATEGORY_card_id_fc, 8),
    embedding_fc(ITEM_CATEGORY_service_type_fc, 8),
    embedding_fc(ITEM_CATEGORY_merchant_id_fc, 8),
    embedding_fc(ITEM_CATEGORY_doctor_type_fc, 8),
    embedding_fc(ITEM_CATEGORY_doctor_id_fc, 8),
    embedding_fc(ITEM_CATEGORY_doctor_famous_fc, 8),
    embedding_fc(ITEM_CATEGORY_hospital_id_fc, 8),
    embedding_fc(ITEM_CATEGORY_hospital_city_tag_id_fc, 8),
    embedding_fc(ITEM_CATEGORY_hospital_type_fc, 8),
    embedding_fc(ITEM_CATEGORY_hospital_is_high_quality_fc, 8),
    embedding_fc(ITEM_MULTI_CATEGORY__projects_fc, 8),
    embedding_fc(ITEM_MULTI_CATEGORY__second_demands_fc, 8),
    embedding_fc(ITEM_MULTI_CATEGORY__second_positions_fc, 8),
    embedding_fc(ITEM_MULTI_CATEGORY__second_solutions_fc, 8),
]
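# Wide & Deep split: the linear side memorizes via bucketized numerics and
# 1-dim category embeddings; the DNN side generalizes via 8-dim embeddings
# of the same features.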
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2"  # os already imported above; no spaces in the device list
from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())
distribution = tf.distribute.MirroredStrategy()  # only takes effect via the commented-out RunConfig below
# session_config = tf.compat.v1.ConfigProto(log_device_placement = True, allow_soft_placement = True)
session_config = tf.compat.v1.ConfigProto(allow_soft_placement=True)
session_config.gpu_options.allow_growth = True
# config = tf.estimator.RunConfig(save_checkpoints_steps = 10000, train_distribute = distribution, eval_distribute = distribution)
config = tf.estimator.RunConfig(save_checkpoints_steps=10000, session_config=session_config)
wideAndDeepModel = tf.estimator.DNNLinearCombinedClassifier(model_dir=BASE_DIR + 'model_tfrecord',
                                                            linear_feature_columns=linear_feature_columns,
                                                            dnn_feature_columns=dnn_feature_columns,
                                                            dnn_hidden_units=[128, 32],
                                                            dnn_dropout=0.5,
                                                            config=config)
# early_stopping = tf.contrib.estimator.stop_if_no_decrease_hook(wideAndDeepModel, eval_dir = wideAndDeepModel.eval_dir(), metric_name='auc', max_steps_without_decrease=1000, min_steps = 100)
# early_stopping = tf.contrib.estimator.stop_if_no_increase_hook(wideAndDeepModel, metric_name = 'auc', max_steps_without_increase = 1000, min_steps = 1000)
hooks = []
train_spec = tf.estimator.TrainSpec(input_fn = lambda: input_fn(BASE_DIR + 'train_samples.tfrecord', 20, True, 512), hooks = hooks)
serving_feature_spec = tf.feature_column.make_parse_example_spec(
    linear_feature_columns + dnn_feature_columns)
serving_input_receiver_fn = (
    tf.estimator.export.build_parsing_serving_input_receiver_fn(
        serving_feature_spec))
exporter = tf.estimator.BestExporter(
    name="best_exporter",
    compare_fn=lambda best_eval_result, current_eval_result: current_eval_result['auc'] > best_eval_result['auc'],
    serving_input_receiver_fn=serving_input_receiver_fn,
    exports_to_keep=3)
eval_spec = tf.estimator.EvalSpec(input_fn=lambda: input_fn(BASE_DIR + 'eval_samples.tfrecord', 1, False, 2 ** 15), steps=None, throttle_secs=120, exporters=exporter)
# def my_auc(labels, predictions):
# return {'auc_pr_careful_interpolation': tf.metrics.auc(labels, predictions['logistic'], curve='ROC',
# summation_method='careful_interpolation')}
# wideAndDeepModel = tf.contrib.estimator.add_metrics(wideAndDeepModel, my_auc)
tf.estimator.train_and_evaluate(wideAndDeepModel, train_spec, eval_spec)
eval_results = wideAndDeepModel.evaluate(lambda: input_fn(BASE_DIR + 'eval_samples.tfrecord', 1, False, 2 ** 15))
print('final eval: ', eval_results)
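# A hedged sketch of scoring the best export offline (assumes TF 1.x with
# tf.contrib available, and that BestExporter has written at least one
# SavedModel under model_tfrecord/export/best_exporter/<timestamp>):
#
#   import glob
#   from tensorflow.contrib import predictor
#   export_dir = sorted(glob.glob(BASE_DIR + 'model_tfrecord/export/best_exporter/*'))[-1]
#   predict_fn = predictor.from_saved_model(export_dir)
#
# build_parsing_serving_input_receiver_fn expects serialized tf.Example protos
# under the 'examples' key: predict_fn({'examples': [example.SerializeToString()]}).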