Commit f11eb7bc authored by 宋柯

Model debugging

parent d8b86606
import tensorflow as tf
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.saved_model import tag_constants
from google.protobuf import text_format
import os
import shutil
export_dir = 'inference/pb2saved'
graph_pb = '/Users/edz/PycharmProjects/serviceRec/train/saved_model_test/1640591747/saved_model.pb'
# SavedModelBuilder requires a fresh export directory; os.rmdir fails on a
# non-empty directory, so remove the whole tree instead.
if os.path.exists(export_dir):
shutil.rmtree(export_dir)
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
with tf.gfile.GFile(graph_pb, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
sigs = {}
with tf.Session(graph=tf.Graph()) as sess:
# name="" is important to ensure we don't get spurious prefixing
tf.import_graph_def(graph_def, name="")
g = tf.get_default_graph()
print(sess.graph.get_name_scope())
print(sess.graph.get_all_collection_keys())
print(sess.graph.get_operations())
# input_ids = sess.graph.get_tensor_by_name(
# "input_ids:0")
# input_mask = sess.graph.get_tensor_by_name(
# "input_mask:0")
# segment_ids = sess.graph.get_tensor_by_name(
# "segment_ids:0")
# probabilities = g.get_tensor_by_name("loss/pred_prob:0")
# sigs[signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY] = \
# tf.saved_model.signature_def_utils.predict_signature_def(
# {
# "input_ids": input_ids,
# "input_mask": input_mask,
# "segment_ids": segment_ids
# }, {
# "probabilities": probabilities
# })
# builder.add_meta_graph_and_variables(sess,
# [tag_constants.SERVING],
# signature_def_map=sigs)
# builder.save()
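# Hedged debugging aid (assumption, not part of the original commit): list the
# graph's placeholder ops to find the input tensor names needed for the
# commented-out signature above; `saved_model_cli show --dir <export_dir> --all`
# prints the same information for an existing SavedModel.
for op in g.get_operations():
if op.type == 'Placeholder':
print(op.name, op.outputs[0].dtype, op.outputs[0].shape)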
import base64
import tensorflow as tf
import requests
import time
with open('/Users/edz/software/Recommend/train_samples.csv', 'r') as f:
count = 0
examples = []
for line in f:
# print(line)
splits = line.split('|')
features = {
'ITEM_CATEGORY_card_id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[0].encode()])),
'USER_CATEGORY_device_id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[2].encode()])),
'USER_CATEGORY_os': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[3].encode()])),
'USER_CATEGORY_user_city_id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[4].encode()])),
'USER_MULTI_CATEGORY_second_solutions': tf.train.Feature(
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), splits[6].split(','))))),
'USER_MULTI_CATEGORY_second_demands': tf.train.Feature(
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), splits[7].split(','))))),
'USER_MULTI_CATEGORY_second_positions': tf.train.Feature(
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), splits[8].split(','))))),
'USER_MULTI_CATEGORY_projects': tf.train.Feature(
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), splits[9].split(','))))),
'ITEM_NUMERIC_click_count_sum': tf.train.Feature(float_list=tf.train.FloatList(value=[float(splits[10])])),
'ITEM_NUMERIC_click_count_avg': tf.train.Feature(float_list=tf.train.FloatList(value=[float(splits[11])])),
'ITEM_NUMERIC_click_count_stddev': tf.train.Feature(
float_list=tf.train.FloatList(value=[float(splits[12])])),
'ITEM_NUMERIC_exp_count_sum': tf.train.Feature(float_list=tf.train.FloatList(value=[float(splits[13])])),
'ITEM_NUMERIC_exp_count_avg': tf.train.Feature(float_list=tf.train.FloatList(value=[float(splits[14])])),
'ITEM_NUMERIC_exp_count_stddev': tf.train.Feature(float_list=tf.train.FloatList(value=[float(splits[15])])),
'ITEM_NUMERIC_discount': tf.train.Feature(float_list=tf.train.FloatList(value=[float(splits[16])])),
'ITEM_NUMERIC_case_count': tf.train.Feature(float_list=tf.train.FloatList(value=[float(splits[17])])),
'ITEM_NUMERIC_sales_count': tf.train.Feature(float_list=tf.train.FloatList(value=[float(splits[18])])),
'ITEM_CATEGORY_service_type': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[19].encode()])),
'ITEM_CATEGORY_merchant_id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[20].encode()])),
'ITEM_CATEGORY_doctor_type': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[21].encode()])),
'ITEM_CATEGORY_doctor_id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[22].encode()])),
'ITEM_CATEGORY_doctor_famous': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[23].encode()])),
'ITEM_CATEGORY_hospital_id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[24].encode()])),
'ITEM_CATEGORY_hospital_city_tag_id': tf.train.Feature(
bytes_list=tf.train.BytesList(value=[splits[25].encode()])),
'ITEM_CATEGORY_hospital_type': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[26].encode()])),
'ITEM_CATEGORY_hospital_is_high_quality': tf.train.Feature(
bytes_list=tf.train.BytesList(value=[splits[27].encode()])),
'ITEM_MULTI_CATEGORY_second_demands': tf.train.Feature(
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), splits[28].split(','))))),
'ITEM_MULTI_CATEGORY_second_solutions': tf.train.Feature(
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), splits[29].split(','))))),
'ITEM_MULTI_CATEGORY_second_positions': tf.train.Feature(
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), splits[30].split(','))))),
'ITEM_MULTI_CATEGORY_projects': tf.train.Feature(
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), splits[31].split(','))))),
'ITEM_NUMERIC_sku_price': tf.train.Feature(float_list=tf.train.FloatList(value=[float(splits[32])])),
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[int(splits[5])])),
}
# print(features)
# print(splits[32])
tf_features = tf.train.Features(feature=features)
tf_example = tf.train.Example(features=tf_features)
tf_serialized = tf_example.SerializeToString()
examples.append({'b64': base64.b64encode(tf_serialized).decode('utf-8')})  # JSON requires str, not bytes
count += 1
if count == 1000:
break
start = time.time()
res = requests.post("http://localhost:8501/v1/models/wide_deep:predict",
json={"inputs": {"examples": examples},
"signature_name": "predict"})
print(res.text)
print(time.time() - start)
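# Note (assumption, not part of the original script): TensorFlow Serving's REST
# API encodes binary tensors as {"b64": "<base64 string>"}, so the request above
# sends {"inputs": {"examples": [{"b64": ...}, ...]}, "signature_name": "predict"}.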
# encoding=utf8
import time
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
np.set_printoptions(threshold=np.inf)
np.set_printoptions(precision=3)
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc
import grpc
tf.app.flags.DEFINE_string('server', 'localhost:8502', 'PredictionService host:port')
FLAGS = tf.app.flags.FLAGS
def prediction():
options = [('grpc.max_send_message_length', 1000 * 1024 * 1024), ('grpc.max_receive_message_length', 1000 * 1024 * 1024)]
channel = grpc.insecure_channel(FLAGS.server, options = options)
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
request = predict_pb2.PredictRequest()
request.model_spec.name = 'wide_deep'  # the name TF Serving loaded the model under
request.model_spec.signature_name = 'predict'  # one of the exported signatures
for _ in range(1):
with open('/Users/edz/software/Recommend/train_samples.csv', 'r') as f:
count = 0
examples = []
for line in f:
splits = line.split('|')
features = {
'ITEM_CATEGORY_card_id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[0].encode()] * 2)),
'USER_CATEGORY_device_id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[2].encode()] * 2)),
'USER_CATEGORY_os': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[3].encode()] * 2)),
'USER_CATEGORY_user_city_id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[4].encode()] * 2)),
'USER_MULTI_CATEGORY_second_solutions': tf.train.Feature(
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), splits[6].split(','))) * 2)),
'USER_MULTI_CATEGORY_second_demands': tf.train.Feature(
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), splits[7].split(','))) * 2)),
'USER_MULTI_CATEGORY_second_positions': tf.train.Feature(
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), splits[8].split(','))) * 2)),
'USER_MULTI_CATEGORY_projects': tf.train.Feature(
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), splits[9].split(','))) * 2)),
'ITEM_NUMERIC_click_count_sum': tf.train.Feature(float_list=tf.train.FloatList(value=[float(splits[10])] * 2)),
'ITEM_NUMERIC_click_count_avg': tf.train.Feature(float_list=tf.train.FloatList(value=[float(splits[11])] * 2)),
'ITEM_NUMERIC_click_count_stddev': tf.train.Feature(  # duplicated (* 2) for consistency with the other features
float_list=tf.train.FloatList(value=[float(splits[12])] * 2)),
'ITEM_NUMERIC_exp_count_sum': tf.train.Feature(float_list=tf.train.FloatList(value=[float(splits[13])] * 2)),
'ITEM_NUMERIC_exp_count_avg': tf.train.Feature(float_list=tf.train.FloatList(value=[float(splits[14])] * 2)),
'ITEM_NUMERIC_exp_count_stddev': tf.train.Feature(float_list=tf.train.FloatList(value=[float(splits[15])] * 2)),
'ITEM_NUMERIC_discount': tf.train.Feature(float_list=tf.train.FloatList(value=[float(splits[16])] * 2)),
'ITEM_NUMERIC_case_count': tf.train.Feature(float_list=tf.train.FloatList(value=[float(splits[17])] * 2)),
'ITEM_NUMERIC_sales_count': tf.train.Feature(float_list=tf.train.FloatList(value=[float(splits[18])] * 2)),
'ITEM_CATEGORY_service_type': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[19].encode()] * 2)),
'ITEM_CATEGORY_merchant_id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[20].encode()] * 2)),
'ITEM_CATEGORY_doctor_type': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[21].encode()] * 2)),
'ITEM_CATEGORY_doctor_id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[22].encode()] * 2)),
'ITEM_CATEGORY_doctor_famous': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[23].encode()] * 2)),
'ITEM_CATEGORY_hospital_id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[24].encode()] * 2)),
'ITEM_CATEGORY_hospital_city_tag_id': tf.train.Feature(
bytes_list=tf.train.BytesList(value=[splits[25].encode()] * 2)),
'ITEM_CATEGORY_hospital_type': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[26].encode()] * 2)),
'ITEM_CATEGORY_hospital_is_high_quality': tf.train.Feature(
bytes_list=tf.train.BytesList(value=[splits[27].encode()] * 2)),
'ITEM_MULTI_CATEGORY_second_demands': tf.train.Feature(
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), splits[28].split(','))) * 2)),
'ITEM_MULTI_CATEGORY_second_solutions': tf.train.Feature(
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), splits[29].split(','))) * 2)),
'ITEM_MULTI_CATEGORY_second_positions': tf.train.Feature(
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), splits[30].split(','))) * 2)),
'ITEM_MULTI_CATEGORY_projects': tf.train.Feature(
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), splits[31].split(','))) * 2)),
'ITEM_NUMERIC_sku_price': tf.train.Feature(float_list=tf.train.FloatList(value=[float(splits[32])] * 2)),
# 'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[int(splits[5])] * 2)),
}
# print(features)
# print(splits[32])
tf_features = tf.train.Features(feature=features)
tf_example = tf.train.Example(features=tf_features)
# print(tf_example)
tf_serialized = tf_example.SerializeToString()
examples.append(tf_serialized)
count += 1
if count == 1000:
break
start = time.time()
# request.inputs['examples'].CopyFrom(tf.make_tensor_proto(examples))  # 'examples' is the model's input name
# print(examples)
tensor_proto = tf.make_tensor_proto(examples)
print(time.time() - start)
request.inputs['examples'].CopyFrom(tensor_proto)  # 'examples' is the input name of the 'predict' signature
result_future = stub.Predict.future(request, 10.0) # 10 secs timeout
result = result_future.result()
# print(result)
print(time.time() - start)
if __name__ == "__main__":
prediction()
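# Hedged sketch (assumption, not part of the original script): the gRPC
# PredictResponse holds TensorProtos; assuming the 'predict' signature defines an
# output key such as 'probabilities', the result could be inspected with:
# for key, tensor_proto in result.outputs.items():
#     print(key, tf.make_ndarray(tensor_proto))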
# encoding=utf8
import time
import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
np.set_printoptions(threshold=np.inf)
np.set_printoptions(precision=3)
from tensorflow_serving.apis import predict_pb2
from tensorflow_serving.apis import prediction_service_pb2_grpc
import grpc
tf.app.flags.DEFINE_string('server', 'localhost:8502', 'PredictionService host:port')
FLAGS = tf.app.flags.FLAGS
def prediction():
options = [('grpc.max_send_message_length', 1000 * 1024 * 1024), ('grpc.max_receive_message_length', 1000 * 1024 * 1024)]
channel = grpc.insecure_channel(FLAGS.server, options = options)
stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)
request = predict_pb2.PredictRequest()
request.model_spec.name = 'wide_deep'  # the name TF Serving loaded the model under
request.model_spec.signature_name = 'regression'  # the estimator's regression signature
for _ in range(20):
with open('/Users/edz/software/Recommend/train_samples.csv', 'r') as f:
count = 0
examples = []
for line in f:
splits = line.split('|')
features = {
'ITEM_CATEGORY_card_id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[0].encode()])),
'USER_CATEGORY_device_id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[2].encode()])),
'USER_CATEGORY_os': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[3].encode()])),
'USER_CATEGORY_user_city_id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[4].encode()])),
'USER_MULTI_CATEGORY_second_solutions': tf.train.Feature(
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), splits[6].split(','))))),
'USER_MULTI_CATEGORY_second_demands': tf.train.Feature(
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), splits[7].split(','))))),
'USER_MULTI_CATEGORY_second_positions': tf.train.Feature(
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), splits[8].split(','))))),
'USER_MULTI_CATEGORY_projects': tf.train.Feature(
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), splits[9].split(','))))),
'ITEM_NUMERIC_click_count_sum': tf.train.Feature(float_list=tf.train.FloatList(value=[float(splits[10])])),
'ITEM_NUMERIC_click_count_avg': tf.train.Feature(float_list=tf.train.FloatList(value=[float(splits[11])])),
'ITEM_NUMERIC_click_count_stddev': tf.train.Feature(
float_list=tf.train.FloatList(value=[float(splits[12])])),
'ITEM_NUMERIC_exp_count_sum': tf.train.Feature(float_list=tf.train.FloatList(value=[float(splits[13])])),
'ITEM_NUMERIC_exp_count_avg': tf.train.Feature(float_list=tf.train.FloatList(value=[float(splits[14])])),
'ITEM_NUMERIC_exp_count_stddev': tf.train.Feature(float_list=tf.train.FloatList(value=[float(splits[15])])),
'ITEM_NUMERIC_discount': tf.train.Feature(float_list=tf.train.FloatList(value=[float(splits[16])])),
'ITEM_NUMERIC_case_count': tf.train.Feature(float_list=tf.train.FloatList(value=[float(splits[17])])),
'ITEM_NUMERIC_sales_count': tf.train.Feature(float_list=tf.train.FloatList(value=[float(splits[18])])),
'ITEM_CATEGORY_service_type': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[19].encode()])),
'ITEM_CATEGORY_merchant_id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[20].encode()])),
'ITEM_CATEGORY_doctor_type': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[21].encode()])),
'ITEM_CATEGORY_doctor_id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[22].encode()])),
'ITEM_CATEGORY_doctor_famous': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[23].encode()])),
'ITEM_CATEGORY_hospital_id': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[24].encode()])),
'ITEM_CATEGORY_hospital_city_tag_id': tf.train.Feature(
bytes_list=tf.train.BytesList(value=[splits[25].encode()])),
'ITEM_CATEGORY_hospital_type': tf.train.Feature(bytes_list=tf.train.BytesList(value=[splits[26].encode()])),
'ITEM_CATEGORY_hospital_is_high_quality': tf.train.Feature(
bytes_list=tf.train.BytesList(value=[splits[27].encode()])),
'ITEM_MULTI_CATEGORY_second_demands': tf.train.Feature(
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), splits[28].split(','))))),
'ITEM_MULTI_CATEGORY_second_solutions': tf.train.Feature(
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), splits[29].split(','))))),
'ITEM_MULTI_CATEGORY_second_positions': tf.train.Feature(
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), splits[30].split(','))))),
'ITEM_MULTI_CATEGORY_projects': tf.train.Feature(
bytes_list=tf.train.BytesList(value=list(map(lambda s: s.encode(), splits[31].split(','))))),
'ITEM_NUMERIC_sku_price': tf.train.Feature(float_list=tf.train.FloatList(value=[float(splits[32])])),
'label': tf.train.Feature(int64_list=tf.train.Int64List(value=[int(splits[5])])),
}
# print(features)
# print(splits[32])
tf_features = tf.train.Features(feature=features)
tf_example = tf.train.Example(features=tf_features)
tf_serialized = tf_example.SerializeToString()
examples.append(tf_serialized)
count += 1
if count == 1000:
break
start = time.time()
# request.inputs['examples'].CopyFrom(tf.make_tensor_proto(examples))  # 'examples' is the model's input name
tensor_proto = tf.make_tensor_proto(examples)
request.inputs['inputs'].CopyFrom(tensor_proto)  # the regression signature's input key is 'inputs'
result_future = stub.Predict.future(request, 10.0) # 10 secs timeout
result = result_future.result()
# print(result)
print(time.time() - start)
if __name__ == "__main__":
prediction()
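# Note (assumption based on TF Serving conventions, not stated in the commit):
# estimator exports give the 'predict' signature the input key 'examples', while
# the 'classification'/'regression' signatures use the key 'inputs', which is
# why this client fills request.inputs['inputs'] rather than ['examples'].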
import tensorflow as tf
# Left as a stub in the original commit: both calls below require arguments.
# A hypothetical completion, mirroring the training script above, would be:
# model = tf.estimator.DNNLinearCombinedClassifier(linear_feature_columns = linear_feature_columns,
#                                                  dnn_feature_columns = dnn_feature_columns)
# model.export_saved_model('./saved_model', serving_input_receiver_fn)
@@ -80,6 +80,8 @@ def input_fn(csv_path, epoch, shuffle, batch_size):
     dataset = dataset.map(parse_line, num_parallel_calls = tf.data.experimental.AUTOTUNE)
+    dataset = dataset.cache()
     if shuffle:
         dataset = dataset.shuffle(1024)
     else:
@@ -220,9 +222,9 @@ session_config.gpu_options.allow_growth = True
 # config = tf.estimator.RunConfig(save_checkpoints_steps = 10000, train_distribute = distribution, eval_distribute = distribution)
-config = tf.estimator.RunConfig(save_checkpoints_steps = 10000, session_config = session_config)
-wideAndDeepModel = tf.estimator.DNNLinearCombinedClassifier(model_dir = BASE_DIR + 'model',
+config = tf.estimator.RunConfig(save_checkpoints_steps = 3000, session_config = session_config)
+wideAndDeepModel = tf.estimator.DNNLinearCombinedClassifier(model_dir = './model',
                                                             linear_feature_columns = linear_feature_columns,
                                                             dnn_feature_columns = dnn_feature_columns,
                                                             dnn_hidden_units = [128, 32],
@@ -233,15 +235,15 @@ wideAndDeepModel = tf.estimator.DNNLinearCombinedClassifier(model_dir = BASE_DIR
 # early_stopping = tf.contrib.estimator.stop_if_no_increase_hook(wideAndDeepModel, metric_name = 'auc', max_steps_without_increase = 1000, min_steps = 1000)
-hooks = [tf.train.ProfilerHook(save_steps=100, output_dir='./profile/')]
-train_spec = tf.estimator.TrainSpec(input_fn = lambda: input_fn(BASE_DIR + 'train_samples.csv', 20, True, 512), hooks = hooks)
-serving_feature_spec = tf.feature_column.make_parse_example_spec(
-    linear_feature_columns + dnn_feature_columns)
-serving_input_receiver_fn = (
-    tf.estimator.export.build_parsing_serving_input_receiver_fn(
-        serving_feature_spec))
+# hooks = [tf.train.ProfilerHook(save_steps=100, output_dir='./profile/')]
+train_spec = tf.estimator.TrainSpec(input_fn = lambda: input_fn(BASE_DIR + 'eval_samples.csv', 1, False, 512), hooks = [])
+serving_feature_spec = tf.feature_column.make_parse_example_spec(linear_feature_columns + dnn_feature_columns)
+serving_input_receiver_fn = tf.estimator.export.build_parsing_serving_input_receiver_fn(serving_feature_spec)
 exporter = tf.estimator.BestExporter(
     name = "best_exporter",
@@ -249,7 +251,7 @@ exporter = tf.estimator.BestExporter(
     serving_input_receiver_fn = serving_input_receiver_fn,
     exports_to_keep = 3)
-eval_spec = tf.estimator.EvalSpec(input_fn = lambda: input_fn(BASE_DIR + 'eval_samples.csv', 1, False, 2 ** 15), steps = None, throttle_secs = 120, exporters = exporter)
+eval_spec = tf.estimator.EvalSpec(input_fn = lambda: input_fn(BASE_DIR + 'eval_samples.csv', 1, False, 2 ** 15), steps = 100, throttle_secs = 120, exporters = exporter)
 # def my_auc(labels, predictions):
 #     return {'auc_pr_careful_interpolation': tf.metrics.auc(labels, predictions['logistic'], curve='ROC',
@@ -257,6 +259,12 @@ eval_spec = tf.estimator.EvalSpec(input_fn = lambda: input_fn(BASE_DIR + 'eval_s
 # wideAndDeepModel = tf.contrib.estimator.add_metrics(wideAndDeepModel, my_auc)
-tf.estimator.train_and_evaluate(wideAndDeepModel, train_spec, eval_spec)
-wideAndDeepModel.evaluate(lambda: input_fn(BASE_DIR + 'eval_samples.csv', 1, False, 2 ** 15))
+# tf.estimator.train_and_evaluate(wideAndDeepModel, train_spec, eval_spec)
+wideAndDeepModel.export_saved_model('./saved_model', serving_input_receiver_fn, as_text = False)
+# wideAndDeepModel.evaluate(lambda: input_fn(BASE_DIR + 'eval_samples.csv', 1, False, 2 ** 15))
import tensorflow as tf
tf.logging.set_verbosity(tf.logging.INFO)
import sys
import time
import os
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'
start = time.time()
BASE_DIR = '/data/files/wideAndDeep/'
MODEL_BASE_DIR = '/data/files/wideAndDeep_tf2_dist/'
def input_fn(csv_path, epoch, shuffle, batch_size):
dataset = tf.data.TextLineDataset(csv_path)
def parse_line(line_tensor):
splits = tf.compat.v1.string_split([line_tensor], delimiter='|', skip_empty=False).values
return {
'ITEM_CATEGORY_card_id': splits[0],
'USER_CATEGORY_device_id': splits[2],
'USER_CATEGORY_os': splits[3],
'USER_CATEGORY_user_city_id': splits[4],
'USER_MULTI_CATEGORY_second_solutions': tf.compat.v1.string_split([splits[6]], delimiter=',').values,
'USER_MULTI_CATEGORY_second_demands': tf.compat.v1.string_split([splits[7]], delimiter=',').values,
'USER_MULTI_CATEGORY_second_positions': tf.compat.v1.string_split([splits[8]], delimiter=',').values,
'USER_MULTI_CATEGORY_projects': tf.compat.v1.string_split([splits[9]], delimiter=',').values,
'ITEM_NUMERIC_click_count_sum': tf.compat.v1.string_to_number(splits[10]),
'ITEM_NUMERIC_click_count_avg': tf.compat.v1.string_to_number(splits[11]),
'ITEM_NUMERIC_click_count_stddev': tf.compat.v1.string_to_number(splits[12]),
'ITEM_NUMERIC_exp_count_sum': tf.compat.v1.string_to_number(splits[13]),
'ITEM_NUMERIC_exp_count_avg': tf.compat.v1.string_to_number(splits[14]),
'ITEM_NUMERIC_exp_count_stddev': tf.compat.v1.string_to_number(splits[15]),
'ITEM_NUMERIC_discount': tf.compat.v1.string_to_number(splits[16]),
'ITEM_NUMERIC_case_count': tf.compat.v1.string_to_number(splits[17]),
'ITEM_NUMERIC_sales_count': tf.compat.v1.string_to_number(splits[18]),
'ITEM_CATEGORY_service_type': splits[19],
'ITEM_CATEGORY_merchant_id': splits[20],
'ITEM_CATEGORY_doctor_type': splits[21],
'ITEM_CATEGORY_doctor_id': splits[22],
'ITEM_CATEGORY_doctor_famous': splits[23],
'ITEM_CATEGORY_hospital_id': splits[24],
'ITEM_CATEGORY_hospital_city_tag_id': splits[25],
'ITEM_CATEGORY_hospital_type': splits[26],
'ITEM_CATEGORY_hospital_is_high_quality': splits[27],
'ITEM_MULTI_CATEGORY_second_demands': tf.compat.v1.string_split([splits[28]], delimiter=',').values,
'ITEM_MULTI_CATEGORY_second_solutions': tf.compat.v1.string_split([splits[29]],
delimiter=',').values,
'ITEM_MULTI_CATEGORY_second_positions': tf.compat.v1.string_split([splits[30]],
delimiter=',').values,
'ITEM_MULTI_CATEGORY_projects': tf.compat.v1.string_split([splits[31]], delimiter=',').values,
'ITEM_NUMERIC_sku_price': tf.compat.v1.string_to_number(splits[32]),
# 'label': tf.compat.v1.string_to_number(splits[5])
}, tf.compat.v1.string_to_number(splits[5])
padded_shapes = ({'ITEM_CATEGORY_card_id': (), 'USER_CATEGORY_device_id': (), 'USER_CATEGORY_os': (),
'USER_CATEGORY_user_city_id': (), 'USER_MULTI_CATEGORY_second_solutions': [-1],
'USER_MULTI_CATEGORY_second_demands': [-1], 'USER_MULTI_CATEGORY_second_positions': [-1],
'USER_MULTI_CATEGORY_projects': [-1], 'ITEM_NUMERIC_click_count_sum': (),
'ITEM_NUMERIC_click_count_avg': (), 'ITEM_NUMERIC_click_count_stddev': (),
'ITEM_NUMERIC_exp_count_sum': (), 'ITEM_NUMERIC_exp_count_avg': (),
'ITEM_NUMERIC_exp_count_stddev': (), 'ITEM_NUMERIC_discount': (), 'ITEM_NUMERIC_case_count': (),
'ITEM_NUMERIC_sales_count': (), 'ITEM_CATEGORY_service_type': (), 'ITEM_CATEGORY_merchant_id': (),
'ITEM_CATEGORY_doctor_type': (), 'ITEM_CATEGORY_doctor_id': (), 'ITEM_CATEGORY_doctor_famous': (),
'ITEM_CATEGORY_hospital_id': (), 'ITEM_CATEGORY_hospital_city_tag_id': (),
'ITEM_CATEGORY_hospital_type': (), 'ITEM_CATEGORY_hospital_is_high_quality': (),
'ITEM_MULTI_CATEGORY_second_demands': [-1], 'ITEM_MULTI_CATEGORY_second_solutions': [-1],
'ITEM_MULTI_CATEGORY_second_positions': [-1], 'ITEM_MULTI_CATEGORY_projects': [-1],
'ITEM_NUMERIC_sku_price': ()}, ())
padding_values = ({'ITEM_CATEGORY_card_id': '-1', 'USER_CATEGORY_device_id': '-1', 'USER_CATEGORY_os': '-1',
'USER_CATEGORY_user_city_id': '-1', 'USER_MULTI_CATEGORY_second_solutions': '-1',
'USER_MULTI_CATEGORY_second_demands': '-1', 'USER_MULTI_CATEGORY_second_positions': '-1',
'USER_MULTI_CATEGORY_projects': '-1', 'ITEM_NUMERIC_click_count_sum': 0.0,
'ITEM_NUMERIC_click_count_avg': 0.0, 'ITEM_NUMERIC_click_count_stddev': 0.0,
'ITEM_NUMERIC_exp_count_sum': 0.0, 'ITEM_NUMERIC_exp_count_avg': 0.0,
'ITEM_NUMERIC_exp_count_stddev': 0.0, 'ITEM_NUMERIC_discount': 0.0,
'ITEM_NUMERIC_case_count': 0.0, 'ITEM_NUMERIC_sales_count': 0.0,
'ITEM_CATEGORY_service_type': '-1', 'ITEM_CATEGORY_merchant_id': '-1',
'ITEM_CATEGORY_doctor_type': '-1', 'ITEM_CATEGORY_doctor_id': '-1',
'ITEM_CATEGORY_doctor_famous': '-1', 'ITEM_CATEGORY_hospital_id': '-1',
'ITEM_CATEGORY_hospital_city_tag_id': '-1', 'ITEM_CATEGORY_hospital_type': '-1',
'ITEM_CATEGORY_hospital_is_high_quality': '-1', 'ITEM_MULTI_CATEGORY_second_demands': '-1',
'ITEM_MULTI_CATEGORY_second_solutions': '-1', 'ITEM_MULTI_CATEGORY_second_positions': '-1',
'ITEM_MULTI_CATEGORY_projects': '-1', 'ITEM_NUMERIC_sku_price': 0.0}, 0.0)
dataset = dataset.map(parse_line, num_parallel_calls = 8).cache()
dataset = dataset.padded_batch(batch_size, padded_shapes, padding_values=padding_values)
if shuffle:
dataset = dataset.shuffle(2048).prefetch(512 * 100).repeat(epoch)
else:
dataset = dataset.prefetch(512 * 100).repeat(epoch)
return dataset
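# Hedged sketch (assumption, not part of the original script): sanity-check one
# padded batch in a throwaway TF1 session before training, e.g.:
# it = tf.compat.v1.data.make_one_shot_iterator(input_fn(BASE_DIR + 'eval_samples.csv', 1, False, 2))
# with tf.compat.v1.Session() as s:
#     feats, label = s.run(it.get_next())
#     print(label, {k: v.shape for k, v in feats.items()})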
boundaries = [0, 10, 100]
ITEM_NUMERIC_click_count_sum_fc = tf.feature_column.bucketized_column(tf.feature_column.numeric_column('ITEM_NUMERIC_click_count_sum'), boundaries)
ITEM_NUMERIC_exp_count_sum_fc = tf.feature_column.bucketized_column(tf.feature_column.numeric_column('ITEM_NUMERIC_exp_count_sum'), boundaries)
ITEM_NUMERIC_click_count_avg_fc = tf.feature_column.bucketized_column(tf.feature_column.numeric_column('ITEM_NUMERIC_click_count_avg'), boundaries)
ITEM_NUMERIC_exp_count_avg_fc = tf.feature_column.bucketized_column(tf.feature_column.numeric_column('ITEM_NUMERIC_exp_count_avg'), boundaries)
boundaries = [0, 0.01, 0.1]
ITEM_NUMERIC_click_count_stddev_fc = tf.feature_column.bucketized_column(tf.feature_column.numeric_column('ITEM_NUMERIC_click_count_stddev'), boundaries)
ITEM_NUMERIC_exp_count_stddev_fc = tf.feature_column.bucketized_column(tf.feature_column.numeric_column('ITEM_NUMERIC_exp_count_stddev'), boundaries)
boundaries = [0, 0.01, 0.1, 1]
ITEM_NUMERIC_discount_fc = tf.feature_column.bucketized_column(tf.feature_column.numeric_column('ITEM_NUMERIC_discount'), boundaries)
boundaries = [0, 10, 100]
ITEM_NUMERIC_case_count_fc = tf.feature_column.bucketized_column(tf.feature_column.numeric_column('ITEM_NUMERIC_case_count'), boundaries)
ITEM_NUMERIC_sales_count_fc = tf.feature_column.bucketized_column(tf.feature_column.numeric_column('ITEM_NUMERIC_sales_count'), boundaries)
ITEM_NUMERIC_sku_price_fc = tf.feature_column.bucketized_column(tf.feature_column.numeric_column('ITEM_NUMERIC_sku_price'), boundaries)
USER_CATEGORY_device_id_fc = tf.feature_column.categorical_column_with_vocabulary_file('USER_CATEGORY_device_id', BASE_DIR + 'USER_CATEGORY_device_id_vocab.csv')
USER_CATEGORY_os_fc = tf.feature_column.categorical_column_with_vocabulary_file('USER_CATEGORY_os', BASE_DIR + 'USER_CATEGORY_os_vocab.csv')
USER_CATEGORY_user_city_id_fc = tf.feature_column.categorical_column_with_vocabulary_file('USER_CATEGORY_user_city_id', BASE_DIR + 'USER_CATEGORY_user_city_id_vocab.csv')
USER_MULTI_CATEGORY__second_solutions_fc = tf.feature_column.categorical_column_with_vocabulary_file('USER_MULTI_CATEGORY_second_solutions', BASE_DIR + 'USER_MULTI_CATEGORY_second_solutions_vocab.csv')
USER_MULTI_CATEGORY__second_positions_fc = tf.feature_column.categorical_column_with_vocabulary_file('USER_MULTI_CATEGORY_second_positions', BASE_DIR + 'USER_MULTI_CATEGORY_second_positions_vocab.csv')
USER_MULTI_CATEGORY__second_demands_fc = tf.feature_column.categorical_column_with_vocabulary_file('USER_MULTI_CATEGORY_second_demands', BASE_DIR + 'USER_MULTI_CATEGORY_second_demands_vocab.csv')
USER_MULTI_CATEGORY__projects_fc = tf.feature_column.categorical_column_with_vocabulary_file('USER_MULTI_CATEGORY_projects', BASE_DIR + 'USER_MULTI_CATEGORY_projects_vocab.csv')
ITEM_CATEGORY_card_id_fc = tf.feature_column.categorical_column_with_vocabulary_file('ITEM_CATEGORY_card_id', BASE_DIR + 'ITEM_CATEGORY_card_id_vocab.csv')
ITEM_CATEGORY_service_type_fc = tf.feature_column.categorical_column_with_vocabulary_file('ITEM_CATEGORY_service_type', BASE_DIR + 'ITEM_CATEGORY_service_type_vocab.csv')
ITEM_CATEGORY_merchant_id_fc = tf.feature_column.categorical_column_with_vocabulary_file('ITEM_CATEGORY_merchant_id', BASE_DIR + 'ITEM_CATEGORY_merchant_id_vocab.csv')
ITEM_CATEGORY_doctor_type_fc = tf.feature_column.categorical_column_with_vocabulary_file('ITEM_CATEGORY_doctor_type', BASE_DIR + 'ITEM_CATEGORY_doctor_type_vocab.csv')
ITEM_CATEGORY_doctor_id_fc = tf.feature_column.categorical_column_with_vocabulary_file('ITEM_CATEGORY_doctor_id', BASE_DIR + 'ITEM_CATEGORY_doctor_id_vocab.csv')
ITEM_CATEGORY_doctor_famous_fc = tf.feature_column.categorical_column_with_vocabulary_file('ITEM_CATEGORY_doctor_famous', BASE_DIR + 'ITEM_CATEGORY_doctor_famous_vocab.csv')
ITEM_CATEGORY_hospital_id_fc = tf.feature_column.categorical_column_with_vocabulary_file('ITEM_CATEGORY_hospital_id', BASE_DIR + 'ITEM_CATEGORY_hospital_id_vocab.csv')
ITEM_CATEGORY_hospital_city_tag_id_fc = tf.feature_column.categorical_column_with_vocabulary_file('ITEM_CATEGORY_hospital_city_tag_id', BASE_DIR + 'ITEM_CATEGORY_hospital_city_tag_id_vocab.csv')
ITEM_CATEGORY_hospital_type_fc = tf.feature_column.categorical_column_with_vocabulary_file('ITEM_CATEGORY_hospital_type', BASE_DIR + 'ITEM_CATEGORY_hospital_type_vocab.csv')
ITEM_CATEGORY_hospital_is_high_quality_fc = tf.feature_column.categorical_column_with_vocabulary_file('ITEM_CATEGORY_hospital_is_high_quality', BASE_DIR + 'ITEM_CATEGORY_hospital_is_high_quality_vocab.csv')
ITEM_MULTI_CATEGORY__second_solutions_fc = tf.feature_column.categorical_column_with_vocabulary_file('ITEM_MULTI_CATEGORY_second_solutions', BASE_DIR + 'ITEM_MULTI_CATEGORY_second_solutions_vocab.csv')
ITEM_MULTI_CATEGORY__second_positions_fc = tf.feature_column.categorical_column_with_vocabulary_file('ITEM_MULTI_CATEGORY_second_positions', BASE_DIR + 'ITEM_MULTI_CATEGORY_second_positions_vocab.csv')
ITEM_MULTI_CATEGORY__second_demands_fc = tf.feature_column.categorical_column_with_vocabulary_file('ITEM_MULTI_CATEGORY_second_demands', BASE_DIR + 'ITEM_MULTI_CATEGORY_second_demands_vocab.csv')
ITEM_MULTI_CATEGORY__projects_fc = tf.feature_column.categorical_column_with_vocabulary_file('ITEM_MULTI_CATEGORY_projects', BASE_DIR + 'ITEM_MULTI_CATEGORY_projects_vocab.csv')
def embedding_fc(categorical_column, dim):
return tf.feature_column.embedding_column(categorical_column, dim)
linear_feature_columns = [
ITEM_NUMERIC_click_count_sum_fc,
ITEM_NUMERIC_exp_count_sum_fc,
ITEM_NUMERIC_click_count_avg_fc,
ITEM_NUMERIC_exp_count_avg_fc,
ITEM_NUMERIC_click_count_stddev_fc,
ITEM_NUMERIC_exp_count_stddev_fc,
ITEM_NUMERIC_discount_fc,
ITEM_NUMERIC_case_count_fc,
ITEM_NUMERIC_sales_count_fc,
ITEM_NUMERIC_sku_price_fc,
embedding_fc(ITEM_CATEGORY_card_id_fc, 1),
embedding_fc(ITEM_CATEGORY_service_type_fc, 1),
embedding_fc(ITEM_CATEGORY_merchant_id_fc, 1),
embedding_fc(ITEM_CATEGORY_doctor_type_fc, 1),
embedding_fc(ITEM_CATEGORY_doctor_id_fc, 1),
embedding_fc(ITEM_CATEGORY_doctor_famous_fc, 1),
embedding_fc(ITEM_CATEGORY_hospital_id_fc, 1),
embedding_fc(ITEM_CATEGORY_hospital_city_tag_id_fc, 1),
embedding_fc(ITEM_CATEGORY_hospital_type_fc, 1),
embedding_fc(ITEM_CATEGORY_hospital_is_high_quality_fc, 1),
embedding_fc(ITEM_MULTI_CATEGORY__projects_fc, 1),
embedding_fc(ITEM_MULTI_CATEGORY__second_demands_fc, 1),
embedding_fc(ITEM_MULTI_CATEGORY__second_positions_fc, 1),
embedding_fc(ITEM_MULTI_CATEGORY__second_solutions_fc, 1),
]
dnn_feature_columns = [
embedding_fc(USER_CATEGORY_device_id_fc, 8),
embedding_fc(USER_CATEGORY_os_fc, 8),
embedding_fc(USER_CATEGORY_user_city_id_fc, 8),
embedding_fc(USER_MULTI_CATEGORY__second_solutions_fc, 8),
embedding_fc(USER_MULTI_CATEGORY__second_positions_fc, 8),
embedding_fc(USER_MULTI_CATEGORY__second_demands_fc, 8),
embedding_fc(USER_MULTI_CATEGORY__projects_fc, 8),
embedding_fc(ITEM_NUMERIC_click_count_sum_fc, 8),
embedding_fc(ITEM_NUMERIC_exp_count_sum_fc, 8),
embedding_fc(ITEM_NUMERIC_click_count_avg_fc, 8),
embedding_fc(ITEM_NUMERIC_exp_count_avg_fc, 8),
embedding_fc(ITEM_NUMERIC_click_count_stddev_fc, 8),
embedding_fc(ITEM_NUMERIC_exp_count_stddev_fc, 8),
embedding_fc(ITEM_NUMERIC_discount_fc, 8),
embedding_fc(ITEM_NUMERIC_case_count_fc, 8),
embedding_fc(ITEM_NUMERIC_sales_count_fc, 8),
embedding_fc(ITEM_NUMERIC_sku_price_fc, 8),
embedding_fc(ITEM_CATEGORY_card_id_fc, 8),
embedding_fc(ITEM_CATEGORY_service_type_fc, 8),
embedding_fc(ITEM_CATEGORY_merchant_id_fc, 8),
embedding_fc(ITEM_CATEGORY_doctor_type_fc, 8),
embedding_fc(ITEM_CATEGORY_doctor_id_fc, 8),
embedding_fc(ITEM_CATEGORY_doctor_famous_fc, 8),
embedding_fc(ITEM_CATEGORY_hospital_id_fc, 8),
embedding_fc(ITEM_CATEGORY_hospital_city_tag_id_fc, 8),
embedding_fc(ITEM_CATEGORY_hospital_type_fc, 8),
embedding_fc(ITEM_CATEGORY_hospital_is_high_quality_fc, 8),
embedding_fc(ITEM_MULTI_CATEGORY__projects_fc, 8),
embedding_fc(ITEM_MULTI_CATEGORY__second_demands_fc, 8),
embedding_fc(ITEM_MULTI_CATEGORY__second_positions_fc, 8),
embedding_fc(ITEM_MULTI_CATEGORY__second_solutions_fc, 8),
]
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0, 1, 2"
from tensorflow.python.client import device_lib
print(device_lib.list_local_devices())
distribution = tf.distribute.MirroredStrategy()
# session_config = tf.compat.v1.ConfigProto(log_device_placement = True, allow_soft_placement = True)
session_config = tf.compat.v1.ConfigProto(allow_soft_placement = True)
session_config.gpu_options.allow_growth = True
# config = tf.estimator.RunConfig(save_checkpoints_steps = 10000, train_distribute = distribution, eval_distribute = distribution)
config = tf.estimator.RunConfig(save_checkpoints_steps = 3000, session_config = session_config)
wideAndDeepModel = tf.estimator.DNNLinearCombinedClassifier(model_dir = MODEL_BASE_DIR + 'model_csv',
linear_feature_columns = linear_feature_columns,
dnn_feature_columns = dnn_feature_columns,
dnn_hidden_units = [128, 32],
dnn_dropout = 0.5,
config = config)
# early_stopping = tf.contrib.estimator.stop_if_no_decrease_hook(wideAndDeepModel, eval_dir = wideAndDeepModel.eval_dir(), metric_name='auc', max_steps_without_decrease=1000, min_steps = 100)
# early_stopping = tf.contrib.estimator.stop_if_no_increase_hook(wideAndDeepModel, metric_name = 'auc', max_steps_without_increase = 1000, min_steps = 1000)
hooks = []
train_spec = tf.estimator.TrainSpec(input_fn = lambda: input_fn(BASE_DIR + 'train_samples.csv', 20, True, 1024), hooks = hooks)
serving_feature_spec = tf.feature_column.make_parse_example_spec(
linear_feature_columns + dnn_feature_columns)
serving_input_receiver_fn = (
tf.estimator.export.build_parsing_serving_input_receiver_fn(
serving_feature_spec))
exporter = tf.estimator.BestExporter(
name = "best_exporter",
compare_fn = lambda best_eval_result, current_eval_result: current_eval_result['auc'] > best_eval_result['auc'],
serving_input_receiver_fn = serving_input_receiver_fn,
exports_to_keep = 3)
eval_spec = tf.estimator.EvalSpec(input_fn = lambda: input_fn(BASE_DIR + 'eval_samples.csv', 1, False, 2 ** 15), steps = None, throttle_secs = 120, exporters = exporter)
# def my_auc(labels, predictions):
# return {'auc_pr_careful_interpolation': tf.metrics.auc(labels, predictions['logistic'], curve='ROC',
# summation_method='careful_interpolation')}
# wideAndDeepModel = tf.contrib.estimator.add_metrics(wideAndDeepModel, my_auc)
tf.estimator.train_and_evaluate(wideAndDeepModel, train_spec, eval_spec)
wideAndDeepModel.evaluate(lambda: input_fn(BASE_DIR + 'eval_samples.csv', 1, False, 2 ** 15))
print("训练耗时: {}s".format(time.time() - start))
import os
import argparse
import shutil
import pandas as pd
from sklearn.model_selection import train_test_split
import tensorflow as tf
def build_model_columns():
# Define the continuous-value column
actual_price = tf.feature_column.numeric_column('actual_price', normalizer_fn=lambda x: (x - 0) / 150000,
dtype=tf.float32)
# Define the categorical column
gender = tf.feature_column.categorical_column_with_vocabulary_list(
'Gender', [1, -1, 0], dtype=tf.int64)
# Bucketize the total purchase amount
actual_price_bin = tf.feature_column.bucketized_column(
actual_price, boundaries=[100, 250, 550, 1300])
# Wide-side features are sparse 0/1 vectors fed to the linear (LR) part: all categorical features plus selected feature crosses
wide_columns = [actual_price_bin, gender]
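# Hedged sketch (assumption, not part of the original script): the comment above
# mentions feature crosses; a hypothetical cross for the wide side could be:
# price_x_gender = tf.feature_column.crossed_column([actual_price_bin, gender], hash_bucket_size=100)
# wide_columns.append(price_x_gender)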
gender_emb = tf.feature_column.embedding_column(gender, 10)
# All features also go through the deep part, as continuous values plus one-hot or embedded categoricals
deep_columns = [
gender_emb
]
return wide_columns, deep_columns
def build_estimator(model_dir, model_type, warm_start_from=None):
"""按照指定的模型生成估算器对象."""
# Lists of feature-column objects from the feature engineering above
wide_columns, deep_columns = build_model_columns()
# Hidden units per fully connected layer in the deep part; every layer uses ReLU
hidden_units = [50, 25]
run_config = tf.estimator.RunConfig().replace(  # only checkpoint options are set; this small model runs fine on CPU
save_checkpoints_steps=100,
keep_checkpoint_max=2)
if model_type == 'wide':  # wide-only (linear) estimator
return tf.estimator.LinearClassifier(
model_dir=model_dir,
feature_columns=wide_columns,
config=run_config)
elif model_type == 'deep':  # deep-only (DNN) estimator
return tf.estimator.DNNClassifier(
model_dir=model_dir,
feature_columns=deep_columns,
hidden_units=hidden_units,
config=run_config)
else:
return tf.estimator.DNNLinearCombinedClassifier(  # combined wide & deep estimator
model_dir=model_dir,
linear_feature_columns=wide_columns,
dnn_feature_columns=deep_columns,
dnn_hidden_units=hidden_units,
config=run_config,
warm_start_from=warm_start_from)
def read_pandas(data_file):
"""pandas将数据读取内存"""
assert os.path.exists(data_file), ("%s not found." % data_file)
df = pd.read_csv(data_file).dropna()
train, test = train_test_split(df, test_size=0.15, random_state=1)
y_train = train.pop("label")
y_test = test.pop("label")
return train, test, y_train, y_test
def input_fn(X, y, shuffle, batch_size, predict=False):
"""Input function for the estimator."""
if predict:
# from_tensor_slices builds the dataset from in-memory data
dataset = tf.data.Dataset.from_tensor_slices(X.to_dict(orient='list'))  # features only
else:
dataset = tf.data.Dataset.from_tensor_slices((X.to_dict(orient='list'), y))  # features and labels
if shuffle:  # shuffle the data
dataset = dataset.shuffle(buffer_size=64)  # a larger buffer shuffles more thoroughly
dataset = dataset.batch(batch_size)  # split into batches of batch_size
dataset = dataset.prefetch(1)  # prefetch one batch; usually sufficient
return dataset
def trainmain(train, y_train, test, y_test):
model_dir = "./wide_deep_test"
model_type = "wide_deep"
model = build_estimator(model_dir, model_type) # 生成估算器对象
def train_input_fn():
return input_fn(train, y_train, True, 1, predict=False)
def eval_input_fn():
return input_fn(test, y_test, False, 1, predict=False)
# Epochs are driven by this outer loop rather than by dataset.repeat()
for n in range(1):
model.train(input_fn=train_input_fn)
results = model.evaluate(input_fn=eval_input_fn)
print('{0:-^30}'.format('evaluate at epoch %d' % ((n + 1))))
# results is a dict of metric name -> value
print(pd.Series(results).to_frame('values'))
# Export the model
export_model(model, "saved_model_test")
def export_model(model, export_dir):
features = {
"Gender": tf.placeholder(dtype=tf.int64, shape=(2), name='Gender'),
"actual_price": tf.placeholder(dtype=tf.float32, shape=(2), name='actual_price'),
}
example_input_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(features)
model.export_savedmodel(export_dir, example_input_fn, as_text=False, strip_default_attrs=True)
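# Assumption (not part of the original script): the exported model can be served
# with TensorFlow Serving before running the REST client below, e.g.:
#   docker run -p 8501:8501 \
#     -v "$PWD/saved_model_test:/models/wide_deep" \
#     -e MODEL_NAME=wide_deep tensorflow/serving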
import pandas as pd
train_X = pd.DataFrame({"Gender": [1, 0, 1, 0, 1, 0], "actual_price": [10000.0, 10000.0, 10000.0, 10000.0, 10000.0, 10000.0]})
train_Y = [1, 0, 1, 0, 1, 0]
trainmain(train_X, train_Y, train_X, train_Y)
import requests
data = {
'Gender': [0],
'actual_price': [0]}
res = requests.post("http://localhost:8501/v1/models/wide_deep:predict",
json={"instances": [data], "signature_name": "predict"})
print(res.text)
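# Note (assumption, not part of the original script): TF Serving's REST predict
# endpoint accepts either the row format used here ("instances": [...]) or the
# columnar format used by the Example-based client above ("inputs": {...}).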