# coding=utf-8
from __future__ import unicode_literals, print_function, absolute_import

import collections
import functools
import logging
import os
import random
import signal
import sys
import threading
import time
import traceback
from Queue import Queue, Empty

from injection.data_sync.impl import tracers
from ..connections import create_connections
from ..event_attach import ddsm
from ..queue import get_redis_connection, parse_table_event
from ... import settings as ds_settings
from ...core.data import Event, EventSource
from ...core.event_bridge import EventBridge

# NOTE(review): not referenced anywhere in the visible code; presumably intended as
# the stage_count for MultiStageUniqueQueue — confirm before relying on or removing it.
FINAL_SOURCE_STAGE_COUNT = 10

# Module logger; its name comes from the data-sync settings.
logger = logging.getLogger(ds_settings.LOGGER_NAME)

# Attach a handler writing to stderr (StreamHandler's default stream) whose prefix
# is the PID of the process executing this line; os.getpid() is evaluated once, at
# import time, so a process forked later would still show the importer's PID.
logging_handler = logging.StreamHandler()
logging_handler.setFormatter(logging.Formatter(
    '[{}]'.format(os.getpid()) + ' %(asctime)s %(levelname)s %(module)s.%(funcName)s Line:%(lineno)d  %(message)s'))
# Handler threshold is INFO. NOTE(review): the logger's own level is not set here, so
# INFO records only reach this handler if the logger's effective level allows them —
# confirm it is configured elsewhere.
logging_handler.setLevel(logging.INFO)
logger.addHandler(logging_handler)


class IQueue(object):
    """Minimal interface shared by the event queues in this module.

    Concrete queues must override every method below; the base versions only
    raise NotImplementedError.
    """
    __slots__ = []

    def enqueue(self, value):
        """Add ``value`` to the queue."""
        raise NotImplementedError()

    def iter_one_stage(self):
        """Return an iterable over the items making up the next stage."""
        raise NotImplementedError()

    def is_empty(self):
        """Return True when nothing is pending."""
        raise NotImplementedError()


class UniqueQueue(IQueue):
    """FIFO queue that ignores values already waiting in it.

    A set mirrors the deque's contents so membership checks are O(1).
    """
    __slots__ = '_set', '_deque'

    def __init__(self):
        self._set = set()
        self._deque = collections.deque()

    def enqueue(self, value):
        """Append ``value`` unless an equal value is already queued."""
        if value not in self._set:
            self._set.add(value)
            self._deque.append(value)

    def dequeue(self, default=None):
        """Pop and return the oldest value, or ``default`` if the queue is empty."""
        if not self._set:
            return default
        oldest = self._deque.popleft()
        self._set.remove(oldest)
        return oldest

    def get_all(self):
        """Remove and return every queued value, oldest first."""
        drained = list(self._deque)
        self._deque.clear()
        self._set.clear()
        return drained

    def iter_one_stage(self):
        # One stage is simply everything queued so far (returned as a list snapshot).
        return self.get_all()

    def is_empty(self):
        return not self._deque

    def __contains__(self, item):
        return item in self._set


class DepQueue(object):
    """Ordered, duplicate-free queue where re-enqueueing moves a value to the back.

    Backed by an OrderedDict used as an ordered set: the values are the keys,
    each mapped to None.
    """
    __slots__ = ('_set',)

    def __init__(self):
        self._set = collections.OrderedDict()

    def enqueue(self, value):
        """Add ``value`` at the back; if it is already present, move it to the back."""
        # Removing first is what moves an existing key to the end; plain
        # re-assignment would keep its original position.
        self._set.pop(value, None)
        self._set[value] = None

    def dequeue(self, default=None):
        """Remove and return the front value, or ``default`` when empty."""
        if not self._set:
            return default
        # popitem(last=False) pops the oldest entry in O(1) without mutating the
        # dict while iterating over it.
        value, _ = self._set.popitem(last=False)
        return value

    def get_all(self):
        """Remove and return every value, front first."""
        r = list(self._set.keys())
        self._set.clear()
        return r

    def is_empty(self):
        return len(self._set) == 0

    def __contains__(self, item):
        return item in self._set


