# -*- coding: UTF-8 -*-
import json
import os
import shutil

from django.conf import settings
from django.core.management.base import BaseCommand, CommandError
from gm_dbmw_api.descriptor import connection
from gm_dbmw_api.descriptor import descriptor_to_json
from gm_dbmw_api.descriptor import sink
from gm_dbmw_api.descriptor.source import MySQLTableScanSource

from data_sync.type_info import get_type_info_map


class TaskGroupCategory(object):
    """
    COPY from gm_dbmw_api
    """
    ONLINE = 'online'
    FULL_SCANNING = 'full-scanning'
    ROUND_SCANNING = 'round-scanning'

    ALL_VALUE_SET = frozenset([
        ONLINE,
        FULL_SCANNING,
        ROUND_SCANNING,
    ])


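# Prefix for the dbmw service name; connections register under
# '<prefix>-<task_group_category>', e.g. 'mimas-online'.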
SERVICE_NAME_PREFIX = 'mimas'

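# Per-table column exclusions; the commented-out entry below documents the
# expected shape.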
TABLE_EXCLUDED_COLUMNS_MAP = {
    # 'api_doctor': [
    #     'view_num',
    #     'reply_num',
    #     'new_pm_num',
    #     'reply_pm_num',
    #     'no_reply_pm_num',
    #     'share_topic_num',
    #     'last_answered_time',
    # ],
}

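# Data-source attributes probed on each type_info; a descriptor is generated
# only when the attribute is set (see Command.handle).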
VALID_CONFIGURATION_SET = (
    'pk_data_source',
    'index_data_source',
)

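# Celery queue assignment per (configuration, es_type). None means the es_type
# has no dedicated queue for that configuration, and Command.handle skips it
# (CeleryConfigFactory.get_instance returns None).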
CELERY_QUEUE_CONFIGURATION = {
    'pk_data_source': {
        'question': {'instance': 'question', 'worker_count': 4},
        'topic': {'instance': 'topic', 'worker_count': 8},
        'article': {'instance': 'article', 'worker_count': 8},
        'answer': {'instance': 'answer', 'worker_count': 8}
    },
    'index_data_source': {
        'question': None,
        'topic': {'instance': 'topic-index', 'worker_count': 4},
        'article': {'instance': 'article', 'worker_count': 8},
        'answer': {'instance': 'answer', 'worker_count': 8}
    },
}

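# Sink tuning per task group category: online pushes small batches at a low
# flow rate, full scans push large batches at a high rate, and round scans
# derive their rates from the type's round-insert chunk size and period.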
SINK_CONFIGURATION = {
    TaskGroupCategory.ONLINE: lambda type_info: dict(
        batch_size_limit=min(type_info.bulk_insert_chunk_size, 20),
        flow_rate_limit=20.0,
        push_rate_limit=3.0,
        preferred_batch_size=type_info.bulk_insert_chunk_size,
        linger_seconds=3.0,
    ),
    TaskGroupCategory.FULL_SCANNING: lambda type_info: dict(
        batch_size_limit=type_info.bulk_insert_chunk_size * 5,
        flow_rate_limit=1000.0,
        push_rate_limit=3.0,
        preferred_batch_size=type_info.bulk_insert_chunk_size,
        linger_seconds=3.0,
    ),
    TaskGroupCategory.ROUND_SCANNING: lambda type_info: dict(
        batch_size_limit=type_info.round_insert_chunk_size,
        flow_rate_limit=float(type_info.round_insert_chunk_size) / type_info.round_insert_period,  # float() avoids integer division
        push_rate_limit=1.0 / type_info.round_insert_period,
        preferred_batch_size=type_info.round_insert_chunk_size,
        linger_seconds=3.0,
    ),
}


def get_configuration_pretty_name(configuration):
    """Strip the '_data_source' suffix, e.g. 'pk_data_source' -> 'pk'."""
    suffix = '_data_source'
    assert configuration.endswith(suffix)
    return configuration[:-len(suffix)]


class CeleryConfigFactory(object):
    """Shared service-level paths and names for building CeleryConfigInstance objects."""
    DEFAULT_CELERY_SERVICE_NAME = 'mentha-dbmw-celery'
    DEFAULT_QUEUE_NAME_PREFIX = 'mentha-dbmw'

    def __init__(
            self,
            service_name=None,
            queue_name_prefix=None,
    ):
        self.service_name = service_name or self.DEFAULT_CELERY_SERVICE_NAME
        self.queue_name_prefix = queue_name_prefix or self.DEFAULT_QUEUE_NAME_PREFIX
        self.virtualenv_path = '/srv/envs/{}/'.format(self.service_name)
        self.project_path = '/srv/apps/{}/'.format(self.service_name)
        self.log_base_dir = '/data/log/{}/'.format(self.service_name)

    def get_instance(self, configuration, es_type):
        c = CELERY_QUEUE_CONFIGURATION[configuration][es_type]
        if not c:
            return None
        return CeleryConfigInstance(
            factory=self,
            instance_name=c['instance'],
            worker_count=c['worker_count'],
        )

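    # Full and round scans share a single 'scan' queue with a fixed 24-worker
    # pool; Command.handle reuses this one instance for both scan categories.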
    def get_scan_instance(self):
        return CeleryConfigInstance(
            factory=self,
            instance_name='scan',
            worker_count=24,
        )


class CeleryConfigInstance(object):
    """One Celery worker pool: its queue name, worker count and supervisor stanza."""
    def __init__(
            self,
            factory,
            instance_name,
            worker_count,
    ):
        assert isinstance(factory, CeleryConfigFactory)
        self.__factory = factory
        self.instance_name = instance_name

        self.worker_count = int(worker_count)
        self.queue_name = '{}-{}'.format(self.__factory.queue_name_prefix, instance_name)

    @property
    def supervisor_conf(self):
        return '''
[program:{service_name}-{instance}]
command={virtualenv_path}/bin/celery worker -c {worker_count} -A api -Q {queue} --loglevel=DEBUG --maxtasksperchild=500
directory={project_path}
user=gmuser
stdout_logfile={stdout_logfile}
stderr_logfile={stderr_logfile}
autostart=true
autorestart=true
startsecs=10
stopwaitsecs=180
stopasgroup=true
killasgroup=true
'''.format(
            instance=self.instance_name,
            queue=self.queue_name,
            service_name=self.__factory.service_name,
            virtualenv_path=self.__factory.virtualenv_path,
            worker_count=self.worker_count,
            project_path=self.__factory.project_path,
            stdout_logfile=os.path.join(
                self.__factory.log_base_dir, 'supervisor', 'celery-{}.log'.format(self.instance_name),
            ),
            stderr_logfile=os.path.join(
                self.__factory.log_base_dir, 'supervisor', 'celery-error-{}.log'.format(self.instance_name),
            ),
        )
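
# For reference, `supervisor_conf` for the default factory and the built-in
# 'scan' instance renders roughly as follows (illustrative only, derived from
# the defaults above):
#
#   [program:mentha-dbmw-celery-scan]
#   command=/srv/envs/mentha-dbmw-celery/bin/celery worker -c 24 -A api -Q mentha-dbmw-scan --loglevel=DEBUG --maxtasksperchild=500
#   directory=/srv/apps/mentha-dbmw-celery/
#   user=gmuser
#   ...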


class ConnectionInstance(object):
    """Assembles one dbmw Connection (data source plus sink) for a type_info."""
    def __init__(
            self,
            celery_config_instance,
            task_group_category,
            type_info,
            configuration=None,  # for online
            queue_name=None,  # for dev
    ):
        if task_group_category == TaskGroupCategory.ONLINE:
            assert configuration is not None
            data_source = getattr(type_info, configuration)
            self.__name = '{}-{}'.format(type_info.name, get_configuration_pretty_name(configuration))
            is_online = True
        else:
            assert configuration is None
            self.__name = type_info.name
            is_online = False
            model = type_info.model
            logical_database_id = type_info.logic_database_id or settings.LOGICAL_DATABASE_ID
            if task_group_category == TaskGroupCategory.FULL_SCANNING:
                data_source = MySQLTableScanSource(
                    logical_database_id=logical_database_id,
                    table_name=model._meta.db_table,
                    key_columns=[model._meta.pk.column],
                )
            elif task_group_category == TaskGroupCategory.ROUND_SCANNING:
                data_source = MySQLTableScanSource(
                    logical_database_id=logical_database_id,
                    table_name=model._meta.db_table,
                    key_columns=[model._meta.pk.column],
                    cyclic=True,
                    scan_rate=(
                            float(type_info.round_insert_chunk_size) / float(type_info.round_insert_period) * 2
                    ),  # limit in sink
                    random_start_point=True,
                )
            else:
                raise ValueError('unknown task group category: {!r}'.format(task_group_category))

        self.__task_group_category = task_group_category
        sink_configuration = SINK_CONFIGURATION[task_group_category](type_info)

        s = sink.CeleryV1Sink(
            task='data_sync.tasks.write_to_es',
            key_list_param_name='pk_list',
            static_params_json=json.dumps({
                'es_type': type_info.name,
            }),
            broker_url=settings.BROKER_URL,
            queue=queue_name or celery_config_instance.queue_name,
            unpack_input_tuple=True,
            **sink_configuration
        )
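        # Online connections with a GM MQ endpoint additionally publish through
        # a GMMQSink; GroupedSink.of combines it with the Celery sink above so
        # both are fed by the same connection.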
        if type_info.gm_mq_endpoint and task_group_category == TaskGroupCategory.ONLINE:
            mq_sink = sink.GMMQSink(
                endpoint=type_info.gm_mq_endpoint,
                key_list_param_name='pk_list',
                static_params_json=json.dumps({
                    'es_type': type_info.name,
                    'configuration': configuration,
                    'use_batch_query_set': not is_online,
                }),
                unpack_input_tuple=True,
                **sink_configuration
            )
            s = sink.GroupedSink.of(s, mq_sink)
        self.__connection = connection.Connection(
            name=self.__name,
            source=data_source,
            sink=s,
            service_name='{}-{}'.format(SERVICE_NAME_PREFIX, task_group_category),
            task_group_category=task_group_category,
        )

    def save(self, base_path):
        """Write the descriptor JSON to <base_path>/<task_group_category>/<name>.json."""
        if not base_path:
            return
        path = os.path.join(
            base_path,
            self.__task_group_category,
            '{}.json'.format(self.__name)
        )
        with open(path, 'w') as f:
            json_str = descriptor_to_json(self.__connection)
            f.write(json_str)


def make_and_clear_path(path):
    """Remove `path` if it exists, then recreate it as an empty directory."""
    if os.path.exists(path):
        shutil.rmtree(path, ignore_errors=True)
    os.makedirs(path)


class Command(BaseCommand):
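    help = ('Generate gm_dbmw connection descriptor JSON files and, optionally, '
            'the supervisord program configuration for the matching Celery workers.')
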
    def add_arguments(self, parser):
        parser.add_argument('-sp', '--store-path', dest='path')
        parser.add_argument('--supervisor-conf-path')
        parser.add_argument('--queue-name')

    def handle(self, *args, **options):
        dir_path = options.get('path')
        supervisor_conf_path = options.get('supervisor_conf_path')
        queue_name = options.get('queue_name')
        if supervisor_conf_path and queue_name:
            raise CommandError('options --supervisor-conf-path and --queue-name are mutually exclusive')

        if dir_path:
            for task_group_category in TaskGroupCategory.ALL_VALUE_SET:
                make_and_clear_path(os.path.join(dir_path, task_group_category))
        supervisor_conf_list = []
        factory = CeleryConfigFactory()
        scan_celery_config_instance = factory.get_scan_instance()
        supervisor_conf_list.append(scan_celery_config_instance.supervisor_conf)

        for name, type_info in get_type_info_map().items():
            for configuration in VALID_CONFIGURATION_SET:

                print("Name: %s, config: %s, queue: %s" % (name, connection.Connection, queue_name))

                if not getattr(type_info, configuration):
                    continue

                celery_config_instance = factory.get_instance(
                    configuration=configuration,
                    es_type=name,
                )
                if not celery_config_instance:
                    continue
                supervisor_conf_list.append(celery_config_instance.supervisor_conf)
                ConnectionInstance(
                    celery_config_instance=celery_config_instance,
                    type_info=type_info,
                    configuration=configuration,
                    task_group_category=TaskGroupCategory.ONLINE,
                    queue_name=queue_name,
                ).save(base_path=dir_path)
            for task_group_category in [TaskGroupCategory.FULL_SCANNING, TaskGroupCategory.ROUND_SCANNING]:
                celery_config_instance = scan_celery_config_instance
                ConnectionInstance(
                    celery_config_instance=celery_config_instance,
                    type_info=type_info,
                    configuration=None,
                    task_group_category=task_group_category,
                    queue_name=queue_name,
                ).save(base_path=dir_path)

        if supervisor_conf_path:
            with open(supervisor_conf_path, 'w') as f:
                f.write(''.join(supervisor_conf_list))
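

# Example invocations (the management command name follows this module's file
# name; `generate_dbmw_config` below is a placeholder):
#
#   python manage.py generate_dbmw_config --store-path /data/dbmw-descriptors \
#       --supervisor-conf-path /etc/supervisor/conf.d/dbmw-celery.conf
#
#   # Dev: route all sinks to a single ad-hoc queue instead of per-type queues.
#   python manage.py generate_dbmw_config -sp /tmp/dbmw --queue-name my-dev-queue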