#!/usr/bin/env python
# -*- coding: utf-8 -*-

import time
from itertools import groupby, chain
from django.conf import settings
from django.core.management import BaseCommand

from agile.models import TagMapOldTag
from launch.models import (
    QuestionnaireAnswer,
    QuestionnaireAnswerRelationTag,
    QuestionnaireAnswerRelationTagV3,
)


def big_data_iter(qs, fetch_num=100):
    """
    Split a large sliceable collection into consecutive batches.

    :param qs: any object supporting ``len()`` and slicing (a list, a string,
        a Django QuerySet, ...). Calling ``len()`` on an unevaluated QuerySet
        evaluates it, i.e. fetches every row up front.
    :param fetch_num: maximum number of items per batch
    :return: a generator yielding slices of ``qs`` in order; the final batch
        may be shorter, and nothing is yielded when ``qs`` is empty
    """
    limit = len(qs)
    position = 0
    while True:
        if position > limit:
            break
        window = qs[position:position + fetch_num]
        # An empty slice means every item has already been handed out.
        if not window:
            break
        yield window
        position += fetch_num


class Command(BaseCommand):
    """
    Tag -> tag 3.0 migration.

    For every ``QuestionnaireAnswer``, reads its old tag relations from
    ``QuestionnaireAnswerRelationTag``, translates each old tag id to v3 tag
    ids through ``TagMapOldTag``, and bulk-inserts the resulting
    ``QuestionnaireAnswerRelationTagV3`` rows.

    Reads use the database alias ``settings.SLAVE_DB_NAME`` (presumably a
    read replica). Writes use the model manager without ``.using()``, so they
    go wherever the project's DB router sends writes.

    NOTE(review): nothing here checks for existing V3 rows, so running the
    command twice inserts the relations twice unless the table has a unique
    constraint. Confirm before re-running.
    """

    help = "Migrate questionnaire answer tag relations to tag v3"

    def get_questionnaireanswer_rel_old_tag_ids(self, questionnaireanswer_ids):
        """
        Map each questionnaire answer id to the old tag ids attached to it.

        :param questionnaireanswer_ids: iterable of ``QuestionnaireAnswer`` ids
        :return: dict ``{questionnaire_answer_id: [old_tag_id, ...]}``; answers
            with no old tag relation do not appear as keys
        """
        answer_rel_old_tags = QuestionnaireAnswerRelationTag.objects.using(settings.SLAVE_DB_NAME).filter(
            questionnaire_answer_id__in=questionnaireanswer_ids
        ).values_list("questionnaire_answer_id", "tag_id")

        # Group with a plain dict rather than itertools.groupby. The queryset
        # has no ORDER BY, so rows of one answer need not be adjacent. groupby
        # would then emit several groups for the same key, and each later group
        # would overwrite the earlier one, silently dropping tags.
        old_map_dic = {}
        for answer_id, tag_id in answer_rel_old_tags:
            old_map_dic.setdefault(answer_id, []).append(tag_id)
        return old_map_dic

    def get_tag_old_new_map_dic(self, old_tag_ids):
        """
        Map each old tag id to the v3 tag ids it was migrated to.

        :param old_tag_ids: iterable (a one-shot generator is fine) of old tag ids
        :return: dict ``{old_tag_id: [tag_v3_id, ...]}``; old tags without a
            mapping row do not appear as keys
        """
        # Materialise and deduplicate once: callers may pass a generator, and
        # an empty input needs no query at all.
        old_tag_ids = set(old_tag_ids)
        if not old_tag_ids:
            return {}

        tag_old_new_map_ids = TagMapOldTag.objects.using(settings.SLAVE_DB_NAME).filter(
            old_tag_id__in=old_tag_ids).values_list("old_tag_id", "tag_id")

        # Same reasoning as above: unordered rows, so group with a dict.
        tag_old_new_map_dic = {}
        for old_tag_id, new_tag_id in tag_old_new_map_ids:
            tag_old_new_map_dic.setdefault(old_tag_id, []).append(new_tag_id)
        return tag_old_new_map_dic

    def handle(self, *args, **options):
        """Run the migration in batches of questionnaire answer ids."""
        self.stdout.write("START")

        start_time = time.time()
        # flat=True yields bare ids. Without it, every item is a 1-tuple
        # ``(id,)`` and the ``__in`` filters would be handed tuples, not ints.
        # order_by("id") makes the batch boundaries deterministic.
        questionnaireanswer_ids = QuestionnaireAnswer.objects.using(
            settings.SLAVE_DB_NAME
        ).order_by("id").values_list("id", flat=True)

        for _ids in big_data_iter(questionnaireanswer_ids):
            ans_rel_tag_map = self.get_questionnaireanswer_rel_old_tag_ids(_ids)
            if not ans_rel_tag_map:
                continue

            all_old_tag_ids = set(chain.from_iterable(ans_rel_tag_map.values()))
            tag_old_new_map = self.get_tag_old_new_map_dic(all_old_tag_ids)

            bulk_create_list = []
            for ans_id, old_tag_ids in ans_rel_tag_map.items():
                # One old tag may map to several v3 tags, and several old tags
                # may map to the same v3 tag. The set keeps one row per pair.
                new_tag_ids = set(chain.from_iterable(
                    tag_old_new_map.get(old_tag_id, []) for old_tag_id in old_tag_ids
                ))
                bulk_create_list.extend(
                    QuestionnaireAnswerRelationTagV3(
                        questionnaire_answer_id=ans_id,
                        tag_v3_id=new_tag_id,
                    )
                    for new_tag_id in new_tag_ids
                )

            if bulk_create_list:
                QuestionnaireAnswerRelationTagV3.objects.bulk_create(bulk_create_list)

        self.stdout.write('Done! cost {} s'.format(time.time() - start_time))
        self.stdout.write("END")
