Commit 3cc3b0ca authored by 赵威's avatar 赵威

add diary

parent 279f2323
......@@ -9,7 +9,7 @@ import faiss
import numpy as np
from bert_serving.client import BertClient
from utils.cache import redis_client_db
from utils.es import es_scan, get_answer_info_from_es
from utils.es import get_answer_info_from_es
from utils.files import MODEL_PATH
......
import json
import os
import random
import sys
sys.path.append(os.path.realpath("."))
import faiss
import numpy as np
from bert_serving.client import BertClient
from utils.cache import redis_client_db
from utils.es import get_diary_info_from_es
from utils.files import MODEL_PATH
if __name__ == "__main__":
bc = BertClient("172.16.44.82", check_length=False)
level_dict = {"6": [], "5": [], "4": [], "3.5": [], "3": []}
count = 0
embedding_dict = {}
for item in get_diary_info_from_es(["id", "answer", "content_level"]):
count += 1
if count < 1000:
id = int(item["_id"])
print(count, id)
content = item["_source"]["answer"]
content_level = str(item["_source"]["content_level"])
level_dict[content_level].append(id)
embedding_dict[id] = bc.encode([content]).tolist()[0]
# redis_client_db.hmset("diary:level_dict", json.dumps(level_dict))
tmp_tuple = random.choice(list(embedding_dict.items()))
print(tmp_tuple)
diary_ids = np.array(list(embedding_dict.keys())).astype("int")
diary_embeddings = np.array(list(embedding_dict.values())).astype("float32")
print(diary_embeddings.shape)
index = faiss.IndexFlatL2(diary_embeddings.shape[1])
print("trained: " + str(index.is_trained))
index2 = faiss.IndexIDMap(index)
index2.add_with_ids(diary_embeddings, diary_ids)
print("trained: " + str(index2.is_trained))
print("total index: " + str(index2.ntotal))
# index_path = os.path.join(MODEL_PATH, "faiss_diary_similarity.index")
# faiss.write_index(index2, index_path)
# print(index_path)
id = tmp_tuple[0]
emb = np.array([embedding_dict[id]]).astype("float32")
print(emb)
D, I = index2.search(emb, 10)
distances = D.tolist()[0]
ids = I.tolist()[0]
res = []
for (index, i) in enumerate(distances):
if i <= 0.5:
res.append(ids[index])
print(res, "\n")
print(ids, "\n")
print(D)
......@@ -109,64 +109,48 @@ def get_answer_info_from_es(fields=["id"]):
return results
# def save_diary_info_from_es():
# q = {
# "query": {
# "bool": {
# "filter": [{
# "term": {
# "is_online": True
# }
# }, {
# "term": {
# "has_cover": True
# }
# }, {
# "term": {
# "is_sink": False
# }
# }, {
# "term": {
# "has_before_cover": True
# }
# }, {
# "term": {
# "has_after_cover": True
# }
# }, {
# "terms": {
# "content_level": [6, 5, 4, 3.5, 3]
# }
# }, {
# "term": {
# "content_simi_bol_show": 0
# }
# }, {
# "exists": {
# "field": "before_cover_url"
# }
# }]
# }
# },
# "_source": {
# "include": ["id"]
# }
# }
# count = 0
# # before_res_dict = {}
# after_res_dict = {}
# results = es_scan("diary", q)
# for item in results:
# diary_id = item["_id"]
# # before_cover_url = item["_source"]["before_cover_url"] + "-w"
# # before_img = url_to_ndarray(before_cover_url)
# after_cover_url = item["_source"]["after_cover_url"] + "-w"
# img = url_to_ndarray(after_cover_url)
# if img.any():
# count += 1
# print("count: " + str(count) + " " + str(diary_id))
# faces = FACE_TO_VEC_FUN(img)
# for face in faces:
# after_res_dict[diary_id] = face["feature"]
# redis_client_db.hmset(DIARY_AFTER_COVER_FEATURE_KEY, after_res_dict)
def get_diary_info_from_es(fields=["id"]):
q = {
"query": {
"bool": {
"filter": [{
"term": {
"is_online": True
}
}, {
"term": {
"has_cover": True
}
}, {
"term": {
"is_sink": False
}
}, {
"term": {
"has_before_cover": True
}
}, {
"term": {
"has_after_cover": True
}
}, {
"terms": {
"content_level": [6, 5, 4, 3.5, 3]
}
}, {
"term": {
"content_simi_bol_show": 0
}
}, {
"exists": {
"field": "before_cover_url"
}
}]
}
},
"_source": {
"include": fields
}
}
results = es_scan("diary", q)
return results
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment