Commit 6d2da688 authored by 张彦钊

Remove the partitioning of the prediction set

parent 9355647a
......@@ -188,7 +188,8 @@ def feature_engineer():
df = df.drop_duplicates(["ucity_id", "level2_ids", "ccity_name", "device_type", "manufacturer",
"channel", "top", "time", "stat_date", "app_list", "hospital_id", "level3_ids",
"tag1", "tag2", "tag3", "tag4", "tag5", "tag6", "tag7"])
print("样本总量:")
print(df.count())
df = df.na.fill(dict(zip(features, features)))
rdd = df.select("stat_date", "y", "z", "app_list", "level2_ids", "level3_ids",
......@@ -299,7 +300,7 @@ def get_predict(date,value_map,app_list_map,leve2_map,leve3_map):
spark.createDataFrame(rdd.filter(lambda x: x[0] == 0)
.map(lambda x: (x[1],x[2],x[6],x[7],x[8],x[9],x[10],x[11],x[12],x[13],x[14],x[15],x[16]))) \
.toDF("y","z","app_list", "level2_list", "level3_list","tag1_list", "tag2_list", "tag3_list", "tag4_list",
"tag5_list", "tag6_list", "tag7_list", "ids").repartition(1).write.format("tfrecords") \
"tag5_list", "tag6_list", "tag7_list", "ids").write.format("tfrecords") \
.save(path=path+"native/", mode="overwrite")
print("native tfrecord done")
h = time.time()
......@@ -316,7 +317,7 @@ def get_predict(date,value_map,app_list_map,leve2_map,leve3_map):
.map(
lambda x: (x[1], x[2], x[6], x[7], x[8], x[9], x[10], x[11], x[12], x[13], x[14], x[15], x[16]))) \
.toDF("y", "z", "app_list", "level2_list", "level3_list", "tag1_list", "tag2_list", "tag3_list", "tag4_list",
"tag5_list", "tag6_list", "tag7_list", "ids").repartition(1).write.format("tfrecords") \
"tag5_list", "tag6_list", "tag7_list", "ids").write.format("tfrecords") \
.save(path=path + "nearby/", mode="overwrite")
print("nearby tfrecord done")
......
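Without the .repartition(1), the tfrecords writer emits roughly one part-r-* file per partition of the DataFrame instead of a single coalesced file, so the export runs in parallel and downstream code can no longer assume a single fixed file name. A minimal sketch of the idea (illustrative only, not part of the commit; df and path stand in for the DataFrame and output prefix used above):

    # Before: df.repartition(1).write.format("tfrecords")...  -> exactly one part file
    # After:  df.write.format("tfrecords")...                 -> one part file per partition
    print("expected part files:", df.rdd.getNumPartitions())
    df.write.format("tfrecords").save(path=path + "native/", mode="overwrite")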
......@@ -24,4 +24,9 @@ echo "infer nearby..."
${PYTHON_PATH} ${MODEL_PATH}/train.py --ctr_task_wgt=0.5 --learning_rate=0.0001 --deep_layers=512,256,128,64,32 --dropout=0.3,0.3,0.3,0.3,0.3 --optimizer=Adam --num_epochs=1 --embedding_size=16 --batch_size=2000 --field_size=15 --feature_size=600000 --l2_reg=0.005 --log_steps=100 --num_threads=36 --model_dir=${LOCAL_PATH}/model_ckpt/DeepCvrMTL/ --local_dir=${LOCAL_PATH}/nearby --hdfs_dir=${HDFS_PATH}/nearby --task_type=infer
echo "sort and 2sql"
${PYTHON_PATH} ${MODEL_PATH}/to_database.py
\ No newline at end of file
${PYTHON_PATH} ${MODEL_PATH}/to_database.py
echo "delete files"
rm /home/gmuser/esmm/*.csv
rm /home/gmuser/esmm/native/*
rm /home/gmuser/esmm/nearby/*
......@@ -348,20 +348,18 @@ def main(_):
print("Not Implemented, Do It Yourself!")
def get_filename(dir_in):
pre_add = "hdfs://172.16.32.4:8020/strategy/esmm/"
x = []
for i in range(0, 200):
if i < 10:
t = pre_add + dir_in + "/part-r-0000" + str(i)
x.append(t)
elif 10 <= i < 100:
t = pre_add + dir_in + "/part-r-000" + str(i)
x.append(t)
elif 100 <= i < 200:
t = pre_add + dir_in + "/part-r-00" + str(i)
x.append(t)
return x
pre_path = "hdfs://172.16.32.4:8020"
args = "hdfs dfs -ls " + dir_in + " | awk '{print $8}'"
proc = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
s_output, s_err = proc.communicate()
all_dart_dirs = s_output.split()
a = []
for i in all_dart_dirs:
b = str(i).split("/")[4]
if b[:4] == "part":
tmp = pre_path + str(i)[2:-1]
a.append(tmp)
return a
if __name__ == "__main__":
b = time.time()
......
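With the rewritten get_filename, the consumer discovers however many part files the parallel write produced by listing the HDFS directory, instead of enumerating 200 hard-coded part-r-XXXXX names. A hedged usage sketch follows; the directory argument and the tf.data consumption are assumptions, since the caller inside main() is not shown in this hunk:

    import tensorflow as tf

    files = get_filename("/strategy/esmm/native")   # hypothetical argument: a full HDFS directory path
    dataset = tf.data.TFRecordDataset(files)        # reads every discovered part file, however many exist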
......@@ -7,6 +7,7 @@ from pyspark.sql import SparkSession
import datetime
import pandas as pd
import subprocess
import tensorflow as tf
def app_list_func(x,l):
......@@ -142,7 +143,19 @@ def get_filename(dir_in):
x.append(t)
return x
def parse_fn(record):
    # Parse only the y / z labels out of a serialized tf.train.Example record.
    features = {
        "y": tf.FixedLenFeature([], tf.float32),
        "z": tf.FixedLenFeature([], tf.float32)
    }
    parsed = tf.parse_single_example(record, features)
    y = parsed.pop('y')
    z = parsed.pop('z')
    return {"y": y, "z": z}
if __name__ == '__main__':
# sparkConf = SparkConf().set("spark.hive.mapred.supports.subdirectories", "true") \
# .set("spark.hadoop.mapreduce.input.fileinputformat.input.dir.recursive", "true") \
# .set("spark.tispark.plan.allow_index_double_read", "false") \
......@@ -161,22 +174,17 @@ if __name__ == '__main__':
#
# validate_date, value_map, app_list_map = feature()
# get_predict(validate_date, value_map, app_list_map)
# [path + "tr/part-r-00000"]
#
#
# spark = SparkSession.builder.getOrCreate()
#
# b = [("a", 1), ("a", 1), ("b", 3), ("a", 2)]
# rdd = spark.sparkContext.parallelize(b)
# df = spark.createDataFrame(rdd).toDF("id", "n")
# df.show()
# df.createOrReplaceTempView("df")
# t = spark.sql("select id from df").map()
    import glob
    import random

    tr_files = glob.glob("/home/gmuser/test/*")
    random.shuffle(tr_files)
    print("tr_files:", tr_files)
......
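The newly added parse_fn, combined with the glob/shuffle snippet in __main__, points at a quick way to spot-check exported TFRecords that have been copied locally. A minimal sketch under that assumption (TF 1.x style, matching tf.parse_single_example above; the batch size is arbitrary):

    import glob
    import tensorflow as tf

    files = glob.glob("/home/gmuser/test/*")                         # local copies of the part files
    dataset = tf.data.TFRecordDataset(files).map(parse_fn).batch(4)
    labels = dataset.make_one_shot_iterator().get_next()
    with tf.Session() as sess:
        print(sess.run(labels))                                      # e.g. {'y': array([...]), 'z': array([...])}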