# coding=utf8

from django.conf import settings
from celery import shared_task
import json

from qa.models import (
    QuestionTag,
    QuestionTagV3,
    AnswerTag,
    AnswerTagV3,
)
from talos.logger import sync_tag_mapping_logger


# Maximum number of ids handled by a single delete / bulk_create statement
# in the sync tasks below.
BATCH_COUNT = 300


@shared_task
def question_mapping_sync(add_relations=None, delete_relations=None, delete2old_tag_ids=None):
    """Apply old-tag -> v3-tag mapping changes to the QuestionTagV3 table.

    Candidate ids are read from the ``settings.SLAVE_DB_NAME`` database; the
    deletes and inserts go through the default manager, in batches of
    ``BATCH_COUNT`` ids.

    :param add_relations: iterable of dicts with ``"tag_id"`` and
        ``"tag_v3_id"`` keys. Every question carrying ``tag_id`` in
        QuestionTag gets a QuestionTagV3 row for ``tag_v3_id`` unless one
        already exists. ``None`` is treated as empty.
    :param delete_relations: iterable of dicts with the same keys. The
        QuestionTagV3 rows for ``tag_v3_id`` are removed from questions
        carrying ``tag_id``. ``None`` is treated as empty.
    :param delete2old_tag_ids: optional mapping keyed by ``str(tag_v3_id)``
        whose values are collections of old tag ids. A question that also
        carries one of those tags (other than ``tag_id``) keeps its v3 row.
    """
    # The parameters default to None, so guard before iterating.
    for item in delete_relations or []:
        tag_id = item["tag_id"]
        tag_v3_id = item["tag_v3_id"]

        # Questions that currently hold the v3 tag ...
        question_ids = list(QuestionTagV3.objects.using(settings.SLAVE_DB_NAME).filter(
            tag_v3_id=tag_v3_id).values_list("question_id", flat=True))
        # ... narrowed down to those that also hold the old tag being unmapped.
        question_ids = list(QuestionTag.objects.using(settings.SLAVE_DB_NAME).filter(
            question_id__in=question_ids, tag=tag_id).values_list("question_id", flat=True))

        # Handle several old tags mapping to the same new tag: a question that
        # still carries another of those old tags must keep its v3 relation.
        old_tag_ids = delete2old_tag_ids.get(str(tag_v3_id)) if delete2old_tag_ids else []
        if old_tag_ids:
            exclude_question_ids = list(QuestionTag.objects.using(settings.SLAVE_DB_NAME).filter(
                question_id__in=question_ids, tag__in=old_tag_ids
            ).exclude(tag=tag_id).values_list("question_id", flat=True))
            question_ids = list(set(question_ids) - set(exclude_question_ids))

        sync_tag_mapping_logger.info(json.dumps({"action": "delete", "tag_v3_id": tag_v3_id, "question_ids": question_ids}))
        for start in range(0, len(question_ids), BATCH_COUNT):
            qids = question_ids[start:start + BATCH_COUNT]
            QuestionTagV3.objects.filter(tag_v3_id=tag_v3_id, question_id__in=qids).delete()

    for item in add_relations or []:
        tag_id = item["tag_id"]
        tag_v3_id = item["tag_v3_id"]

        question_ids = list(QuestionTag.objects.using(settings.SLAVE_DB_NAME).filter(
            tag=tag_id).values_list("question_id", flat=True))
        question_ids_v3 = list(QuestionTagV3.objects.using(settings.SLAVE_DB_NAME).filter(
            question_id__in=question_ids, tag_v3_id=tag_v3_id).values_list("question_id", flat=True))

        # Convert the set difference back to a list: sets cannot be sliced,
        # and the batching below relies on slicing.
        question_ids = list(set(question_ids) - set(question_ids_v3))
        sync_tag_mapping_logger.info(json.dumps({"action": "add", "tag_v3_id": tag_v3_id, "question_ids": question_ids}))
        for start in range(0, len(question_ids), BATCH_COUNT):
            qids = question_ids[start:start + BATCH_COUNT]
            QuestionTagV3.objects.bulk_create([
                QuestionTagV3(question_id=question_id, tag_v3_id=tag_v3_id)
                for question_id in qids
            ])


@shared_task
def answer_tag_mapping_sync(add_relations=None, delete_relations=None, delete2old_tag_ids=None):
    """Apply old-tag -> v3-tag mapping changes to the AnswerTagV3 table.

    Candidate ids are read from the ``settings.SLAVE_DB_NAME`` database; the
    deletes and inserts go through the default manager, in batches of
    ``BATCH_COUNT`` ids.

    :param add_relations: iterable of dicts with ``"tag_id"`` and
        ``"tag_v3_id"`` keys. Every answer carrying ``tag_id`` in AnswerTag
        gets an AnswerTagV3 row for ``tag_v3_id`` unless one already exists.
        ``None`` is treated as empty.
    :param delete_relations: iterable of dicts with the same keys. The
        AnswerTagV3 rows for ``tag_v3_id`` are removed from answers carrying
        ``tag_id``. ``None`` is treated as empty.
    :param delete2old_tag_ids: optional mapping keyed by ``str(tag_v3_id)``
        whose values are collections of old tag ids. An answer that also
        carries one of those tags (other than ``tag_id``) keeps its v3 row.
    """
    # The parameters default to None, so guard before iterating.
    for item in delete_relations or []:
        tag_id = item["tag_id"]
        tag_v3_id = item["tag_v3_id"]

        # Answers that currently hold the v3 tag ...
        answer_ids = list(AnswerTagV3.objects.using(settings.SLAVE_DB_NAME).filter(
            tag_v3_id=tag_v3_id).values_list("answer_id", flat=True))
        # ... narrowed down to those that also hold the old tag being unmapped.
        answer_ids = list(AnswerTag.objects.using(settings.SLAVE_DB_NAME).filter(
            answer_id__in=answer_ids, tag=tag_id).values_list("answer_id", flat=True))

        # Handle several old tags mapping to the same new tag: an answer that
        # still carries another of those old tags must keep its v3 relation.
        old_tag_ids = delete2old_tag_ids.get(str(tag_v3_id)) if delete2old_tag_ids else []
        if old_tag_ids:
            exclude_answer_ids = list(AnswerTag.objects.using(settings.SLAVE_DB_NAME).filter(
                answer_id__in=answer_ids, tag__in=old_tag_ids
            ).exclude(tag=tag_id).values_list("answer_id", flat=True))
            answer_ids = list(set(answer_ids) - set(exclude_answer_ids))

        sync_tag_mapping_logger.info(json.dumps({"action": "delete", "tag_v3_id": tag_v3_id, "answer_ids": answer_ids}))
        for start in range(0, len(answer_ids), BATCH_COUNT):
            aids = answer_ids[start:start + BATCH_COUNT]
            AnswerTagV3.objects.filter(tag_v3_id=tag_v3_id, answer_id__in=aids).delete()

    for item in add_relations or []:
        tag_id = item["tag_id"]
        tag_v3_id = item["tag_v3_id"]

        answer_ids = list(AnswerTag.objects.using(settings.SLAVE_DB_NAME).filter(
            tag=tag_id).values_list("answer_id", flat=True))
        answer_ids_v3 = list(AnswerTagV3.objects.using(settings.SLAVE_DB_NAME).filter(
            answer_id__in=answer_ids, tag_v3_id=tag_v3_id).values_list("answer_id", flat=True))

        # Convert the set difference back to a list: sets cannot be sliced,
        # and the batching below relies on slicing.
        answer_ids = list(set(answer_ids) - set(answer_ids_v3))
        sync_tag_mapping_logger.info(json.dumps({"action": "add", "tag_v3_id": tag_v3_id, "answer_ids": answer_ids}))
        for start in range(0, len(answer_ids), BATCH_COUNT):
            aids = answer_ids[start:start + BATCH_COUNT]
            AnswerTagV3.objects.bulk_create([
                AnswerTagV3(answer_id=answer_id, tag_v3_id=tag_v3_id)
                for answer_id in aids
            ])
