Commit 5504ec58 authored by 张彦钊's avatar 张彦钊

修改测试文件

parent cf0a5f2c
...@@ -203,17 +203,38 @@ def con_sql(db,sql): ...@@ -203,17 +203,38 @@ def con_sql(db,sql):
def test(): def test():
sql = "select stat_date,cid_id,y,ccity_name from esmm_train_data limit 60" # sql = "select stat_date,cid_id,y,ccity_name from esmm_train_data limit 60"
rdd = spark.sql(sql).select("stat_date","cid_id","y","ccity_name").rdd.map(lambda x:(x[0],x[1],x[2],x[3])) # rdd = spark.sql(sql).select("stat_date","cid_id","y","ccity_name").rdd.map(lambda x:(x[0],x[1],x[2],x[3]))
df = spark.createDataFrame(rdd) # df = spark.createDataFrame(rdd)
df.show(6) # df.show(6)
from hdfs import InsecureClient from hdfs import InsecureClient
from hdfs.ext.dataframe import read_dataframe from hdfs.ext.dataframe import read_dataframe
client = InsecureClient('http://nvwa01:50070') client = InsecureClient('http://nvwa01:50070')
df = read_dataframe(client,"/recommend/native/part-00196-f83757ab-9f64-4a2c-9f27-0b76df51c1c4-c000.avro") df = read_dataframe(client,"/recommend/native/part-00058-e818163a-5502-4339-9d72-3cef1edeb449-c000.avro")
print("native")
print(df.head())
df = read_dataframe(client, "/recommend/nearby/part-00136-93b2ba3d-c098-4c43-8d90-87d3db38c3ec-c000.avro")
print("nearby")
print(df.head())
df = read_dataframe(client, "/recommend/tr/part-00185-acd4327a-a0ac-415a-b2c5-e8ad57857c0d-c000.avro")
print("tr")
print(df.head())
df = read_dataframe(client, "/recommend/va/part-00191-f1aeb1df-048b-4794-af9f-2c71f14b28b6-c000.avro")
print("va")
print(df.head())
df = read_dataframe(client, "/recommend/pre_native/part-00193-d3f6b96e-1eb5-4df2-8800-20b2506363e9-c000.avro")
print("pre_native")
print(df.head()) print(df.head())
# print(df.count())
df = read_dataframe(client, "/recommend/pre_nearby/part-00175-e3b9b9ea-2c9f-4e1f-bf6e-78f107c6f83d-c000.avro")
print("pre_nearby")
print(df.head())
# spark.sql("use online") # spark.sql("use online")
# spark.sql("ADD JAR /srv/apps/brickhouse-0.7.1-SNAPSHOT.jar") # spark.sql("ADD JAR /srv/apps/brickhouse-0.7.1-SNAPSHOT.jar")
...@@ -234,11 +255,13 @@ if __name__ == '__main__': ...@@ -234,11 +255,13 @@ if __name__ == '__main__':
.set("spark.driver.maxResultSize", "8g") .set("spark.driver.maxResultSize", "8g")
spark = SparkSession.builder.config(conf=sparkConf).enableHiveSupport().getOrCreate() spark = SparkSession.builder.config(conf=sparkConf).enableHiveSupport().getOrCreate()
ti = pti.TiContext(spark) # ti = pti.TiContext(spark)
ti.tidbMapDatabase("jerry_test") # ti.tidbMapDatabase("jerry_test")
spark.sparkContext.setLogLevel("WARN") # spark.sparkContext.setLogLevel("WARN")
validate_date, value_map, app_list_map, leve2_map, leve3_map = feature_engineer() # validate_date, value_map, app_list_map, leve2_map, leve3_map = feature_engineer()
get_predict(validate_date, value_map, app_list_map, leve2_map, leve3_map) # get_predict(validate_date, value_map, app_list_map, leve2_map, leve3_map)
test()
......
...@@ -20,41 +20,37 @@ tf.app.flags.DEFINE_string("output_dir", "./", "output dir") ...@@ -20,41 +20,37 @@ tf.app.flags.DEFINE_string("output_dir", "./", "output dir")
tf.app.flags.DEFINE_integer("threads", 16, "threads num") tf.app.flags.DEFINE_integer("threads", 16, "threads num")
def gen_tfrecords(in_file): def gen_tfrecords(in_file):
# basename = os.path.basename("/home/gmuser/") + ".tfrecord" basename = os.path.basename(in_file) + ".tfrecord"
# out_file = os.path.join(FLAGS.output_dir, basename) out_file = os.path.join(FLAGS.output_dir, basename)
out_file = "/home/gmuser/hello.tfrecord"
tfrecord_out = tf.python_io.TFRecordWriter(out_file) tfrecord_out = tf.python_io.TFRecordWriter(out_file)
from hdfs import InsecureClient from hdfs import InsecureClient
from hdfs.ext.dataframe import read_dataframe from hdfs.ext.dataframe import read_dataframe
client = InsecureClient('http://nvwa01:50070') client = InsecureClient('http://nvwa01:50070')
df = read_dataframe(client,"/recommend/tr/part-00000-2f0d632b-0c61-4a0b-97d4-54bd5e579c5e-c000.avro") df = read_dataframe(client,in_file)
df = df.rename({"app_list","level2_ids","level3_ids","stat_date","ucity_id", "ccity_name", "device_type", "manufacturer",
"channel", "top", "time", "hospital_id","treatment_method", "price_min",
"price_max", "treatment_time","maintain_time", "recover_time","y","z"})
for i in range(df.shape[0]): for i in range(df.shape[0]):
feats = ["cid_id"] feats = ["cid_id"]
id = np.array([]) id = np.array([])
for j in feats: for j in feats:
id = np.append(id,df[j][i]) id = np.append(id,df[j][i])
# app_list = np.array(str(df["app_list"][i]).split(",")) app_list = np.array(str(df["app_list"][i]).split(","))
# level2_list = np.array(str(df["clevel2_id"][i]).split(",")) level2_list = np.array(str(df["clevel2_id"][i]).split(","))
# level3_list = np.array(str(df["level3_ids"][i]).split(",")) level3_list = np.array(str(df["level3_ids"][i]).split(","))
features = tf.train.Features(feature={ features = tf.train.Features(feature={
"y": tf.train.Feature(float_list=tf.train.FloatList(value=[df["y"][i]])), "y": tf.train.Feature(float_list=tf.train.FloatList(value=[df["y"][i]])),
"z": tf.train.Feature(float_list=tf.train.FloatList(value=[df["z"][i]])), "z": tf.train.Feature(float_list=tf.train.FloatList(value=[df["z"][i]])),
"ids": tf.train.Feature(int64_list=tf.train.Int64List(value=id.astype(np.int))) "ids": tf.train.Feature(int64_list=tf.train.Int64List(value=id.astype(np.int))),
"app_list": tf.train.Feature(int64_list=tf.train.Int64List(value=app_list.astype(np.int))),
"level2_list": tf.train.Feature(int64_list=tf.train.Int64List(value=level2_list.astype(np.int))),
"level3_list": tf.train.Feature(int64_list=tf.train.Int64List(value=level3_list.astype(np.int)))
}) })
# "app_list":tf.train.Feature(int64_list=tf.train.Int64List(value=app_list.astype(np.int))),
# "level2_list": tf.train.Feature(int64_list=tf.train.Int64List(value=level2_list.astype(np.int))),
# "level3_list": tf.train.Feature(int64_list=tf.train.Int64List(value=level3_list.astype(np.int)))
example = tf.train.Example(features = features) example = tf.train.Example(features = features)
serialized = example.SerializeToString() serialized = example.SerializeToString()
tfrecord_out.write(serialized) tfrecord_out.write(serialized)
tfrecord_out.close() tfrecord_out.close()
def main(_): def main(_):
if not os.path.exists(FLAGS.output_dir): if not os.path.exists(FLAGS.output_dir):
os.mkdir(FLAGS.output_dir) os.mkdir(FLAGS.output_dir)
...@@ -68,7 +64,5 @@ def main(_): ...@@ -68,7 +64,5 @@ def main(_):
if __name__ == "__main__": if __name__ == "__main__":
# tf.logging.set_verbosity(tf.logging.INFO) tf.logging.set_verbosity(tf.logging.INFO)
# tf.app.run() tf.app.run()
gen_tfrecords("a")
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment