from gm_rpcd.all import bind
from libs.cache import redis_client
import base64
from gm_types.doris import MIND_TYPE
import logging
import traceback
import redis
from django.conf import settings
from libs.es import get_es, es_indices_analyze
import pdb

# Redis key template for the set of labels stored for one query; the
# placeholder is filled with the base64-encoded UTF-8 query text.
QUERY_KEY = "query:{}:set"

# Priority of each label. When a query has several candidate labels,
# query_inference returns the candidate with the highest value here.
# BRAND and HOSPITAL deliberately or accidentally share the same priority (7)
# -- NOTE(review): confirm the tie is intended.
LABEL_VALUE = {
    MIND_TYPE.PROJECT: 8,
    MIND_TYPE.BRAND: 7,
    MIND_TYPE.HOSPITAL: 7,
    MIND_TYPE.DOCTOR: 6,
    MIND_TYPE.FREE_FACE: 4,
    MIND_TYPE.FACE: 3,
    MIND_TYPE.AREA: 2,
    MIND_TYPE.USER: 1,
    MIND_TYPE.UNKNOWN: 0
}

# Query words for which query_inference adds an extra candidate label on top
# of whatever was found for the query. Every entry currently maps to PROJECT.
QUERY_WORD_LABEL_NEED_MODIFIED = {
    u"玻尿酸": MIND_TYPE.PROJECT,
    u"鼻": MIND_TYPE.PROJECT,
    u"眼": MIND_TYPE.PROJECT,
    u"嘴": MIND_TYPE.PROJECT,
    u"脱毛": MIND_TYPE.PROJECT
}


def label_key(label):
    """Return the priority of ``label`` for use as a sort/max key.

    Labels that are present in ``LABEL_VALUE`` yield their configured
    priority. Any other label yields ``-1`` so that it ranks below every
    configured label (the lowest configured priority is 0) and, more
    importantly, so that the key is always an ``int``: a ``None`` key cannot
    be ordered against integers in Python 3 and would make sorting raise
    ``TypeError``.
    """
    return LABEL_VALUE.get(label, -1)


@bind('doris/query/inference')
def query_inference(query=''):
    """Infer the highest-priority label for a search query.

    Candidate labels are read from the Redis set keyed by the base64-encoded
    query. If that set is empty, labels of the query's synonyms are used
    instead (see ``get_synonym_query``). ``MIND_TYPE.UNKNOWN`` is always added
    as a fallback candidate, and queries listed in
    ``QUERY_WORD_LABEL_NEED_MODIFIED`` additionally get their mapped label.

    Args:
        query: The raw query text.

    Returns:
        A dict ``{'label': <label>}`` holding the candidate with the highest
        ``label_key`` priority. On any error the failure is logged and
        ``{'label': MIND_TYPE.UNKNOWN}`` is returned.
    """
    try:
        query_base64 = base64.b64encode(query.encode('utf8')).decode('utf8')
        key = QUERY_KEY.format(query_base64)
        labels = [member.decode("utf8") for member in redis_client.smembers(key)]
        if not labels:
            # Tolerate a None/empty result from the helper: it simply means
            # that no synonym labels are available.
            labels = list(get_synonym_query(query) or ())
        labels.append(MIND_TYPE.UNKNOWN)
        if query in QUERY_WORD_LABEL_NEED_MODIFIED:
            # Use the label configured for this word rather than a
            # hard-coded one, so the table stays the single source of truth.
            labels.append(QUERY_WORD_LABEL_NEED_MODIFIED[query])
        # max() returns the first candidate with the highest priority, which
        # is the same element a stable descending sort would put first.
        return {'label': max(labels, key=label_key)}
    except Exception:
        # RPC boundary: never propagate, fall back to UNKNOWN instead.
        logging.error("catch exception,err_msg:%s" % traceback.format_exc())
        return {'label': MIND_TYPE.UNKNOWN}


def get_synonym_query(query=''):
    """Collect the labels stored in Redis for whole-query synonyms of ``query``.

    The query is run through the ``gm_default_index`` analyzer. Every token of
    type ``SYNONYM`` that spans the entire query (start offset 0, end offset
    equal to ``len(query)``) is treated as a synonym, and the members of its
    Redis label set are added to the result.

    Args:
        query: The raw query text.

    Returns:
        A set of label strings. The set is empty when no synonym has labels
        or when any step fails (the failure is logged).
    """
    try:
        es = get_es()
        body = {
            'text': query,
            'analyzer': "gm_default_index"
        }
        res = es_indices_analyze(doc_type="newwiki", body=body, es=es)

        # Computed once from the original query; the loop below must not
        # rebind ``query``, otherwise later tokens would be compared against
        # the wrong length.
        query_length = len(query)
        seen_terms = set()
        synonym_labels = set()
        for item in res["tokens"]:
            if item["type"] != "SYNONYM":
                continue
            if item["start_offset"] != 0 or item["end_offset"] != query_length:
                continue
            term = item["token"]
            if term in seen_terms:
                # Already looked up; avoid a redundant Redis round-trip.
                continue
            seen_terms.add(term)
            term_base64 = base64.b64encode(term.encode('utf8')).decode('utf8')
            key = QUERY_KEY.format(term_base64)
            synonym_labels.update(
                member.decode("utf8") for member in redis_client.smembers(key)
            )
        return synonym_labels
    except Exception:
        logging.error("catch exception, query_sku:%s" % traceback.format_exc())
        # Return an empty set so callers can always iterate the result.
        return set()
