Commit 876f7a6d authored by 赵威

add search

parent a45d3de3
...@@ -83,12 +83,14 @@ if __name__ == "__main__": ...@@ -83,12 +83,14 @@ if __name__ == "__main__":
quantizer = faiss.IndexFlatL2(tractate_embeddings.shape[1]) quantizer = faiss.IndexFlatL2(tractate_embeddings.shape[1])
index = faiss.IndexIVFFlat(quantizer, tractate_embeddings.shape[1], 100, faiss.METRIC_L2) index = faiss.IndexIVFFlat(quantizer, tractate_embeddings.shape[1], 100, faiss.METRIC_L2)
index.train(tractate_embeddings)
index.add_with_ids(tractate_embeddings, tractate_ids)
print("trained: " + str(index.is_trained)) print("trained: " + str(index.is_trained))
index2 = faiss.IndexIDMap(index) # index2 = faiss.IndexIDMap(index)
index2.add_with_ids(tractate_embeddings, tractate_ids) # index2.add_with_ids(tractate_embeddings, tractate_ids)
print("trained: " + str(index2.is_trained)) # print("trained: " + str(index2.is_trained))
print("total index: " + str(index2.ntotal)) # print("total index: " + str(index2.ntotal))
# base_dir = os.getcwd() # base_dir = os.getcwd()
# model_dir = os.path.join(base_dir, "_models") # model_dir = os.path.join(base_dir, "_models")
...@@ -106,7 +108,7 @@ if __name__ == "__main__": ...@@ -106,7 +108,7 @@ if __name__ == "__main__":
vecs.append(np.array(json.loads(vec)).astype("float32")) vecs.append(np.array(json.loads(vec)).astype("float32"))
if vecs: if vecs:
t = np.array([np.average(vecs, axis=0)]).astype("float32") t = np.array([np.average(vecs, axis=0)]).astype("float32")
D, I = index2.search(t, 10) D, I = index.search(t, 10)
print(row["cl_id"], row["business_tags"]) print(row["cl_id"], row["business_tags"])
print(I) print(I)
......
Markdown is supported
0% Try again or attach a new file
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment