Commit 6d492c13 authored by 赵威's avatar 赵威

save index

parent 086d0f85
......@@ -14,9 +14,9 @@ DATA_PATH = os.path.join(base_dir, "_data")
if __name__ == "__main__":
spark = get_spark("personas_vector_data")
card_type = "user_post"
days = 5 # TODO days 30
start, end = get_ndays_before_no_minus(days), get_ndays_before_no_minus(1)
# card_type = "user_post"
# days = 5 # TODO days 30
# start, end = get_ndays_before_no_minus(days), get_ndays_before_no_minus(1)
# click_df = get_click_data(spark, card_type, start, end)
# save_df_to_csv(click_df, "personas_tractate_click.csv")
......
......@@ -7,10 +7,13 @@ sys.path.append(os.path.realpath("."))
import multiprocessing
import faiss
import numpy as np
import pandas as pd
from gensim.models import Word2Vec, word2vec
from utils.defs import nth_element
from utils.files import get_df
from utils.cache import redis_client_db
def device_tractate_fe():
......@@ -40,17 +43,68 @@ if __name__ == "__main__":
tags_data = tractate_tags_df["business_tags"].to_list()
model = tractate_business_tags_word2vec(tags_data)
# all business tags
tags_set = set()
for i in tags_data:
for j in i:
tags_set.add(j)
# tag vector dict
tags_vector_dict = {}
for i in tags_set:
tags_vector_dict[i] = json.dumps(model.wv.get_vector(i))
try:
# vec = json.dumps(model.wv.get_vector(i).tolist())
vec = model.wv.get_vector(i)
tags_vector_dict[i] = vec
except Exception as e:
pass
redis_client_db.hmset("personas_tags_embedding", tags_vector_dict)
print(random.choice(tags_vector_dict.items()))
print(len(tags_vector_dict.items()))
# print(random.choice(list(tags_vector_dict.items())))
# for i in ["自体脂肪面部年轻化", "自体脂肪填充面部", "自体脂肪全面部填充", "自体脂肪面部填充", "鼻综合", "鼻部综合"]:
# print(model.wv.most_similar(i))
# print(model.wv.get_vector(i))
# tractate vector dict
tractate_vector_dict = {}
for _, row in tractate_tags_df.iterrows():
vecs = []
for i in row["business_tags"]:
vec = tags_vector_dict.get(i, np.array([]))
if vec.any():
vecs.append(vec)
if vecs:
tractate_vector_dict[row["tractate_id"]] = np.average(vecs, axis=0)
print(len(tractate_vector_dict.items()))
# print(random.choice(list(tractate_vector_dict.items())))
# tractate vector index
tractate_ids = np.array(list(tractate_vector_dict.keys())).astype("int")
tractate_embeddings = np.array(list(tractate_vector_dict.values())).astype("float32")
index = faiss.IndexFlatL2(tractate_embeddings.shape[1])
print("trained: " + str(index.is_trained))
index2 = faiss.IndexIDMap(index)
index2.add_with_ids(tractate_embeddings, tractate_ids)
print("trained: " + str(index2.is_trained))
print("total index: " + str(index2.ntotal))
base_dir = os.getcwd()
model_dir = os.path.join(base_dir, "_models")
index_path = os.path.join(model_dir, "faiss_personas_vector.index")
faiss.write_index(index2, index_path)
print(index_path)
# device vector
# for _, row in device_tags_df.iterrows():
# vecs = []
# for i in row["business_tags"]:
# vec = tags_vector_dict.get(i, np.array([]))
# if vec.any():
# vecs.append(vec)
# if vecs:
# t = np.array([np.average(vecs, axis=0)]).astype("float32")
# D, I = index2.search(t, 10)
# print(row["cl_id"], row["business_tags"])
# print(I)
# curl "http://172.16.31.17:9000/gm-dbmw-tractate-read/_search?pretty" -d '{"query": {"term": {"id": "10269"}}, "_source": {"include": ["content", "portrait_tag_name"]}}'
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment