Commit 04dccced authored by 赵威's avatar 赵威

Try writing answer similarity results

parent 9491a0b9
......@@ -8,6 +8,7 @@ sys.path.append(os.path.realpath("."))
import faiss
import numpy as np
from bert_serving.client import BertClient
from bs4 import BeautifulSoup
from utils.cache import redis_client_db
from utils.es import get_answer_info_from_es
from utils.files import MODEL_PATH
......@@ -29,10 +30,70 @@ def cos_sim(vector_a, vector_b):
return sim
def write_result():
    """Embed ES answers with BERT, index them in faiss, and print neighbours.

    For each answer (up to the first ~1000 scanned), strips the HTML from the
    answer body, encodes it with a remote bert-serving instance, builds an
    exact-L2 faiss index keyed by answer id, then prints, per answer, the ids
    of its 10 nearest neighbours whose squared L2 distance is <= 1.0.

    Side effects only (prints to stdout); intended as a one-off batch job.
    """
    bc = BertClient("172.16.44.82", check_length=False)
    count = 0
    embedding_dict = {}
    for item in get_answer_info_from_es(["id", "answer", "content_level"]):
        count += 1
        # Only embed the first batch of answers; break instead of silently
        # draining the rest of the ES scroll once the cap is reached.
        if count >= 1000:
            break
        try:
            answer_id = int(item["_id"])
            # Answers are stored as HTML; strip markup before encoding.
            soup = BeautifulSoup(item["_source"]["answer"], "html.parser")
            content = soup.get_text()
            print(count, answer_id, content)
            embedding_dict[answer_id] = bc.encode([content]).tolist()[0]
        except Exception as e:
            # Best-effort: skip malformed documents, but surface the reason
            # instead of swallowing it silently.
            print("skip item {}: {}".format(item.get("_id"), e))

    answer_ids = np.array(list(embedding_dict.keys())).astype("int")
    answer_embeddings = np.array(list(embedding_dict.values())).astype("float32")
    print(answer_embeddings.shape)

    # Exact L2 index, wrapped in an IDMap so vectors are addressed by answer id.
    index = faiss.IndexFlatL2(answer_embeddings.shape[1])
    print("trained: " + str(index.is_trained))
    index2 = faiss.IndexIDMap(index)
    index2.add_with_ids(answer_embeddings, answer_ids)
    print("trained: " + str(index2.is_trained))
    print("total index: " + str(index2.ntotal))

    for (answer_id, emb) in embedding_dict.items():
        query = np.array([emb]).astype("float32")
        # D: squared L2 distances, I: matching ids, each shaped (1, 10).
        D, I = index2.search(query, 10)
        distances = D.tolist()[0]
        ids = I.tolist()[0]
        # Keep only neighbours within the similarity threshold.
        res = [ids[pos] for (pos, dist) in enumerate(distances) if dist <= 1.0]
        if res:
            # BUG FIX: ids are ints — str.join requires strings, so the
            # original ",".join(res) raised TypeError on the first hit.
            data = "{}:{}".format(str(answer_id), ",".join(str(i) for i in res))
            print(data)
def save_result():
bc = BertClient("172.16.44.82", check_length=False)
index_path = os.path.join(MODEL_PATH, "faiss_answer_similarity.index")
index_path = os.path.join(MODEL_PATH, "faiss_answer_similarity6.index")
faiss_index = faiss.read_index(index_path)
count = 0
......@@ -73,51 +134,4 @@ if __name__ == "__main__":
# print(cos_sim(sen1_em, sen2_em))
# level_dict = {"6": [], "5": [], "4": [], "3.5": [], "3": []}
# count = 0
# embedding_dict = {}
# for item in get_answer_info_from_es(["id", "answer", "content_level"]):
# count += 1
# id = int(item["_id"])
# print(count, id)
# content = item["_source"]["answer"]
# content_level = str(item["_source"]["content_level"])
# level_dict[content_level].append(id)
# try:
# embedding_dict[id] = bc.encode([content]).tolist()[0]
# except Exception as e:
# pass
# # redis_client_db.hmset("answer:level_dict", json.dumps(level_dict))
# tmp_tuple = random.choice(list(embedding_dict.items()))
# print(tmp_tuple)
# answer_ids = np.array(list(embedding_dict.keys())).astype("int")
# answer_embeddings = np.array(list(embedding_dict.values())).astype("float32")
# print(answer_embeddings.shape)
# index = faiss.IndexFlatL2(answer_embeddings.shape[1])
# print("trained: " + str(index.is_trained))
# index2 = faiss.IndexIDMap(index)
# index2.add_with_ids(answer_embeddings, answer_ids)
# print("trained: " + str(index2.is_trained))
# print("total index: " + str(index2.ntotal))
# index_path = os.path.join(MODEL_PATH, "faiss_answer_similarity.index")
# faiss.write_index(index2, index_path)
# print(index_path)
# id = tmp_tuple[0]
# emb = np.array([embedding_dict[id]]).astype("float32")
# print(emb)
# D, I = index2.search(emb, 10)
# distances = D.tolist()[0]
# ids = I.tolist()[0]
# res = []
# for (index, i) in enumerate(distances):
# if i <= 1.0:
# res.append(ids[index])
# print(res, "\n")
save_result()
write_result()
......@@ -32,6 +32,7 @@ tensorflow==2.3.1
keras==2.4.3
protobuf==3.13.0
beautifulsoup4==4.9.3
bert-serving-server==1.10.0
bert-serving-client==1.10.0
......
......@@ -80,6 +80,7 @@ def get_tractate_info_from_es(fields=["id"]):
return results
# [6, 5, 4, 3.5, 3]
def get_answer_info_from_es(fields=["id"]):
q = {
"query": {
......@@ -90,7 +91,7 @@ def get_answer_info_from_es(fields=["id"]):
}
}, {
"terms": {
"content_level": [6, 5, 4, 3.5, 3]
"content_level": [6]
}
}, {
"range": {
......
......@@ -22,15 +22,15 @@ def get_user_portrait_tag3_from_redis(device_id, limit_score=0, tags_num=5):
portrait_key = get_user_portrait_tag3_redis_key(device_id)
if redis_client2.exists(portrait_key):
user_portrait = json.loads(redis_client2.get(portrait_key))
first_demands = items_gt_score(user_portrait.get("first_demands", {}))
second_demands = items_gt_score(user_portrait.get("second_demands", {}))
first_solutions = items_gt_score(user_portrait.get("first_solutions", {}))
second_solutions = items_gt_score(user_portrait.get("second_solutions", {}))
first_positions = items_gt_score(user_portrait.get("first_positions", {}))
second_positions = items_gt_score(user_portrait.get("second_positions", {}))
projects = items_gt_score(user_portrait.get("projects", {}))
anecdote_tags = items_gt_score(user_portrait.get("anecdote_tags", {}))
business_tags = items_gt_score(user_portrait.get("business_tags", {}))
first_demands = items_gt_score(user_portrait.get("first_demands", {}), limit_score=limit_score, tags_num=tags_num)
second_demands = items_gt_score(user_portrait.get("second_demands", {}), limit_score=limit_score, tags_num=tags_num)
first_solutions = items_gt_score(user_portrait.get("first_solutions", {}), limit_score=limit_score, tags_num=tags_num)
second_solutions = items_gt_score(user_portrait.get("second_solutions", {}), limit_score=limit_score, tags_num=tags_num)
first_positions = items_gt_score(user_portrait.get("first_positions", {}), limit_score=limit_score, tags_num=tags_num)
second_positions = items_gt_score(user_portrait.get("second_positions", {}), limit_score=limit_score, tags_num=tags_num)
projects = items_gt_score(user_portrait.get("projects", {}), limit_score=limit_score, tags_num=tags_num)
anecdote_tags = items_gt_score(user_portrait.get("anecdote_tags", {}), limit_score=limit_score, tags_num=tags_num)
business_tags = items_gt_score(user_portrait.get("business_tags", {}), limit_score=limit_score, tags_num=tags_num)
return {
"first_demands": first_demands,
"second_demands": second_demands,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment