Commit 8b967550 authored by Your Name's avatar Your Name

dist test

parent d8b5a892
...@@ -157,8 +157,6 @@ def model_fn(features, labels, mode, params): ...@@ -157,8 +157,6 @@ def model_fn(features, labels, mode, params):
predictions=predictions, predictions=predictions,
export_outputs=export_outputs) export_outputs=export_outputs)
def main(te_file): def main(te_file):
dt_dir = (date.today() + timedelta(-1)).strftime('%Y%m%d') dt_dir = (date.today() + timedelta(-1)).strftime('%Y%m%d')
model_dir = "hdfs://172.16.32.4:8020/strategy/esmm/model_ckpt/DeepCvrMTL/" + dt_dir model_dir = "hdfs://172.16.32.4:8020/strategy/esmm/model_ckpt/DeepCvrMTL/" + dt_dir
...@@ -197,6 +195,10 @@ def main(te_file): ...@@ -197,6 +195,10 @@ def main(te_file):
# indices.append([prob['pctr'], prob['pcvr'], prob['pctcvr']]) # indices.append([prob['pctr'], prob['pcvr'], prob['pctcvr']])
# return indices # return indices
def trans(x):
    """Normalize a value that may be raw ``bytes`` (or a ``"b'...'"`` string
    repr of bytes, as produced by ``str(b'...')``) into a plain string.

    Used to clean id/city/uid columns after ``toPandas()``, where values read
    from Spark may come through as bytes.

    Fixes two defects in the original ``str(x)[0] == 'b'`` heuristic:
      * IndexError on an empty string (``str(x)[0]`` on ``''``);
      * false positives for ordinary strings starting with 'b'
        (e.g. a city named "beijing" was mangled to "ijin").

    Non-bytes values that don't carry a ``b'...'`` wrapper are returned
    unchanged, matching the original pass-through behavior.
    """
    # Real bytes: decode directly instead of round-tripping through repr.
    if isinstance(x, bytes):
        return x.decode("utf-8")
    s = str(x)
    # Only strip when the value genuinely looks like a bytes repr.
    if s.startswith("b'") and s.endswith("'") and len(s) >= 3:
        return s[2:-1]
    return x
if __name__ == "__main__": if __name__ == "__main__":
...@@ -260,44 +262,51 @@ if __name__ == "__main__": ...@@ -260,44 +262,51 @@ if __name__ == "__main__":
spark = SparkSession.builder.config(conf=sparkConf).enableHiveSupport().getOrCreate() spark = SparkSession.builder.config(conf=sparkConf).enableHiveSupport().getOrCreate()
spark.sparkContext.setLogLevel("WARN") spark.sparkContext.setLogLevel("WARN")
path = "hdfs://172.16.32.4:8020/strategy/esmm/" # path = "hdfs://172.16.32.4:8020/strategy/esmm/"
df = spark.read.format("tfrecords").load(path+"test_nearby/part-r-00000") # df = spark.read.format("tfrecords").load(path+"test_nearby/part-r-00000")
df.show() # df.show()
#
# te_files = [] # te_files = []
# for i in range(0,10): # for i in range(0,10):
# te_files.append([path + "test_nearby/part-r-0000" + str(i)]) # te_files.append([path + "test_nearby/part-r-0000" + str(i)])
# for i in range(10,100): # for i in range(10,100):
# te_files.append([path + "test_nearby/part-r-000" + str(i)]) # te_files.append([path + "test_nearby/part-r-000" + str(i)])
te_files = ["hdfs://172.16.32.4:8020/strategy/esmm/test_nearby/part-r-00000"] #
# te_files = ["hdfs://172.16.32.4:8020/strategy/esmm/test_nearby/part-r-00000"]
rdd_te_files = spark.sparkContext.parallelize(te_files) #
print("-"*100) # rdd_te_files = spark.sparkContext.parallelize(te_files)
indices = rdd_te_files.repartition(1).map(lambda x: main(x)) # print("-"*100)
# print(indices.take(1)) # indices = rdd_te_files.repartition(1).map(lambda x: main(x))
print("-" * 100) # # print(indices.take(1))
# print("-" * 100)
te_result_dataframe = spark.createDataFrame(indices.flatMap(lambda x: x.split(";")).map( #
lambda l: Row(sample_id=l.split(":")[0],uid=l.split(":")[1],city=l.split(":")[2],cid_id=l.split(":")[3],ctcvr=l.split(":")[4]))) # te_result_dataframe = spark.createDataFrame(indices.flatMap(lambda x: x.split(";")).map(
# lambda l: Row(sample_id=l.split(":")[0],uid=l.split(":")[1],city=l.split(":")[2],cid_id=l.split(":")[3],ctcvr=l.split(":")[4])))
print("nearby rdd data") #
te_result_dataframe.show() # print("nearby rdd data")
nearby_data = te_result_dataframe.toPandas() # te_result_dataframe.show()
print("nearby pd data") # nearby_data = te_result_dataframe.toPandas()
print(nearby_data.head()) # print("nearby pd data")
print(nearby_data.dtypes) # print(nearby_data.head())
print("elem type") # print(nearby_data.dtypes)
print(nearby_data["cid_id"][0]) # print("elem type")
print(type(nearby_data["cid_id"][0])) # print(nearby_data["cid_id"][0])
# print(type(nearby_data["cid_id"][0]))
native_data = spark.read.parquet(path+"native_result/") native_data = spark.read.parquet(path+"native_result/")
print("native rdd data") print("native rdd data")
native_data.show() native_data.show()
native_data_pd = native_data.toPandas() native_data_pd = native_data.toPandas()
native_data_pd.apply()
print("native pd data") print("native pd data")
print(native_data_pd.head()) print(native_data_pd.head())
native_data_pd["cid_id1"] = native_data_pd["cid_id"].apply(trans)
native_data_pd["city1"] = native_data_pd["city"].apply(trans)
native_data_pd["uid1"] = native_data_pd["uid"].apply(trans)
print(native_data_pd.head())
print(native_data_pd.dtypes) print(native_data_pd.dtypes)
print("耗时(秒):") print("耗时(秒):")
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment