from collections import defaultdict

import itertools
from celery import shared_task
from django.conf import settings
from django.utils.functional import cached_property
from gm_types.gaia import TAG_V3_TYPE

from qa.models import QuestionTag, QuestionTagV3, Question, AnswerTagV3, Answer, AnswerTag
from talos.cache.base import tag_map_tag3_record
from talos.logger import info_logger
from talos.models.diary import Diary, DiaryTagV3, DiaryTag
from talos.models.topic import ProblemTagV3, Problem, ProblemTag
from talos.services import TagV3Service
from utils.rpc import get_rpc_invoker, logging_exception
from utils.group_routine import GroupRoutine


# RPC invoker; TagMapTool.request_task uses it to call the doris tokenizer endpoint.
rpc_client = get_rpc_invoker()
# Cache key (formatted with the content type) for the full tag-clean progress record;
# TagMapTool reads it as an int content id.
cache_key = "{}_clean_record_max_id"
# Cache key for the incremental task's progress record; tag_map_tag3 writes the
# content max id it processed up to into it.
task_cache_key = "{content_type}:current_content_id"
# content type -> (content model, tag-v3 relation model, content id field name on the
#                  relation models, legacy tag relation model, legacy tag field name)
CONTENT_TYPE_MAP = {
    "diary": (Diary, DiaryTagV3, "diary_id", DiaryTag, 'tag_id'),
    "topic": (Problem, ProblemTagV3, "problem_id", ProblemTag, 'tag_id'),
    "question": (Question, QuestionTagV3, "question_id", QuestionTag, 'tag'),
    "answer": (Answer, AnswerTagV3, "answer_id", AnswerTag, 'tag'),
}


class TagMapTool(object):
    """Map the legacy tags of newly added content of one type onto tag-v3 ids.

    One instance handles one content type from ``CONTENT_TYPE_MAP``. Processing
    starts from the larger of the incremental-task record and the full-clean
    record stored in ``tag_map_tag3_record``.
    """

    # Doris strategy-tokenizer endpoint; params: content_id_list, content_type.
    # Kept on the class so the classmethods below can reach it without an instance.
    RPC_URL = 'doris/search/content_tagv3_tokenizer'

    def __init__(self, content_type):
        try:
            (self.model, self.tag3_model, self.field,
             self.tag_model, self.tag_field) = CONTENT_TYPE_MAP[content_type]
        except KeyError:
            raise ValueError('unsupported content_type: {!r}'.format(content_type))
        task_record = int(tag_map_tag3_record.get(task_cache_key.format(content_type=content_type)) or 0)
        full_clean_record = int(tag_map_tag3_record.get(cache_key.format(content_type)) or 0)
        self.cache_record = max(task_record, full_clean_record)
        # Legacy tag ids seen so far; filled by get_content_map_old_tag_ids().
        self.old_tag_ids = set()

    @cached_property
    def get_content_max_id(self):
        """Largest content id currently stored (0 when the table is empty)."""
        # Order explicitly by id: ``last()`` follows the model's default ordering
        # when one is declared, which need not be the primary key.
        max_id = (
            self.model.objects.using(settings.SLAVE_DB_NAME)
            .order_by('-id')
            .values_list('id', flat=True)
            .first()
        )
        if max_id is None:
            max_id = 0
        info_logger.info('max id {}'.format(max_id))
        return max_id

    @cached_property
    def get_all_new_content_ids(self):
        """Ids of online content with ``cache_record <= id <= get_content_max_id``."""
        content_ids = list(self.model.objects.using(settings.SLAVE_DB_NAME).filter(
            id__gte=self.cache_record, id__lte=self.get_content_max_id, is_online=True,
        ).values_list('id', flat=True))

        return content_ids

    @cached_property
    def get_exists_content_tag_map(self):
        """Set of ``(content_id, tag_v3_id)`` pairs already stored for the new content."""
        query = {'{}__in'.format(self.field): self.get_all_new_content_ids}
        relation = set(self.tag3_model.objects.using(settings.SLAVE_DB_NAME).filter(
            **query
        ).values_list(self.field, 'tag_v3_id'))
        return relation

    def status(self):
        """Return True when a previous record exists and newer content exists beyond it."""
        return 0 < self.cache_record < self.get_content_max_id

    def get_content_map_old_tag_ids(self):
        """Group the new content ids by legacy tag id.

        Also adds every legacy tag id found to ``self.old_tag_ids``.

        :return: defaultdict mapping legacy tag id -> list of content ids.
        """
        query = {'{}__in'.format(self.field): self.get_all_new_content_ids}
        tag_ids = set(self.tag_model.objects.using(settings.SLAVE_DB_NAME).filter(
            **query
        ).values_list(self.field, self.tag_field))

        old_tag_map_content = defaultdict(list)
        for content_id, old_tag_id in tag_ids:
            old_tag_map_content[old_tag_id].append(content_id)
        self.old_tag_ids.update(old_tag_map_content)

        return old_tag_map_content

    def get_old_tag_map_tag3(self, old_tag_ids):
        """Map each legacy tag id (as int) to the tag-v3 ids TagV3Service returns for it."""
        tags = TagV3Service.get_tag_v3_by_old_tag_ids(list(old_tag_ids))
        old_tag_map_tag3 = defaultdict(list)
        for tag_id, tag3_info_list in tags.items():
            old_tag_map_tag3[int(tag_id)].extend([tag['id'] for tag in tag3_info_list])

        return old_tag_map_tag3

    @classmethod
    def request_task(cls, _ids, content_type):
        """Ask the doris strategy tokenizer for the tags of the given content ids.

        http://wiki.wanmeizhensuo.com/pages/viewpage.action?pageId=36564788

        :return: the unwrapped RPC result, or ``{}`` if the call fails (the failure
            is reported through ``logging_exception``).
        """
        try:
            result = rpc_client[cls.RPC_URL](
                content_id_list=_ids, content_type=content_type
            ).unwrap()
        except Exception:
            # Best effort: one failing batch must not abort the others.
            logging_exception()
            result = {}

        return result

    @classmethod
    def get_doris_tag_map(cls, content_ids, content_type, batch_size=20):
        """Run strategy tokenization for content that has no legacy->v3 tag mapping.

        :param content_ids: any iterable of content ids (list, set, ...).
        :param content_type: content type name forwarded to the RPC.
        :param batch_size: number of ids sent per RPC call.
        :return: dict merged from the results of every successful batch.
        """
        content_ids = list(content_ids)  # sets cannot be sliced
        if not content_ids:
            return {}
        # Materialize the chunks: they are iterated twice (submit, then collect).
        chunks = [content_ids[i:i + batch_size] for i in range(0, len(content_ids), batch_size)]
        routine = GroupRoutine()
        for index, chunk in enumerate(chunks):
            routine.submit(index, cls.request_task, chunk, content_type)
        routine.go()

        content_map_tag = {}
        for index in range(len(chunks)):
            data = routine.results.get(index, {})
            if data:
                content_map_tag.update(data)
        return content_map_tag


@shared_task
def tag_map_tag3(content_type):
    """Incrementally attach tag-v3 ids to newly added content of one type.

    Only the legacy-tag -> tag-v3 mapping is used; the doris strategy tokenizer
    is not called. Pairs already stored are skipped. Afterwards the incremental
    record is advanced to the content max id that was processed.
    """
    tool = TagMapTool(content_type=content_type)
    if not tool.status():
        info_logger.info('{} 无新增内容同步'.format(content_type))
        return

    content_by_old_tag = tool.get_content_map_old_tag_ids()
    # Filled as a side effect of get_content_map_old_tag_ids() just above.
    old_tag_ids = tool.old_tag_ids
    tag3_by_old_tag = tool.get_old_tag_map_tag3(old_tag_ids)
    existing_pairs = tool.get_exists_content_tag_map

    pending_pairs = set()
    for old_tag_id in old_tag_ids:
        content_ids = content_by_old_tag.get(old_tag_id)
        tag3_ids = tag3_by_old_tag.get(old_tag_id)
        info_logger.info('content_ids:{}, tag3_ids:{}'.format(content_ids, tag3_ids))
        if content_ids and tag3_ids:
            pending_pairs.update(
                pair for pair in itertools.product(content_ids, tag3_ids)
                if pair not in existing_pairs
            )

    # Content without any mapping should also go through the doris strategy
    # tokenizer; the strategy does not support this yet, so it stays disabled.
    # all_content_ids = tool.get_all_new_content_ids
    # invalid_content_ids = set(all_content_ids) - set([content_id for content_id, _ in pending_pairs])
    # content_tag_map = tool.get_doris_tag_map(invalid_content_ids, content_type)
    #
    # for c_id in invalid_content_ids:
    #     tag3_ids = content_tag_map.get(c_id, {}).get(TAG_V3_TYPE.NORMAL, [])
    #     info_logger.info('doris content_id:{}, tag3_ids:{}'.format(c_id, tag3_ids))
    #     if not tag3_ids:
    #         continue
    #     pending_pairs.update([(c_id, tag3_id) for tag3_id in tag3_ids])

    new_rows = [
        tool.tag3_model(**{tool.field: content_id, 'tag_v3_id': tag3_id})
        for content_id, tag3_id in pending_pairs
    ]
    tool.tag3_model.objects.bulk_create(new_rows)
    tag_map_tag3_record.set(
        task_cache_key.format(content_type=content_type),
        tool.get_content_max_id,
    )


@shared_task
def sync_tag_map_tag3():
    """Queue one asynchronous ``tag_map_tag3`` run per content type in CONTENT_TYPE_MAP."""
    content_types = list(CONTENT_TYPE_MAP)
    for name in content_types:
        tag_map_tag3.apply_async(args=(name,))
