Commit 1c24580e authored by 张彦钊's avatar 张彦钊

修改测试文件

parent b1bf2771
This diff is collapsed.
#coding=utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import sys
import os
import glob
import tensorflow as tf
import numpy as np
import re
from multiprocessing import Pool as ThreadPool
flags = tf.app.flags
FLAGS = flags.FLAGS
LOG = tf.logging
tf.app.flags.DEFINE_string("input_dir", "./", "input dir")
tf.app.flags.DEFINE_string("output_dir", "./", "output dir")
tf.app.flags.DEFINE_integer("threads", 16, "threads num")
#保证顺序以及字段数量
#User_Fileds = set(['101','109_14','110_14','127_14','150_14','121','122','124','125','126','127','128','129'])
#Ad_Fileds = set(['205','206','207','210','216'])
#Context_Fileds = set(['508','509','702','853','301'])
#Common_Fileds = {'1':'1','2':'2','3':'3','4':'4','5':'5','6':'6','7':'7','8':'8','9':'9','10':'10','11':'11','12':'12','13':'13','14':'14','15':'15','16':'16','17':'17','18':'18','19':'19','20':'20','21':'21','22':'22','23':'23'}
Common_Fileds = {'1':'1','2':'2','3':'3','4':'4','5':'5','6':'6','7':'7','8':'8'}
UMH_Fileds = {'109_14':('u_cat','12'),'110_14':('u_shop','13'),'127_14':('u_brand','14'),'150_14':('u_int','15')} #user multi-hot feature
Ad_Fileds = {'206':('a_cat','16'),'207':('a_shop','17'),'210':('a_int','18'),'216':('a_brand','19')} #ad feature for DIN
#40362692,0,0,216:9342395:1.0 301:9351665:1.0 205:7702673:1.0 206:8317829:1.0 207:8967741:1.0 508:9356012:2.30259 210:9059239:1.0 210:9042796:1.0 210:9076972:1.0 210:9103884:1.0 210:9063064:1.0 127_14:3529789:2.3979 127_14:3806412:2.70805
def gen_tfrecords(in_file):
basename = os.path.basename(in_file) + ".tfrecord"
out_file = os.path.join(FLAGS.output_dir, basename)
tfrecord_out = tf.python_io.TFRecordWriter(out_file)
with open(in_file) as fi:
for line in fi:
line = line.strip().split('\t')[-1]
fields = line.strip().split(',')
if len(fields) != 4:
continue
#1 label
y = [float(fields[1])]
z = [float(fields[2])]
feature = {
"y": tf.train.Feature(float_list = tf.train.FloatList(value=y)),
"z": tf.train.Feature(float_list = tf.train.FloatList(value=z))
}
splits = re.split('[ :]', fields[3])
ffv = np.reshape(splits,(-1,3))
#common_mask = np.array([v in Common_Fileds for v in ffv[:,0]])
#af_mask = np.array([v in Ad_Fileds for v in ffv[:,0]])
#cf_mask = np.array([v in Context_Fileds for v in ffv[:,0]])
#2 不需要特殊处理的特征
feat_ids = np.array([])
#feat_vals = np.array([])
for f, def_id in Common_Fileds.items():
if f in ffv[:,0]:
mask = np.array(f == ffv[:,0])
feat_ids = np.append(feat_ids, ffv[mask,1])
#np.append(feat_vals,ffv[mask,2].astype(np.float))
else:
feat_ids = np.append(feat_ids, def_id)
#np.append(feat_vals,1.0)
feature.update({"feat_ids": tf.train.Feature(int64_list=tf.train.Int64List(value=feat_ids.astype(np.int)))})
#"feat_vals": tf.train.Feature(float_list=tf.train.FloatList(value=feat_vals))})
#3 特殊字段单独处理
for f, (fname, def_id) in UMH_Fileds.items():
if f in ffv[:,0]:
mask = np.array(f == ffv[:,0])
feat_ids = ffv[mask,1]
feat_vals= ffv[mask,2]
else:
feat_ids = np.array([def_id])
feat_vals = np.array([1.0])
feature.update({fname+"ids": tf.train.Feature(int64_list=tf.train.Int64List(value=feat_ids.astype(np.int))),
fname+"vals": tf.train.Feature(float_list=tf.train.FloatList(value=feat_vals.astype(np.float)))})
for f, (fname, def_id) in Ad_Fileds.items():
if f in ffv[:,0]:
mask = np.array(f == ffv[:,0])
feat_ids = ffv[mask,1]
else:
feat_ids = np.array([def_id])
feature.update({fname+"ids": tf.train.Feature(int64_list=tf.train.Int64List(value=feat_ids.astype(np.int)))})
# serialized to Example
example = tf.train.Example(features = tf.train.Features(feature = feature))
serialized = example.SerializeToString()
tfrecord_out.write(serialized)
#num_lines += 1
#if num_lines % 10000 == 0:
# print("Process %d" % num_lines)
tfrecord_out.close()
def main(_):
if not os.path.exists(FLAGS.output_dir):
os.mkdir(FLAGS.output_dir)
file_list = glob.glob(os.path.join(FLAGS.input_dir, "*.csv"))
print("total files: %d" % len(file_list))
pool = ThreadPool(FLAGS.threads) # Sets the pool size
pool.map(gen_tfrecords, file_list)
pool.close()
pool.join()
if __name__ == "__main__":
tf.logging.set_verbosity(tf.logging.INFO)
tf.app.run()
\ No newline at end of file
# -*- coding: utf-8 -*-
from pyspark.sql import HiveContext
from pyspark.context import SparkContext
import pymysql
from pyspark.conf import SparkConf
import pytispark.pytispark as pti
# from pyspark.sql import SQLContext
from pyspark.sql import SparkSession
# import datetime
import datetime
import pandas as pd
def feature_engineer():
db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
sql = "select max(stat_date) from esmm_train_data"
validate_date = con_sql(db, sql)[0].values.tolist()[0]
print("validate_date:" + validate_date)
temp = datetime.datetime.strptime(validate_date, "%Y-%m-%d")
start = (temp - datetime.timedelta(days=300)).strftime("%Y-%m-%d")
print(start)
sparkConf = SparkConf().set("spark.hive.mapred.supports.subdirectories", "true") \
.set("spark.hadoop.mapreduce.input.fileinputformat.input.dir.recursive", "true") \
.set("spark.tispark.plan.allow_index_double_read", "false") \
.set("spark.tispark.plan.allow_index_read", "true") \
.set("spark.sql.extensions", "org.apache.spark.sql.TiExtensions") \
.set("spark.tispark.pd.addresses", "172.16.40.158:2379").set("spark.io.compression.codec", "lzf")
spark = SparkSession.builder.config(conf=sparkConf).enableHiveSupport().getOrCreate()
ti = pti.TiContext(spark)
ti.tidbMapDatabase("jerry_test")
spark.sparkContext.setLogLevel("WARN")
sql = "select e.y,e.z,e.stat_date,e.ucity_id,feat.level2_ids,e.ccity_name,u.device_type,u.manufacturer," \
"u.channel,c.top,e.device_id,cut.time,dl.app_list,e.diary_service_id,feat.level3_ids," \
"k.treatment_method,k.price_min,k.price_max,k.treatment_time,k.maintain_time,k.recover_time " \
"from esmm_train_data e left join user_feature u on e.device_id = u.device_id " \
"left join cid_type_top c on e.device_id = c.device_id " \
"left join cid_time_cut cut on e.cid_id = cut.cid " \
"left join device_app_list dl on e.device_id = dl.device_id " \
"left join diary_feat feat on e.cid_id = feat.diary_id " \
"left join train_Knowledge_network_data k on feat.level2 = k.level2_id " \
"where e.stat_date >= '{}'".format(start)
df = spark.sql(sql)
df.show(6)
print(df.count())
# df = df.rename(columns={0: "y", 1: "z", 2: "stat_date", 3: "ucity_id", 4: "clevel2_id", 5: "ccity_name",
# 6: "device_type", 7: "manufacturer", 8: "channel", 9: "top", 10: "device_id",
# 11: "time", 12: "app_list", 13: "service_id", 14: "level3_ids", 15: "level2"})
#
# db = pymysql.connect(host='172.16.40.158', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
# sql = "select level2_id,treatment_method,price_min,price_max,treatment_time,maintain_time,recover_time " \
# "from train_Knowledge_network_data"
# knowledge = con_sql(db, sql)
# knowledge = knowledge.rename(columns={0: "level2", 1: "method", 2: "min", 3: "max",
# 4: "treatment_time", 5: "maintain_time", 6: "recover_time"})
# knowledge["level2"] = knowledge["level2"].astype("str")
#
# df = pd.merge(df, knowledge, on='level2', how='left')
# df = df.drop("level2", axis=1)
#
# service_id = tuple(df["service_id"].unique())
# db = pymysql.connect(host='172.16.30.143', port=3306, user='work',
# passwd='BJQaT9VzDcuPBqkd', db='zhengxing')
# sql = "select s.id,d.hospital_id from api_service s left join api_doctor d on s.doctor_id = d.id " \
# "where s.id in {}".format(service_id)
# hospital = con_sql(db, sql)
# hospital = hospital.rename(columns={0: "service_id", 1: "hospital_id"})
# # print(hospital.head())
# # print("hospital")
# # print(hospital.count())
# hospital["service_id"] = hospital["service_id"].astype("str")
# df = pd.merge(df, hospital, on='service_id', how='left')
# df = df.drop("service_id", axis=1)
#
# print(df.count())
#
# print("before")
# print(df.shape)
#
# df = df.drop_duplicates(["ucity_id", "clevel2_id", "ccity_name", "device_type", "manufacturer",
# "channel", "top", "time", "stat_date", "app_list", "hospital_id", "level3_ids"])
#
# print("after")
# print(df.shape)
# app_list_number, app_list_map = multi_hot(df, "app_list", 2)
# level2_number, level2_map = multi_hot(df, "clevel2_id", 2 + app_list_number)
# level3_number, level3_map = multi_hot(df, "level3_ids", 2 + app_list_number + level2_number)
#
# unique_values = []
# features = ["ucity_id", "ccity_name", "device_type", "manufacturer",
# "channel", "top", "time", "stat_date", "hospital_id",
# "method", "min", "max", "treatment_time", "maintain_time", "recover_time"]
# for i in features:
# df[i] = df[i].astype("str")
# df[i] = df[i].fillna("lost")
# # 下面这行代码是为了区分不同的列中有相同的值
# df[i] = df[i] + i
# unique_values.extend(list(df[i].unique()))
#
# temp = list(range(2 + app_list_number + level2_number + level3_number,
# 2 + app_list_number + level2_number + level3_number + len(unique_values)))
# value_map = dict(zip(unique_values, temp))
#
# df = df.drop("device_id", axis=1)
# train = df[df["stat_date"] != validate_date + "stat_date"]
# test = df[df["stat_date"] == validate_date + "stat_date"]
# for i in ["ucity_id", "ccity_name", "device_type", "manufacturer",
# "channel", "top", "time", "stat_date", "hospital_id",
# "method", "min", "max", "treatment_time", "maintain_time", "recover_time"]:
# train[i] = train[i].map(value_map)
# test[i] = test[i].map(value_map)
#
# print("train shape")
# print(train.shape)
# print("test shape")
# print(test.shape)
#
# write_csv(train, "tr", 100000)
# write_csv(test, "va", 80000)
def con_sql(db,sql):
cursor = db.cursor()
cursor.execute(sql)
result = cursor.fetchall()
df = pd.DataFrame(list(result))
db.close()
return df
def test():
......@@ -22,9 +141,6 @@ def test():
spark = SparkSession.builder.config(conf= sparkConf).enableHiveSupport().getOrCreate()
spark.sql("use online")
# spark.sql("ADD JAR hdfs:///user/hive/share/lib/udf/brickhouse-0.7.1-SNAPSHOT.jar")
# spark.sql("ADD JAR hdfs:///user/hive/share/lib/udf/hive-udf-1.0-SNAPSHOT.jar")
spark.sql("ADD JAR /srv/apps/brickhouse-0.7.1-SNAPSHOT.jar")
spark.sql("ADD JAR /srv/apps/hive-udf-1.0-SNAPSHOT.jar")
spark.sql("CREATE TEMPORARY FUNCTION json_map AS 'brickhouse.udf.json.JsonMapUDF'")
......@@ -57,4 +173,4 @@ def test():
if __name__ == '__main__':
test()
\ No newline at end of file
feature_engineer()
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment