import sys
import time

import pandas as pd


def trans(x):
    """Strip the ``b'...'`` wrapper from a stringified bytes value.

    Prediction rows come back bytes-wrapped (``str(b'abc')`` is ``"b'abc'"``).
    If the string form starts with ``b``, the wrapper is removed; otherwise
    the value is returned unchanged.
    """
    s = str(x)
    return s[2:-1] if s[0] == 'b' else x


def df_sort(result, queue_name):
    """Rank predictions per (uid, city) and collapse cid_ids into one queue.

    :param result: rows of ``[uid, city, cid_id, pctcvr]`` as produced by
        ``predict`` (all values stringified, possibly bytes-wrapped).
    :param queue_name: name of the output queue column, e.g. ``"native_queue"``.
    :return: DataFrame with columns ``[device_id, city_id, queue_name, time]``.
    """
    df = pd.DataFrame(result, columns=["uid", "city", "cid_id", "pctcvr"])
    print(df.head(10))
    df['uid1'] = df['uid'].apply(trans)
    df['city1'] = df['city'].apply(trans)
    df['cid_id1'] = df['cid_id'].apply(trans)
    # pctcvr arrives as a string; rank numerically, not lexically
    # (lexical order breaks for scientific notation such as "1e-05").
    df['pctcvr'] = df['pctcvr'].astype(float)
    # Global descending sort, then group: row order within each group is
    # preserved by groupby, so the joined queue is highest-pctcvr first.
    df2 = (df.sort_values(by="pctcvr", ascending=False)
             .groupby(by=["uid1", "city1"])
             .agg({'cid_id1': set_join})
             .reset_index(drop=False))
    df2.columns = ["device_id", "city_id", queue_name]
    # TODO(review): hard-coded date kept from the original — should this be
    # the actual run date?
    df2["time"] = "2019-06-27"
    return df2


def update_or_insert(df2, queue_name):
    """Upsert per-device queues into ``esmm_device_diary_queue_test``.

    :param df2: DataFrame from ``df_sort`` with columns
        ``[device_id, city_id, queue_name, time]``.
    :param queue_name: column to write, e.g. ``"native_queue"``.
    """
    import pymysql

    # A column name cannot be bound as a DB-API parameter — the driver would
    # quote it as a string literal and produce invalid SQL — so the
    # identifier is formatted into the statement text.  queue_name comes
    # from sys.argv (not end-user input), but restrict it anyway.
    if not queue_name.replace('_', '').isalnum():
        raise ValueError("invalid queue column name: %r" % queue_name)
    sql = (
        "INSERT INTO esmm_device_diary_queue_test "
        "(device_id, city_id, time, {0}) VALUES (%s, %s, %s, %s) "
        "ON DUPLICATE KEY UPDATE "
        "device_id=%s, city_id=%s, time=%s, {0}=%s"
    ).format(queue_name)
    con = pymysql.connect(host='172.16.40.158', port=4000, user='root',
                          passwd='3SYz54LS9#^9sBvC', db='jerry_test',
                          charset='utf8')
    try:
        cur = con.cursor()
        for row in df2.itertuples(index=False):
            values = (row.device_id, row.city_id, row.time,
                      getattr(row, queue_name))
            cur.execute(sql, values + values)
        con.commit()
        print("insert or update success")
    except Exception as e:
        # Best-effort, matching the original behavior: log and continue.
        print(e)
    finally:
        # Always release the connection, even on failure.
        con.close()


if __name__ == "__main__":
    start = time.time()
    task = str(sys.argv[1])  # "native" or "nearby"
    print(task + " task")
    path = "hdfs://172.16.32.4:8020/strategy/esmm/"
    te_files = path + "test_" + task + "/part-r-00000"
    queue_name = task + "_queue"
    result = predict(te_files)
    df = df_sort(result, queue_name)
    update_or_insert(df, queue_name)
    print("Elapsed (seconds):")
    print(time.time() - start)