Commit 73e77473 authored by 赵威's avatar 赵威

add printer

parent 911a2f60
import os
import sys
import random
sys.path.append(os.path.realpath("."))
import numpy as np
......@@ -52,6 +53,7 @@ if __name__ == "__main__":
level_dict[content_level].append(id)
embedding_dict[id] = bc.encode([content])
print(random.choice(list(embedding_dict.item())))
answer_ids = np.array(list(embedding_dict.keys())).astype("int")
answer_embeddings = np.array(list(embedding_dict.values())).astype("float32")
print(answer_embeddings.shape)
......@@ -64,7 +66,7 @@ if __name__ == "__main__":
print("trained: " + str(index2.is_trained))
print("total index: " + str(index2.ntotal))
for i in [59753, 54792, 42643]:
for i in [1015527, 1015536, 292379]:
D, I = index2.search(np.array(answer_embeddings[i]).astype("float32"))
res = I.tolist()
print(res, "\n")
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment