Commit 8249f3d9 authored by 赵威's avatar 赵威

try search

parent 0f48d413
...@@ -32,16 +32,12 @@ def match_tractate_by_device(device_id, n=10): ...@@ -32,16 +32,12 @@ def match_tractate_by_device(device_id, n=10):
vectors.append(np.array(lst).astype("float32")) vectors.append(np.array(lst).astype("float32"))
if vectors: if vectors:
average_time_begin = time.time()
average_vectors = np.array([np.average(vectors, axis=0)]).astype("float32") average_vectors = np.array([np.average(vectors, axis=0)]).astype("float32")
average_time_end = time.time() - average_time_begin
search_time_begin = time.time() search_time_begin = time.time()
D, I = FAISS_TAGS_INDEX.search(average_vectors, n) D, I = FAISS_TAGS_INDEX.search(average_vectors, n)
search_time_end = time.time() - search_time_begin search_time_end = time.time() - search_time_begin
to_list_time_begin = time.time()
distances = D.tolist()[0] distances = D.tolist()[0]
ids = I.tolist()[0] ids = I.tolist()[0]
to_list_time_end = time.time() - to_list_time_begin
for (index, i) in enumerate(distances): for (index, i) in enumerate(distances):
if i <= 5.0: if i <= 5.0:
res.append(ids[index]) res.append(ids[index])
...@@ -52,8 +48,6 @@ def match_tractate_by_device(device_id, n=10): ...@@ -52,8 +48,6 @@ def match_tractate_by_device(device_id, n=10):
"api": "strategy_embedding/personas_vector/match", "api": "strategy_embedding/personas_vector/match",
"device_id": device_id, "device_id": device_id,
"n": n, "n": n,
"average_time": "{:.3f}ms".format(average_time_end * 1000),
"to_list_time": "{:.3f}ms".format(to_list_time_end * 1000),
"search": "{:.3f}ms".format(search_time_end * 1000), "search": "{:.3f}ms".format(search_time_end * 1000),
"total_time": "{:.3f}ms".format(time_end * 1000) "total_time": "{:.3f}ms".format(time_end * 1000)
} }
......
...@@ -79,7 +79,8 @@ if __name__ == "__main__": ...@@ -79,7 +79,8 @@ if __name__ == "__main__":
tractate_ids = np.array(list(tractate_vector_dict.keys())).astype("int") tractate_ids = np.array(list(tractate_vector_dict.keys())).astype("int")
tractate_embeddings = np.array(list(tractate_vector_dict.values())).astype("float32") tractate_embeddings = np.array(list(tractate_vector_dict.values())).astype("float32")
index = faiss.IndexFlatL2(tractate_embeddings.shape[1]) # index = faiss.IndexFlatL2(tractate_embeddings.shape[1])
index= faiss.IndexIVFFlat(tractate_embeddings.shape[1])
print("trained: " + str(index.is_trained)) print("trained: " + str(index.is_trained))
index2 = faiss.IndexIDMap(index) index2 = faiss.IndexIDMap(index)
...@@ -87,24 +88,24 @@ if __name__ == "__main__": ...@@ -87,24 +88,24 @@ if __name__ == "__main__":
print("trained: " + str(index2.is_trained)) print("trained: " + str(index2.is_trained))
print("total index: " + str(index2.ntotal)) print("total index: " + str(index2.ntotal))
base_dir = os.getcwd() # base_dir = os.getcwd()
model_dir = os.path.join(base_dir, "_models") # model_dir = os.path.join(base_dir, "_models")
index_path = os.path.join(model_dir, "faiss_personas_vector.index") # index_path = os.path.join(model_dir, "faiss_personas_vector.index")
faiss.write_index(index2, index_path) # faiss.write_index(index2, index_path)
print(index_path) # print(index_path)
# device vector # device vector
# for _, row in device_tags_df.iterrows(): for _, row in device_tags_df.iterrows():
# vecs = [] vecs = []
# for i in row["business_tags"]: for i in row["business_tags"]:
# # vec = tags_vector_dict.get(i, np.array([])) # vec = tags_vector_dict.get(i, np.array([]))
# vec = tags_vector_dict.get(i) vec = tags_vector_dict.get(i)
# if vec: if vec:
# vecs.append(np.array(json.loads(vec)).astype("float32")) vecs.append(np.array(json.loads(vec)).astype("float32"))
# if vecs: if vecs:
# t = np.array([np.average(vecs, axis=0)]).astype("float32") t = np.array([np.average(vecs, axis=0)]).astype("float32")
# D, I = index2.search(t, 10) D, I = index2.search(t, 10)
# print(row["cl_id"], row["business_tags"]) print(row["cl_id"], row["business_tags"])
# print(I) print(I)
# curl "http://172.16.31.17:9000/gm-dbmw-tractate-read/_search?pretty" -d '{"query": {"term": {"id": "10269"}}, "_source": {"include": ["content", "portrait_tag_name"]}}' # curl "http://172.16.31.17:9000/gm-dbmw-tractate-read/_search?pretty" -d '{"query": {"term": {"id": "10269"}}, "_source": {"include": ["content", "portrait_tag_name"]}}'
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment