Commit 5504ec58 authored by 张彦钊's avatar 张彦钊

修改测试文件

parent cf0a5f2c
......@@ -203,17 +203,38 @@ def con_sql(db,sql):
def test():
sql = "select stat_date,cid_id,y,ccity_name from esmm_train_data limit 60"
rdd = spark.sql(sql).select("stat_date","cid_id","y","ccity_name").rdd.map(lambda x:(x[0],x[1],x[2],x[3]))
df = spark.createDataFrame(rdd)
df.show(6)
# sql = "select stat_date,cid_id,y,ccity_name from esmm_train_data limit 60"
# rdd = spark.sql(sql).select("stat_date","cid_id","y","ccity_name").rdd.map(lambda x:(x[0],x[1],x[2],x[3]))
# df = spark.createDataFrame(rdd)
# df.show(6)
from hdfs import InsecureClient
from hdfs.ext.dataframe import read_dataframe
client = InsecureClient('http://nvwa01:50070')
df = read_dataframe(client,"/recommend/native/part-00196-f83757ab-9f64-4a2c-9f27-0b76df51c1c4-c000.avro")
df = read_dataframe(client,"/recommend/native/part-00058-e818163a-5502-4339-9d72-3cef1edeb449-c000.avro")
print("native")
print(df.head())
df = read_dataframe(client, "/recommend/nearby/part-00136-93b2ba3d-c098-4c43-8d90-87d3db38c3ec-c000.avro")
print("nearby")
print(df.head())
df = read_dataframe(client, "/recommend/tr/part-00185-acd4327a-a0ac-415a-b2c5-e8ad57857c0d-c000.avro")
print("tr")
print(df.head())
df = read_dataframe(client, "/recommend/va/part-00191-f1aeb1df-048b-4794-af9f-2c71f14b28b6-c000.avro")
print("va")
print(df.head())
df = read_dataframe(client, "/recommend/pre_native/part-00193-d3f6b96e-1eb5-4df2-8800-20b2506363e9-c000.avro")
print("pre_native")
print(df.head())
# print(df.count())
df = read_dataframe(client, "/recommend/pre_nearby/part-00175-e3b9b9ea-2c9f-4e1f-bf6e-78f107c6f83d-c000.avro")
print("pre_nearby")
print(df.head())
# spark.sql("use online")
# spark.sql("ADD JAR /srv/apps/brickhouse-0.7.1-SNAPSHOT.jar")
......@@ -234,11 +255,13 @@ if __name__ == '__main__':
.set("spark.driver.maxResultSize", "8g")
spark = SparkSession.builder.config(conf=sparkConf).enableHiveSupport().getOrCreate()
ti = pti.TiContext(spark)
ti.tidbMapDatabase("jerry_test")
spark.sparkContext.setLogLevel("WARN")
validate_date, value_map, app_list_map, leve2_map, leve3_map = feature_engineer()
get_predict(validate_date, value_map, app_list_map, leve2_map, leve3_map)
# ti = pti.TiContext(spark)
# ti.tidbMapDatabase("jerry_test")
# spark.sparkContext.setLogLevel("WARN")
# validate_date, value_map, app_list_map, leve2_map, leve3_map = feature_engineer()
# get_predict(validate_date, value_map, app_list_map, leve2_map, leve3_map)
test()
......
......@@ -20,41 +20,37 @@ tf.app.flags.DEFINE_string("output_dir", "./", "output dir")
tf.app.flags.DEFINE_integer("threads", 16, "threads num")
def gen_tfrecords(in_file):
# basename = os.path.basename("/home/gmuser/") + ".tfrecord"
# out_file = os.path.join(FLAGS.output_dir, basename)
out_file = "/home/gmuser/hello.tfrecord"
basename = os.path.basename(in_file) + ".tfrecord"
out_file = os.path.join(FLAGS.output_dir, basename)
tfrecord_out = tf.python_io.TFRecordWriter(out_file)
from hdfs import InsecureClient
from hdfs.ext.dataframe import read_dataframe
client = InsecureClient('http://nvwa01:50070')
df = read_dataframe(client,"/recommend/tr/part-00000-2f0d632b-0c61-4a0b-97d4-54bd5e579c5e-c000.avro")
df = df.rename({"app_list","level2_ids","level3_ids","stat_date","ucity_id", "ccity_name", "device_type", "manufacturer",
"channel", "top", "time", "hospital_id","treatment_method", "price_min",
"price_max", "treatment_time","maintain_time", "recover_time","y","z"})
df = read_dataframe(client,in_file)
for i in range(df.shape[0]):
feats = ["cid_id"]
id = np.array([])
for j in feats:
id = np.append(id,df[j][i])
# app_list = np.array(str(df["app_list"][i]).split(","))
# level2_list = np.array(str(df["clevel2_id"][i]).split(","))
# level3_list = np.array(str(df["level3_ids"][i]).split(","))
app_list = np.array(str(df["app_list"][i]).split(","))
level2_list = np.array(str(df["clevel2_id"][i]).split(","))
level3_list = np.array(str(df["level3_ids"][i]).split(","))
features = tf.train.Features(feature={
"y": tf.train.Feature(float_list=tf.train.FloatList(value=[df["y"][i]])),
"z": tf.train.Feature(float_list=tf.train.FloatList(value=[df["z"][i]])),
"ids": tf.train.Feature(int64_list=tf.train.Int64List(value=id.astype(np.int)))
"ids": tf.train.Feature(int64_list=tf.train.Int64List(value=id.astype(np.int))),
"app_list": tf.train.Feature(int64_list=tf.train.Int64List(value=app_list.astype(np.int))),
"level2_list": tf.train.Feature(int64_list=tf.train.Int64List(value=level2_list.astype(np.int))),
"level3_list": tf.train.Feature(int64_list=tf.train.Int64List(value=level3_list.astype(np.int)))
})
# "app_list":tf.train.Feature(int64_list=tf.train.Int64List(value=app_list.astype(np.int))),
# "level2_list": tf.train.Feature(int64_list=tf.train.Int64List(value=level2_list.astype(np.int))),
# "level3_list": tf.train.Feature(int64_list=tf.train.Int64List(value=level3_list.astype(np.int)))
example = tf.train.Example(features = features)
serialized = example.SerializeToString()
tfrecord_out.write(serialized)
tfrecord_out.close()
def main(_):
if not os.path.exists(FLAGS.output_dir):
os.mkdir(FLAGS.output_dir)
......@@ -68,7 +64,5 @@ def main(_):
if __name__ == "__main__":
# tf.logging.set_verbosity(tf.logging.INFO)
# tf.app.run()
gen_tfrecords("a")
\ No newline at end of file
tf.logging.set_verbosity(tf.logging.INFO)
tf.app.run()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment