Commit 2bf956e8 authored by Your Name

change train.py

parent 915416d1
#coding=utf-8
#from __future__ import absolute_import
#from __future__ import division
#from __future__ import print_function
#import argparse
import shutil
import pymysql
import os
import json
from datetime import date, timedelta
import tensorflow as tf
import subprocess
import time
import glob
import random
import pandas as pd
#################### CMD Arguments ####################
FLAGS = tf.app.flags.FLAGS
@@ -65,7 +58,10 @@ def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
"tag6_list": tf.VarLenFeature(tf.int64),
"tag7_list": tf.VarLenFeature(tf.int64),
"search_tag2_list": tf.VarLenFeature(tf.int64),
"search_tag3_list": tf.VarLenFeature(tf.int64)
"search_tag3_list": tf.VarLenFeature(tf.int64),
"uid": tf.VarLenFeature(tf.string),
"city": tf.VarLenFeature(tf.string),
"cid_id": tf.VarLenFeature(tf.string)
}
parsed = tf.parse_single_example(record, features)
y = parsed.pop('y')
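# Each tf.VarLenFeature above parses into a tf.SparseTensor whose length can vary
# per example; the int64 tag lists are embedded and concatenated further down,
# while the string fields (uid, city, cid_id) are only carried along so that
# predictions can later be keyed back to a device, city and content id.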
@@ -135,6 +131,10 @@ def model_fn(features, labels, mode, params):
tag7_list = features['tag7_list']
search_tag2_list = features['search_tag2_list']
search_tag3_list = features['search_tag3_list']
+uid = features['uid']
+city = features['city']
+cid_id = features['cid_id']
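+# uid/city/cid_id play no part in the network itself; they are read here only so
+# that infer-mode predictions can be joined back to the originating request.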
if FLAGS.task_type != "infer":
y = labels['y']
@@ -161,6 +161,10 @@ def model_fn(features, labels, mode, params):
x_concat = tf.concat([tf.reshape(embedding_id, shape=[-1, common_dims]), app_id, level2, level3, tag1,
tag2, tag3, tag4, tag5, tag6, tag7,search_tag2,search_tag3], axis=1)
+uid = features['uid']
+city = features['city']
+cid_id = features['cid_id']
with tf.name_scope("CVR_Task"):
if mode == tf.estimator.ModeKeys.TRAIN:
train_phase = True
@@ -205,7 +209,7 @@ def model_fn(features, labels, mode, params):
pcvr = tf.sigmoid(y_cvr)
pctcvr = pctr*pcvr
predictions={"pcvr": pcvr, "pctr": pctr, "pctcvr": pctcvr}
predictions = {"pctcvr": pctcvr, "uid": uid, "city": city, "cid_id": cid_id}
export_outputs = {tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: tf.estimator.export.PredictOutput(predictions)}
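# The same predictions dict doubles as the default serving signature, so an
# exported SavedModel would also emit uid/city/cid_id alongside pctcvr.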
# Provide an estimator spec for `ModeKeys.PREDICT`
if mode == tf.estimator.ModeKeys.PREDICT:
@@ -313,26 +317,26 @@ def set_dist_env():
print(json.dumps(tf_config))
os.environ['TF_CONFIG'] = json.dumps(tf_config)
-def main(_):
+def main(file_path):
#------check Arguments------
if FLAGS.dt_dir == "":
FLAGS.dt_dir = (date.today() + timedelta(-1)).strftime('%Y%m%d')
FLAGS.model_dir = FLAGS.model_dir + FLAGS.dt_dir
#FLAGS.data_dir = FLAGS.data_dir + FLAGS.dt_dir
-tr_files = ["hdfs://172.16.32.4:8020/strategy/esmm/tr/part-r-00000"]
va_files = ["hdfs://172.16.32.4:8020/strategy/esmm/va/part-r-00000"]
-te_files = ["%s/part-r-00000" % FLAGS.hdfs_dir]
-if FLAGS.clear_existing_model:
-try:
-shutil.rmtree(FLAGS.model_dir)
-except Exception as e:
-print(e, "at clear_existing_model")
-else:
-print("existing model cleaned at %s" % FLAGS.model_dir)
-set_dist_env()
+# if FLAGS.clear_existing_model:
+# try:
+# shutil.rmtree(FLAGS.model_dir)
+# except Exception as e:
+# print(e, "at clear_existing_model")
+# else:
+# print("existing model cleaned at %s" % FLAGS.model_dir)
+# set_dist_env()
#------build Tasks------
model_params = {
@@ -350,7 +354,7 @@ def main(_):
Estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir=FLAGS.model_dir, params=model_params, config=config)
if FLAGS.task_type == 'train':
-train_spec = tf.estimator.TrainSpec(input_fn=lambda: input_fn(tr_files, num_epochs=FLAGS.num_epochs, batch_size=FLAGS.batch_size))
+train_spec = tf.estimator.TrainSpec(input_fn=lambda: input_fn(file_path, num_epochs=FLAGS.num_epochs, batch_size=FLAGS.batch_size))
eval_spec = tf.estimator.EvalSpec(input_fn=lambda: input_fn(va_files, num_epochs=1, batch_size=FLAGS.batch_size), steps=None, start_delay_secs=1000, throttle_secs=1200)
result = tf.estimator.train_and_evaluate(Estimator, train_spec, eval_spec)
for key,value in sorted(result[0].items()):
@@ -360,18 +364,68 @@ def main(_):
for key,value in sorted(result.items()):
print('%s: %s' % (key,value))
elif FLAGS.task_type == 'infer':
-preds = Estimator.predict(input_fn=lambda: input_fn(te_files, num_epochs=1, batch_size=FLAGS.batch_size), predict_keys=["pctcvr","pctr","pcvr"])
-with open(FLAGS.local_dir + "/pred.txt", "w") as fo:
-for prob in preds:
-fo.write("%f\t%f\t%f\n" % (prob['pctr'], prob['pcvr'], prob['pctcvr']))
+preds = Estimator.predict(input_fn=lambda: input_fn(file_path, num_epochs=1, batch_size=FLAGS.batch_size), predict_keys=["pctcvr", "uid", "city", "cid_id"])
+result = []
+for prob in preds:
+result.append([str(prob["uid"][0]), str(prob["city"][0]), str(prob["cid_id"][0]), str(prob['pctcvr'])])
+return result
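+# Each row is [uid, city, cid_id, pctcvr]; the string fields come back as bytes
+# (repr "b'...'"), which trans() below strips back to plain strings.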
elif FLAGS.task_type == 'export':
print("Not Implemented, Do It Yourself!")
+def trans(x):
+    return str(x)[2:-1] if str(x)[0] == 'b' else x
+
+def set_join(lst):
+    l = lst.unique().tolist()
+    r = [str(i) for i in l]
+    r = r[:500]
+    return ','.join(r)
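+# set_join dedupes one group's cid ids, keeps at most 500, and joins them into a
+# comma-separated string, e.g. set_join(pd.Series(['a', 'b', 'a'])) -> 'a,b'.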
+def df_sort(result, queue_name):
+    df = pd.DataFrame(result, columns=["uid", "city", "cid_id", "pctcvr"])
+    print(df.head(10))
+    df['uid1'] = df['uid'].apply(trans)
+    df['city1'] = df['city'].apply(trans)
+    df['cid_id1'] = df['cid_id'].apply(trans)
+    df2 = df.groupby(by=["uid1", "city1"]).apply(lambda x: x.sort_values(by="pctcvr", ascending=False)) \
+        .reset_index(drop=True).groupby(by=["uid1", "city1"]).agg({'cid_id1': set_join}).reset_index(drop=False)
+    df2.columns = ["device_id", "city_id", queue_name]
+    df2["time"] = "2019-06-27"
+    return df2
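+# df_sort yields one row per (device_id, city_id) with that pair's cid ids joined
+# in descending pctcvr order, ready for update_or_insert below. Note pctcvr is
+# stored as a string here, so the sort is lexicographic rather than numeric.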
+def update_or_insert(df2, queue_name):
+    device_count = df2.shape[0]
+    con = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test', charset='utf8')
+    cur = con.cursor()
+    try:
+        for i in range(0, device_count):
+            query = """INSERT INTO esmm_device_diary_queue_test (device_id, city_id, time,%s) VALUES('%s', '%s', '%s', '%s') \
+            ON DUPLICATE KEY UPDATE device_id='%s', city_id='%s', time='%s', %s='%s'""" % (queue_name, df2.device_id[i], df2.city_id[i], df2.time[i], df2[queue_name][i], df2.device_id[i], df2.city_id[i], df2.time[i], queue_name, df2[queue_name][i])
+            print(query)
+            cur.execute(query)
+            con.commit()
+        con.close()
+        print("insert or update success")
+    except Exception as e:
+        print(e)
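+# Interpolating values straight into the SQL string risks injection and breaks on
+# quotes in the data. A parameterized sketch of the same upsert (same table and
+# columns assumed; only the column name still needs string formatting) would be:
+#     sql = ("INSERT INTO esmm_device_diary_queue_test (device_id, city_id, time, %s) "
+#            "VALUES (%%s, %%s, %%s, %%s) ON DUPLICATE KEY UPDATE %s=%%s" % (queue_name, queue_name))
+#     cur.execute(sql, (df2.device_id[i], df2.city_id[i], df2.time[i], df2[queue_name][i], df2[queue_name][i]))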
if __name__ == "__main__":
    b = time.time()
    path = "hdfs://172.16.32.4:8020/strategy/esmm/"
    tf.logging.set_verbosity(tf.logging.INFO)
-    tf.app.run()
+    if FLAGS.task_type == 'train':
+        print("train task")
+        tr_files = ["hdfs://172.16.32.4:8020/strategy/esmm/tr/part-r-00000"]
+        main(tr_files)
+    elif FLAGS.task_type == 'infer':
+        te_files = ["%s/part-r-00000" % FLAGS.hdfs_dir]
+        queue_name = te_files[0].split('_')[-1] + "_queue"
+        print(queue_name + " task")
+        result = main(te_files)
+        df = df_sort(result, queue_name)
+        update_or_insert(df, queue_name)
print("耗时(分钟):")
print((time.time()-b)/60)
\ No newline at end of file
print((time.time()-b)/60)