class DualUniqueQueue(IQueue):
    """Two UniqueQueues used as a head/tail double buffer.

    New values always go into the tail. ``get_half``/``iter_one_stage`` drain the
    current head (which may be empty) and then swap the buffers, promoting the
    tail to head. A value already pending in either buffer is ignored.
    """
    __slots__ = '_tail', '_head'

    def __init__(self):
        self._tail = UniqueQueue()
        self._head = UniqueQueue()

    def _swap(self):
        self._tail, self._head = self._head, self._tail

    def enqueue(self, value):
        if value not in self._tail and value not in self._head:
            self._tail.enqueue(value)

    def dequeue(self, default=None):
        """Pop from the head, promoting the tail first if the head is exhausted."""
        if self._head.is_empty():
            self._swap()
        return self._head.dequeue(default=default)

    def get_half(self):
        """Drain and return the head's contents, then swap head and tail."""
        current = self._head.get_all()
        self._swap()
        return current

    def get_all(self):
        """Drain both buffers; head contents come first."""
        head_items = self._head.get_all()
        return head_items + self._tail.get_all()

    def iter_one_stage(self):
        return self.get_half()

    def is_empty(self):
        return self._head.is_empty() and self._tail.is_empty()


class FreeStageUniqueQueue(IQueue):
    """Duplicate-free FIFO whose items are grouped into explicitly opened stages.

    ``_last_count`` counts the items added since the last :meth:`open_stage`;
    ``open_stage`` closes that group by appending its size to ``_counts``.
    :meth:`iter_one_stage` returns a generator that removes and yields the items
    of the oldest closed group, closing the current group first if none is
    closed yet. A value is skipped by :meth:`enqueue` only while it is still
    pending (not yet yielded).
    """
    __slots__ = '_set', '_queue', '_counts', '_last_count'

    def __init__(self):
        self._set = set()
        self._queue = collections.deque()
        self._counts = collections.deque()
        self._last_count = 0

    def open_stage(self):
        """Close the current group of items and start a new, empty one."""
        closed_size = self._last_count
        self._counts.append(closed_size)
        self._last_count = 0

    def enqueue(self, value):
        """Add ``value`` to the current group unless it is already pending."""
        if value not in self._set:
            self._set.add(value)
            self._queue.append(value)
            self._last_count += 1

    def _yield_one_stage(self):
        # Generator: none of this runs until the caller starts iterating.
        if not self._counts:
            self.open_stage()
        stage_size = self._counts.popleft()
        for _unused in range(stage_size):
            item = self._queue.popleft()
            self._set.remove(item)
            yield item

    def iter_one_stage(self):
        return self._yield_one_stage()

    def is_empty(self):
        return not self._set


class MultiStageUniqueQueue(FreeStageUniqueQueue):
    """FreeStageUniqueQueue that holds items back for ``stage_count`` extra stages.

    The constructor pre-opens ``stage_count`` empty stages. Each
    :meth:`iter_one_stage` call opens one more stage before yielding the oldest
    one. As a result, an item enqueued before call N is yielded by call
    N + stage_count, assuming each returned iterator is consumed before the
    next call.
    """
    __slots__ = '_stage_count'

    def __init__(self, stage_count):
        super(MultiStageUniqueQueue, self).__init__()
        self._stage_count = stage_count
        # The pre-opened empty stages act as the delay line.
        for _unused in range(stage_count):
            self.open_stage()

    def iter_one_stage(self):
        # Close the current group now (eagerly); the returned generator pops lazily.
        self.open_stage()
        stage_iter = super(MultiStageUniqueQueue, self)._yield_one_stage()
        return stage_iter


class Scheduler(object):
    """Route events to per-source EventBuffers and drive their processing.

    A buffer is created lazily, with a UniqueQueue, the first time an
    EventSource is seen. Buffers holding pending events are tracked in an
    ordered, de-duplicated DepQueue of "active" buffers.
    """

    def __init__(self):
        self._source_buffer_map = {}
        self._active_buffers = DepQueue()
        # TODO: topological sort

    def _get_event_buffer(self, event_source):
        """Return the buffer for ``event_source``, creating it on first use."""
        try:
            return self._source_buffer_map[event_source]
        except KeyError:
            pass

        assert isinstance(event_source, EventSource)

        # TODO: handle weakref
        new_buffer = EventBuffer(scheduler=self, event_source=event_source, queue=UniqueQueue())
        self._source_buffer_map[event_source] = new_buffer
        return new_buffer

    def push_event(self, event):
        """Queue ``event`` in its source's buffer and mark that buffer active."""
        target = self._get_event_buffer(event.event_source)
        target.push_event(event)
        self._active_buffers.enqueue(target)

    def _run_buffer(self, buf):
        assert isinstance(buf, EventBuffer)
        buf.process_one_stage()
        # Still has work: keep it scheduled.
        if not buf.is_empty():
            self._active_buffers.enqueue(buf)

    def run_once(self):
        """Process one stage of the next active buffer, if there is one."""
        next_buffer = self._active_buffers.dequeue()
        if next_buffer is not None:
            self._run_buffer(next_buffer)

    def run_one_iteration(self):
        """Process one stage of every buffer that is active right now."""
        for active in self._active_buffers.get_all():
            self._run_buffer(active)

    def is_idle(self):
        return self._active_buffers.is_empty()


class EventBuffer(object):
    """Pending events of one EventSource, plus the logic that processes them.

    On construction the buffer registers itself in
    ``event_source.bound_objects``.
    """

    def __init__(self, scheduler, event_source, queue):
        assert isinstance(scheduler, Scheduler)
        assert isinstance(event_source, EventSource)
        assert isinstance(queue, IQueue)
        self._sched = scheduler
        self._event_source = event_source
        event_source.bound_objects.add(self)
        self._event_queue = queue

    def push_event(self, event):
        """Queue ``event`` for a later :meth:`process_one_stage`."""
        assert isinstance(event, Event)
        self._event_queue.enqueue(event)

    def process_one_stage(self):
        """Run every handler of the source on each event of the next queue stage.

        Events returned by a handler inherit the input event's ``trace_id`` and
        are pushed back into the scheduler. An exception raised by
        ``process_event`` itself is logged and that handler is skipped.
        Exceptions raised while iterating the returned events are not caught.
        """
        for event in self._event_queue.iter_one_stage():
            assert isinstance(event, Event)
            for handler in self._event_source.handler_set:
                try:
                    produced = handler.process_event(event=event)
                except Exception:
                    logger.exception('Exception while processing event {}'.format(event))
                    continue
                if produced is None:
                    continue
                for new_event in produced:
                    assert isinstance(new_event, Event)
                    new_event.trace_id = event.trace_id
                    self._sched.push_event(new_event)

    def is_empty(self):
        return self._event_queue.is_empty()


def abort_on_exception(func):
    """Decorate ``func`` so that any exception it raises kills the whole process.

    Used on the thread entry points in this file. An exception that would
    otherwise just end the thread is printed with its traceback, and then the
    process sends itself SIGABRT. On success, the wrapped function's return
    value is passed through.
    """
    @functools.wraps(func)
    def wrapped(*args, **kwargs):
        try:
            return func(*args, **kwargs)
        except:  # noqa: E722 -- deliberately catch everything, incl. SystemExit
            traceback.print_exc()
            # SIGABRT's default action terminates the process without any
            # Python-level cleanup, so flush the traceback out of stderr's
            # buffer first or it may be lost.
            sys.stderr.flush()
            os.kill(os.getpid(), signal.SIGABRT)

    return wrapped


class DelayedQueue(object):
    """Hold raw messages until they are ``delay`` old, then emit the parsed table event.

    The constructor starts one daemon worker thread. It takes messages in
    arrival order and parses them with ``parse_table_event``. If the reported
    age is below ``delay``, it sleeps for the difference, then calls
    ``output_callback(table_event)``. Messages that fail to parse are logged and
    dropped. NOTE(review): the age difference is passed straight to
    ``time.sleep``, so the age is presumably in seconds — confirm in
    parse_table_event.

    :param delay: minimum age before a message is emitted (must be >= 0).
    :param output_callback: callable invoked with each table event.
    :param maxsize: bound of the input queue (None or 0 means unbounded); when
        the queue is full, :meth:`enqueue` blocks.
    """

    def __init__(self, delay, output_callback, maxsize=None):
        min_delay = float(delay)
        assert min_delay >= 0
        assert callable(output_callback)

        self._delay = min_delay
        self._output_callback = output_callback
        self._queue = Queue(maxsize=maxsize or 0)

        worker_thread = threading.Thread(target=self.worker)
        worker_thread.daemon = True
        self._thread = worker_thread
        worker_thread.start()

    def enqueue(self, timed_data_json):
        """Hand a raw message to the worker thread."""
        self._queue.put(timed_data_json)

    def __len__(self):
        # Approximate number of messages not yet taken by the worker.
        return self._queue.qsize()

    @abort_on_exception
    def worker(self):
        """Worker-thread loop; never returns."""
        while True:
            raw_message = self._queue.get()
            self._queue.task_done()
            try:
                age, table_event = parse_table_event(raw_message)
            except Exception:
                logger.exception('Exception while parsing data: {}'.format(raw_message))
                continue

            if age < self._delay:
                time.sleep(self._delay - age)

            self._output_callback(table_event)


class ForkQueue(object):
    """Fan every message out to one DelayedQueue per delay and merge their output.

    Each message given to :meth:`enqueue` is copied into every DelayedQueue.
    The same table event is therefore emitted once per entry in ``delay_list``
    (duplicates included), each time once it is at least that old. All emitted
    events are collected in a single output queue, read with :meth:`dequeue`.

    :param delay_list: iterable of delays; one DelayedQueue is created per entry.
    :param maxsize: input bound passed to each DelayedQueue (None = unbounded).
    """

    def __init__(self, delay_list, maxsize=None):
        self._delay_list = sorted(delay_list)
        # Create the merged output queue *before* the DelayedQueues. Each of them
        # starts a worker thread in its constructor that calls
        # self._output_callback, which writes to self._queue. Assigning it first
        # means the object is fully initialised before any of those threads run.
        self._queue = Queue()
        self._delayed_queues = [
            DelayedQueue(delay=delay, output_callback=self._output_callback, maxsize=maxsize)
            for delay in self._delay_list
        ]

    def enqueue(self, timed_data_json):
        """Copy a raw message into every DelayedQueue."""
        for dq in self._delayed_queues:
            dq.enqueue(timed_data_json)

    def dequeue(self, block):
        """Return the next emitted table event.

        With ``block=False``, returns None when nothing is available. With
        ``block=True``, waits indefinitely for an event.
        """
        try:
            table_event = self._queue.get(block=block)
        except Empty:
            return None
        self._queue.task_done()
        return table_event

    def qsize_tuple(self):
        """Approximate backlog of each DelayedQueue, ordered by ascending delay."""
        return tuple(len(dq) for dq in self._delayed_queues)

    def _output_callback(self, table_event):
        # Called from the DelayedQueue worker threads.
        self._queue.put(table_event)


class RedisPuller(object):
    """Read raw table-event messages from Redis on a daemon thread.

    Each message is passed to ``output_callback``. There are two modes:

    * list mode (default): ``lpop`` from ``REDIS_TABLE_EVENT_QUEUE_NAME``. When
      the list is empty, sleep a random time in ``[0, random_delay]`` and then
      wait on a blocking ``blpop``.
    * pub/sub mode (``use_pubsub=True``): subscribe to
      ``REDIS_TABLE_EVENT_CHANNEL_NAME`` and forward the data of every
      ``'message'`` item.
    """

    def __init__(self, redis_connection, output_callback, random_delay=0.0, use_pubsub=False):
        assert callable(output_callback)
        self._redis_connection = redis_connection
        self._output_callback = output_callback
        max_jitter = float(random_delay)
        self._random_delay = max_jitter
        assert max_jitter >= 0.0

        entry_point = self.pubsub_worker if use_pubsub else self.worker

        self._thread = threading.Thread(target=entry_point)
        self._thread.daemon = True
        self._thread.start()

    @abort_on_exception
    def worker(self):
        """List-mode loop; never returns."""
        max_jitter = self._random_delay
        while True:
            payload = self._redis_connection.lpop(ds_settings.REDIS_TABLE_EVENT_QUEUE_NAME)
            if payload is None:
                # List is empty: wait a random jitter, then block until an item arrives.
                time.sleep(random.uniform(0, max_jitter))
                _key, payload = self._redis_connection.blpop(ds_settings.REDIS_TABLE_EVENT_QUEUE_NAME)
            self._output_callback(payload)

    @abort_on_exception
    def pubsub_worker(self):
        """Pub/sub-mode loop; forwards the payload of each 'message' item."""
        channel = ds_settings.REDIS_TABLE_EVENT_CHANNEL_NAME

        subscription = self._redis_connection.pubsub()
        subscription.subscribe(channel)
        for message in subscription.listen():
            # Skip subscribe confirmations and other non-data items.
            if message['type'] == 'message':
                assert message['channel'] == channel
                self._output_callback(message['data'])


class Main(object):
    """Top-level driver: pull table events from Redis and run them through the Scheduler.

    ``run`` configures tracing and connections. It then runs :meth:`_worker`
    on a daemon thread while the calling thread waits for Ctrl-C.
    """

    def __init__(self):
        self._scheduler = Scheduler()
        # NOTE(review): self._event_bridge is not used anywhere in the visible code
        # (_worker calls EventBridge.static_process_table_event directly); confirm
        # whether it is still needed, e.g. for constructor side effects.
        self._event_bridge = EventBridge(django_data_source_manager=ddsm, event_callback=self._scheduler.push_event)
        self._redis_connection = get_redis_connection()

    def run(
            self,
            configuration,
            connection_type,
            data_source_name_list,
            trace_kafka=None,
            use_pubsub=False,
            chunk_size=None,
            max_queue_size=None,
    ):
        """Start processing and block until interrupted.

        :param configuration: forwarded to ``create_connections``.
        :param connection_type: forwarded to ``create_connections``.
        :param data_source_name_list: forwarded to ``create_connections``.
        :param trace_kafka: optional mapping with ``'brokers'`` and ``'topic'``
            keys (both keys must be present when it is truthy); see
            ``__create_tracer``. The resulting tracer is also installed as the
            default tracer.
        :param use_pubsub: read from a Redis pub/sub channel instead of popping
            from a Redis list.
        :param chunk_size: see :meth:`_worker`.
        :param max_queue_size: see :meth:`_worker`.

        Never returns normally; a ``KeyboardInterrupt`` is turned into ``sys.exit()``.
        """
        tracer = self.__create_tracer(trace_kafka=trace_kafka)
        tracers.set_default_tracer(tracer)

        create_connections(
            configuration=configuration,
            connection_type=connection_type,
            data_source_name_list=data_source_name_list,
        )

        thread = threading.Thread(
            target=self._worker,
            kwargs=dict(
                chunk_size=chunk_size,
                max_queue_size=max_queue_size,
                tracer=tracer,
                use_pubsub=use_pubsub,
            )
        )
        thread.daemon = True
        thread.start()

        # Join with a timeout so the main thread wakes up regularly; presumably
        # this lets KeyboardInterrupt be delivered (an untimed join cannot be
        # interrupted on Python 2).
        while True:
            try:
                thread.join(timeout=1.0)
            except KeyboardInterrupt:
                sys.exit()

    def __create_tracer(self, trace_kafka):
        """Return a KafkaTracer for ``trace_kafka``, or an EmptyTracer if not configured.

        An EmptyTracer is returned when ``trace_kafka`` is falsy or when its
        ``'brokers'``/``'topic'`` values are empty. A missing key raises KeyError.
        """
        default = tracers.EmptyTracer()
        if not trace_kafka:
            return default
        brokers = trace_kafka['brokers']
        topic = trace_kafka['topic']
        if not brokers or not topic:
            return default
        return tracers.KafkaTracer(
            brokers=brokers,
            topic=topic,
            period_millis=5 * 60 * 1000,
        )

    @abort_on_exception
    def _worker(self, tracer, use_pubsub, chunk_size, max_queue_size):
        """Main event loop: move table events from Redis into the Scheduler and run it.

        Pipeline, in order:

        1. A RedisPuller (background thread) reads raw messages.
        2. A ForkQueue re-emits each message once per delay in
           ``ds_settings.TABLE_EVENT_PROCESS_DELAY_LIST``.
        3. This loop turns the parsed table events into Events via EventBridge.
        4. The Scheduler processes those Events.

        :param tracer: ``tracers.ITracer`` receiving raw messages and entry events.
        :param use_pubsub: passed to RedisPuller (pub/sub instead of list popping).
        :param chunk_size: max table events drained from the ForkQueue per loop
            iteration (defaults to ``ds_settings.REDIS_FETCH_CHUNK_SIZE``).
        :param max_queue_size: input bound of each DelayedQueue (None = unbounded).
        """
        assert isinstance(tracer, tracers.ITracer)
        if chunk_size is None:
            chunk_size = ds_settings.REDIS_FETCH_CHUNK_SIZE

        delay_list = sorted(ds_settings.TABLE_EVENT_PROCESS_DELAY_LIST)
        assert len(delay_list) > 0

        table_event_fork_deque = ForkQueue(
            delay_list=delay_list,
            maxsize=max_queue_size,
        )

        def redis_puller_output_callback(timed_data_json):
            tracer.send({'redis_puller': timed_data_json})
            table_event_fork_deque.enqueue(timed_data_json)

        # Constructing the puller starts its background thread; the local name
        # is not otherwise used.
        redis_puller = RedisPuller(
            self._redis_connection,
            output_callback=redis_puller_output_callback,
            random_delay=0.5,
            use_pubsub=use_pubsub,
        )

        # Loop-invariant: build the EventBridge callback once instead of
        # re-creating it on every iteration of the main loop.
        scheduler_push_event = self._scheduler.push_event

        def event_bridge_callback(event):
            event_source = event.event_source

            tracer.send({'entry_event': event.to_json()})
            tracers.trace_event_source(event_source)

            scheduler_push_event(event)

        table_event_queue = []

        previous_idle = False
        previous_duration_s = '-'
        while True:
            # Drain up to chunk_size already-delayed table events without blocking.
            for _ in range(chunk_size):
                table_event = table_event_fork_deque.dequeue(block=False)
                if table_event is None:
                    break
                table_event_queue.append(table_event)

            qsize_tuple = table_event_fork_deque.qsize_tuple()

            logger.info(
                'MainLoop: ForkQueue.Info: new.event.count: {:3d}, qsize: {}, previous.duration: {}'.format(
                    len(table_event_queue),
                    ', '.join(['{:3d}'.format(qlen) for qlen in qsize_tuple]),
                    previous_duration_s,
                )
            )
            previous_duration_s = '-'

            # Translate table events into Events; failures are logged per table event.
            for table_event in table_event_queue:
                try:
                    EventBridge.static_process_table_event(
                        ddsm=ddsm,
                        event_callback=event_bridge_callback,
                        table_event=table_event,
                    )
                except Exception:
                    logger.exception('Exception while processing data: {}'.format(table_event))
            table_event_queue = []

            if not self._scheduler.is_idle():
                processing_time_start = time.time()
                self._scheduler.run_one_iteration()
                processing_time_end = time.time()

                previous_idle = False
                previous_duration_s = '{:6.3f}'.format(processing_time_end - processing_time_start)
                # Work may still be pending: loop again without blocking.
                continue
            else:
                # Log the idle transition only once per idle period.
                if not previous_idle:
                    logger.info('MainLoop: Idle')
                    previous_idle = True

            # Scheduler is idle: block until the next table event arrives.
            table_event = table_event_fork_deque.dequeue(block=True)
            table_event_queue.append(table_event)


def run(*args, **kwargs):
    """Construct :class:`Main` and run it, logging any unexpected exception.

    All arguments are forwarded to ``Main.run``. ``SystemExit`` (raised by
    ``Main.run`` on Ctrl-C) and ``KeyboardInterrupt`` are normal shutdown paths.
    They are re-raised without being logged. Any other exception is logged with
    its traceback and then re-raised.
    """
    try:
        main = Main()
        main.run(*args, **kwargs)
    except (SystemExit, KeyboardInterrupt):
        # Normal shutdown: don't report it as an 'unknown exception'.
        raise
    except BaseException:
        logger.exception('unknown exception')
        raise
