import json
import os
import random
import sys

# Make the project-local ``utils`` package importable when run from the repo root.
sys.path.append(os.path.realpath("."))

import faiss
import numpy as np
from bert_serving.client import BertClient
from bs4 import BeautifulSoup

from utils.cache import redis_client_db
from utils.es import get_answer_info_from_es
from utils.files import MODEL_PATH

# Host of the bert-as-service server used to embed answers.
BERT_SERVER_IP = "172.16.44.82"
# Number of nearest neighbours fetched per answer (the answer itself is
# normally among them and is filtered out).
TOP_K = 10
# Maximum distance for a neighbour to be reported. IndexFlatL2 reports
# *squared* L2 distances, so this is a squared-distance threshold.
MAX_DISTANCE = 1.0
# Fields requested from Elasticsearch for every answer.
ES_FIELDS = ("id", "answer", "content_level")
# File name (under MODEL_PATH) of the persisted answer-similarity index.
INDEX_FILE_NAME = "faiss_answer_similarity6.index"


def cos_sim(vector_a, vector_b):
    """
    Compute the cosine similarity of two vectors, rescaled into [0, 1].

    The raw cosine ``cos`` (in [-1, 1]) is mapped to ``0.5 + 0.5 * cos`` so
    that identical directions give 1.0 and opposite directions give 0.0.

    :param vector_a: vector a (any array-like; it is flattened to 1-D)
    :param vector_b: vector b, same number of elements as ``vector_a``
    :return: sim in [0, 1], or NaN when either vector has zero norm
    """
    a = np.asarray(vector_a, dtype=float).ravel()
    b = np.asarray(vector_b, dtype=float).ravel()
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0:
        # Cosine similarity is undefined for a zero vector.
        return float("nan")
    cos = float(np.dot(a, b)) / float(denom)
    return 0.5 + 0.5 * cos


def _html_to_text(html):
    """Strip HTML markup from an answer body and return its visible text."""
    return BeautifulSoup(html, "html.parser").get_text()


def _report_skipped(caller, position, exc):
    """Log to stderr that the ``position``-th ES answer was skipped by ``caller``."""
    print("{}: skipping answer #{}: {!r}".format(caller, position, exc), file=sys.stderr)


def _similar_ids(distances, neighbour_ids, self_id, max_distance):
    """
    Select the neighbours of ``self_id`` that lie within ``max_distance``.

    FAISS pads result rows with label -1 when fewer than k neighbours exist;
    those padding slots and the query's own id are skipped.

    :return: list of neighbour ids as strings, in FAISS result order
    """
    return [
        str(neighbour_id)
        for dist, neighbour_id in zip(distances, neighbour_ids)
        if neighbour_id != -1 and neighbour_id != self_id and dist <= max_distance
    ]


def _format_line(content_level, answer_id, similar_ids):
    """Render one output record as ``"<content_level>:<answer_id>:<id,id,...>"``."""
    return "{}:{}:{}".format(content_level, answer_id, ",".join(similar_ids))


def _encode_answers(bc):
    """
    Fetch every answer from Elasticsearch and embed its plain text with ``bc``.

    :param bc: connected BertClient
    :return: ``(content_levels, embeddings)`` dicts keyed by int answer id.
             Answers that fail to parse or encode are reported on stderr and
             left out of ``embeddings``.
    """
    content_levels = {}
    embeddings = {}
    for position, item in enumerate(get_answer_info_from_es(list(ES_FIELDS)), 1):
        try:
            answer_id = int(item["_id"])
            content = _html_to_text(item["_source"]["answer"])
            content_levels[answer_id] = str(item["_source"]["content_level"])
            embeddings[answer_id] = bc.encode([content]).tolist()[0]
        except Exception as exc:  # best-effort: one bad answer must not abort the run
            _report_skipped("write_result", position, exc)
    return content_levels, embeddings


def write_result(server_ip=BERT_SERVER_IP, top_k=TOP_K, max_distance=MAX_DISTANCE, index_path=None):
    """
    Embed all ES answers, index them in FAISS, and print near-duplicates.

    Each answer's HTML is stripped to text and encoded with bert-as-service.
    The embeddings go into an exact L2 FAISS index (IndexFlatL2 wrapped in
    IndexIDMap so results carry answer ids). Every answer is then queried
    against that index. For each answer with at least one other answer within
    ``max_distance``, one line ``"<content_level>:<answer_id>:<id1,id2,...>"``
    is printed to stdout. Diagnostics for skipped answers go to stderr.

    :param server_ip: host of the bert-as-service server
    :param top_k: neighbours fetched per answer (including the answer itself)
    :param max_distance: squared-L2 threshold for reporting a neighbour
    :param index_path: if given, the built index is also written to this path
                       (e.g. ``os.path.join(MODEL_PATH, INDEX_FILE_NAME)``
                       to produce the file ``save_result`` loads)
    """
    bc = BertClient(server_ip, check_length=False)
    try:
        content_levels, embeddings = _encode_answers(bc)
    finally:
        bc.close()

    if not embeddings:
        print("write_result: no answers could be embedded; nothing to index", file=sys.stderr)
        return

    # Both lists come from the same unmodified dict, so row i matches id i.
    # FAISS ids are 64-bit.
    answer_ids = np.array(list(embeddings.keys()), dtype="int64")
    answer_embeddings = np.array(list(embeddings.values()), dtype="float32")
    print(answer_embeddings.shape)

    flat_index = faiss.IndexFlatL2(answer_embeddings.shape[1])
    id_index = faiss.IndexIDMap(flat_index)
    id_index.add_with_ids(answer_embeddings, answer_ids)
    print("trained: " + str(id_index.is_trained))
    print("total index: " + str(id_index.ntotal))

    if index_path is not None:
        faiss.write_index(id_index, index_path)
        print(index_path)

    # One batched search instead of one call per answer; row i of the
    # results belongs to answer_ids[i].
    distances, neighbours = id_index.search(answer_embeddings, top_k)
    for answer_id, row_dist, row_ids in zip(answer_ids.tolist(), distances.tolist(), neighbours.tolist()):
        similar = _similar_ids(row_dist, row_ids, answer_id, max_distance)
        if similar:
            print(_format_line(content_levels.get(answer_id, "-1"), answer_id, similar))


def save_result(server_ip=BERT_SERVER_IP, top_k=TOP_K, max_distance=MAX_DISTANCE, index_path=None):
    """
    Query a persisted FAISS index with every ES answer and print near-duplicates.

    Uses the same preprocessing as ``write_result``: HTML is stripped to text
    before encoding, so query embeddings match the text-based embeddings the
    index was built from. Output lines use the same
    ``"<content_level>:<answer_id>:<id1,id2,...>"`` format, followed by "done".

    :param server_ip: host of the bert-as-service server
    :param top_k: neighbours fetched per answer (including the answer itself)
    :param max_distance: squared-L2 threshold for reporting a neighbour
    :param index_path: index file to load; defaults to
                       ``os.path.join(MODEL_PATH, INDEX_FILE_NAME)``
    """
    if index_path is None:
        index_path = os.path.join(MODEL_PATH, INDEX_FILE_NAME)
    faiss_index = faiss.read_index(index_path)

    bc = BertClient(server_ip, check_length=False)
    try:
        for position, item in enumerate(get_answer_info_from_es(list(ES_FIELDS)), 1):
            try:
                answer_id = int(item["_id"])
                content = _html_to_text(item["_source"]["answer"])
                content_level = str(item["_source"]["content_level"])
                query = np.ascontiguousarray(bc.encode([content]), dtype="float32")
                distances, neighbours = faiss_index.search(query, top_k)
            except Exception as exc:  # best-effort: one bad answer must not abort the run
                _report_skipped("save_result", position, exc)
                continue
            similar = _similar_ids(distances[0].tolist(), neighbours[0].tolist(), answer_id, max_distance)
            if similar:
                print(_format_line(content_level, answer_id, similar))
    finally:
        bc.close()
    print("done")


if __name__ == "__main__":
    write_result()