Commit a207254f authored by Your Name

update

parent 998a9aa1
from datetime import date, timedelta
import tensorflow as tf
import pymysql
from pyspark.conf import SparkConf
import pytispark.pytispark as pti
from pyspark.sql import SparkSession
import datetime
import pandas as pd
import time
from pyspark import StorageLevel
def model_fn(features, labels, mode, params):
    """Build model function f(x) for Estimator."""
@@ -135,7 +141,7 @@ def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
    #print(batch_features, batch_labels)
    return batch_features, batch_labels
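input_fn's body is also collapsed by the diff; only its return is shown. A hedged sketch of the kind of TF 1.x text-line pipeline that would end in that return (the tab-separated layout of y, z, then numeric features is an assumption):

# Hedged sketch only: the tab-separated record layout (y, z, f0..f2) is assumed.
def input_fn_sketch(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
    def decode_line(line):
        cols = tf.decode_csv(line, record_defaults=[[0.0]] * 5, field_delim='\t')
        features = {"f0": cols[2], "f1": cols[3], "f2": cols[4]}
        labels = {"y": cols[0], "z": cols[1]}    # click / conversion labels
        return features, labels
    dataset = tf.data.TextLineDataset(filenames).map(decode_line, num_parallel_calls=4)
    if perform_shuffle:
        dataset = dataset.shuffle(buffer_size=256)
    dataset = dataset.repeat(num_epochs).batch(batch_size)
    batch_features, batch_labels = dataset.make_one_shot_iterator().get_next()
    return batch_features, batch_labels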
def esmm_predict(dist_data):
    dt_dir = (date.today() + timedelta(-1)).strftime('%Y%m%d')
    model_dir = "hdfs://172.16.32.4:8020/strategy/esmm/model_ckpt/DeepCvrMTL/" + dt_dir
    te_files = ["hdfs://172.16.32.4:8020/strategy/esmm/nearby/part-r-00000"]
@@ -153,13 +159,28 @@ def main(_):
                              log_step_count_steps=100, save_summary_steps=100)
    Estimator = tf.estimator.Estimator(model_fn=model_fn, model_dir="hdfs://172.16.32.4:8020/strategy/esmm/model_ckpt/DeepCvrMTL/", params=model_params, config=config)
    preds = Estimator.predict(input_fn=lambda: input_fn(dist_data, num_epochs=1, batch_size=10000), predict_keys=["pctcvr","pctr","pcvr"])
    indices = []
    for prob in preds:
        indices.append([prob['pctr'], prob['pcvr'], prob['pctcvr']])
    return indices
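This change is the point of the commit: instead of writing pred.txt on the driver inside main, the renamed esmm_predict(dist_data) returns the probability triples so the caller decides what to do with them. A hedged driver-side usage sketch (the file list and output path are reused from the old code purely for illustration):

# Hypothetical usage; paths come from the pre-commit code, not the elided new code.
probs = esmm_predict(["hdfs://172.16.32.4:8020/strategy/esmm/nearby/part-r-00000"])
df = pd.DataFrame(probs, columns=["pctr", "pcvr", "pctcvr"])
df.to_csv("/home/gmuser/esmm/nearby/pred.txt", sep="\t", header=False, index=False)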
if __name__ == "__main__":
    sparkConf = SparkConf().set("spark.hive.mapred.supports.subdirectories", "true") \
        .set("spark.hadoop.mapreduce.input.fileinputformat.input.dir.recursive", "true") \
        .set("spark.tispark.plan.allow_index_double_read", "false") \
        .set("spark.tispark.plan.allow_index_read", "true") \
        .set("spark.sql.extensions", "org.apache.spark.sql.TiExtensions") \
        .set("spark.tispark.pd.addresses", "172.16.40.158:2379").set("spark.io.compression.codec", "lzf") \
        .set("spark.driver.maxResultSize", "8g").set("spark.sql.avro.compression.codec", "snappy")
    spark = SparkSession.builder.config(conf=sparkConf).enableHiveSupport().getOrCreate()
    spark.sparkContext.setLogLevel("WARN")
    path = "hdfs://172.16.32.4:8020/strategy/esmm/"
    b = time.time()
    tf.logging.set_verbosity(tf.logging.INFO)
    tf.app.run()
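The new dist_data parameter and the freshly added Spark session suggest the elided tail of this commit distributes scoring across executors. One plausible wiring, illustrative only; note that tf.app.run() above still expects a module-level main(), so the real entry point must live in the collapsed lines below.

# Illustrative only: fan prediction out with mapPartitions. Assumes TensorFlow and
# the checkpoint path are reachable from every executor; the partitioning of the
# input file list is an assumption about the elided code.
te_files = ["hdfs://172.16.32.4:8020/strategy/esmm/nearby/part-r-00000"]
rdd = spark.sparkContext.parallelize(te_files, len(te_files))
result = rdd.mapPartitions(lambda part: esmm_predict(list(part))).collect()
print("predict done in %.1f s, %d rows" % (time.time() - b, len(result)))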
......