import json
import os
import random
import sys

sys.path.append(os.path.realpath("."))

import faiss
import numpy as np
from bert_serving.client import BertClient

from utils.cache import redis_client_db
from utils.es import get_diary_info_from_es
from utils.files import DATA_PATH, MODEL_PATH


def save_result():
    """For every diary in ES, look up similar diaries in the prebuilt faiss
    index and print one "content_level:diary_id:neighbor_ids" line per hit."""
    bc = BertClient("172.16.44.82", check_length=False)
    index_path = os.path.join(MODEL_PATH, "faiss_diary_similarity.index")
    faiss_index = faiss.read_index(index_path)
    for item in get_diary_info_from_es(["id", "answer", "content_level"]):
        diary_id = int(item["_id"])
        content = item["_source"]["answer"]
        content_level = str(item["_source"]["content_level"])
        try:
            # Embed the diary answer and fetch its 10 nearest neighbors.
            emb = np.array(bc.encode([content])).astype("float32")
            D, I = faiss_index.search(emb, 10)
            distances = D.tolist()[0]
            ids = I.tolist()[0]
            # Keep neighbors within the distance threshold (IndexFlatL2
            # returns squared L2 distances), excluding the diary itself.
            res = [
                str(neighbor_id)
                for distance, neighbor_id in zip(distances, ids)
                if distance <= 1.0 and neighbor_id != diary_id
            ]
            if res:
                print("{}:{}:{}".format(content_level, diary_id, ",".join(res)))
        except Exception as e:
            # Don't let a single bad document abort the whole sweep,
            # but surface the failure instead of swallowing it.
            print(diary_id, e)
    print("done")


if __name__ == "__main__":
    # One-off index build and sanity check, kept for reference:
    #
    # bc = BertClient("172.16.44.82", check_length=False)
    # level_dict = {"6": [], "5": [], "4": [], "3.5": [], "3": []}
    # count = 0
    # embedding_dict = {}
    # for item in get_diary_info_from_es(["id", "answer", "content_level"]):
    #     count += 1
    #     id = int(item["_id"])
    #     print(count, id)
    #     content = item["_source"]["answer"]
    #     content_level = str(item["_source"]["content_level"])
    #     level_dict[content_level].append(id)
    #     try:
    #         embedding_dict[id] = bc.encode([content]).tolist()[0]
    #     except Exception as e:
    #         pass
    # # redis_client_db.hmset("diary:level_dict", json.dumps(level_dict))
    # tmp_tuple = random.choice(list(embedding_dict.items()))
    # print(tmp_tuple)
    # diary_ids = np.array(list(embedding_dict.keys())).astype("int")
    # diary_embeddings = np.array(list(embedding_dict.values())).astype("float32")
    # print(diary_embeddings.shape)
    # index = faiss.IndexFlatL2(diary_embeddings.shape[1])
    # print("trained: " + str(index.is_trained))
    # index2 = faiss.IndexIDMap(index)
    # index2.add_with_ids(diary_embeddings, diary_ids)
    # print("trained: " + str(index2.is_trained))
    # print("total index: " + str(index2.ntotal))
    # index_path = os.path.join(MODEL_PATH, "faiss_diary_similarity.index")
    # faiss.write_index(index2, index_path)
    # print(index_path)
    # id = tmp_tuple[0]
    # emb = np.array([embedding_dict[id]]).astype("float32")
    # print(emb)
    # D, I = index2.search(emb, 10)
    # distances = D.tolist()[0]
    # ids = I.tolist()[0]
    # res = []
    # for (index, i) in enumerate(distances):
    #     if i <= 1.0:
    #         res.append(ids[index])
    # print(res, "\n")
    # print(ids, "\n")
    # print(D)
    save_result()
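

# ---------------------------------------------------------------------------
# Sketch: the one-off build step above, folded into a reusable function.
# This is a minimal reconstruction of the commented-out scratch code, under
# the same assumptions (same BERT server and utils helpers); `build_index`
# is a name introduced here, not part of the original, and the disabled
# Redis level_dict write and the random sanity-check query are left out.
# Run it once before save_result() whenever the index needs (re)building.
# ---------------------------------------------------------------------------
def build_index():
    bc = BertClient("172.16.44.82", check_length=False)
    embedding_dict = {}
    for item in get_diary_info_from_es(["id", "answer", "content_level"]):
        diary_id = int(item["_id"])
        try:
            # One BERT embedding per diary answer; skip documents that fail.
            embedding_dict[diary_id] = bc.encode([item["_source"]["answer"]]).tolist()[0]
        except Exception as e:
            print(diary_id, e)

    # faiss expects float32 vectors and int64 ids.
    diary_ids = np.array(list(embedding_dict.keys())).astype("int64")
    diary_embeddings = np.array(list(embedding_dict.values())).astype("float32")

    # Exact L2 index wrapped in an IDMap so searches return diary ids
    # rather than row positions.
    index = faiss.IndexIDMap(faiss.IndexFlatL2(diary_embeddings.shape[1]))
    index.add_with_ids(diary_embeddings, diary_ids)
    print("total index: " + str(index.ntotal))

    index_path = os.path.join(MODEL_PATH, "faiss_diary_similarity.index")
    faiss.write_index(index, index_path)
    return index_path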