# match_api.py 2.8 KB
import json
import os
import time
import traceback

import faiss
import numpy as np
from gm_rpcd.all import bind
from utils.cache import redis_client_db
from utils.message import (send_msg_to_dingtalk, send_performance_msg_to_dingtalk)
from utils.personas import get_user_portrait_tag3_business_tags

# Directory expected to hold model artifacts; resolved against the process's
# current working directory, so it depends on where the service is started.
MODEL_PATH = os.path.join(os.getcwd(), "_models")
# Path of the serialized FAISS index file inside MODEL_PATH.
INDEX_PATH = os.path.join(MODEL_PATH, "faiss_personas_vector.index")
# Index read from disk once, when this module is imported.
FAISS_TAGS_INDEX = faiss.read_index(INDEX_PATH)

# Result of a single hgetall("personas_tags_embedding") call made at import
# time. match_tractate_by_device looks entries up by UTF-8 encoded tag name and
# JSON-decodes the value into a list of numbers.
# NOTE(review): the mapping is never refreshed in this file — confirm that a
# restart is the intended way to pick up new embeddings.
TAG_EMBEDDING_DICT = redis_client_db.hgetall("personas_tags_embedding")


# Device ids for which no lookup is attempted; the endpoint returns [] for them.
_SKIPPED_DEVICE_IDS = ("0", "unknown", "87654", "")
# Search hits whose distance exceeds this value are dropped from the result.
_MAX_MATCH_DISTANCE = 5.0
# Calls taking longer than this many seconds trigger a performance message.
_SLOW_CALL_SECONDS = 0.04


@bind("strategy_embedding/personas_vector/match")
def match_tractate_by_device(device_id, n=10):
    """Return ids from the personas FAISS index matching a device's tags.

    Up to three business tags are fetched for the device, their embeddings are
    looked up in ``TAG_EMBEDDING_DICT`` and averaged, and the averaged vector
    is used to query ``FAISS_TAGS_INDEX`` for ``n`` neighbours.

    Args:
        device_id: Identifier passed to the portrait lookup. Values listed in
            ``_SKIPPED_DEVICE_IDS`` short-circuit to an empty result.
        n: Number of neighbours requested from the index.

    Returns:
        A list of ids, in the order returned by the index, whose distance is
        at most ``_MAX_MATCH_DISTANCE``. An empty list is returned for skipped
        device ids, when the device has no tags, when none of its tags has an
        embedding, or when any exception occurs (the traceback is then sent
        through ``send_msg_to_dingtalk``).
    """
    try:
        if device_id in _SKIPPED_DEVICE_IDS:
            return []

        time_begin = time.time()
        business_tags = get_user_portrait_tag3_business_tags(device_id, tags_num=3)
        portrait_elapsed = time.time() - time_begin
        if not business_tags:
            return []

        vectors = []
        for tag in business_tags:
            # Tags without a stored embedding decode to an empty list and are skipped.
            embedding = json.loads(TAG_EMBEDDING_DICT.get(bytes(tag, "utf-8"), b"[]"))
            if embedding:
                vectors.append(np.array(embedding).astype("float32"))

        res = []
        # Stays at zero when the index is never queried, so the slow-call
        # report below can always be built.
        search_elapsed = 0.0
        if vectors:
            query = np.array([np.average(vectors, axis=0)]).astype("float32")
            search_begin = time.time()
            distances, ids = FAISS_TAGS_INDEX.search(query, n)
            search_elapsed = time.time() - search_begin
            # FAISS fills missing neighbours with the label -1; never return
            # that placeholder as an id.
            res = [
                found_id
                for distance, found_id in zip(distances.tolist()[0], ids.tolist()[0])
                if found_id != -1 and distance <= _MAX_MATCH_DISTANCE
            ]

        total_elapsed = time.time() - time_begin
        if total_elapsed > _SLOW_CALL_SECONDS:
            timer_dict = {
                "method": "match_tractate_by_device",
                "api": "strategy_embedding/personas_vector/match",
                "device_id": device_id,
                "n": n,
                "get_business_tags": "{:.3f}ms".format(portrait_elapsed * 1000),
                "search": "{:.3f}ms".format(search_elapsed * 1000),
                "total_time": "{:.3f}ms".format(total_elapsed * 1000),
            }
            msg_res = "".join(
                "{}: {}\n".format(key, value) for key, value in timer_dict.items()
            )
            send_performance_msg_to_dingtalk(msg_res)
        return res
    except Exception:
        # Best-effort endpoint: report the failure and fall back to no matches.
        send_msg_to_dingtalk(str(traceback.format_exc()))
        return []