import json
import os
import random
import sys

# Make repo-root imports (utils.*) resolvable when run as a script.
sys.path.append(os.path.realpath("."))

import faiss
import numpy as np
from bert_serving.client import BertClient

from utils.cache import redis_client_db
from utils.es import get_answer_info_from_es
from utils.files import MODEL_PATH

# Address of the bert-as-service server used to embed answer text.
BERT_SERVER_HOST = "172.16.44.82"


def cos_sim(vector_a, vector_b):
    """Compute a rescaled cosine similarity between two vectors.

    The raw cosine (in [-1, 1]) is mapped to [0, 1] via 0.5 + 0.5 * cos,
    matching the original scoring convention of this module.

    :param vector_a: vector a (1-D array-like, or a 1xN row such as a
        bert-as-service embedding)
    :param vector_b: vector b, same shape convention as ``vector_a``
    :return: similarity in [0.0, 1.0]
    """
    # np.mat is deprecated; flatten to 1-D arrays and use a plain dot product.
    a = np.asarray(vector_a, dtype=float).ravel()
    b = np.asarray(vector_b, dtype=float).ravel()
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0.0:
        # A zero vector has no direction; return the midpoint of the
        # rescaled range instead of propagating NaN/inf.
        return 0.5
    cos = float(np.dot(a, b)) / denom
    return 0.5 + 0.5 * cos


def save_result():
    """Report near-duplicate answers using the prebuilt FAISS index.

    For every answer fetched from Elasticsearch, embeds the text with
    bert-as-service, queries the L2 FAISS index for its 10 nearest
    neighbours, and prints one line per answer that has close neighbours:

        <content_level>:<answer_id>:<id1,id2,...>

    Neighbours are kept when their reported distance is <= 1.0 and they
    are not the answer itself. Answers that fail to encode or search are
    skipped (best-effort batch job).
    """
    bc = BertClient(BERT_SERVER_HOST, check_length=False)
    index_path = os.path.join(MODEL_PATH, "faiss_answer_similarity.index")
    faiss_index = faiss.read_index(index_path)

    # Distance threshold below which two answers count as similar.
    max_distance = 1.0
    for item in get_answer_info_from_es(["id", "answer", "content_level"]):
        # Avoid shadowing the builtin `id`.
        answer_id = int(item["_id"])
        content = item["_source"]["answer"]
        content_level = str(item["_source"]["content_level"])
        try:
            # bc.encode returns one embedding per input sentence; FAISS
            # expects a float32 (n_queries, dim) matrix.
            emb = np.asarray(bc.encode([content]), dtype="float32")
            dist_mat, id_mat = faiss_index.search(emb, 10)
        except Exception:
            # Best-effort: skip answers the encoder/index cannot handle.
            continue
        similar_ids = [
            str(neighbour_id)
            for distance, neighbour_id in zip(dist_mat.tolist()[0], id_mat.tolist()[0])
            if distance <= max_distance and neighbour_id != answer_id
        ]
        if similar_ids:
            print("{}:{}:{}".format(content_level, answer_id, ",".join(similar_ids)))
    print("done")


if __name__ == "__main__":
    save_result()