Commit e337d89f authored by 赵威's avatar 赵威

save index

parent d6ea66b1
import os import os
import sys
import random import random
import sys
sys.path.append(os.path.realpath(".")) sys.path.append(os.path.realpath("."))
import faiss
import numpy as np import numpy as np
from bert_serving.client import BertClient from bert_serving.client import BertClient
from utils.es import es_scan, get_answer_info_from_es from utils.es import es_scan, get_answer_info_from_es
import faiss
def cos_sim(vector_a, vector_b): def cos_sim(vector_a, vector_b):
...@@ -43,15 +44,13 @@ if __name__ == "__main__": ...@@ -43,15 +44,13 @@ if __name__ == "__main__":
count = 0 count = 0
embedding_dict = {} embedding_dict = {}
for item in get_answer_info_from_es(["id", "answer", "content_level"]): for item in get_answer_info_from_es(["id", "answer", "content_level"]):
if count < 1000: count += 1
count += 1 id = int(item["_id"])
id = int(item["_id"]) print(count, id)
print(count, id) content = item["_source"]["answer"]
content = item["_source"]["answer"] content_level = str(item["_source"]["content_level"])
content_level = str(item["_source"]["content_level"]) level_dict[content_level].append(id)
# print(id, content_level, content) embedding_dict[id] = bc.encode([content]).tolist()[0]
level_dict[content_level].append(id)
embedding_dict[id] = bc.encode([content]).tolist()[0]
tmp_tuple = random.choice(list(embedding_dict.items())) tmp_tuple = random.choice(list(embedding_dict.items()))
print(tmp_tuple) print(tmp_tuple)
...@@ -67,6 +66,12 @@ if __name__ == "__main__": ...@@ -67,6 +66,12 @@ if __name__ == "__main__":
print("trained: " + str(index2.is_trained)) print("trained: " + str(index2.is_trained))
print("total index: " + str(index2.ntotal)) print("total index: " + str(index2.ntotal))
base_dir = os.getcwd()
model_dir = os.path.join(base_dir, "_models")
index_path = os.path.join(model_dir, "faiss_answer_similarity.index")
faiss.write_index(index2, index_path)
print(index_path)
id = tmp_tuple[0] id = tmp_tuple[0]
emb = np.array([embedding_dict[id]]).astype("float32") emb = np.array([embedding_dict[id]]).astype("float32")
print(emb) print(emb)
...@@ -78,4 +83,3 @@ if __name__ == "__main__": ...@@ -78,4 +83,3 @@ if __name__ == "__main__":
if i <= 0.1: if i <= 0.1:
res.append(ids[index]) res.append(ids[index])
print(res, "\n") print(res, "\n")
print(D)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment