Commit 7167cda1 authored by 赵威's avatar 赵威

update save function

parent 95395422
...@@ -18,14 +18,14 @@ facerec_model_path = os.path.join(model_dir, "dlib_face_recognition_resnet_model ...@@ -18,14 +18,14 @@ facerec_model_path = os.path.join(model_dir, "dlib_face_recognition_resnet_model
shape_model_path = os.path.join(model_dir, "shape_predictor_68_face_landmarks.dat") shape_model_path = os.path.join(model_dir, "shape_predictor_68_face_landmarks.dat")
faiss_index_path = os.path.join(base_dir, "_index", "diary_cover.index") faiss_index_path = os.path.join(base_dir, "_index", "diary_cover.index")
diary_after_cover_vec_file = "./diary_after_cover_vec.txt"
face_rec = dlib.face_recognition_model_v1(facerec_model_path) face_rec = dlib.face_recognition_model_v1(facerec_model_path)
face_detector = dlib.get_frontal_face_detector() face_detector = dlib.get_frontal_face_detector()
shape_predictor = dlib.shape_predictor(shape_model_path) shape_predictor = dlib.shape_predictor(shape_model_path)
FACE_TO_VEC_FUN = lambda img: face_to_vec(img, face_rec, face_detector, shape_predictor) FACE_TO_VEC_FUN = lambda img: face_to_vec(img, face_rec, face_detector, shape_predictor)
FAISS_DIARY_INDEX = faiss.read_index(faiss_index_path) FAISS_DIARY_INDEX = faiss.read_index(faiss_index_path)
DIARY_AFTER_COVER_FEATURE_KEY = "strategy_embedding:diary:cover:after"
@bind("strategy_embedding/face_similarity/hello") @bind("strategy_embedding/face_similarity/hello")
def hello(): def hello():
...@@ -93,31 +93,26 @@ def save_diary_image_info(): ...@@ -93,31 +93,26 @@ def save_diary_image_info():
faces = FACE_TO_VEC_FUN(img) faces = FACE_TO_VEC_FUN(img)
for face in faces: for face in faces:
after_res_dict[diary_id] = face["feature"] after_res_dict[diary_id] = face["feature"]
redis_client_db.hmset("strategy_embedding:diary:cover:after", after_res_dict) redis_client_db.hmset(DIARY_AFTER_COVER_FEATURE_KEY, after_res_dict)
def save_faiss_index(load_file, save_path): def save_faiss_index(save_path):
with open(load_file, "r") as f: data = redis_client_db.hgetall(DIARY_AFTER_COVER_FEATURE_KEY)
ids = [] ids = []
features = [] features = []
lines = f.readlines() for (k, v) in data.items():
print("lines: " + str(len(lines))) ids.append(str(k, "utf-8"))
count = 0 features.append(np.array(json.loads(v)))
for line in lines:
count += 1
tmp = line.split("\t")
ids.append(tmp[0])
features.append(np.array(json.loads(tmp[1])))
print("{} {}".format(count, tmp[0]))
ids_np = np.array(ids).astype("int") print("ids: " + str(len(ids)))
features_np = np.array(features).astype("float32") ids_np = np.array(ids).astype("int")
index = faiss.IndexHNSWFlat(128, 32) features_np = np.array(features).astype("float32")
print("trained: " + str(index.is_trained)) index = faiss.IndexHNSWFlat(128, 32)
index2 = faiss.IndexIDMap(index) print("trained: " + str(index.is_trained))
index2.add_with_ids(features_np, ids_np) index2 = faiss.IndexIDMap(index)
faiss.write_index(index2, save_path) index2.add_with_ids(features_np, ids_np)
print("faiss index saved") faiss.write_index(index2, save_path)
print("faiss index saved")
@bind("strategy_embedding/face_similarity/diary_url") @bind("strategy_embedding/face_similarity/diary_url")
...@@ -179,10 +174,11 @@ def get_similar_diary_ids_by_face_features(feature, limit=0.1): ...@@ -179,10 +174,11 @@ def get_similar_diary_ids_by_face_features(feature, limit=0.1):
# redis_key = "doris:diary:face_similary" # redis_key = "doris:diary:face_similary"
# redis_client3.hmset(redis_key, res_dict) # redis_client3.hmset(redis_key, res_dict)
if __name__ == "__main__":
begin_time = time.time()
def main():
save_diary_image_info() save_diary_image_info()
# save_faiss_index(diary_after_cover_vec_file, faiss_index_path) # save_faiss_index(faiss_index_path)
# imgs = [ # imgs = [
# "https://pic.igengmei.com/2020/07/03/1437/1b9975bb0b81-w", "https://pic.igengmei.com/2020/07/01/1812/ca64827a83da-w", # "https://pic.igengmei.com/2020/07/03/1437/1b9975bb0b81-w", "https://pic.igengmei.com/2020/07/01/1812/ca64827a83da-w",
...@@ -224,10 +220,4 @@ def main(): ...@@ -224,10 +220,4 @@ def main():
# res = get_similar_diary_ids_by_face_features(a) # res = get_similar_diary_ids_by_face_features(a)
# print(res) # print(res)
if __name__ == "__main__":
begin_time = time.time()
main()
print("total cost: {:.2f}mins".format((time.time() - begin_time) / 60)) print("total cost: {:.2f}mins".format((time.time() - begin_time) / 60))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment