# diary_cover_similarity.py
import json
import os
import random
import time

import dlib
import faiss
import numpy as np
from gm_rpcd.all import bind
from utils.cache import redis_client3, redis_client_db
from utils.es import es_query, es_scan
from utils.images import face_to_vec, url_to_ndarray

# Resolve model/index paths relative to the current working directory.
# NOTE(review): assumes the service is started from the project root —
# confirm, since os.getcwd() depends on the launcher.
base_dir = os.getcwd()
print("base_dir: " + base_dir)
model_dir = os.path.join(base_dir, "_models")
# dlib's pretrained ResNet face-embedding model (produces 128-d vectors).
facerec_model_path = os.path.join(model_dir, "dlib_face_recognition_resnet_model_v1.dat")
# dlib's 68-point facial landmark predictor (needed before embedding).
shape_model_path = os.path.join(model_dir, "shape_predictor_68_face_landmarks.dat")
faiss_index_path = os.path.join(base_dir, "_index", "diary_cover.index")

# Heavy models are loaded once at import time and shared by all handlers.
face_rec = dlib.face_recognition_model_v1(facerec_model_path)
face_detector = dlib.get_frontal_face_detector()
shape_predictor = dlib.shape_predictor(shape_model_path)
# Convenience wrapper: image ndarray -> detected faces with their features.
FACE_TO_VEC_FUN = lambda img: face_to_vec(img, face_rec, face_detector, shape_predictor)
# Prebuilt faiss index over diary "after"-cover face embeddings.
FAISS_DIARY_INDEX = faiss.read_index(faiss_index_path)

# Redis hash: diary_id -> JSON-encoded face feature of the "after" cover.
DIARY_AFTER_COVER_FEATURE_KEY = "strategy_embedding:diary:cover:after"


@bind("strategy_embedding/face_similarity/hello")
def hello():
    """Liveness probe: always returns the same two-word payload."""
    greeting = ["hello", "world"]
    return greeting


# curl "http://172.16.31.17:9200/gm-dbmw-diary/_search?pretty&size=0" -d '
def save_diary_image_info():
    """Harvest face features from diary "after" cover images into Redis.

    Scans ES for online, non-sunk diaries that have both before/after covers
    and an acceptable content level, downloads each "after" cover, extracts
    face feature vectors, and stores them in the Redis hash
    ``DIARY_AFTER_COVER_FEATURE_KEY`` keyed by diary id.

    Side effects only; returns ``None``.
    """
    q = {
        "query": {
            "bool": {
                "filter": [{
                    "term": {
                        "is_online": True
                    }
                }, {
                    "term": {
                        "has_cover": True
                    }
                }, {
                    "term": {
                        "is_sink": False
                    }
                }, {
                    "term": {
                        "has_before_cover": True
                    }
                }, {
                    "term": {
                        "has_after_cover": True
                    }
                }, {
                    "terms": {
                        "content_level": [6, 5, 4, 3.5, 3]
                    }
                }, {
                    "term": {
                        "content_simi_bol_show": 0
                    }
                }, {
                    "exists": {
                        "field": "before_cover_url"
                    }
                }]
            }
        },
        "_source": {
            "include": ["id", "before_cover_url", "after_cover_url"]
        }
    }
    count = 0
    after_res_dict = {}
    results = es_scan("diary", q)
    for item in results:
        diary_id = item["_id"]

        # "-w" suffix requests the web-sized rendition of the image.
        after_cover_url = item["_source"]["after_cover_url"] + "-w"
        img = url_to_ndarray(after_cover_url)
        # url_to_ndarray presumably returns an all-zero/empty array on a
        # failed download — skip those covers. TODO(review): confirm.
        if img.any():
            count += 1
            print("count: " + str(count) + " " + str(diary_id))
            faces = FACE_TO_VEC_FUN(img)
            # NOTE: when several faces are detected, each assignment
            # overwrites the previous one, so only the LAST face's feature
            # is kept per diary (preserved from the original behavior).
            for face in faces:
                after_res_dict[diary_id] = face["feature"]
    # hmset raises redis.DataError on an empty mapping — only write when
    # at least one feature was collected.
    if after_res_dict:
        redis_client_db.hmset(DIARY_AFTER_COVER_FEATURE_KEY, after_res_dict)


def save_faiss_index(save_path):
    """Build an HNSW faiss index from the features stored in Redis and
    serialize it to ``save_path``.

    Args:
        save_path: filesystem path the index file is written to.
    """
    data = redis_client_db.hgetall(DIARY_AFTER_COVER_FEATURE_KEY)
    ids = []
    features = []
    # Redis returns bytes: keys are diary ids, values are JSON-encoded
    # feature vectors (see save_diary_image_info).
    for (k, v) in data.items():
        ids.append(str(k, "utf-8"))
        features.append(np.array(json.loads(v)))

    print("ids: " + str(len(ids)))
    # faiss add_with_ids requires int64 ids; "int" is platform-sized and
    # could be int32 on some platforms, so be explicit.
    ids_np = np.array(ids).astype("int64")
    features_np = np.array(features).astype("float32")
    # 128-d vectors, HNSW graph with 32 links per node; HNSW needs no training.
    index = faiss.IndexHNSWFlat(128, 32)
    print("trained: " + str(index.is_trained))
    # Wrap in an ID map so search results report diary ids rather than
    # insertion order.
    index2 = faiss.IndexIDMap(index)
    index2.add_with_ids(features_np, ids_np)
    faiss.write_index(index2, save_path)
    print("faiss index saved")


@bind("strategy_embedding/face_similarity/diary_url")
def get_similar_diary_ids_by_url(url, limit=0.1):
    """Return diaries whose "after" cover face resembles the face at ``url``.

    Only the first detected face is considered. Results are
    ``(diary_id, score)`` tuples (converted to ``str``/``float`` for safe
    RPC serialization, matching get_similar_diary_ids_by_face_features),
    sorted by score descending.

    Args:
        url: image URL to fetch and analyse.
        limit: minimum similarity score to keep a result.

    Returns:
        Sorted list of (diary_id, score) tuples; ``[]`` when the image
        cannot be fetched or no face is detected.
    """
    img = url_to_ndarray(url)
    if not img.any():
        return []
    faces = FACE_TO_VEC_FUN(img)
    for face in faces:
        face_feature = np.array(json.loads(face["feature"])).astype("float32")
        # Top-10 nearest neighbours in the prebuilt diary-cover index.
        _scores, _ids = FAISS_DIARY_INDEX.search(np.array([face_feature]), 10)
        # Deduplicate (id, score) pairs and keep those above the threshold.
        pairs = set(zip(_ids.flat, _scores.flat))
        res = [(str(diary_id), float(score)) for diary_id, score in pairs if score >= limit]
        res.sort(key=lambda x: x[1], reverse=True)
        # Returning inside the loop means only the first face is used.
        return res
    # Bug fix: a valid image with no detectable face used to fall through
    # and return None; always return a list.
    return []


@bind("strategy_embedding/face_similarity/diary_feature")
def get_similar_diary_ids_by_face_features(feature, limit=0.1):
    """Return diaries whose "after" cover face resembles ``feature``.

    Args:
        feature: 128-element face embedding (any numeric sequence).
        limit: minimum similarity score to keep a result.

    Returns:
        List of ``(diary_id_str, score_float)`` tuples sorted by score
        descending.
    """
    query = np.array(feature).astype("float32")
    # Top-10 nearest neighbours; first return value holds the scores
    # (was misleadingly named "_sources" before).
    _scores, _ids = FAISS_DIARY_INDEX.search(np.array([query]), 10)
    # Deduplicate (id, score) pairs and keep those above the threshold.
    pairs = set(zip(_ids.flat, _scores.flat))
    res = [(str(diary_id), float(score)) for diary_id, score in pairs if score >= limit]
    res.sort(key=lambda x: x[1], reverse=True)
    return res


# def save_diary_similarity(load_file, index_path):
#     res_dict = {}
#     with open(load_file) as f:
#         lines = f.readlines()
#         print("lines: " + str(len(lines)))
#         count = 0
#         for line in lines:
#             count += 1
#             tmp = line.split("\t")
#             id = tmp[0]
#             feature = np.array(json.loads(tmp[1]))
#             print("{} {}".format(count, id))
#             tup_res = get_similar_diary_ids_by_face_features(feature, index_path, FACE_TO_VEC_FUN)
#             if tup_res:
#                 res_dict[id] = json.dumps(tup_res)

#     print("done: " + str(len(res_dict)))
#     key = random.choice(list(res_dict.keys()))
#     print(key + str(res_dict[key]))

#     redis_key = "doris:diary:face_similary"
#     redis_client3.hmset(redis_key, res_dict)

if __name__ == "__main__":
    begin_time = time.time()

    # Step 1: harvest face features from ES diary covers into Redis.
    save_diary_image_info()
    # Step 2 (enable manually after step 1): rebuild the faiss index file.
    # save_faiss_index(faiss_index_path)

    # imgs = [
    #     "https://pic.igengmei.com/2020/07/03/1437/1b9975bb0b81-w", "https://pic.igengmei.com/2020/07/01/1812/ca64827a83da-w",
    #     "https://pic.igengmei.com/2020/07/04/1711/24f4131a9b1e-w", "https://pic.igengmei.com/2020/07/04/1507/e17a995be219-w"
    # ]
    # for img_url in imgs:
    #     res = get_similar_diary_ids_by_url(img_url, limit=0.18232107)
    #     print(res)

    # print("@@@@@@@@")
    # a = [
    #     -0.08361373096704483, 0.06760436296463013, 0.10752949863672256, -0.020746365189552307, -0.07035162299871445,
    #     -0.014547230675816536, -0.043201886117458344, -0.12196271121501923, 0.13929598033428192, -0.1360183209180832,
    #     0.23247791826725006, -0.08867999166250229, -0.24177594482898712, -0.05600903555750847, -0.05371646583080292,
    #     0.22015368938446045, -0.12883149087429047, -0.0822330191731453, -0.0413128100335598, 0.08704500645399094,
    #     0.10081718862056732, -0.03764188289642334, 0.036720920354127884, 0.04766431450843811, -0.0685625970363617,
    #     -0.38336044549942017, -0.10978807508945465, -0.07328074425458908, -0.023904308676719666, -0.007438751868903637,
    #     -0.09545779973268509, 0.027364756911993027, -0.1537190079689026, -0.04008519649505615, -0.03581209108233452,
    #     0.04322449117898941, -0.05686069279909134, -0.11610691249370575, 0.1640746295452118, -0.004643512889742851,
    #     -0.34821364283561707, 0.03711444139480591, -0.0026186704635620117, 0.1917344480752945, 0.14298999309539795,
    #     0.04084448516368866, 0.06119539216160774, -0.12611950933933258, 0.10941470414400101, -0.20786598324775696,
    #     0.03435457497835159, 0.11412393301725388, 0.0602775476872921, 0.054409340023994446, -0.002967053558677435,
    #     -0.12524624168872833, 0.026284342631697655, 0.08236880600452423, -0.10654348134994507, 0.00403654295951128,
    #     0.10716681182384491, -0.08270247280597687, 0.018992319703102112, -0.11595900356769562, 0.18344789743423462,
    #     0.0895184576511383, -0.1307670772075653, -0.15750591456890106, 0.11103398352861404, -0.13521818816661835,
    #     -0.03199139982461929, 0.11129119992256165, -0.17407448589801788, -0.20658859610557556, -0.3114454746246338,
    #     0.01914297416806221, 0.39955294132232666, 0.12365783005952835, -0.14545315504074097, -0.03254598751664162,
    #     -0.10342024266719818, 0.03375910595059395, 0.11272192746400833, 0.21788232028484344, 0.08588762581348419,
    #     0.012640122324228287, -0.07646650820970535, -0.043292030692100525, 0.21306097507476807, -0.12407292425632477,
    #     -0.025112995877861977, 0.2634827196598053, 0.005047444254159927, 0.06562616676092148, -0.07397496700286865,
    #     0.06206338107585907, -0.0634055882692337, 0.05882266163825989, -0.05909111723303795, 0.027562778443098068,
    #     0.043835900723934174, 0.00407575536519289, -0.007656056433916092, 0.1048622876405716, -0.17822585999965668,
    #     0.1303984671831131, -0.021631652489304543, 0.0836174339056015, 0.11956407874822617, 0.007379574701189995,
    #     -0.07777556777000427, -0.08474794030189514, 0.09585978090763092, -0.21120299398899078, 0.1435444951057434,
    #     0.19884724915027618, 0.07154559344053268, 0.06259742379188538, 0.10118959099054337, 0.10188969224691391,
    #     -0.015351934358477592, -0.04335442930459976, -0.26258283853530884, -0.021509556099772453, 0.12185295671224594,
    #     -0.011788002215325832, 0.01337978895753622, -0.008025042712688446
    # ]
    # res = get_similar_diary_ids_by_face_features(a)
    # print(res)

    # Report total wall-clock runtime in minutes.
    print("total cost: {:.2f}mins".format((time.time() - begin_time) / 60))