#!/usr/bin/env python
# -*- coding: utf-8 -*-

from io import StringIO
import json
import math
import time
import random
from itertools import groupby, chain
from multiprocessing import Process, JoinableQueue

from gm_types.gaia import (
    TAG_V3_TYPE,
)
from django.core.management import BaseCommand
from django.conf import settings

from tags.models.tag import (
    TagV3,
    TagMapOldTag,
)
from talos.cache.base import tag_map_tag3_record
from talos.models.diary import (
    Diary,
    DiaryTagV3,
)
from talos.models.topic import (
    Problem,
    ProblemTag,
    ProblemTagV3,
)
from talos.models.tractate import (
    Tractate,
    TractateTagV3,
)
from qa.models.answer import (
    Answer,
    AnswerTagV3,
    Question,
    QuestionTagV3,
)

from utils.rpc import get_rpc_invoker
from utils.group_routine import GroupRoutine


# Module-level RPC client, created at import time; TagMapTools.request_task
# uses it to call the doris content-tag endpoint.
rpc_client = get_rpc_invoker()


class TagMapTools(object):
    """
    Tag mapping tools: reconcile content -> tag v3 relation rows.

    Recommended tag v3 ids come from doris (via RPC) or, for topics, from the
    old-tag -> tag v3 mapping table. Existing relation rows are read from the
    slave DB, and the differences are applied (delete / bulk_create) through
    the model's default manager. Every decision is appended to a log file
    under ``file_path``.
    """
    # Batch size for SQL ``__in`` queries; also the consumer sub-batch size.
    step = 100
    # Number of content ids the producer enqueues per batch.
    BATCH_SIZE = 1000
    # Number of content ids sent per doris RPC call.
    each_rpc_doris_num = 20
    RPC_URL = 'doris/search/content_tagv3_info'
    # Directory prefix for the log files written by this class.
    file_path = "/tmp/"

    @classmethod
    def get_all_item_tag_v3_ids(cls):
        """
        Fetch the ids of all online "item" tags (tag v3).

        Item tags are rows whose ``tag_type`` is NORMAL, FIRST_CLASSIFY or
        SECOND_CLASSIFY.

        :return: list of tag v3 ids
        """
        print("tag data from sql")
        return list(TagV3.objects.using(settings.ZHENGXING_DB).filter(
            is_online=True,
            tag_type__in=[TAG_V3_TYPE.NORMAL, TAG_V3_TYPE.FIRST_CLASSIFY, TAG_V3_TYPE.SECOND_CLASSIFY]
        ).values_list("id", flat=True))

    @classmethod
    def get_tag_id_map_tag_v3_ids_relation(cls):
        """
        Build the old tag id -> tag v3 ids mapping for all item tags.

        :return: dict ``{old_tag_id: [tag_v3_id, ...]}``
        """
        result = {}
        all_item_tag_v3_ids = cls.get_all_item_tag_v3_ids()

        for start in range(0, len(all_item_tag_v3_ids), cls.step):
            _tag_v3_ids = all_item_tag_v3_ids[start: start + cls.step]
            map_ids = TagMapOldTag.objects.using(settings.ZHENGXING_DB).filter(
                tag_id__in=_tag_v3_ids
            ).values_list("tag_id", "old_tag_id")
            cls._merge_old_tag_map(result, map_ids)

        print(result)
        return result

    @staticmethod
    def _merge_old_tag_map(result, map_ids):
        """
        Merge ``(tag_v3_id, old_tag_id)`` pairs into ``result`` in place.

        Ids are appended, never replaced. An old tag whose tag v3 ids fall
        into different query batches therefore keeps all of them.

        :param result: dict ``{old_tag_id: [tag_v3_id, ...]}`` to update
        :param map_ids: iterable of ``(tag_v3_id, old_tag_id)`` pairs
        :return: ``result``
        """
        for tag_v3_id, old_tag_id in map_ids:
            result.setdefault(old_tag_id, []).append(tag_v3_id)
        return result

    @classmethod
    def request_task(cls, _ids, content_type):
        """
        Ask doris for the recommended tag v3 ids of one batch of content.

        :param _ids: content ids in this batch
        :param content_type: content type string forwarded to doris
        :return: ``{"success": <decoded JSON>, "error": []}`` on success, or
            ``{"success": {}, "error": _ids}`` when the call or decoding fails.
        """
        result = {
            "success": {},
            "error": []
        }
        try:
            _result = rpc_client[cls.RPC_URL](
                content_id_list=_ids, content_type=content_type
            ).unwrap()

            # The payload is a JSON string, e.g.
            # "{\"20895845\": [202], \"20895848\": [203]}" -> {content_id: [project tag ids]}
            result["success"] = json.loads(_result)

        except Exception as e:
            # Best effort: the caller writes failed ids to an error log. Keep
            # the failure reason visible instead of discarding it.
            print("request doris failed, content_type:{}, ids:{}, err:{!r}".format(content_type, _ids, e))
            result["error"] = _ids

        return result

    @classmethod
    def get_tag_map_result_from_doris(cls, content_ids, content_type):
        """
        Fetch recommended tag v3 ids from doris for many content ids.

        Ids are split into batches of ``each_rpc_doris_num`` and submitted
        through ``GroupRoutine``. Failed batches are appended to
        ``<file_path><content_type>_get_rel_tag_v3_from_doris_err_log.log``.

        :param content_ids: list of content ids
        :param content_type: content type string forwarded to doris
        :return: dict merging every successful doris response (keys as
            returned by doris), plus ``"rpc_err_ids"``: ids whose request failed
        """
        result = {
            "rpc_err_ids": []
        }
        write_in_rpc_err_log = []
        batch_size = cls.each_rpc_doris_num

        routine = GroupRoutine()
        _keys = []
        for key, start in enumerate(range(0, len(content_ids), batch_size)):
            _keys.append(key)
            routine.submit(
                key,
                cls.request_task,
                content_ids[start: start + batch_size],
                content_type
            )

        routine.go()

        for key in _keys:
            _data = routine.results.get(key, {})
            if not _data:
                continue

            if _data.get("error"):
                result["rpc_err_ids"].extend(_data["error"])
                write_in_rpc_err_log.append(
                    "content_type:{},content_ids:{}\n".format(content_type, json.dumps(_data["error"]))
                )
            elif _data.get("success"):
                result.update(_data["success"])

        # RPC error log
        if write_in_rpc_err_log:
            _file_path = cls.file_path + "{}_get_rel_tag_v3_from_doris_err_log.log".format(content_type)
            with open(_file_path, 'a+') as f:
                f.writelines(write_in_rpc_err_log)

        return result

    @classmethod
    def get_content_ids(cls, model_):
        """
        Return a flat ``values_list`` queryset of online content ids.

        Reads from the slave DB.

        :param model_: content model (e.g. ``Diary``)
        :return: queryset of ids
        """
        content_ids = model_.objects.using(settings.SLAVE_DB_NAME).filter(
            is_online=True).values_list('id', flat=True)

        return content_ids

    @classmethod
    def get_content_max_id(cls, model_):
        """
        Return the last online content id, or ``None`` if there is none.

        ``last()`` on an unordered queryset orders by primary key.
        """
        return cls.get_content_ids(model_).last()

    @staticmethod
    def get_content_rel_tag_v3_ids(tag_model_, content_ids, rel_param, all_tag_ids):
        """
        Get the existing tag v3 relation rows of each content item.

        Only rows whose tag is in ``all_tag_ids`` are kept.

        :param tag_model_: relation model (e.g. ``DiaryTagV3``)
        :param content_ids: content ids to look up
        :param rel_param: name of the content foreign-key field on ``tag_model_``
        :param all_tag_ids: tag v3 ids to keep
        :return: dict ``{content_id: [(relation_pk, tag_v3_id), ...]}``;
            content without matching rows is omitted
        """
        if not all_tag_ids:
            # Every row would be filtered out anyway; skip the query.
            return {}

        # Set membership: ``all_tag_ids`` may hold every item tag id.
        all_tag_ids = set(all_tag_ids)

        _filters = {
            "{}__in".format(rel_param): content_ids
        }
        _fields = [rel_param, "id", "tag_v3_id"]
        rel_tag_v3_ids = tag_model_.objects.using(settings.SLAVE_DB_NAME).filter(**_filters).values_list(*_fields)

        result = {}
        for _id, items in groupby(sorted(rel_tag_v3_ids, key=lambda x: x[0]), key=lambda x: x[0]):
            _rel_old_tag_v3_ids = [(item[1], item[2]) for item in items if item[2] in all_tag_ids]

            if _rel_old_tag_v3_ids:
                result[_id] = _rel_old_tag_v3_ids

        return result

    @classmethod
    def get_content_reco_tag_v3_ids(cls, old_tag_model, content_ids, rel_param, map_tags):
        """
        Derive recommended tag v3 ids from the content's old tags.

        Uses the old tag -> tag v3 mapping.

        :param old_tag_model: old-tag relation model (e.g. ``ProblemTag``)
        :param content_ids: content ids to look up
        :param rel_param: name of the content foreign-key field on ``old_tag_model``
        :param map_tags: dict ``{old_tag_id: [tag_v3_id, ...]}``
        :return: dict ``{content_id: [tag_v3_id, ...]}``. Lists may contain
            duplicates; content without old-tag rows is omitted.
        """
        _filters = {
            "{}__in".format(rel_param): content_ids
        }
        _fields = [rel_param, "tag_id"]
        rel_tag_ids = old_tag_model.objects.using(settings.SLAVE_DB_NAME).filter(**_filters).values_list(*_fields)

        result = {}
        for _id, items in groupby(sorted(rel_tag_ids, key=lambda x: x[0]), key=lambda x: x[0]):
            result[_id] = list(chain.from_iterable(map_tags.get(item[1], []) for item in items))

        return result

    @staticmethod
    def del_content_rel_tag_info(tag_model_, rel_ids):
        """
        Delete relation rows by primary key.

        Uses the default manager/DB.

        :param tag_model_: relation model
        :param rel_ids: primary keys to delete
        """
        print("will delete rel_pks nums", len(rel_ids))
        tag_model_.objects.filter(pk__in=rel_ids).delete()

    @staticmethod
    def add_content_rel_tag_info(tag_model_, rel_bulk_list):
        """
        Bulk-create relation rows.

        :param tag_model_: relation model
        :param rel_bulk_list: field dicts, e.g. ``[{rel_param: content_id, "tag_v3_id": 1}]``
        """
        print("will create nums", len(rel_bulk_list))
        tag_model_.objects.bulk_create([tag_model_(**item) for item in rel_bulk_list])

    @classmethod
    def content_rel_tag_data_clean(cls, tag_model_, old_tag_model_, rel_param, content_type, content_ids, tags):
        """
        Make the content's item-tag relations match the recommended tags.

        Relation rows whose tag is not recommended are deleted. Recommended
        tags that are missing are added. One log line per content id is
        appended to ``<file_path><content_type>_tag_v3_item_clean_log.log``.

        :param tag_model_: tag v3 relation model
        :param old_tag_model_: old-tag relation model (only used for topics)
        :param rel_param: content foreign-key field name
        :param content_type: content type key
        :param content_ids: content ids in this batch
        :param tags: for topics, dict ``{old_tag_id: [tag_v3_id, ...]}``;
            otherwise the list of all item tag v3 ids
        """
        write_in_log_str = "content_id:{},content_type:{},old_rel_item_tags:{},new_rel_item_tags:{},del_item_tags:{},add_item_tags:{}\n"

        if content_type == "topic":
            reco_tags_dic = cls.get_content_reco_tag_v3_ids(old_tag_model_, content_ids, rel_param, tags)
            _err_content_ids = set()
            # NOTE(review): with an empty item-tag list, no existing relation
            # rows are loaded for topics. Nothing is ever deleted, and every
            # recommended tag is inserted even if that row already exists.
            # Confirm duplicates are acceptable for this model.
            _all_item_tags = []
        else:
            _all_item_tags = tags
            reco_tags_dic = cls.get_tag_map_result_from_doris(content_ids, content_type)
            _err_content_ids = set(reco_tags_dic.pop("rpc_err_ids", []))

        rel_data_from_sql = cls.get_content_rel_tag_v3_ids(tag_model_, content_ids, rel_param, _all_item_tags)

        print("new reco data nums {}".format(len(reco_tags_dic)))
        print("old rel data from sql nums {}".format(len(rel_data_from_sql)), rel_data_from_sql)

        need_del_infos, need_add_infos, write_in_log_list = [], [], []

        for content_id in content_ids:
            # doris returns JSON, whose object keys are strings; the SQL path
            # uses int keys. Try both.
            _reco_project_tags = set(reco_tags_dic.get(content_id) or reco_tags_dic.get(str(content_id)) or [])
            _old_rel_item_tags = rel_data_from_sql.get(content_id, [])

            if content_id in _err_content_ids or not any([_reco_project_tags, _old_rel_item_tags]):
                write_in_log_list.append(write_in_log_str.format(content_id, content_type, [], [], [], []))
                continue

            _rel_item_tags = set()
            for rel_pk_id, rel_old_tag_v3_id in _old_rel_item_tags:
                _rel_item_tags.add(rel_old_tag_v3_id)

                if rel_old_tag_v3_id not in _reco_project_tags:
                    need_del_infos.append(rel_pk_id)

            _need_add_rel_tags = _reco_project_tags - _rel_item_tags
            for tag_id in _need_add_rel_tags:
                need_add_infos.append({
                    rel_param: content_id,
                    "tag_v3_id": tag_id,
                })

            write_in_log_list.append(write_in_log_str.format(
                content_id,
                content_type,
                list(_rel_item_tags),
                list(_reco_project_tags),
                list(_rel_item_tags - _reco_project_tags),
                list(_need_add_rel_tags)
            ))

        if need_del_infos:
            cls.del_content_rel_tag_info(tag_model_, need_del_infos)

        if need_add_infos:
            cls.add_content_rel_tag_info(tag_model_, need_add_infos)

        if write_in_log_list:
            _file_path = cls.file_path + "{}_tag_v3_item_clean_log.log".format(content_type)
            with open(_file_path, "a+") as f:
                f.writelines(write_in_log_list)


def producer(queue, model_, content_type):
    """
    Producer: enqueue batches of online content ids that are not yet processed.

    Reads the last processed id from ``tag_map_tag3_record`` under the key
    ``"<content_type>_clean_record_max_id"``. Pushes batches of up to
    ``TagMapTools.BATCH_SIZE`` ids with larger pks onto ``queue``, ordered by
    id, and waits for the consumers to acknowledge them. Only then does it
    store the largest enqueued id as the new progress marker.

    :param queue: joinable queue shared with the consumers
    :param model_: content model (e.g. ``Diary``)
    :param content_type: content type key, used in the cache key and logs
    :return: None
    """
    cache_key = "{}_clean_record_max_id".format(content_type)
    content_max_id = TagMapTools.get_content_max_id(model_)
    cache_content_max_id = int(tag_map_tag3_record.get(cache_key) or 0)

    last_produced_id = None
    if content_max_id and content_max_id > cache_content_max_id:
        print("will handle {} data".format(content_type))

        # Resume strictly after the recorded id; that id was already handled.
        # With no record (0), this starts at pk 1.
        transfer_id = cache_content_max_id + 1

        count = TagMapTools.get_content_ids(model_).filter(pk__gte=transfer_id).count()
        _round_num = int(math.ceil(count / TagMapTools.BATCH_SIZE))

        for _ in range(_round_num):
            next_ids = list(TagMapTools.get_content_ids(model_).filter(
                pk__gte=transfer_id).order_by("id")[:TagMapTools.BATCH_SIZE])

            if not next_ids:
                break

            print("{} is producing {} nums ids to the queue".format(content_type, len(next_ids)))

            queue.put(next_ids)

            last_produced_id = max(next_ids)
            transfer_id = last_produced_id + 1

            time.sleep(0.2 * random.random())
    else:
        print("{} not need to deal".format(content_type))

    # Block until the consumers have called task_done() for every batch.
    queue.join()

    # Advance the progress marker only after the consumers finish. This way an
    # interrupted run does not skip batches that were queued but never handled.
    if last_produced_id is not None:
        tag_map_tag3_record.set(cache_key, last_produced_id)

    print("{} producer finished".format(content_type))


def consumer(queue, tag_model_, old_tag_model_, content_type, rel_param, tags):
    """
    消费者
    :param queue:
    :param tag_model_:
    :param old_tag_model_: 老标签表
    :param content_type:
    :param rel_param:
    :param tags:
    :return:
    """
    while True:
        content_ids = queue.get()
        if not content_ids:
            break

        print("{} is consuming. {} nums in the queue is consumed!".format(content_type, len(content_ids)))

        _step = TagMapTools.step
        for i in range(int(math.ceil(len(content_ids)/_step))):
            _new_content_ids = content_ids[i*_step: (i+1)*_step]
            TagMapTools.content_rel_tag_data_clean(
                tag_model_,
                old_tag_model_,
                rel_param,
                content_type,
                _new_content_ids,
                tags
            )
            time.sleep(0.5 * random.random())

        queue.task_done()  # 向q.join()发送一次信号,证明一个数据已经被取走

    print("{} consumer finished".format(content_type))


class Command(BaseCommand):
    """
    Clean content -> tag v3 relations for one content type.

    Usage: python django_manage.py content_tag_map_v2 --content_type <type>
    """

    all_item_tag_ids = []
    old_tag_map_tag_v3s = {}
    # content_type -> (content model, tag v3 relation model, FK field name, old-tag relation model or None)
    content_type_map = {
        "diary": (Diary, DiaryTagV3, "diary_id", None),
        "topic": (Problem, ProblemTagV3, "problem_id", ProblemTag),
        "question": (Question, QuestionTagV3, "question_id", None),
        "answer": (Answer, AnswerTagV3, "answer_id", None),
        "tractate": (Tractate, TractateTagV3, "tractate_id", None),
    }
    # Number of consumer processes.
    consumer_num = 3

    def add_arguments(self, parser):
        parser.add_argument(
            '--content_type',
            help=u'内容类型(单选), choice is diary/topic/question/answer/tractate ...'
        )

    def handle(self, *args, **options):
        """Run one producer and ``consumer_num`` daemon consumers; wait for the producer."""
        from django.db import connections

        content_type = options['content_type']
        if content_type not in self.content_type_map:
            print("内容参数有误，请重新输入")
            return

        print("START")
        start_time = time.time()

        model_, tag_model_, _rel_tag_param, old_tag_model_ = self.content_type_map[content_type]

        self.all_item_tag_ids = TagMapTools.get_all_item_tag_v3_ids()
        if content_type == "topic":
            self.old_tag_map_tag_v3s = TagMapTools.get_tag_id_map_tag_v3_ids_relation()
            tags = self.old_tag_map_tag_v3s
        else:
            tags = self.all_item_tag_ids

        # Close connections opened by the queries above before forking. A DB
        # socket inherited by several child processes must not be shared;
        # Django reopens connections lazily in each process.
        for conn in connections.all():
            conn.close()

        queue = JoinableQueue()

        producer_proc = Process(target=producer, args=(queue, model_, content_type))

        _consumer_list = []
        consumer_args = (queue, tag_model_, old_tag_model_, content_type, _rel_tag_param, tags)
        for _ in range(self.consumer_num):
            c = Process(target=consumer, args=consumer_args)
            c.daemon = True  # daemon: terminated when the main process exits
            _consumer_list.append(c)

        for proc in chain([producer_proc], _consumer_list):
            proc.start()

        # The producer returns only after queue.join(), i.e. once consumers
        # have acknowledged every batch. So waiting on it alone is enough.
        producer_proc.join()

        end_time = time.time()
        print("END")
        print("total_time: %s s" % int(math.ceil(end_time - start_time)))
