#!/usr/bin/env python
# -*- coding: utf-8 -*-
from collections import defaultdict
from io import StringIO
import json
import math
import time
import random
from itertools import groupby, chain
from multiprocessing import Process, JoinableQueue

import itertools
from gm_types.gaia import (
    TAG_V3_TYPE,
)
from django.core.management import BaseCommand
from django.conf import settings

from tags.models.tag import (
    TagV3,
    TagMapOldTag,
)
from talos.models.diary import (
    Diary,
    DiaryTagV3,
    DiaryTag,
)
from talos.models.tractate import (
    Tractate,
    TractateTagV3,
    TractateTag,
)
from qa.models.answer import (
    Answer,
    AnswerTagV3,
    AnswerTag,
    QuestionTag,
    QuestionTagV3, Question)

from utils.rpc import get_rpc_invoker, logging_exception
from utils.group_routine import GroupRoutine


# Module-level RPC client; TagMapTools.request_task sends its doris requests through it.
rpc_client = get_rpc_invoker()


class TagMapTools(object):
    """
    Tag mapping tools.

    Helpers that read content <-> tag relations and back-fill 3.0 tag
    relations for content based on the content's old-tag relations.
    """
    # Number of content ids per content_rel_tag_data_clean() call (used by consumer()).
    step = 100
    # Number of content ids per queue item (used by producer()).
    BATCH_SIZE = 1000
    # Number of content ids sent in a single doris RPC request.
    each_rpc_doris_num = 20
    RPC_URL = 'doris/search/content_tagv3_tokenizer'
    # Directory for the log files; file names are appended directly, so keep
    # the trailing separator.
    file_path = "/tmp/"

    @classmethod
    def get_all_item_tag_v3_ids(cls):
        """
        Fetch the ids of all online "item" 3.0 tags, i.e. tags of type
        NORMAL, FIRST_CLASSIFY or SECOND_CLASSIFY.

        :return: list of tag ids
        """
        print("tag data from sql")
        return list(TagV3.objects.using(settings.ZHENGXING_DB).filter(
            is_online=True,
            tag_type__in=[TAG_V3_TYPE.NORMAL, TAG_V3_TYPE.FIRST_CLASSIFY, TAG_V3_TYPE.SECOND_CLASSIFY]
        ).values_list("id", flat=True))

    @classmethod
    def get_tag_id_map_tag_v3_ids_relation(cls, old_tag_ids):
        """
        Map old tag ids to 3.0 tag ids using TagMapOldTag.

        :param old_tag_ids: iterable of old tag ids
        :return: dict {old_tag_id: {tag_v3_id, ...}}
        """
        result = defaultdict(set)

        map_ids = TagMapOldTag.objects.using(settings.ZHENGXING_DB).filter(
            old_tag_id__in=old_tag_ids
        ).values_list("tag_id", "old_tag_id")

        for tag_v3_id, tag_v1_id in map_ids:
            result[tag_v1_id].add(tag_v3_id)

        return dict(result)

    @classmethod
    def get_content_map_old_tag_ids(cls, old_tag_model, field, content_ids, tag_field):
        """
        Map old tag ids to the contents that carry them.

        :param old_tag_model: old content <-> tag relation model
        :param field: name of the content id field on ``old_tag_model``
        :param content_ids: content ids to look up
        :param tag_field: name of the old tag field on ``old_tag_model``
        :return: dict {old_tag_id: {content_id, ...}}
        """
        query = {'{}__in'.format(field): content_ids}
        map_ids = set(old_tag_model.objects.using(settings.SLAVE_DB_NAME).filter(
            **query
        ).values_list(field, tag_field))

        old_tag_map_content = defaultdict(set)
        for content_id, tag_v1_id in map_ids:
            old_tag_map_content[tag_v1_id].add(content_id)
        return dict(old_tag_map_content)

    @classmethod
    def request_task(cls, _ids, content_type):
        """
        Send one doris RPC request for ``_ids``.

        :return: ``{"success": <JSON string of the RPC result>, "error": []}``
            on success, or ``{"success": {}, "error": _ids}`` when the call
            raised (the exception is logged).
        """
        result = {
            "success": {},
            "error": []
        }
        try:
            _result = rpc_client[cls.RPC_URL](
                content_id_list=_ids, content_type=content_type
            ).unwrap()
            result["success"] = json.dumps(_result)
        except Exception:
            # Narrower than a bare except: don't swallow KeyboardInterrupt/SystemExit.
            logging_exception()
            result["error"] = _ids

        return result

    @classmethod
    def get_tag_map_result_from_doris(cls, content_ids, content_type):
        """
        Query doris for ``content_ids`` concurrently.

        The ids are split into chunks of ``each_rpc_doris_num``, and each chunk
        is sent as one request.

        :param content_ids: sliceable sequence of content ids
        :param content_type: content type passed through to the RPC
        :return: dict merged from every successful RPC payload (presumably
            keyed by content id -- unverified), plus ``"rpc_err_ids"`` listing
            the ids whose request failed. Failed chunks are also appended to a
            per-content-type error log under ``file_path``.
        """
        result = {
            "rpc_err_ids": []
        }
        write_in_rpc_err_log = []
        chunk_size = cls.each_rpc_doris_num

        routine = GroupRoutine()
        keys = []
        for key, start in enumerate(range(0, len(content_ids), chunk_size)):
            keys.append(key)
            routine.submit(
                key,
                cls.request_task,
                content_ids[start:start + chunk_size],
                content_type
            )

        routine.go()

        for key in keys:
            data = routine.results.get(key) or {}
            if data.get("error"):
                result["rpc_err_ids"].extend(data["error"])
                write_in_rpc_err_log.append(
                    "content_type:{},content_ids:{}\n".format(content_type, json.dumps(data["error"]))
                )
            elif data.get("success"):
                success = data["success"]
                # request_task stores the payload JSON-encoded; dict.update()
                # on a str raises ValueError, so decode it first.
                if isinstance(success, str):
                    success = json.loads(success)
                result.update(success)

        # RPC error log
        if write_in_rpc_err_log:
            _file_path = cls.file_path + "level_lte3_{}_get_rel_tag_v3_from_doris_err_log.log".format(content_type)
            with open(_file_path, 'a+') as f:
                f.writelines(write_in_rpc_err_log)

        return result

    @classmethod
    def get_content_ids(cls, model_):
        """
        Build a queryset of the ids of all online rows of ``model_``.

        :param model_: content model
        :return: lazy ``values_list('id', flat=True)`` queryset read from the slave DB
        """
        content_ids = model_.objects.using(settings.SLAVE_DB_NAME).filter(
            is_online=True).values_list('id', flat=True)

        return content_ids

    @classmethod
    def get_content_max_id(cls, model_):
        """
        Return the largest id among online rows of ``model_``, or None when there are none.
        """
        # Order explicitly: .last() on an unordered queryset follows the
        # model's Meta.ordering (if any), which need not be by id.
        return cls.get_content_ids(model_).order_by('-id').first()

    @staticmethod
    def get_content_rel_tag_v3_ids(tag_model_, content_ids, rel_param, all_tag_ids):
        """
        Fetch the existing content <-> 3.0 tag relations for ``content_ids``.

        :param tag_model_: 3.0 tag relation model
        :param content_ids: content ids to look up
        :param rel_param: name of the content id field on ``tag_model_``
        :param all_tag_ids: NOTE(review): currently unused -- confirm whether
            the relations were meant to be filtered by these tag ids
        :return: set of ``(content_id, tag_v3_id)`` tuples
        """
        _filters = {
            "{}__in".format(rel_param): content_ids
        }
        return set(
            tag_model_.objects.using(settings.SLAVE_DB_NAME)
            .filter(**_filters)
            .values_list(rel_param, "tag_v3_id")
        )

    @staticmethod
    def add_content_rel_tag_info(tag_model_, rel_bulk_list, rel_param):
        """
        Bulk-create 3.0 tag relation rows on the default database.

        :param tag_model_: 3.0 tag relation model
        :param rel_bulk_list: iterable of ``(content_id, tag_v3_id)`` pairs
        :param rel_param: name of the content id field on ``tag_model_``
        :return: None
        """
        print("will create nums", len(rel_bulk_list), rel_bulk_list)
        tag_model_.objects.bulk_create([tag_model_(**{
            rel_param: content_id, "tag_v3_id": tag3_id
        }) for content_id, tag3_id in rel_bulk_list])

    @classmethod
    def content_rel_tag_data_clean(cls, tag_model_, old_tag_model_, rel_param, content_type, content_ids, tags, tag_field):
        """
        Add missing 3.0 tag relations for a batch of contents.

        For each content, the 3.0 tags mapped from its old tags are computed.
        Any (content, tag) pair not yet in ``tag_model_`` is created. One log
        line per content in ``content_ids`` is appended to a per-content-type
        log file.

        :param tag_model_: 3.0 tag relation model
        :param old_tag_model_: old tag relation model
        :param rel_param: name of the content id field on both relation models
        :param content_type: content type (used in the log)
        :param content_ids: content ids in this batch
        :param tags: ids of online item 3.0 tags; only forwarded to
            get_content_rel_tag_v3_ids, which does not use it
        :param tag_field: name of the old tag field on ``old_tag_model_``
        :return: None
        """
        write_in_log_str = "content_id:{}, content_type:{}, create_tags_map:{}\n"
        need_add_infos = set()
        # content_id -> 3.0 tag ids created for it, for the log
        log_result = defaultdict(set)

        rel_data_from_sql = cls.get_content_rel_tag_v3_ids(tag_model_, content_ids, rel_param, tags)
        tag1_map_content = cls.get_content_map_old_tag_ids(old_tag_model_, rel_param, content_ids, tag_field)
        tag1_map_tag3 = cls.get_tag_id_map_tag_v3_ids_relation(list(tag1_map_content.keys()))

        # Use a distinct name here: reassigning ``content_ids`` would make the
        # log loop below cover only the last old tag's contents.
        for tag_v1_id, mapped_content_ids in tag1_map_content.items():
            tag_v3_ids = tag1_map_tag3.get(tag_v1_id)
            if not mapped_content_ids or not tag_v3_ids:
                continue
            for content_id, tag3_id in itertools.product(mapped_content_ids, tag_v3_ids):
                if (content_id, tag3_id) in rel_data_from_sql:
                    continue
                need_add_infos.add((content_id, tag3_id))
                log_result[content_id].add(tag3_id)

        write_in_log_list = [
            write_in_log_str.format(content_id, content_type, list(log_result.get(content_id, [])))
            for content_id in content_ids
        ]

        if need_add_infos:
            cls.add_content_rel_tag_info(tag_model_, need_add_infos, rel_param)

        if write_in_log_list:
            _file_path = cls.file_path + "level_lte3_{}_tag_v3_item_clean_log.log".format(content_type)
            with open(_file_path, "a+") as f:
                f.writelines(write_in_log_list)


def producer(queue, model_, content_type, level_field):
    """
    Producer: page through online content ids in ascending id order and push
    them onto ``queue`` in batches of ``TagMapTools.BATCH_SIZE``.

    Only rows whose ``level_field`` is 0, 1 or 2 are enqueued. The values are
    strings for 'diary' and integers otherwise. After the last batch it blocks
    on ``queue.join()`` until every batch has been marked done.

    :param queue: JoinableQueue shared with the consumer processes
    :param model_: content model to read ids from
    :param content_type: content type key (logging and level value typing)
    :param level_field: name of the content-level field on ``model_``
    :return: None
    """
    print("will handle {} data".format(content_type))

    cursor_id = 1
    batch_size = TagMapTools.BATCH_SIZE

    # Upper bound on the number of pages. It counts every online row, not just
    # the level-filtered ones, so it is never below the pages actually needed.
    total = TagMapTools.get_content_ids(model_).filter(pk__gte=cursor_id).count()
    max_rounds = int(math.ceil(total / batch_size))

    if content_type == 'diary':
        levels = ['0', '1', '2']
    else:
        levels = [0, 1, 2]
    level_lookup = '{}__in'.format(level_field)

    rounds_done = 0
    while rounds_done < max_rounds:
        rounds_done += 1

        page = TagMapTools.get_content_ids(model_).filter(
            **{'pk__gte': cursor_id, level_lookup: levels}
        ).order_by("id")[:batch_size]
        if not page:
            break

        batch = list(page)
        print("{} is producing {} nums ids to the queue".format(content_type, len(batch)))
        queue.put(batch)

        # Keyset pagination: continue just after the largest id seen.
        cursor_id = max(batch) + 1

        time.sleep(0.2 * random.random())

    # Block until consumers have called task_done() for every batch.
    queue.join()

    print("{} producer finished".format(content_type))


def consumer(queue, tag_model_, old_tag_model_, content_type, rel_param, tags, tag_field):
    """
    Consumer: take batches of content ids off ``queue`` and clean their 3.0
    tag relations ``TagMapTools.step`` ids at a time.

    A falsy item (e.g. ``None`` or ``[]``) stops the loop. Every item taken
    off the queue is acknowledged with ``task_done()``, even when processing
    it fails, so the producer's ``queue.join()`` cannot block forever. A chunk
    that raises is logged and skipped; the remaining chunks still run.

    :param queue: JoinableQueue fed by producer()
    :param tag_model_: 3.0 tag relation model
    :param old_tag_model_: old tag relation model
    :param content_type: content type key (logging / log file names)
    :param rel_param: name of the content id field on the relation models
    :param tags: ids of online item 3.0 tags, forwarded to the cleaner
    :param tag_field: name of the old tag field on ``old_tag_model_``
    :return: None
    """
    while True:
        content_ids = queue.get()
        try:
            if not content_ids:
                break

            print("{} is consuming. {} nums in the queue is consumed!".format(content_type, len(content_ids)))

            _step = TagMapTools.step
            for start in range(0, len(content_ids), _step):
                _new_content_ids = content_ids[start:start + _step]
                try:
                    TagMapTools.content_rel_tag_data_clean(
                        tag_model_,
                        old_tag_model_,
                        rel_param,
                        content_type,
                        _new_content_ids,
                        tags,
                        tag_field,
                    )
                except Exception:
                    # Best effort: record the failed chunk and keep this
                    # worker alive for the rest of the queue.
                    logging_exception()
                    print("{} failed to clean content_ids: {}".format(content_type, _new_content_ids))
                time.sleep(0.5 * random.random())
        finally:
            # Signal q.join() that this item has been handled (runs on break too).
            queue.task_done()

    print("{} consumer finished".format(content_type))


class Command(BaseCommand):
    """
    Usage: python django_manage.py level_lte3_content_map_tag --content_type <diary|answer|tractate>

    New (3.0) tag cleanup for content below three stars. producer() selects
    content levels 0, 1 and 2. Each content's old tags are mapped to 3.0 tags,
    and the missing relations are created.
    """

    all_item_tag_ids = []
    old_tag_map_tag_v3s = {}
    # content_type -> (content model, 3.0 tag relation model, content id field,
    #                  old tag relation model, content-level field, old tag field)
    content_type_map = {
        "diary": (Diary, DiaryTagV3, "diary_id", DiaryTag, 'content_level', 'tag_id'),
        "answer": (Answer, AnswerTagV3, "answer_id", AnswerTag, 'level', 'tag'),
        "tractate": (Tractate, TractateTagV3, "tractate_id", TractateTag, 'content_level', 'tag_id'),
    }

    def add_arguments(self, parser):
        parser.add_argument(
            '--content_type',
            help=u'内容类型(单选), choice is diary/answer/tractate ...'
        )

    def handle(self, *args, **options):
        """
        Validate ``--content_type``, then run one producer process and three
        daemon consumer processes over a shared JoinableQueue.
        """
        content_type = options['content_type']
        if content_type not in self.content_type_map:
            print("内容参数有误，请重新输入")
            return

        print("START")
        start_time = time.time()

        _producer_list, _consumer_list = [], []
        model_, tag_model_, _rel_tag_param, old_tag_model_, level_field, tag_field = self.content_type_map[content_type]

        self.all_item_tag_ids = TagMapTools.get_all_item_tag_v3_ids()

        # Queue shared by producer and consumers
        queue = JoinableQueue()

        # Create the producer
        for _ in range(1):
            p = Process(target=producer, args=(queue, model_, content_type, level_field))
            _producer_list.append(p)

        # Create the consumers
        for _ in range(3):
            args = (queue, tag_model_, old_tag_model_, content_type, _rel_tag_param, self.all_item_tag_ids, tag_field)
            c = Process(target=consumer, args=args)
            c.daemon = True  # daemon: terminated when the main process exits
            _consumer_list.append(c)

        # get_all_item_tag_v3_ids() above opened a DB connection in this
        # process. Forked children must not share the parent's DB socket, so
        # close every connection now; Django reconnects lazily in each child.
        from django.db import connections
        for conn in connections.all():
            conn.close()

        # Start everything
        for s in chain(_producer_list, _consumer_list):
            s.start()

        # The producer returns only after queue.join(), i.e. once every batch
        # has been marked done, so joining the producer is enough to wait for
        # all work; the daemon consumers are then torn down with this process.
        for s in _producer_list:
            s.join()

        end_time = time.time()
        print("END")
        print("total_time: %s s" % int(math.ceil(end_time - start_time)))
