Commit 8249f3d9 authored by 赵威's avatar 赵威

try search

parent 0f48d413
......@@ -32,16 +32,12 @@ def match_tractate_by_device(device_id, n=10):
vectors.append(np.array(lst).astype("float32"))
if vectors:
average_time_begin = time.time()
average_vectors = np.array([np.average(vectors, axis=0)]).astype("float32")
average_time_end = time.time() - average_time_begin
search_time_begin = time.time()
D, I = FAISS_TAGS_INDEX.search(average_vectors, n)
search_time_end = time.time() - search_time_begin
to_list_time_begin = time.time()
distances = D.tolist()[0]
ids = I.tolist()[0]
to_list_time_end = time.time() - to_list_time_begin
for (index, i) in enumerate(distances):
if i <= 5.0:
res.append(ids[index])
......@@ -52,8 +48,6 @@ def match_tractate_by_device(device_id, n=10):
"api": "strategy_embedding/personas_vector/match",
"device_id": device_id,
"n": n,
"average_time": "{:.3f}ms".format(average_time_end * 1000),
"to_list_time": "{:.3f}ms".format(to_list_time_end * 1000),
"search": "{:.3f}ms".format(search_time_end * 1000),
"total_time": "{:.3f}ms".format(time_end * 1000)
}
......
......@@ -79,7 +79,8 @@ if __name__ == "__main__":
tractate_ids = np.array(list(tractate_vector_dict.keys())).astype("int")
tractate_embeddings = np.array(list(tractate_vector_dict.values())).astype("float32")
index = faiss.IndexFlatL2(tractate_embeddings.shape[1])
# index = faiss.IndexFlatL2(tractate_embeddings.shape[1])
index= faiss.IndexIVFFlat(tractate_embeddings.shape[1])
print("trained: " + str(index.is_trained))
index2 = faiss.IndexIDMap(index)
......@@ -87,24 +88,24 @@ if __name__ == "__main__":
print("trained: " + str(index2.is_trained))
print("total index: " + str(index2.ntotal))
base_dir = os.getcwd()
model_dir = os.path.join(base_dir, "_models")
index_path = os.path.join(model_dir, "faiss_personas_vector.index")
faiss.write_index(index2, index_path)
print(index_path)
# base_dir = os.getcwd()
# model_dir = os.path.join(base_dir, "_models")
# index_path = os.path.join(model_dir, "faiss_personas_vector.index")
# faiss.write_index(index2, index_path)
# print(index_path)
# device vector
# for _, row in device_tags_df.iterrows():
# vecs = []
# for i in row["business_tags"]:
# # vec = tags_vector_dict.get(i, np.array([]))
# vec = tags_vector_dict.get(i)
# if vec:
# vecs.append(np.array(json.loads(vec)).astype("float32"))
# if vecs:
# t = np.array([np.average(vecs, axis=0)]).astype("float32")
# D, I = index2.search(t, 10)
# print(row["cl_id"], row["business_tags"])
# print(I)
for _, row in device_tags_df.iterrows():
vecs = []
for i in row["business_tags"]:
# vec = tags_vector_dict.get(i, np.array([]))
vec = tags_vector_dict.get(i)
if vec:
vecs.append(np.array(json.loads(vec)).astype("float32"))
if vecs:
t = np.array([np.average(vecs, axis=0)]).astype("float32")
D, I = index2.search(t, 10)
print(row["cl_id"], row["business_tags"])
print(I)
# curl "http://172.16.31.17:9000/gm-dbmw-tractate-read/_search?pretty" -d '{"query": {"term": {"id": "10269"}}, "_source": {"include": ["content", "portrait_tag_name"]}}'
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment