Commit 805509ea authored by 赵威's avatar 赵威

add answer model

parent c61d32b0
......@@ -13,6 +13,12 @@ from utils.spark import get_spark
# Filesystem path of the persisted item2vec (gensim Word2Vec) model trained
# on per-session clicked answer ids (produced by save_clicked_answer_ids_item2vec).
answer_click_ids_model_path = os.path.join(MODEL_PATH, "answer_click_ids_item2vec_model")
# Best-effort load at import time: on a fresh deployment the model file may
# not exist yet, so a broad except is used deliberately to keep the module
# importable; ANSWER_CLICK_IDS_MODEL / ANSWER_CLICK_IDS are then undefined.
# NOTE(review): `wv.vocab` implies a gensim < 4.0 API — confirm pinned version.
try:
    ANSWER_CLICK_IDS_MODEL = word2vec.Word2Vec.load(answer_click_ids_model_path)
    # Set of all answer ids known to the model, for O(1) membership checks.
    ANSWER_CLICK_IDS = set(ANSWER_CLICK_IDS_MODEL.wv.vocab.keys())
except Exception as e:
    print(e)
def get_answer_click_data(spark, start, end):
reg = r"""^\\d+$"""
......@@ -117,6 +123,34 @@ def get_answer_click_data(spark, start, end):
return df
def get_device_click_answer_ids_dict(click_df):
    """Group clicked answer card ids by app session, most recent partition first.

    Args:
        click_df: Spark DataFrame with at least the columns ``card_id``,
            ``app_session_id`` and ``partition_date``.

    Returns:
        defaultdict(list) mapping app_session_id -> ordered list of unique
        card_ids, in first-seen order after sorting by partition_date desc.
    """
    res = defaultdict(list)
    # Per-session set of already-collected ids: O(1) membership test instead
    # of the O(n) `card_id not in res[session_id]` list scan per row.
    seen = defaultdict(set)
    # NOTE: collect() pulls the whole DataFrame to the driver — assumes the
    # click data fits in driver memory.
    rows = click_df.orderBy("partition_date", ascending=False).collect()
    for row in rows:
        card_id = row["card_id"]
        session_id = row["app_session_id"]
        if card_id not in seen[session_id]:
            seen[session_id].add(card_id)
            res[session_id].append(card_id)
    return res
def save_clicked_answer_ids_item2vec():
    """Train and persist an item2vec (Word2Vec) model over clicked answer ids.

    Reads ``click_answer_ids.csv`` (lines of ``<app_session_id>|<id1>,<id2>,...``
    as written by the __main__ section), treats each session's id sequence as
    a "sentence", trains Word2Vec on them and saves the model to
    ``answer_click_ids_model_path``.

    Returns:
        The trained gensim Word2Vec model.
    """
    click_ids = []
    with open(os.path.join(DATA_PATH, "click_answer_ids.csv"), "r") as f:
        # Stream line by line instead of readlines() — the file can be large.
        for line in f:
            _session_id, sep, ids_part = line.partition("|")
            if not sep:
                # Malformed line without a "|" separator — previously this
                # raised IndexError on tmp[1]; skip it instead.
                continue
            ids = ids_part.rstrip("\n").split(",")
            if ids and ids != [""]:
                click_ids.append(ids)
    # hs=0 -> negative sampling; min_count=3 drops rarely-clicked answers.
    # NOTE(review): `iter=` is the gensim < 4.0 keyword (renamed `epochs=`).
    model = Word2Vec(click_ids, hs=0, min_count=3, workers=multiprocessing.cpu_count(), iter=10)
    print(model)
    print(len(click_ids))
    model.save(answer_click_ids_model_path)
    return model
if __name__ == "__main__":
begin_time = time.time()
......@@ -125,4 +159,19 @@ if __name__ == "__main__":
click_df.show(5, False)
print(click_df.count())
res_dict = get_device_click_answer_ids_dict(click_df)
with open(os.path.join(DATA_PATH, "click_answer_ids.csv"), "w") as f:
for (k, v) in res_dict.items():
if v:
f.write("{}|{}\n".format(k, ",".join([str(x) for x in v])))
print("write data done.")
save_clicked_answer_ids_item2vec()
for id in ["986424", "744910", "703622"]:
print(ANSWER_CLICK_IDS_MODEL.wv.most_similar(id, topn=5))
print("total cost: {:.2f}mins".format((time.time() - begin_time) / 60))
# spark-submit --master yarn --deploy-mode client --queue root.strategy --driver-memory 16g --executor-memory 1g --executor-cores 1 --num-executors 70 --conf spark.default.parallelism=100 --conf spark.storage.memoryFraction=0.5 --conf spark.shuffle.memoryFraction=0.3 --conf spark.locality.wait=0 --jars /srv/apps/tispark-core-2.1-SNAPSHOT-jar-with-dependencies.jar,/srv/apps/spark-connector_2.11-1.9.0-rc2.jar,/srv/apps/mysql-connector-java-5.1.38.jar /srv/apps/strategy_embedding/word_vector/answer.py
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment