# coding=utf-8
from __future__ import unicode_literals, print_function, absolute_import

import six
from django.db.models import ObjectDoesNotExist

from .. import tracers
from ...core.data import GroupedDataSource, DataConnection, DataSourceChangeEvent
from .common import TracedESDataSink
from trans2es import type_info


class TypeInfoDSClass(GroupedDataSource):
    """Grouped data source that resolves primary keys through a ``TypeInfo``.

    ``pk_data_source`` is handed to ``GroupedDataSource.__init__``.
    Change events whose ``pk`` appears in the type info's ``pk_blacklist`` are
    dropped. All other events are delegated to the base class.

    Item lookup (``self[pk]``) returns the data for one primary key. It uses
    the type info's ``batch_get_data_func`` when one is configured. Otherwise
    it loads the model instance and passes it to ``get_data_func``.
    """

    def __init__(self, ds_type_info, pk_data_source):
        assert isinstance(ds_type_info, type_info.TypeInfo)
        super(TypeInfoDSClass, self).__init__(pk_data_source)
        self._ds_type_info = ds_type_info
        # Bound ``get`` of ``model.objects.all()``; used only when no batch
        # loader is configured.
        self._get_from_qs = ds_type_info.model.objects.all().get
        self._get_data = ds_type_info.get_data_func
        self._batch_get_data = ds_type_info.batch_get_data_func  # maybe None

    def process_event(self, event):
        """Drop events for blacklisted pks; delegate the rest to the base class.

        Returns an empty tuple for a blacklisted ``event.pk``. Otherwise
        returns whatever ``GroupedDataSource.process_event`` returns.
        """
        assert isinstance(event, DataSourceChangeEvent)
        if event.pk in self._ds_type_info.pk_blacklist:
            return ()
        return super(TypeInfoDSClass, self).process_event(event)

    def __getitem__(self, key):
        """Return the data for primary key ``key``.

        Raises:
            KeyError: if ``key`` cannot be resolved. That means either the
                batch loader returned an empty result, or the ORM lookup
                raised ``ObjectDoesNotExist``.
        """
        if self._batch_get_data:
            results = self._batch_get_data([key])
            # Only one pk was requested, so an empty result means "not found".
            # Report it as KeyError (mapping protocol), like the ORM branch
            # below, instead of leaking an IndexError from ``results[0]``.
            if not results:
                raise KeyError(key)
            return results[0]
        try:
            instance = self._get_from_qs(pk=key)
        except ObjectDoesNotExist:
            raise KeyError(key)
        return self._get_data(instance)

    def _to_extra_json(self):
        """Extra serializable info identifying which type info backs this source."""
        return {
            'type_info_name': self._ds_type_info.name,
        }


def get_data_source(name, configuration):
    """Return attribute ``configuration`` of the data source module for ``name``.

    The per-name module is imported lazily from this package. Note that the
    ``'topic'`` name is served by the ``problem`` module.

    Args:
        name: data source name (``'diary'``, ``'doctor'``, ``'itemwiki'``,
            ``'topic'``, ``'service'``, ``'tag'``, ``'user'`` or ``'board'``).
        configuration: name of the attribute to read from that module, e.g.
            ``'pk_data_source'`` or ``'index_data_source'``.

    Raises:
        ValueError: if ``name`` is not one of the names above. ``ValueError``
            subclasses ``Exception``, so existing ``except Exception``
            handlers still catch it.
        AttributeError: if the module has no attribute ``configuration``.
    """
    assert isinstance(configuration, six.string_types)

    if name == 'diary':
        from . import diary as module
    elif name == 'doctor':
        from . import doctor as module
    elif name == 'itemwiki':
        from . import itemwiki as module
    elif name == 'topic':
        from . import problem as module
    elif name == 'service':
        from . import service as module
    elif name == 'tag':
        from . import tag as module
    elif name == 'user':
        from . import user as module
    elif name == 'board':
        from . import board as module
    else:
        # Same exception type as create_connection_of_type uses for bad input.
        raise ValueError('unexpected data_source name: {}'.format(name))

    data_source = getattr(module, configuration)

    return data_source


def get_data_source_config():
    """Return the name -> ``type_info.TypeInfo`` mapping from ``type_info``.

    Thin wrapper around ``type_info.get_type_info_map()``. Its keys are used
    as data source names elsewhere in this module.
    """
    type_info_by_name = type_info.get_type_info_map()
    return type_info_by_name


def get_type_info_data_source(name, configuration):
    """Wrap the raw ``name``/``configuration`` data source in ``TypeInfoDSClass``.

    The raw source is resolved first (so an unknown ``name`` fails there).
    Then the matching ``TypeInfo`` is looked up in the type info map.
    """
    raw_source = get_data_source(name, configuration)

    matching_type_info = type_info.get_type_info_map()[name]
    assert isinstance(matching_type_info, type_info.TypeInfo)

    return TypeInfoDSClass(
        pk_data_source=raw_source,
        ds_type_info=matching_type_info,
    )


def create_tracing_connection(name, configuration):
    """Wrap the requested data source in a ``tracers.TracingEventHandler``.

    The handler's ``info`` records the ``name`` and ``configuration`` it was
    built from.
    """
    source = get_data_source(name, configuration)
    trace_info = {'name': name, 'configuration': configuration}
    return tracers.TracingEventHandler(data_source=source, info=trace_info)


def create_es_connection(name, configuration):
    """Connect a ``TypeInfoDSClass`` source to a ``TracedESDataSink``.

    The sink is created for the ``type`` attribute of the ``TypeInfo``
    registered under ``name``.
    """
    target_es_type = get_data_source_config()[name].type
    source = get_type_info_data_source(name, configuration)
    sink = TracedESDataSink(target_es_type)
    return DataConnection(data_sink=sink, data_source=source)


def create_connection_of_type(name, configuration, connection_type):
    """Build a connection with the factory chosen by ``connection_type``.

    ``'tracing'`` uses ``create_tracing_connection``; ``'es'`` uses
    ``create_es_connection``.

    Raises:
        ValueError: for any other ``connection_type``.
    """
    if connection_type not in ('tracing', 'es'):
        raise ValueError('unexpected connection type: {}'.format(repr(connection_type)))

    if connection_type == 'tracing':
        factory = create_tracing_connection
    else:
        factory = create_es_connection
    return factory(name=name, configuration=configuration)


# Module-level registry: create_connections() appends every connection it
# builds here.
connection_list = []

# Configuration attribute names accepted by create_connections(); _validate()
# checks each against every data source module, in this order.
# NOTE: despite its name this is a list, not a set.
_valid_configuration_set = ['pk_data_source', 'index_data_source']

# Connection kinds accepted by create_connections().
_valid_connection_type = ['tracing', 'es']


def _validate():
    """Resolve every (configuration, data source name) pair once.

    Calls ``get_data_source`` for each configuration in
    ``_valid_configuration_set`` and for each name yielded by
    ``get_data_source_config()``. Whatever exception ``get_data_source``
    raises for a bad pair propagates out of this function.
    """
    for config_name in _valid_configuration_set:
        known_names = get_data_source_config()
        for source_name in known_names:
            get_data_source(source_name, config_name)


def create_connections(
        configuration,
        connection_type,
        data_source_name_list=None
):
    """Build connections and append them to the module-level ``connection_list``.

    Args:
        configuration: one of ``_valid_configuration_set``.
        connection_type: one of ``_valid_connection_type``.
        data_source_name_list: names to connect. ``None`` means every name
            from ``get_data_source_config()``.

    Connections are appended one at a time. If building one raises, the
    ones created before it remain in ``connection_list``.
    """
    assert configuration in _valid_configuration_set
    assert connection_type in _valid_connection_type

    # Resolve every configuration for every data source before building any
    # connection.
    _validate()

    names = data_source_name_list
    if names is None:
        names = list(get_data_source_config())

    for source_name in names:
        new_connection = create_connection_of_type(
            connection_type=connection_type,
            configuration=configuration,
            name=source_name,
        )
        connection_list.append(new_connection)
