Commit 96420a74 authored by 张彦钊's avatar 张彦钊

multi hot

parent f8f8ef62
...@@ -78,6 +78,22 @@ def get_data(): ...@@ -78,6 +78,22 @@ def get_data():
return validate_date,value_map return validate_date,value_map
def multi_hot(df,i,n):
ItemID_set = set()
for i in df[i].unique():
ItemID_set.update(set(i.split(",")))
ItemID2int = dict(zip(list(ItemID_set),list(range(n+1,n+1+len(ItemID_set),1))))
ItemID_map = {val: [ItemID2int[row] for row in val.split(',')] \
for ii, val in enumerate(set(df[1]))}
ItemID_map_max_len = 3
for key in ItemID_map:
for cnt in range(ItemID_map_max_len - len(ItemID_map[key])):
ItemID_map[key].insert(len(ItemID_map[key]) + cnt, 88)
df[i] = df[i].map(ItemID_map)
def write_csv(df,name,n): def write_csv(df,name,n):
for i in range(0, df.shape[0], n): for i in range(0, df.shape[0], n):
if i == 0: if i == 0:
......
import pandas as pd
import pymysql
import datetime
def con_sql(db,sql):
cursor = db.cursor()
try:
cursor.execute(sql)
result = cursor.fetchall()
df = pd.DataFrame(list(result))
except Exception:
print("发生异常", Exception)
df = pd.DataFrame()
finally:
db.close()
return df
def multi():
db = pymysql.connect(host='10.66.157.22', port=4000, user='root', passwd='3SYz54LS9#^9sBvC', db='jerry_test')
sql = "select diary_id,level2_ids from diary_feat"
df = con_sql(db, sql).dropna()
print(df.shape)
df = df.rename(columns={0: "cid", 1: "level"})
df["l1"] = "lost"
df["l2"] = "lost"
df["l3"] = "lost"
for i in list(df["level"].unique()):
l = i.split(",")
if len(l) == 3:
df.loc[df["level"] == i, ["l1"]] = l[0]
df.loc[df["level"] == i, ["l2"]] = l[1]
df.loc[df["level"] == i, ["l3"]] = l[2]
elif len(l) == 2:
df.loc[df["level"] == i, ["l1"]] = l[0]
df.loc[df["level"] == i, ["l2"]] = l[1]
elif len(l) == 1:
df.loc[df["level"] == i, ["l1"]] = l[0]
df = df.drop("level",axis=1)
print(df.head())
a = list(df["l1"].unique())
b = list(df["l2"].unique())
c = list(df["l3"].unique())
print(len(a))
print(a)
print(len(b))
print(b)
print(len(c))
print(c)
if __name__ == "__main__":
multi()
...@@ -43,4 +43,6 @@ ${PYTHON_PATH} ${MODEL_PATH}/train.py --ctr_task_wgt=0.3 --learning_rate=0.0001 ...@@ -43,4 +43,6 @@ ${PYTHON_PATH} ${MODEL_PATH}/train.py --ctr_task_wgt=0.3 --learning_rate=0.0001
echo "infer nearby..." echo "infer nearby..."
${PYTHON_PATH} ${MODEL_PATH}/train.py --ctr_task_wgt=0.3 --learning_rate=0.0001 --deep_layers=256,128 --dropout=0.8,0.5 --optimizer=Adam --num_epochs=1 --embedding_size=16 --batch_size=1024 --field_size=10 --feature_size=2000 --l2_reg=0.005 --log_steps=100 --num_threads=36 --model_dir=${DATA_PATH}/model_ckpt/DeepCvrMTL/ --data_dir=${DATA_PATH}/nearby --task_type=infer > ${DATA_PATH}/infer.log ${PYTHON_PATH} ${MODEL_PATH}/train.py --ctr_task_wgt=0.3 --learning_rate=0.0001 --deep_layers=256,128 --dropout=0.8,0.5 --optimizer=Adam --num_epochs=1 --embedding_size=16 --batch_size=1024 --field_size=10 --feature_size=2000 --l2_reg=0.005 --log_steps=100 --num_threads=36 --model_dir=${DATA_PATH}/model_ckpt/DeepCvrMTL/ --data_dir=${DATA_PATH}/nearby --task_type=infer > ${DATA_PATH}/infer.log
echo "sort and 2sql"
${PYTHON_PATH} ${OLD_PATH}/Model_pipline/sort_and_2sql.py
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment