import json
import logging
import os
from threading import Thread, Lock
from multiprocessing import cpu_count
from traceback import format_exc

from django.core.management import BaseCommand

from api.models.message import Message
from gm_types.msg import MESSAGE_TYPE
from message.utils.es_abstract import get_esop, table_message

# Process-wide logger instance, created lazily by get_logger().
_logger = None
# Guards first-time creation of the logger when several threads race.
_lock = Lock()


def _build_logger():
    """Create the INFO-level logger that writes to /tmp/<pid>-dbfix.log."""
    log = logging.getLogger(__file__)
    log.setLevel(logging.INFO)

    log_path = '/tmp/{}-dbfix.log'.format(os.getpid())
    file_handler = logging.FileHandler(log_path)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(logging.Formatter('%(levelname)s:%(message)s'))

    log.addHandler(file_handler)
    return log


def get_logger():
    """Return the shared dbfix logger, building it on the first call.

    The fast path reads the cached instance without locking. On a miss the
    lock is taken and ``_logger`` is checked again before building, so
    threads that race on the first call attach only one file handler.
    """
    global _logger
    cached = _logger
    if cached is None:
        with _lock:
            if _logger is None:
                # Publish only the fully configured logger.
                _logger = _build_logger()
            cached = _logger
    return cached


def _strip_prefix(text, prefix):
    """Return *text* with one leading occurrence of *prefix* removed.

    ``str.lstrip(prefix)`` is not a substitute: it strips any leading
    characters that appear anywhere in *prefix*, not the prefix string.
    """
    if text.startswith(prefix):
        return text[len(prefix):]
    return text


def _fetch_text_with_url_body(message_id, logger):
    """Look up a TEXT_WITH_URL message's content in Elasticsearch.

    Returns ``{'text': ..., 'url': ...}``, or ``None`` (after logging a
    warning) when the search does not return exactly one hit.
    """
    query = {
        "query": {
            "match": {
                "id": message_id
            }
        }
    }
    res = get_esop().search(
        table_message,
        body=query,
        _source=['content']
    )
    total = res['hits']['total']
    # Elasticsearch 7+ reports hits.total as {'value': n, 'relation': ...};
    # older versions return a bare integer. Accept either form.
    if isinstance(total, dict):
        total = total.get('value')
    # Not found (or ambiguous): record it so missing content can be counted.
    if total != 1:
        logger.warning('message id {} content is missing.'.format(message_id))
        return None
    content = res['hits']['hits'][0]['_source']['content']
    return {
        'text': content['text'],
        'url': content['url']
    }


def _build_body(message, logger):
    """Derive the value to store in ``message.body`` from legacy ``content``.

    Returns the body value (later JSON-encoded by the caller), or ``None``
    when the type is not handled here or the ES content is missing.
    """
    msg_type = message.type
    if msg_type in (MESSAGE_TYPE.TEXT, MESSAGE_TYPE.AUDIO, MESSAGE_TYPE.IMAGE,
                    MESSAGE_TYPE.CUSTOMER_SRV_CTRL):
        return message.content
    if msg_type == MESSAGE_TYPE.SERVICE:
        return int(_strip_prefix(message.content, 'service:'))
    if msg_type == MESSAGE_TYPE.DOCTOR_TOPIC:
        return int(_strip_prefix(message.content, 'doctor_topic:'))
    if msg_type == MESSAGE_TYPE.DIARY:
        return int(_strip_prefix(message.content, 'diary:'))
    if msg_type == MESSAGE_TYPE.GIFT:
        # Content looks like "<key>:<channel_id>,<key>:<gift_id>".
        channel, gift = message.content.split(',')
        return {
            'channel_id': int(channel.split(':')[1]),
            'gift_id': int(gift.split(':')[1])
        }
    if msg_type == MESSAGE_TYPE.TEXT_WITH_URL:
        return _fetch_text_with_url_body(message.id, logger)
    return None


def fix_message_body(messages):
    """Backfill the JSON ``body`` column of each message from its ``content``.

    For every message whose body can be derived, ``body`` is set to the
    JSON encoding (``ensure_ascii=False``) and only that column is saved.
    A message that fails to convert or save is logged with its id and
    traceback and skipped, so one bad row does not abort the rest of the
    batch.

    :param messages: iterable of Message model instances.
    """
    logger = get_logger()
    for message in messages:
        try:
            body = _build_body(message, logger)
            if body is not None:
                message.body = json.dumps(body, ensure_ascii=False)
                message.save(update_fields=['body'])
        except Exception:
            logger.error('message id {} got wrong, reason:{}'.format(message.id, format_exc()))


def worker(start, stop, chunk=100):
    """Fix message bodies for ids in ``[start, stop]``, ``chunk`` ids per batch.

    Each batch is handled independently. If a batch raises, the error and
    traceback are logged and the worker moves on to the next batch. The
    calling thread's database connection is closed when the worker finishes.

    :param start: first message id (inclusive).
    :param stop: last message id (inclusive).
    :param chunk: number of ids fetched per query.
    """
    from django.db import connection

    logger = get_logger()
    try:
        for left in range(start, stop + 1, chunk):
            right = min(stop, left + chunk - 1)
            try:
                # ``pk__range`` is an inclusive BETWEEN, equivalent to the
                # former ``pk__in=range(left, right + 1)`` without sending a
                # list of up to ``chunk`` id literals in an IN clause.
                messages = Message.objects.filter(pk__range=(left, right))
                fix_message_body(messages)
                logger.info('range[{},{}] fix done'.format(left, right))
            except Exception:
                logger.error('range[{},{}] got wrong, reason:{}'.format(left, right, format_exc()))
    finally:
        # This runs as a bare thread target, outside Django's request
        # lifecycle, so nothing else releases the per-thread connection.
        connection.close()


class Command(BaseCommand):
    """Backfill ``Message.body`` for a range of message ids using threads.

    Options: ``-s``/``-t`` bound the id range (inclusive; default: the
    smallest/largest Message pk); ``-c`` sets the number of worker threads
    (default: CPU count).
    """

    def add_arguments(self, parser):
        parser.add_argument('-c', help='concurrency numbers;default the core numbers of CPU.', type=int)
        parser.add_argument('-s', help='begin id[include].', type=int)
        parser.add_argument('-t', help='end id[include].', type=int)

    def handle(self, *args, **options):
        """Split the id range across worker threads and wait for them all.

        :raises CommandError: if ``-c`` is negative or the begin id is
            greater than the end id.
        """
        from django.core.management.base import CommandError

        logger = get_logger()

        # ``is None`` rather than truthiness, so an explicit id of 0 is honoured.
        start = options['s']
        stop = options['t']
        if start is None or stop is None:
            lowest, highest = self._id_bounds()
            if start is None:
                start = lowest
            if stop is None:
                stop = highest
        if start is None or stop is None:
            # The table is empty and no explicit bound filled the gap.
            logger.info('no messages to fix.')
            return

        # 0 or omitted falls back to the CPU count, as before.
        thread_nums = options['c'] or cpu_count()
        if thread_nums < 1:
            raise CommandError('-c must be a positive integer.')
        if start > stop:
            raise CommandError('begin id {} is greater than end id {}.'.format(start, stop))

        logger.info('concurrency numbers: {}, fix range[{},{}]'.format(thread_nums, start, stop))

        # Ceiling integer division. It stays an int on Python 3, where ``/``
        # yields a float that range() rejects. It also keeps the number of
        # threads <= thread_nums, where floor division could spawn extra ones.
        total = stop - start + 1
        step = max(-(-total // thread_nums), 1)
        threads = []
        for i, left in enumerate(range(start, stop + 1, step)):
            right = min(stop, left + step - 1)
            t = Thread(target=worker, args=(left, right), name=str(i + 1))
            threads.append(t)
            t.start()

        for t in threads:
            t.join()

        logger.info('database fix done.')

    @staticmethod
    def _id_bounds():
        """Return ``(min_pk, max_pk)`` of Message, or ``(None, None)`` if empty.

        Uses an aggregate rather than ``first()``/``last()``. Those follow
        the model's default ordering, which is not necessarily by pk.
        """
        from django.db.models import Max, Min

        bounds = Message.objects.aggregate(min_pk=Min('pk'), max_pk=Max('pk'))
        return bounds['min_pk'], bounds['max_pk']
