#!/usr/bin/env python
# -*- coding: utf-8 -*-

import time
import math
from itertools import groupby, chain

from django.conf import settings
from django.core.management import BaseCommand

from gm_types.gaia import (
    TAG_V3_TYPE,
    TAG_TYPE_ATTR,
)

from agile.models import (
    TagV3,
    AttrTag,
)

from launch.models import (
    QuestionnaireAnswerRelationAttrTag,
    QuestionnaireAnswerRelationTagV3,
)
from rpc.tool.queryset_tool import big_qs_iter


class Command(BaseCommand):
    """
    Convert the attribute tags (AttrTag) related to KYC questionnaire answers
    into the equivalent TagV3 relations.

    Usage:
    python manage.py sync_questionnaire_answer_attr_tag_to_v3_tag
    """
    # Flush threshold: pending relation rows are bulk-inserted once at least
    # this many have accumulated.
    step = 100

    @staticmethod
    def get_tag_type_from_attr_type(attr_type):
        """
        Translate an attribute-tag aggregate type into the matching TagV3 type.

        :param attr_type: a TAG_TYPE_ATTR value
        :return: the corresponding TAG_V3_TYPE value, or None when the
                 attribute type has no TagV3 counterpart
        """
        # Ordered (attribute type, TagV3 type) pairs. The first pair that
        # compares equal wins, so precedence is explicit and the constants
        # only need to support ``==``.
        type_pairs = (
            (TAG_TYPE_ATTR.FIRST_SYMPTOM, TAG_V3_TYPE.FIRST_SYMPTOM),
            (TAG_TYPE_ATTR.SYMPTOM, TAG_V3_TYPE.SECOND_SYMPTOM),
            # MODE and FIRST_BRAND both collapse into the first-level brand type.
            (TAG_TYPE_ATTR.MODE, TAG_V3_TYPE.FIRST_BRAND),
            (TAG_TYPE_ATTR.FIRST_BRAND, TAG_V3_TYPE.FIRST_BRAND),
            (TAG_TYPE_ATTR.BRAND, TAG_V3_TYPE.SECOND_BRAND),
            (TAG_TYPE_ATTR.MACROSCOPIC_MODE, TAG_V3_TYPE.MACROSCOPIC_MODE),
            (TAG_TYPE_ATTR.FIRST_APPEAL, TAG_V3_TYPE.FIRST_APPEAL),
            (TAG_TYPE_ATTR.SECOND_APPEAL, TAG_V3_TYPE.SECOND_APPEAL),
            (TAG_TYPE_ATTR.FIRST_POSITION, TAG_V3_TYPE.FIRST_POSITION),
            (TAG_TYPE_ATTR.POSITION, TAG_V3_TYPE.SECOND_POSITION),
            (TAG_TYPE_ATTR.DRUG, TAG_V3_TYPE.DRUG),
            (TAG_TYPE_ATTR.INSTRUMENT, TAG_V3_TYPE.INSTRUMENT),
            (TAG_TYPE_ATTR.CONSUMABLES, TAG_V3_TYPE.CONSUMABLES),
            (TAG_TYPE_ATTR.MATERIAL, TAG_V3_TYPE.MATERIAL),
        )
        for attr_value, tag_v3_type in type_pairs:
            if attr_type == attr_value:
                return tag_v3_type
        return None

    def _get_tag_from_attr(self, attr):
        """
        Find the TagV3 row that corresponds to an attribute tag.

        The match is made on the attribute tag's name plus the TagV3 type
        derived from its ``aggregate_type``. No database alias is given
        explicitly for this lookup.

        :param attr: an AttrTag instance (``name`` and ``aggregate_type`` are read)
        :return: the first matching TagV3 (only ``id`` loaded), or None when
                 the type cannot be translated or no TagV3 matches
        """
        tag_type = self.get_tag_type_from_attr_type(attr.aggregate_type)
        if not tag_type:
            return
        return TagV3.objects.filter(name=attr.name, tag_type=tag_type).only("id").first()

    @staticmethod
    def get_kyc_answer_rel_tag_v3_ids(ans_ids, chunk_size=1000):
        """
        Fetch the TagV3 ids already related to the given KYC answers.

        The ids are queried in chunks so that a very large id list does not
        end up in one unbounded ``IN (...)`` clause.

        :param ans_ids: iterable of questionnaire answer ids
        :param chunk_size: maximum number of answer ids per query
        :return: dict mapping answer id -> list of related tag_v3 ids;
                 answers without any relation are absent from the dict
        """
        # De-duplicate so an id can never be fetched (and counted) twice when
        # it would otherwise land in two different chunks.
        unique_ids = list(set(ans_ids or ()))
        chunk_size = max(int(chunk_size), 1)

        result = {}
        for offset in range(0, len(unique_ids), chunk_size):
            rows = QuestionnaireAnswerRelationTagV3.objects.using(
                settings.SLAVE_DB_NAME
            ).filter(
                questionnaire_answer_id__in=unique_ids[offset:offset + chunk_size]
            ).values_list("questionnaire_answer_id", "tag_v3_id")

            for ans_id, tag_v3_id in rows:
                result.setdefault(ans_id, []).append(tag_v3_id)

        return result

    def handle(self, *args, **options):
        """
        Entry point of the command.

        Steps:
        1. Load every (answer id, attr tag id) relation.
        2. Resolve each referenced attribute tag to a TagV3 id.
        3. Create the missing answer -> TagV3 relations in batches, skipping
           relations that already exist or were already queued in this run.
        """
        print("START")

        start_time = time.time()

        rel_attr_tag_infos = QuestionnaireAnswerRelationAttrTag.objects.using(
            settings.SLAVE_DB_NAME
        ).values_list("questionnaire_answer_id", "attr_tag_id").order_by("questionnaire_answer_id")

        # Accumulate per answer instead of relying on the rows arriving
        # grouped: the result is then correct whatever order big_qs_iter
        # yields them in.
        rel_attr_infos = {}
        for row in big_qs_iter(rel_attr_tag_infos):
            rel_attr_infos.setdefault(row[0], []).append(row[1])

        attr_tag_ids = set(chain.from_iterable(rel_attr_infos.values()))
        attr_tags_info_from_sql = AttrTag.objects.filter(
            pk__in=attr_tag_ids
        ).using(settings.SLAVE_DB_NAME).only("name", "aggregate_type", "id")

        # attr tag id -> tag v3 id, only for attribute tags that have a match.
        rel_attr_tag_map_v3_tag = {}
        for attr_obj in big_qs_iter(attr_tags_info_from_sql):
            # Pause between per-tag lookups; presumably a throttle to limit
            # database load (each iteration issues one TagV3 query).
            time.sleep(0.05)

            _tag_obj = self._get_tag_from_attr(attr_obj)
            if not _tag_obj:
                continue

            rel_attr_tag_map_v3_tag[attr_obj.id] = _tag_obj.id

        print(rel_attr_tag_map_v3_tag)
        has_rel_tag_v3_infos = self.get_kyc_answer_rel_tag_v3_ids(list(rel_attr_infos))

        bulk_create_list = []
        bulk_create_nums = 0
        for ans_id, attr_ids in rel_attr_infos.items():
            # Tag ids this answer already has, plus the ones queued below, so
            # two attribute tags resolving to the same TagV3 yield one row.
            seen_tag_v3_ids = set(has_rel_tag_v3_infos.get(ans_id, ()))

            for attr_id in attr_ids:
                _rel_tag_v3_id = rel_attr_tag_map_v3_tag.get(attr_id, 0)
                if not _rel_tag_v3_id or _rel_tag_v3_id in seen_tag_v3_ids:
                    continue

                seen_tag_v3_ids.add(_rel_tag_v3_id)
                bulk_create_list.append(QuestionnaireAnswerRelationTagV3(
                    questionnaire_answer_id=ans_id,
                    tag_v3_id=_rel_tag_v3_id
                ))

            # Flush between answers once enough rows are pending.
            if len(bulk_create_list) >= self.step:
                QuestionnaireAnswerRelationTagV3.objects.bulk_create(
                    bulk_create_list
                )
                bulk_create_nums += len(bulk_create_list)
                bulk_create_list = []

        # Insert whatever is left after the last full batch.
        if bulk_create_list:
            QuestionnaireAnswerRelationTagV3.objects.bulk_create(
                bulk_create_list
            )
            bulk_create_nums += len(bulk_create_list)

        print('Done! cost {} s bulk create nums {}'.format(time.time() - start_time, bulk_create_nums))
        print("END")
