Commit 8a491b2f authored by 赵威's avatar 赵威

get distance

parent 91de5a8e
......@@ -62,7 +62,7 @@ if __name__ == "__main__":
index = faiss.IndexFlatL2(answer_embeddings.shape[1])
print("trained: " + str(index.is_trained))
index2 = faiss.IndexIDMap2(index)
index2 = faiss.IndexIDMap(index)
index2.add_with_ids(answer_embeddings, answer_ids)
print("trained: " + str(index2.is_trained))
print("total index: " + str(index2.ntotal))
......@@ -70,7 +70,7 @@ if __name__ == "__main__":
id = tmp_tuple[0]
emb = np.array([embedding_dict[id]]).astype("float32")
print(emb)
aaa = index2.search(emb, 10)
# res = I.tolist()
print(aaa)
# print(res, "\n")
D, I = index2.search(emb, 10)
res = I.tolist()
print(res, "\n")
print(D)
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment