Commit 876f7a6d authored by 赵威

add search

parent a45d3de3
......@@ -83,12 +83,14 @@ if __name__ == "__main__":
quantizer = faiss.IndexFlatL2(tractate_embeddings.shape[1])
index = faiss.IndexIVFFlat(quantizer, tractate_embeddings.shape[1], 100, faiss.METRIC_L2)
index.train(tractate_embeddings)
index.add_with_ids(tractate_embeddings, tractate_ids)
print("trained: " + str(index.is_trained))
index2 = faiss.IndexIDMap(index)
index2.add_with_ids(tractate_embeddings, tractate_ids)
print("trained: " + str(index2.is_trained))
print("total index: " + str(index2.ntotal))
# index2 = faiss.IndexIDMap(index)
# index2.add_with_ids(tractate_embeddings, tractate_ids)
# print("trained: " + str(index2.is_trained))
# print("total index: " + str(index2.ntotal))
# base_dir = os.getcwd()
# model_dir = os.path.join(base_dir, "_models")
......@@ -106,7 +108,7 @@ if __name__ == "__main__":
vecs.append(np.array(json.loads(vec)).astype("float32"))
if vecs:
t = np.array([np.average(vecs, axis=0)]).astype("float32")
D, I = index2.search(t, 10)
D, I = index.search(t, 10)
print(row["cl_id"], row["business_tags"])
print(I)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment