Commit 04b52880 authored by 王志伟
parents 7491370b 9b554ea0
@@ -41,148 +41,77 @@ def feature_engineer():
     start = (temp - datetime.timedelta(days=2)).strftime("%Y-%m-%d")
     print(start)
-    sql = "select e.y,e.z,e.stat_date,e.ucity_id,feat.level2_ids,e.ccity_name,u.device_type,u.manufacturer," \
-          "u.channel,c.top,cut.time,dl.app_list,feat.level3_ids,doctor.hospital_id," \
-          "k.treatment_method,k.price_min,k.price_max,k.treatment_time,k.maintain_time,k.recover_time " \
-          "from jerry_test.esmm_train_data e left join jerry_test.user_feature u on e.device_id = u.device_id " \
-          "left join jerry_test.cid_type_top c on e.device_id = c.device_id " \
-          "left join jerry_test.cid_time_cut cut on e.cid_id = cut.cid " \
-          "left join jerry_test.device_app_list dl on e.device_id = dl.device_id " \
+    sql = "select e.y,e.z,e.stat_date,e.ucity_id,feat.level2_ids " \
+          "from jerry_test.esmm_train_data e " \
           "left join jerry_test.diary_feat feat on e.cid_id = feat.diary_id " \
-          "left join jerry_test.train_Knowledge_network_data k on feat.level2 = k.level2_id " \
-          "left join eagle.src_zhengxing_api_service service on e.diary_service_id = service.id " \
-          "left join eagle.src_zhengxing_api_doctor doctor on service.doctor_id = doctor.id " \
           "where e.stat_date >= '{}'".format(start)
     df = spark.sql(sql)
-    df = df.drop_duplicates(["ucity_id", "level2_ids", "ccity_name", "device_type", "manufacturer",
-                             "channel", "top", "time", "stat_date", "app_list", "hospital_id", "level3_ids"])
-    features = ["ucity_id", "ccity_name", "device_type", "manufacturer",
-                "channel", "top", "time", "stat_date", "hospital_id",
-                "treatment_method", "price_min", "price_max", "treatment_time", "maintain_time", "recover_time"]
+    features = ["ucity_id","stat_date"]
     df = df.na.fill(dict(zip(features,features)))
-    apps_number, app_list_map = multi_hot(df,"app_list",1)
-    level2_number,leve2_map = multi_hot(df,"level2_ids",1 + apps_number)
-    level3_number, leve3_map = multi_hot(df, "level3_ids", 1 + apps_number + level2_number)
+    apps_number, app_list_map = multi_hot(df,"level2_ids",1)
     unique_values = []
     for i in features:
         unique_values.extend(df.select(i).distinct().rdd.map(lambda x: x[0]).collect())
-    temp = list(range(2 + apps_number + level2_number + level3_number,
-                      2 + apps_number + level2_number + level3_number + len(unique_values)))
+    temp = list(range(2 + apps_number,
+                      2 + apps_number + len(unique_values)))
     value_map = dict(zip(unique_values, temp))
-    rdd = df.select("app_list","level2_ids","level3_ids","stat_date","ucity_id", "ccity_name",
-                    "device_type", "manufacturer", "channel", "top", "time", "hospital_id",
-                    "treatment_method", "price_min","price_max", "treatment_time","maintain_time",
-                    "recover_time","y","z").rdd
+    rdd = df.select("level2_ids","stat_date","ucity_id","y","z").rdd
     rdd.persist()
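To make the new id layout concrete, a minimal sketch (assuming multi_hot assigns ids 1..apps_number to the level2 tag vocabulary; that helper is not shown in this diff): value_map then places the remaining categorical values, here ucity_id and stat_date, in the range starting at 2 + apps_number.

# Illustrative only: multi_hot / app_list_map are assumed to behave like this;
# the real helper is defined elsewhere in the repo.
app_list_map = {"tag_a": 1, "tag_b": 2, "tag_c": 3}   # apps_number == 3
apps_number = len(app_list_map)

unique_values = ["beijing", "shanghai", "2019-03-13"]  # distinct ucity_id / stat_date values
ids = list(range(2 + apps_number, 2 + apps_number + len(unique_values)))  # [5, 6, 7]
value_map = dict(zip(unique_values, ids))
print(value_map)   # {'beijing': 5, 'shanghai': 6, '2019-03-13': 7}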
     # TODO: remove the train filter below after going live, because the most recent day's data should also be used for training
-    train = rdd.filter(lambda x: x[3]!= validate_date)\
-        .map(lambda x: (app_list_func(x[0], app_list_map), app_list_func(x[1], leve2_map),
-                        app_list_func(x[2], leve3_map),value_map[x[3]],value_map[x[4]],
-                        value_map[x[5]],value_map[x[6]],value_map[x[7]],value_map[x[8]],
-                        value_map[x[9]],value_map[x[10]],value_map[x[11]],value_map[x[12]],
-                        value_map[x[13]],value_map[x[14]],value_map[x[15]],value_map[x[16]],
-                        value_map[x[17]], x[18],x[19]))
-    test = rdd.filter(lambda x: x[3] == validate_date)\
-        .map(lambda x: (app_list_func(x[0], app_list_map), app_list_func(x[1], leve2_map),
-                        app_list_func(x[2], leve3_map), value_map[x[3]], value_map[x[4]],
-                        value_map[x[5]], value_map[x[6]], value_map[x[7]], value_map[x[8]],
-                        value_map[x[9]], value_map[x[10]], value_map[x[11]], value_map[x[12]],
-                        value_map[x[13]], value_map[x[14]], value_map[x[15]], value_map[x[16]],
-                        value_map[x[17]], x[18], x[19]))
+    train = rdd.filter(lambda x: x[1]!= validate_date)\
+        .map(lambda x: (app_list_func(x[0], app_list_map),[value_map[x[2]],value_map[x[1]]], x[3],x[4]))
+    test = rdd.filter(lambda x: x[1]== validate_date)\
+        .map(lambda x: (app_list_func(x[0], app_list_map),[value_map[x[2]],value_map[x[1]]], x[3],x[4]))
-    spark.createDataFrame(test).toDF("app_list","level2_ids","level3_ids","stat_date","ucity_id", "ccity_name", "device_type", "manufacturer",
-                                     "channel", "top", "time", "hospital_id","treatment_method", "price_min",
-                                     "price_max", "treatment_time","maintain_time", "recover_time","y","z")\
-        .repartition(1).write.format("tfrecords").option("recordType", "Example").save(path=path+"va/", mode="overwrite")
+    spark.createDataFrame(test).toDF("level2_ids","ids","y","z")\
+        .repartition(1).write.format("tfrecords").option("recordType", "SequenceExample").save(path=path+"va/", mode="overwrite")
     print("va write done")
-    spark.createDataFrame(train).toDF("app_list","level2_ids","level3_ids","stat_date","ucity_id", "ccity_name", "device_type", "manufacturer",
-                                      "channel", "top", "time", "hospital_id","treatment_method", "price_min",
-                                      "price_max", "treatment_time","maintain_time", "recover_time","y","z") \
-        .repartition(1).write.format("tfrecords").option("recordType", "Example").save(path=path+"tr/", mode="overwrite")
+    spark.createDataFrame(train).toDF("level2_ids","ids","y","z") \
+        .repartition(1).write.format("tfrecords").option("recordType", "SequenceExample").save(path=path+"tr/", mode="overwrite")
     print("done")
     rdd.unpersist()
-    return validate_date,value_map,app_list_map,leve2_map,leve3_map
+    return validate_date,value_map,app_list_map
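A training row produced by the new map() above carries a variable-length list of level2 tag ids, a two-element [city id, date id] list, and the two labels y and z. A hypothetical sketch of app_list_func, which is defined elsewhere in this file and not shown in the diff:

# Hypothetical illustration; the real app_list_func may differ.
def app_list_func(ids_str, tag_map):
    # "tag_a,tag_b" -> [1, 2]; unknown tags fall back to 0
    return [tag_map.get(t, 0) for t in str(ids_str).split(",")]

# One element of the new `train` RDD:
#   ( [level2 tag ids...], [city id, stat_date id], y, z )
row = (app_list_func("tag_a,tag_b", {"tag_a": 1, "tag_b": 2}), [5, 7], 1.0, 0.0)
print(row)   # ([1, 2], [5, 7], 1.0, 0.0)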
-def get_predict(date,value_map,app_list_map,level2_map,level3_map):
-    sql = "select e.y,e.z,e.label,e.ucity_id,feat.level2_ids,e.ccity_name," \
-          "u.device_type,u.manufacturer,u.channel,c.top,e.device_id,e.cid_id,cut.time," \
-          "dl.app_list,e.hospital_id,feat.level3_ids," \
-          "k.treatment_method,k.price_min,k.price_max,k.treatment_time,k.maintain_time,k.recover_time " \
-          "from esmm_pre_data e left join user_feature u on e.device_id = u.device_id " \
-          "left join cid_type_top c on e.device_id = c.device_id " \
-          "left join cid_time_cut cut on e.cid_id = cut.cid " \
-          "left join device_app_list dl on e.device_id = dl.device_id " \
-          "left join diary_feat feat on e.cid_id = feat.diary_id " \
-          "left join train_Knowledge_network_data k on feat.level2 = k.level2_id"
-    features = ["app_list", "level2_ids", "level3_ids","ucity_id", "ccity_name", "device_type", "manufacturer",
-                "channel", "top", "time", "hospital_id",
-                "treatment_method", "price_min", "price_max", "treatment_time", "maintain_time", "recover_time"]
+def get_predict(date,value_map,app_list_map):
+    sql = "select e.y,e.z,e.label,e.ucity_id,feat.level2_ids,e.device_id,e.cid_id from esmm_pre_data e " \
+          "left join diary_feat feat on e.cid_id = feat.diary_id limit 50000"
+    features = ["ucity_id"]
     df = spark.sql(sql)
     df = df.na.fill(dict(zip(features, features)))
-    df = df.drop_duplicates(["ucity_id", "level2_ids", "ccity_name", "device_type", "manufacturer",
-                             "device_id","cid_id","label",
-                             "channel", "top", "time", "app_list", "hospital_id", "level3_ids"])
-    rdd = df.select("app_list", "level2_ids", "level3_ids","ucity_id","device_id","cid_id","label", "y", "z",
-                    "ccity_name", "device_type","manufacturer", "channel", "time", "hospital_id",
-                    "treatment_method", "price_min", "price_max", "treatment_time", "maintain_time",
-                    "recover_time","top") \
-        .rdd.map(lambda x: (app_list_func(x[0], app_list_map), app_list_func(x[1], level2_map),
-                            app_list_func(x[2], level3_map), x[3],x[4],x[5],x[6],x[7],x[8],
-                            value_map.get(x[3], 300000),value_map.get(x[9], 299999),
-                            value_map.get(x[10], 299998), value_map.get(x[11], 299997),
-                            value_map.get(x[12], 299996), value_map.get(x[13], 299995),
-                            value_map.get(x[14], 299994),value_map.get(x[15], 299993),
-                            value_map.get(x[16], 299992),value_map.get(x[17], 299991),
-                            value_map.get(x[18], 299990),value_map.get(x[19], 299989),
-                            value_map.get(x[20], 299988),value_map.get(x[21], 299987),
-                            value_map[date]))
+    rdd = df.select("level2_ids","ucity_id","device_id","cid_id","label", "y", "z") \
+        .rdd.map(lambda x: (app_list_func(x[0], app_list_map),x[1],x[2],x[3],x[4],x[5],x[6],
+                            [value_map.get(x[1], 300000),value_map.get(date, 299999)]))
     rdd.persist()
-    native_pre = spark.createDataFrame(rdd.filter(lambda x:x[6] == 0).map(lambda x:(x[3],x[4],x[5])))\
+    native_pre = spark.createDataFrame(rdd.filter(lambda x:x[4] == 0).map(lambda x:(x[1],x[2],x[3])))\
         .toDF("city","uid","cid_id")
     print("native")
     native_pre.toPandas().to_csv(local_path+"native.csv", header=True)
-    spark.createDataFrame(rdd.filter(lambda x: x[6] == 0)
-                          .map(lambda x: (x[0], x[1], x[2],x[7],x[8],x[9],x[10],x[11],x[12],
-                                          x[13],x[14],x[15],
-                                          x[16],x[17],x[18],x[19],x[20],x[21],x[22],x[23]))) \
-        .toDF("app_list", "level2_ids", "level3_ids","y","z","ucity_id",
-              "ccity_name", "device_type","manufacturer", "channel", "time", "hospital_id",
-              "treatment_method", "price_min", "price_max", "treatment_time", "maintain_time",
-              "recover_time", "top","stat_date").repartition(1).write.format("tfrecords").option("recordType", "Example") \
-        .save(path=path+"native/", mode="overwrite")
+    spark.createDataFrame(rdd.filter(lambda x: x[4] == 0).map(lambda x: (x[0],x[5],x[6],x[7]))) \
+        .toDF("level2_ids","y","z","ids").repartition(1).write.format("tfrecords") \
+        .option("recordType", "SequenceExample").save(path=path+"native/", mode="overwrite")
-    nearby_pre = spark.createDataFrame(rdd.filter(lambda x: x[6] == 1).map(lambda x: (x[3], x[4], x[5]))) \
+    nearby_pre = spark.createDataFrame(rdd.filter(lambda x: x[4] == 1).map(lambda x:(x[1],x[2],x[3]))) \
         .toDF("city", "uid", "cid_id")
     print("nearby")
     nearby_pre.toPandas().to_csv(local_path+"nearby.csv", header=True)
-    spark.createDataFrame(rdd.filter(lambda x: x[6] == 1)
-                          .map(lambda x: (x[0], x[1], x[2], x[7], x[8], x[9], x[10], x[11], x[12],
-                                          x[13], x[14], x[15],
-                                          x[16], x[17], x[18], x[19], x[20], x[21], x[22], x[23]))) \
-        .toDF("app_list", "level2_ids", "level3_ids","y","z", "ucity_id",
-              "ccity_name", "device_type", "manufacturer", "channel", "time", "hospital_id",
-              "treatment_method", "price_min", "price_max", "treatment_time", "maintain_time",
-              "recover_time","top","stat_date").repartition(1).write.format("tfrecords").option("recordType", "Example") \
-        .save(path=path+"nearby/", mode="overwrite")
+    spark.createDataFrame(rdd.filter(lambda x: x[4] == 1).map(lambda x: (x[0], x[5], x[6], x[7]))) \
+        .toDF("level2_ids","y","z","ids").repartition(1).write.format("tfrecords") \
+        .option("recordType", "SequenceExample").save(path=path+"nearby/", mode="overwrite")
     rdd.unpersist()
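The value_map.get(key, default) calls above fall back to fixed out-of-vocabulary ids when the prediction data contains a city or date that was never seen while building value_map during training; a tiny illustration:

value_map = {"beijing": 5, "shanghai": 6, "2019-03-13": 7}

# Known city -> its learned id; unseen city -> reserved OOV bucket 300000.
print(value_map.get("beijing", 300000))     # 5
print(value_map.get("chengdu", 300000))     # 300000

# Same idea for the date id, with its own reserved bucket 299999.
print(value_map.get("2019-03-14", 299999))  # 299999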
@@ -196,20 +125,6 @@ def con_sql(db,sql):
     return df
-def test():
-    sql = "select stat_date,cid_id,y,ccity_name from esmm_train_data limit 60"
-    rdd = spark.sql(sql).select("stat_date","cid_id","y","ccity_name").rdd.map(lambda x:(x[0],x[1],x[2],x[3]))
-    df = spark.createDataFrame(rdd)
-    df.show(6)
-    # spark.sql("use online")
-    # spark.sql("ADD JAR /srv/apps/brickhouse-0.7.1-SNAPSHOT.jar")
-    # spark.sql("ADD JAR /srv/apps/hive-udf-1.0-SNAPSHOT.jar")
-    # spark.sql("CREATE TEMPORARY FUNCTION json_map AS 'brickhouse.udf.json.JsonMapUDF'")
-    # spark.sql("CREATE TEMPORARY FUNCTION is_json AS 'com.gmei.hive.common.udf.UDFJsonFormatCheck'")
-    #
-    # spark.sql("select cl_type from online.tl_hdfs_maidian_view where partition_date = '20190312' limit 6").show()
 if __name__ == '__main__':
     sparkConf = SparkConf().set("spark.hive.mapred.supports.subdirectories", "true") \
         .set("spark.hadoop.mapreduce.input.fileinputformat.input.dir.recursive", "true") \
@@ -222,13 +137,13 @@ if __name__ == '__main__':
     spark = SparkSession.builder.config(conf=sparkConf).enableHiveSupport().getOrCreate()
     ti = pti.TiContext(spark)
     ti.tidbMapDatabase("jerry_test")
-    ti.tidbMapDatabase("eagle")
+    # ti.tidbMapDatabase("eagle")
     spark.sparkContext.setLogLevel("WARN")
     path = "hdfs:///strategy/esmm/"
-    local_path = "/home/gmuser/test/"
+    local_path = "/home/gmuser/esmm/"
-    validate_date, value_map, app_list_map, leve2_map, leve3_map = feature_engineer()
-    get_predict(validate_date, value_map, app_list_map, leve2_map, leve3_map)
+    validate_date, value_map, app_list_map = feature_engineer()
+    get_predict(validate_date, value_map, app_list_map)
...
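As a sanity check, the freshly written SequenceExample TFRecords could be read back through the same spark-tensorflow-connector; a sketch under the assumption that the connector version in use supports reads with the same recordType option:

# Illustrative only; run inside the same SparkSession that wrote the files.
check = spark.read.format("tfrecords") \
    .option("recordType", "SequenceExample") \
    .load(path + "tr/")
check.printSchema()            # expect: level2_ids (array), ids (array), y, z
check.show(3, truncate=False)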
@@ -8,10 +8,7 @@
 import shutil
 import os
 import json
-import glob
 from datetime import date, timedelta
-import random
 import tensorflow as tf
 #################### CMD Arguments ####################
@@ -37,7 +34,8 @@ tf.app.flags.DEFINE_string("deep_layers", '256,128,64', "deep layers")
 tf.app.flags.DEFINE_string("dropout", '0.5,0.5,0.5', "dropout rate")
 tf.app.flags.DEFINE_boolean("batch_norm", False, "perform batch normaization (True or False)")
 tf.app.flags.DEFINE_float("batch_norm_decay", 0.9, "decay for the moving average(recommend trying decay=0.9)")
-tf.app.flags.DEFINE_string("data_dir", '', "data dir")
+tf.app.flags.DEFINE_string("hdfs_dir", '', "hdfs dir")
+tf.app.flags.DEFINE_string("local_dir", '', "local dir")
 tf.app.flags.DEFINE_string("dt_dir", '', "data dt partition")
 tf.app.flags.DEFINE_string("model_dir", '', "model check point dir")
 tf.app.flags.DEFINE_string("servable_model_dir", '', "export servable model for TensorFlow Serving")
@@ -49,19 +47,10 @@ def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
     print('Parsing', filenames)
     def _parse_fn(record):
         features = {
-            "y": tf.FixedLenFeature([], tf.float32),
-            "z": tf.FixedLenFeature([], tf.float32),
-            "ids": tf.FixedLenFeature([FLAGS.field_size], tf.int64),
-            "app_list": tf.VarLenFeature(tf.int64),
-            "level2_list": tf.VarLenFeature(tf.int64),
-            "level3_list": tf.VarLenFeature(tf.int64),
-            "tag1_list": tf.VarLenFeature(tf.int64),
-            "tag2_list": tf.VarLenFeature(tf.int64),
-            "tag3_list": tf.VarLenFeature(tf.int64),
-            "tag4_list": tf.VarLenFeature(tf.int64),
-            "tag5_list": tf.VarLenFeature(tf.int64),
-            "tag6_list": tf.VarLenFeature(tf.int64),
-            "tag7_list": tf.VarLenFeature(tf.int64)
+            "y": tf.VarLenFeature(tf.int64),
+            "z": tf.VarLenFeature(tf.int64),
+            "ids": tf.VarLenFeature(tf.int64),
+            "level2_ids": tf.VarLenFeature(tf.int64)
         }
         parsed = tf.parse_single_example(record, features)
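Each tf.VarLenFeature above is parsed into a tf.SparseTensor. A standalone TF 1.x sketch, separate from the repo's input pipeline, showing a record with the same feature names being parsed and one sparse feature densified:

import tensorflow as tf

# Build one serialized Example that reuses the feature names above
# (purely illustrative; the Spark job's on-disk encoding may differ).
example = tf.train.Example(features=tf.train.Features(feature={
    "level2_ids": tf.train.Feature(int64_list=tf.train.Int64List(value=[3, 7, 11])),
    "y": tf.train.Feature(int64_list=tf.train.Int64List(value=[1])),
}))
serialized = example.SerializeToString()

parsed = tf.parse_single_example(serialized, {
    "level2_ids": tf.VarLenFeature(tf.int64),   # -> tf.SparseTensor
    "y": tf.VarLenFeature(tf.int64),
})
dense_ids = tf.sparse_tensor_to_dense(parsed["level2_ids"])  # shape [3]

with tf.Session() as sess:
    print(sess.run(dense_ids))   # [ 3  7 11]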
@@ -108,15 +97,8 @@ def model_fn(features, labels, mode, params):
     feat_ids = features['ids']
     app_list = features['app_list']
-    level2_list = features['level2_list']
-    level3_list = features['level3_list']
-    tag1_list = features['tag1_list']
-    tag2_list = features['tag2_list']
-    tag3_list = features['tag3_list']
-    tag4_list = features['tag4_list']
-    tag5_list = features['tag5_list']
-    tag6_list = features['tag6_list']
-    tag7_list = features['tag7_list']
+    level2_list = features['level2_ids']
     if FLAGS.task_type != "infer":
         y = labels['y']
@@ -127,18 +109,10 @@ def model_fn(features, labels, mode, params):
     embedding_id = tf.nn.embedding_lookup(Feat_Emb,feat_ids)
     app_id = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=app_list, sp_weights=None, combiner="sum")
     level2 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=level2_list, sp_weights=None, combiner="sum")
-    level3 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=level3_list, sp_weights=None, combiner="sum")
-    tag1 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag1_list, sp_weights=None, combiner="sum")
-    tag2 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag2_list, sp_weights=None, combiner="sum")
-    tag3 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag3_list, sp_weights=None, combiner="sum")
-    tag4 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag4_list, sp_weights=None, combiner="sum")
-    tag5 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag5_list, sp_weights=None, combiner="sum")
-    tag6 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag6_list, sp_weights=None, combiner="sum")
-    tag7 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=tag7_list, sp_weights=None, combiner="sum")
     # x_concat = tf.reshape(embedding_id,shape=[-1, common_dims]) # None * (F * K)
-    x_concat = tf.concat([tf.reshape(embedding_id,shape=[-1,common_dims]),app_id,level2,level3,tag1,
-                          tag2,tag3,tag4,tag5,tag6,tag7], axis=1)
+    x_concat = tf.concat([tf.reshape(embedding_id,shape=[-1,common_dims]),app_id,level2], axis=1)
     with tf.name_scope("CVR_Task"):
         if mode == tf.estimator.ModeKeys.TRAIN:
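A self-contained TF 1.x sketch (illustrative shapes and names only) of what tf.nn.embedding_lookup_sparse does with one of these variable-length id lists: each row's ids are looked up in the shared table and combined with sum, so every row yields one fixed-size vector regardless of how many tags it has.

import tensorflow as tf

embedding_dim = 4
Feat_Emb = tf.get_variable("emb", [100, embedding_dim],
                           initializer=tf.glorot_normal_initializer())

# Two rows: row 0 has tags [3, 7], row 1 has the single tag [11].
level2_list = tf.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
                              values=tf.constant([3, 7, 11], dtype=tf.int64),
                              dense_shape=[2, 2])
level2 = tf.nn.embedding_lookup_sparse(Feat_Emb, sp_ids=level2_list,
                                       sp_weights=None, combiner="sum")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(level2).shape)   # (2, 4): one summed vector per row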
@@ -301,7 +275,8 @@ def main(_):
     print('task_type ', FLAGS.task_type)
     print('model_dir ', FLAGS.model_dir)
-    print('data_dir ', FLAGS.data_dir)
+    print('hdfs_dir ', FLAGS.hdfs_dir)
+    print('local_dir ', FLAGS.local_dir)
     print('dt_dir ', FLAGS.dt_dir)
     print('num_epochs ', FLAGS.num_epochs)
     print('feature_size ', FLAGS.feature_size)
@@ -320,6 +295,7 @@ def main(_):
     path = "hdfs:///strategy/esmm/"
     tr_files = [path+"tr/part-r-00000"]
     va_files = [path+"va/part-r-00000"]
+    te_files = ["%s/part-r-00000" % FLAGS.hdfs_dir]
     # tr_files = glob.glob("%s/tr/*tfrecord" % FLAGS.data_dir)
     # random.shuffle(tr_files)
@@ -366,9 +342,9 @@ def main(_):
             print('%s: %s' % (key,value))
     elif FLAGS.task_type == 'infer':
         preds = Estimator.predict(input_fn=lambda: input_fn(te_files, num_epochs=1, batch_size=FLAGS.batch_size), predict_keys=["pctcvr","pctr","pcvr"])
-        with open(FLAGS.data_dir+"/pred.txt", "w") as fo:
+        with open(FLAGS.local_dir+"/pred.txt", "w") as fo:
             print("-"*100)
-            with open(FLAGS.data_dir + "/pred.txt", "w") as fo:
+            with open(FLAGS.local_dir + "/pred.txt", "w") as fo:
                 for prob in preds:
                     fo.write("%f\t%f\t%f\n" % (prob['pctr'], prob['pcvr'], prob['pctcvr']))
     elif FLAGS.task_type == 'export':
...
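For downstream use, a hedged sketch of loading the resulting pred.txt; the column order follows the fo.write call above, while pandas and the /home/gmuser/esmm value for local_dir are assumptions, not something this script establishes:

import pandas as pd

# pred.txt rows are written as "pctr<TAB>pcvr<TAB>pctcvr", one per prediction.
preds = pd.read_csv("/home/gmuser/esmm/pred.txt", sep="\t",
                    header=None, names=["pctr", "pcvr", "pctcvr"])
print(preds.head())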