#!/usr/bin/env python
# -*- coding: utf-8 -*-

import math
import time
import json
import random
from threading import Thread
from queue import Empty, Queue

from django.core.management import BaseCommand
from django.conf import settings

from gm_types.doris import (
    CONTENT_AGGRE_CARD_TYPE,
)
from gm_types.gaia import (
    TOPIC_TYPE,
)
from gm_types.mimas import (
    QUICK_SEARCH_CONTENT_DIVISION,
)
from utils.rpc import get_rpc_invoker

from communal.models import QuickSearchContentKeyword
from qa.models import (
    Answer,
    Question
)
from talos.models.topic import Problem
from talos.models.tractate import Tractate

rpc_client = get_rpc_invoker()

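# Shared flag: the Producer sets this to True once every id range has been queued,
# so the Consumer knows it can exit as soon as the queue is drained.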
is_finished = False


def sync_content_quick_search(content_ids, content_type):
    """
    同步内容快速搜索词
    :param content_ids:
    :param content_type:
    :return:
    """
    def keywords_check(keywords):
        """
        关键字校验,索引列表仅保留一个有效位置
        :param keywords:
        :return:
        """
        _index_map_set = set()
        _result = []
        for item in keywords:
            keyword_index_list = list(map(tuple, item.pop("index", [])))
            _index_map_diff = set(keyword_index_list) - _index_map_set

            if not _index_map_diff:
                continue

            # Record the positions this keyword now occupies.
            _index_map_set.update(set(keyword_index_list))
            item["index"] = sorted(_index_map_diff, key=keyword_index_list.index)[0]  # note: this is a tuple
            _result.append(item)

        return _result

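    # Map QUICK_SEARCH_CONTENT_DIVISION values to the doris card types expected by
    # the highlight RPC below.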
    mapping_dic = {
        QUICK_SEARCH_CONTENT_DIVISION.TRACTATE: CONTENT_AGGRE_CARD_TYPE.TRACTATE,
        QUICK_SEARCH_CONTENT_DIVISION.TOPIC: CONTENT_AGGRE_CARD_TYPE.DIARY,
        QUICK_SEARCH_CONTENT_DIVISION.QUESTION: CONTENT_AGGRE_CARD_TYPE.QA,
        QUICK_SEARCH_CONTENT_DIVISION.ANSWER: CONTENT_AGGRE_CARD_TYPE.ANSWER,
        QUICK_SEARCH_CONTENT_DIVISION.ARTICLE: CONTENT_AGGRE_CARD_TYPE.ARTICLE,
    }

    per_num = 100
    success_create_num = 0

    if not content_ids:
        return success_create_num

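    # Ids that already have quick-search keywords are skipped; the check reads from
    # the slave DB to keep load off the primary.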
    in_sql_ids = set(QuickSearchContentKeyword.objects.filter(
        content_id__in=content_ids,
        content_type=content_type
    ).using(settings.SLAVE_DB_NAME).values_list("content_id", flat=True))

    need_update_content_ids = list(set(content_ids) - in_sql_ids)
    for i in range(int(math.ceil(len(need_update_content_ids)/per_num))):
        bulk_create_list = []
        stn = i * per_num
        _ids = need_update_content_ids[stn: stn + per_num]

        time.sleep(0.5 * random.random())  # give the RPC interface a short break between batches
        try:
            key_words_dic = rpc_client["doris/search/get_content_search_highlight"](
                content_ids=_ids,
                content_type=mapping_dic.get(content_type)
            ).unwrap()
        except Exception as e:
            print("doris rpc error: %s" % e)
            key_words_dic = {}

        keywords_dic = json.loads(key_words_dic.get("content_highlight_list") or '{}')

        if keywords_dic:
            for content_id in _ids:
                keywords = keywords_dic.get(str(content_id), [])
                if not keywords:
                    continue

                _data = {
                    "content_id": content_id,
                    "content_type": content_type,
                    "keywords": json.dumps(keywords_check(keywords)),
                }

                bulk_create_list.append(_data)

        if bulk_create_list:
            # Re-check against the DB right before bulk create to avoid duplicates.
            _in_sql_ids = set(QuickSearchContentKeyword.objects.filter(
                content_id__in=_ids,
                content_type=content_type
            ).values_list("content_id", flat=True))

            print("filter can create ids in sql")
            bulk_create_list = list(filter(lambda item: item["content_id"] not in _in_sql_ids, bulk_create_list))

            QuickSearchContentKeyword.objects.bulk_create([
                QuickSearchContentKeyword(**item) for item in bulk_create_list
            ])
            success_create_num += len(bulk_create_list)

    return success_create_num
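

# Usage sketch (ids are hypothetical):
#     sync_content_quick_search([1, 2, 3], QUICK_SEARCH_CONTENT_DIVISION.TRACTATE)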


class ContentSearchThreadBase(Thread):

    def __init__(self, name, queue, model, content_type, step=1000):
        super(ContentSearchThreadBase, self).__init__(name=name)
        self.queue = queue
        self.model = model
        self.content_type = content_type
        self.step = step

    @property
    def get_ids(self):
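        # Base queryset of online content ids, read from the slave DB. TOPIC and
        # ARTICLE both live in the Problem table, so they are told apart by topic_type.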
        filters = {
            "is_online": True,
        }
        if self.content_type == QUICK_SEARCH_CONTENT_DIVISION.TOPIC:
            filters.update({
                "topic_type__in": [TOPIC_TYPE.ASK, TOPIC_TYPE.SHARE, TOPIC_TYPE.TOPIC],
            })

        elif self.content_type == QUICK_SEARCH_CONTENT_DIVISION.ARTICLE:
            filters.update({
                "topic_type__in": [TOPIC_TYPE.USER_ARTICLE, TOPIC_TYPE.COLUMN_ARTICLE],
            })

        return self.model.objects.filter(**filters).only('id').using(settings.SLAVE_DB_NAME)

    @property
    def get_min_id(self):
        return self.get_ids.first().id

    @property
    def get_max_id(self):
        return self.get_ids.last().id


class Producer(ContentSearchThreadBase):
    """
    生产者
    """

    def run(self):

        global is_finished
        count = self.get_ids.count()
        if not count:
            # Nothing to enqueue; unblock the consumer and exit.
            is_finished = True
            print("{} has no content to process".format(self.content_type))
            return

        min_id = self.get_min_id

        print("{} total_count {}".format(self.content_type, count))

        transfer_id = min_id
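        # Walk the id space in chunks of `step`, enqueueing (start_id, end_id)
        # ranges for the consumer; ids are not assumed to be contiguous.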
        for i in range(int(math.ceil(count/self.step))):
            nexts_data = self.get_ids.filter(pk__gte=transfer_id).order_by("id")[:self.step]
            if not nexts_data:
                break

            next_id = list(nexts_data.values_list("id", flat=True))[-1]
            print("%s is producing %s to the queue!" % (self.getName(), (transfer_id, next_id)))
            self.queue.put((transfer_id, next_id))

            transfer_id = next_id + 1

        is_finished = True

        print("%s finished!" % self.getName())


class Consumer(ContentSearchThreadBase):
    """
    消费者
    """

    def executive_logic(self, start_pk_id, end_pk_id):
        """
        执行逻辑
        :return:
        """
        content_ids = list(self.get_ids.filter(pk__range=[start_pk_id, end_pk_id]).values_list("id", flat=True))
        nums = sync_content_quick_search(content_ids, self.content_type)
        return nums

    def run(self):
        global is_finished
        while True:
            if is_finished and self.queue.empty():
                break

            try:
                start_id, end_id = self.queue.get(timeout=0.2)
                print("%s is consuming. %s in the queue is consumed!" % (self.getName(), (start_id, end_id)))
            except Empty:
                # Queue is momentarily empty; loop back and re-check the finished flag.
                continue

            nums = self.executive_logic(start_id, end_id)
            print("bulk_create_nums", nums)

        print("%s finished!" % self.getName())


class Command(BaseCommand):
    """
    python django_manage.py initialization_content_quick_search_data --content_type
    内容-新标签映射 数据清洗
    """
    _input_content_type_map_dic = {
        "topic": (Problem, QUICK_SEARCH_CONTENT_DIVISION.TOPIC),
        "article": (Problem, QUICK_SEARCH_CONTENT_DIVISION.ARTICLE),
        "question": (Question, QUICK_SEARCH_CONTENT_DIVISION.QUESTION),
        "answer": (Answer, QUICK_SEARCH_CONTENT_DIVISION.ANSWER),
        "tractate": (Tractate, QUICK_SEARCH_CONTENT_DIVISION.TRACTATE),
    }

    def add_arguments(self, parser):
        parser.add_argument(
            '--content_type',
            help=u'Content type (choose one): article / topic / question / answer / tractate'
        )

    def handle(self, *args, **options):
        _input_content_type = options["content_type"]

        model_, content_type = self._input_content_type_map_dic.get(_input_content_type, (None, None))
        if not model_:
            print(u'Please pass a valid content_type')
            return

        print("START!")

        start_time = time.time()
        content_queue = Queue()

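        # One producer enqueues id ranges while a single consumer calls the doris
        # RPC and bulk-creates the keyword rows.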
        producer = Producer("producer_{}".format(content_type), content_queue, model_, content_type)
        consumer = Consumer("Consumer_{}".format(content_type), content_queue, model_, content_type)

        producer.start()
        consumer.start()

        producer.join()
        consumer.join()

        end_time = time.time()
        print("ALL threads finished!")
        print("END")
        print("total_time: %s s" % int(math.ceil(end_time - start_time)))