Commit f3381f35 authored by 高雅喆's avatar 高雅喆

Merge branch 'master' of git.wanmeizhensuo.com:ML/ffm-baseline

add esmm model
parents 34e547b3 96d0a5db
......@@ -136,14 +136,13 @@ class multiFFMFormatPandas:
return False
def get_data():
db = pymysql.connect(host='10.66.157.22', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
sql = "select max(stat_date) from esmm_train_data"
validate_date = con_sql(db, sql)[0].values.tolist()[0]
print("validate_date:"+validate_date)
temp = datetime.datetime.strptime(validate_date, "%Y-%m-%d")
start = (temp - datetime.timedelta(days=2)).strftime("%Y-%m-%d")
start = (temp - datetime.timedelta(days=14)).strftime("%Y-%m-%d")
db = pymysql.connect(host='10.66.157.22', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
sql = "select device_id,y,z,stat_date,ucity_id,cid_id,clevel1_id,ccity_name from esmm_train_data " \
"where stat_date >= '{}'".format(start)
......@@ -160,14 +159,7 @@ def get_data():
df["y"] = df["stat_date"].str.cat([df["device_id"].values.tolist(),df["ucity_id"].values.tolist(), df["cid_id"].values.tolist(),
df["y"].values.tolist(),df["z"].values.tolist()], sep=",")
df = df.drop("z", axis=1)
print(df.head(2))
print("shape")
print(df.shape)
df = pd.merge(df,get_statistics(),how='left',on = "device_id").fillna(0)
print("merge")
# print(df.head())
print("shape")
print(df.shape)
df = df.drop("device_id", axis=1)
print(df.head())
return df,validate_date,ucity_id,cid
......@@ -191,17 +183,20 @@ def transform(a,validate_date):
train = df[df["stat_date"] != validate_date]
train = train.drop("stat_date",axis=1)
# print("train shape")
# print(train.shape)
test = df[df["stat_date"] == validate_date]
test = test.drop("stat_date",axis=1)
# print("test shape")
# print(test.shape)
# train.to_csv(path+"train.csv",index=None)
# test.to_csv(path + "test.csv", index=None)
print("train shape")
print(train.shape)
yconnect = create_engine('mysql+pymysql://root:3SYz54LS9#^9sBvC@10.66.157.22:4000/jerry_test?charset=utf8')
pd.io.sql.to_sql(train, "train_zhao", yconnect, schema='jerry_test', if_exists='replace', index=False)
print("train insert done")
pd.io.sql.to_sql(test, "test_zhao", yconnect, schema='jerry_test', if_exists='replace', index=False)
print("test insert done")
return model
# yconnect = create_engine('mysql+pymysql://root:3SYz54LS9#^9sBvC@10.66.157.22:4000/jerry_test?charset=utf8')
# n = 100000
# for i in range(0,df.shape[0],n):
# print(i)
......@@ -233,7 +228,6 @@ def get_predict_set(ucity_id, cid,model):
df = con_sql(db, sql)
df = df.rename(columns={0: "device_id", 1: "y", 2: "z", 3: "stat_date", 4: "ucity_id", 5: "cid_id",
6: "clevel1_id", 7: "ccity_name",8:"label"})
print("df ok")
df = df[df["cid_id"].isin(cid)]
df = df[df["ucity_id"].isin(ucity_id)]
print(df.shape)
......@@ -269,13 +263,17 @@ def get_predict_set(ucity_id, cid,model):
native_pre = native_pre.drop("label", axis=1)
print("native_pre shape")
print(native_pre.shape)
native_pre.to_csv(path + "native_pre.csv", index=None)
nearby_pre = df[df["label"] == "1"]
nearby_pre = nearby_pre.drop("label", axis=1)
print("nearby_pre shape")
print(nearby_pre.shape)
nearby_pre.to_csv(path + "nearby_pre.csv", index=None)
yconnect = create_engine('mysql+pymysql://root:3SYz54LS9#^9sBvC@10.66.157.22:4000/jerry_test?charset=utf8')
pd.io.sql.to_sql(native_pre, "native_zhao", yconnect, schema='jerry_test', if_exists='replace', index=False)
print("train insert done")
pd.io.sql.to_sql(nearby_pre, "nearby_zhao", yconnect, schema='jerry_test', if_exists='replace', index=False)
print("test insert done")
# df = pd.DataFrame(df)
......@@ -294,8 +292,12 @@ def get_predict_set(ucity_id, cid,model):
if __name__ == "__main__":
path = "/home/gmuser/ffm/"
a = time.time()
df, validate_date, ucity_id, cid = get_data()
model = transform(df, validate_date)
get_predict_set(ucity_id, cid,model)
b = time.time()
print("cost(分钟)")
print((b-a)/60)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment