Commit 3379c89b authored by 高雅喆's avatar 高雅喆

Merge branch 'master' of git.wanmeizhensuo.com:ML/ffm-baseline

bug fix
parents 403e0524 13f4ccb4
...@@ -138,7 +138,7 @@ class multiFFMFormatPandas: ...@@ -138,7 +138,7 @@ class multiFFMFormatPandas:
def get_data(): def get_data():
db = pymysql.connect(host='10.66.157.22', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test') db = pymysql.connect(host='10.66.157.22', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
sql = "select max(stat_date) from esmm_train_data" sql = "select max(stat_date) from esmm_train_test"
validate_date = con_sql(db, sql)[0].values.tolist()[0] validate_date = con_sql(db, sql)[0].values.tolist()[0]
print("validate_date:" + validate_date) print("validate_date:" + validate_date)
temp = datetime.datetime.strptime(validate_date, "%Y-%m-%d") temp = datetime.datetime.strptime(validate_date, "%Y-%m-%d")
...@@ -174,7 +174,7 @@ def get_data(): ...@@ -174,7 +174,7 @@ def get_data():
def transform(a,validate_date): def transform(a,validate_date):
model = multiFFMFormatPandas() model = multiFFMFormatPandas()
df = model.fit_transform(a, y="y", n=160000, processes=26) df = model.fit_transform(a, y="y", n=160000, processes=22)
df = pd.DataFrame(df) df = pd.DataFrame(df)
df["stat_date"] = df[0].apply(lambda x: x.split(",")[0]) df["stat_date"] = df[0].apply(lambda x: x.split(",")[0])
df["device_id"] = df[0].apply(lambda x: x.split(",")[1]) df["device_id"] = df[0].apply(lambda x: x.split(",")[1])
...@@ -194,8 +194,8 @@ def transform(a,validate_date): ...@@ -194,8 +194,8 @@ def transform(a,validate_date):
test = test.drop("stat_date",axis=1) test = test.drop("stat_date",axis=1)
# print("train shape") # print("train shape")
# print(train.shape) # print(train.shape)
# train.to_csv(path + "train.csv", sep="\t", index=False) train.to_csv(path + "tr.csv", sep="\t", index=False)
# test.to_csv(path + "test.csv", sep="\t", index=False) test.to_csv(path + "va.csv", sep="\t", index=False)
return model return model
...@@ -245,20 +245,20 @@ def get_predict_set(ucity_id, cid,model): ...@@ -245,20 +245,20 @@ def get_predict_set(ucity_id, cid,model):
native_pre = df[df["label"] == "0"] native_pre = df[df["label"] == "0"]
native_pre = native_pre.drop("label", axis=1) native_pre = native_pre.drop("label", axis=1)
native_pre.to_csv(path+"native_pre.csv",sep="\t",index=False) native_pre.to_csv(path+"native.csv",sep="\t",index=False)
# print("native_pre shape") # print("native_pre shape")
# print(native_pre.shape) # print(native_pre.shape)
nearby_pre = df[df["label"] == "1"] nearby_pre = df[df["label"] == "1"]
nearby_pre = nearby_pre.drop("label", axis=1) nearby_pre = nearby_pre.drop("label", axis=1)
nearby_pre.to_csv(path + "nearby_pre.csv", sep="\t", index=False) nearby_pre.to_csv(path + "nearby.csv", sep="\t", index=False)
# print("nearby_pre shape") # print("nearby_pre shape")
# print(nearby_pre.shape) # print(nearby_pre.shape)
if __name__ == "__main__": if __name__ == "__main__":
path = "/home/gmuser/ffm/" path = "/home/gaoyazhe/esmm/data/"
a = time.time() a = time.time()
df, validate_date, ucity_id, cid = get_data() df, validate_date, ucity_id, cid = get_data()
model = transform(df, validate_date) model = transform(df, validate_date)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment