# coding=utf8
import json

from celery import shared_task
from django.conf import settings

from talos.models.diary import DiaryTagV3, DiaryTag
from talos.models.topic import ProblemTagV3, ProblemTag
from talos.models.tractate import TractateTagV3, TractateTag
from talos.logger import sync_tag_mapping_logger


BATCH_COUNT = 300


@shared_task
def diary_tag_mapping_sync(add_relations=None, delete_relations=None, delete2old_tag_ids=None):
    """Sync diary <-> tag-v3 relations after an old-tag -> v3-tag mapping change.

    Candidate diary ids are read through ``settings.SLAVE_DB_NAME``; deletes
    and inserts go through the default manager database, in batches of
    ``BATCH_COUNT`` ids. Every affected id list is logged before writing.

    :param add_relations: iterable of dicts with ``"tag_id"`` (old tag) and
        ``"tag_v3_id"``. Each diary carrying ``tag_id`` that does not already
        have ``tag_v3_id`` gets a new ``DiaryTagV3`` row. ``None`` means nothing
        to add.
    :param delete_relations: iterable of dicts with ``"tag_id"`` and
        ``"tag_v3_id"``. ``DiaryTagV3`` rows for ``tag_v3_id`` are removed from
        diaries that carry old tag ``tag_id``. ``None`` means nothing to delete.
    :param delete2old_tag_ids: optional mapping from ``str(tag_v3_id)`` to the
        old tag ids mapped onto that v3 tag. A diary that still carries one of
        those old tags (other than the relation's ``tag_id``) keeps its v3 tag.

    NOTE(review): reads and writes use different databases, so if the read
    database lags behind, the computed id sets may be slightly stale.
    """
    for item in delete_relations or ():
        tag_v3_id = item["tag_v3_id"]

        # Diaries currently holding the v3 tag AND the old tag being unmapped.
        diary_ids = list(DiaryTagV3.objects.using(settings.SLAVE_DB_NAME).filter(
            tag_v3_id=tag_v3_id).values_list("diary_id", flat=True))
        diary_ids = list(DiaryTag.objects.using(settings.SLAVE_DB_NAME).filter(
            diary_id__in=diary_ids, tag_id=item["tag_id"]).values_list("diary_id", flat=True))

        # Several old tags may map to one v3 tag: keep the v3 tag on diaries that
        # still carry another of those old tags. Keys are looked up as strings,
        # presumably because the mapping arrives JSON-serialized (JSON object keys
        # are always strings) -- verify against the caller.
        old_tag_ids = delete2old_tag_ids.get(str(tag_v3_id)) if delete2old_tag_ids else []
        if old_tag_ids:
            exclude_diary_ids = list(DiaryTag.objects.using(settings.SLAVE_DB_NAME).filter(
                diary_id__in=diary_ids, tag_id__in=old_tag_ids).exclude(tag_id=item["tag_id"]).values_list("diary_id", flat=True))
            diary_ids = list(set(diary_ids) - set(exclude_diary_ids))

        sync_tag_mapping_logger.info(json.dumps({"action": "delete", "tag_v3_id": tag_v3_id, "diary_ids": diary_ids}))
        for start in range(0, len(diary_ids), BATCH_COUNT):
            DiaryTagV3.objects.filter(
                tag_v3_id=tag_v3_id, diary_id__in=diary_ids[start:start + BATCH_COUNT]).delete()

    for item in add_relations or ():
        tag_v3_id = item["tag_v3_id"]

        diary_ids = list(DiaryTag.objects.using(settings.SLAVE_DB_NAME).filter(
            tag_id=item["tag_id"]).values_list("diary_id", flat=True))
        diary_ids_v3 = list(DiaryTagV3.objects.using(settings.SLAVE_DB_NAME).filter(
            diary_id__in=diary_ids, tag_v3_id=tag_v3_id).values_list("diary_id", flat=True))

        # Must be a list, not a set: the batching below slices it and sets are
        # not sliceable. The set difference also de-duplicates the ids.
        diary_ids = list(set(diary_ids) - set(diary_ids_v3))
        sync_tag_mapping_logger.info(json.dumps({"action": "add", "tag_v3_id": tag_v3_id, "diary_ids": diary_ids}))
        for start in range(0, len(diary_ids), BATCH_COUNT):
            DiaryTagV3.objects.bulk_create([
                DiaryTagV3(diary_id=diary_id, tag_v3_id=tag_v3_id)
                for diary_id in diary_ids[start:start + BATCH_COUNT]
            ])


@shared_task
def diary_topic_tag_mapping_sync(add_relations=None, delete_relations=None, delete2old_tag_ids=None):
    """Sync problem (topic) <-> tag-v3 relations after an old-tag -> v3-tag mapping change.

    Candidate problem ids are read through ``settings.SLAVE_DB_NAME``; deletes
    and inserts go through the default manager database, in batches of
    ``BATCH_COUNT`` ids. Every affected id list is logged before writing.

    :param add_relations: iterable of dicts with ``"tag_id"`` (old tag) and
        ``"tag_v3_id"``. Each problem carrying ``tag_id`` that does not already
        have ``tag_v3_id`` gets a new ``ProblemTagV3`` row. ``None`` means
        nothing to add.
    :param delete_relations: iterable of dicts with ``"tag_id"`` and
        ``"tag_v3_id"``. ``ProblemTagV3`` rows for ``tag_v3_id`` are removed from
        problems that carry old tag ``tag_id``. ``None`` means nothing to delete.
    :param delete2old_tag_ids: optional mapping from ``str(tag_v3_id)`` to the
        old tag ids mapped onto that v3 tag. A problem that still carries one of
        those old tags (other than the relation's ``tag_id``) keeps its v3 tag.

    NOTE(review): reads and writes use different databases, so if the read
    database lags behind, the computed id sets may be slightly stale.
    """
    for item in delete_relations or ():
        tag_v3_id = item["tag_v3_id"]

        # Problems currently holding the v3 tag AND the old tag being unmapped.
        problem_ids = list(ProblemTagV3.objects.using(settings.SLAVE_DB_NAME).filter(
            tag_v3_id=tag_v3_id).values_list("problem_id", flat=True))
        problem_ids = list(ProblemTag.objects.using(settings.SLAVE_DB_NAME).filter(
            problem_id__in=problem_ids, tag_id=item["tag_id"]).values_list("problem_id", flat=True))

        # Several old tags may map to one v3 tag: keep the v3 tag on problems that
        # still carry another of those old tags. Keys are looked up as strings,
        # presumably because the mapping arrives JSON-serialized (JSON object keys
        # are always strings) -- verify against the caller.
        old_tag_ids = delete2old_tag_ids.get(str(tag_v3_id)) if delete2old_tag_ids else []
        if old_tag_ids:
            exclude_problem_ids = list(ProblemTag.objects.using(settings.SLAVE_DB_NAME).filter(
                problem_id__in=problem_ids, tag_id__in=old_tag_ids).exclude(tag_id=item["tag_id"]).values_list("problem_id", flat=True))
            problem_ids = list(set(problem_ids) - set(exclude_problem_ids))

        sync_tag_mapping_logger.info(json.dumps({"action": "delete", "tag_v3_id": tag_v3_id, "problem_ids": problem_ids}))
        for start in range(0, len(problem_ids), BATCH_COUNT):
            ProblemTagV3.objects.filter(
                tag_v3_id=tag_v3_id, problem_id__in=problem_ids[start:start + BATCH_COUNT]).delete()

    for item in add_relations or ():
        tag_v3_id = item["tag_v3_id"]

        problem_ids = list(ProblemTag.objects.using(settings.SLAVE_DB_NAME).filter(
            tag_id=item["tag_id"]).values_list("problem_id", flat=True))
        problem_ids_v3 = list(ProblemTagV3.objects.using(settings.SLAVE_DB_NAME).filter(
            problem_id__in=problem_ids, tag_v3_id=tag_v3_id).values_list("problem_id", flat=True))

        # Must be a list, not a set: the batching below slices it and sets are
        # not sliceable. The set difference also de-duplicates the ids.
        problem_ids = list(set(problem_ids) - set(problem_ids_v3))
        sync_tag_mapping_logger.info(json.dumps({"action": "add", "tag_v3_id": tag_v3_id, "problem_ids": problem_ids}))
        for start in range(0, len(problem_ids), BATCH_COUNT):
            ProblemTagV3.objects.bulk_create([
                ProblemTagV3(problem_id=problem_id, tag_v3_id=tag_v3_id)
                for problem_id in problem_ids[start:start + BATCH_COUNT]
            ])


@shared_task
def tractate_tag_mapping_sync(add_relations=None, delete_relations=None, delete2old_tag_ids=None):
    """Sync tractate <-> tag-v3 relations after an old-tag -> v3-tag mapping change.

    Candidate tractate ids are read through ``settings.SLAVE_DB_NAME``; deletes
    and inserts go through the default manager database, in batches of
    ``BATCH_COUNT`` ids. Every affected id list is logged before writing.

    :param add_relations: iterable of dicts with ``"tag_id"`` (old tag) and
        ``"tag_v3_id"``. Each tractate carrying ``tag_id`` that does not already
        have ``tag_v3_id`` gets a new ``TractateTagV3`` row. ``None`` means
        nothing to add.
    :param delete_relations: iterable of dicts with ``"tag_id"`` and
        ``"tag_v3_id"``. ``TractateTagV3`` rows for ``tag_v3_id`` are removed
        from tractates that carry old tag ``tag_id``. ``None`` means nothing to
        delete.
    :param delete2old_tag_ids: optional mapping from ``str(tag_v3_id)`` to the
        old tag ids mapped onto that v3 tag. A tractate that still carries one
        of those old tags (other than the relation's ``tag_id``) keeps its v3 tag.

    NOTE(review): reads and writes use different databases, so if the read
    database lags behind, the computed id sets may be slightly stale.
    """
    for item in delete_relations or ():
        tag_v3_id = item["tag_v3_id"]

        # Tractates currently holding the v3 tag AND the old tag being unmapped.
        tractate_ids = list(TractateTagV3.objects.using(settings.SLAVE_DB_NAME).filter(
            tag_v3_id=tag_v3_id).values_list("tractate_id", flat=True))
        tractate_ids = list(TractateTag.objects.using(settings.SLAVE_DB_NAME).filter(
            tractate_id__in=tractate_ids, tag_id=item["tag_id"]).values_list("tractate_id", flat=True))

        # Several old tags may map to one v3 tag: keep the v3 tag on tractates that
        # still carry another of those old tags. Keys are looked up as strings,
        # presumably because the mapping arrives JSON-serialized (JSON object keys
        # are always strings) -- verify against the caller.
        old_tag_ids = delete2old_tag_ids.get(str(tag_v3_id)) if delete2old_tag_ids else []
        if old_tag_ids:
            exclude_tractate_ids = list(TractateTag.objects.using(settings.SLAVE_DB_NAME).filter(
                tractate_id__in=tractate_ids, tag_id__in=old_tag_ids).exclude(tag_id=item["tag_id"]).values_list("tractate_id", flat=True))
            tractate_ids = list(set(tractate_ids) - set(exclude_tractate_ids))

        sync_tag_mapping_logger.info(json.dumps({"action": "delete", "tag_v3_id": tag_v3_id, "tractate_ids": tractate_ids}))
        for start in range(0, len(tractate_ids), BATCH_COUNT):
            TractateTagV3.objects.filter(
                tag_v3_id=tag_v3_id, tractate_id__in=tractate_ids[start:start + BATCH_COUNT]).delete()

    for item in add_relations or ():
        tag_v3_id = item["tag_v3_id"]

        tractate_ids = list(TractateTag.objects.using(settings.SLAVE_DB_NAME).filter(
            tag_id=item["tag_id"]).values_list("tractate_id", flat=True))
        tractate_ids_v3 = list(TractateTagV3.objects.using(settings.SLAVE_DB_NAME).filter(
            tractate_id__in=tractate_ids, tag_v3_id=tag_v3_id).values_list("tractate_id", flat=True))

        # Must be a list, not a set: the batching below slices it and sets are
        # not sliceable. The set difference also de-duplicates the ids.
        tractate_ids = list(set(tractate_ids) - set(tractate_ids_v3))
        sync_tag_mapping_logger.info(json.dumps({"action": "add", "tag_v3_id": tag_v3_id, "tractate_ids": tractate_ids}))
        for start in range(0, len(tractate_ids), BATCH_COUNT):
            TractateTagV3.objects.bulk_create([
                TractateTagV3(tractate_id=tractate_id, tag_v3_id=tag_v3_id)
                for tractate_id in tractate_ids[start:start + BATCH_COUNT]
            ])
