Commit 92a79400 authored by 赵威

get answer result

parent f2e2a137
...@@ -86,6 +86,6 @@ if __name__ == "__main__": ...@@ -86,6 +86,6 @@ if __name__ == "__main__":
ids = I.tolist()[0] ids = I.tolist()[0]
res = [] res = []
for (index, i) in enumerate(distances): for (index, i) in enumerate(distances):
if i <= 0.1: if i <= 1.0:
res.append(ids[index]) res.append(ids[index])
print(res, "\n") print(res, "\n")
...@@ -12,54 +12,85 @@ from utils.cache import redis_client_db ...@@ -12,54 +12,85 @@ from utils.cache import redis_client_db
from utils.es import get_diary_info_from_es from utils.es import get_diary_info_from_es
from utils.files import MODEL_PATH from utils.files import MODEL_PATH
if __name__ == "__main__":
def save_result():
    """Embed every diary answer from ES with BERT and print its near neighbours.

    For each diary returned by ``get_diary_info_from_es`` the answer text is
    encoded via the remote BertClient, the prebuilt faiss L2 index is queried
    for the 10 nearest diaries, and the ids within the distance threshold are
    printed together with the diary's id and content level.

    Side effects only (prints to stdout); returns None.
    """
    bc = BertClient("172.16.44.82", check_length=False)
    index_path = os.path.join(MODEL_PATH, "faiss_diary_similarity.index")
    # NOTE: keep a dedicated name for the faiss index — the original code
    # reused the name `index` as the inner loop variable below, which
    # clobbered this object and made every later `.search` call raise.
    faiss_index = faiss.read_index(index_path)
    print(faiss_index)
    count = 0
    for item in get_diary_info_from_es(["id", "answer", "content_level"]):
        count += 1
        diary_id = int(item["_id"])  # avoid shadowing builtin `id`
        content = item["_source"]["answer"]
        content_level = str(item["_source"]["content_level"])
        try:
            # bc.encode returns one vector for the single sentence; faiss
            # expects float32 with shape (n_queries, dim).
            emb = np.array([bc.encode([content]).tolist()[0]]).astype("float32")
            # D: L2 distances, I: matched diary ids, for the 10 nearest.
            D, I = faiss_index.search(emb, 10)
            distances = D.tolist()[0]
            ids = I.tolist()[0]
            # Keep only neighbours within the similarity threshold.
            res = [ids[pos] for (pos, dist) in enumerate(distances) if dist <= 1.0]
            print(count, diary_id, content_level, res)
        except Exception as e:
            # Best-effort batch job: log the failure and continue with the
            # next diary rather than aborting the whole run.
            print(e)
if __name__ == "__main__":
    # Entry point: run the similarity lookup over all diaries.
    # The legacy commented-out index-building code (BertClient encoding of all
    # diaries, IndexIDMap construction and faiss.write_index) was removed as
    # dead code; see version control history if it needs to be restored.
    save_result()
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment