# coding=utf-8
from __future__ import unicode_literals, print_function, absolute_import

import six
import sys
import json
import warnings
import traceback

from django.db.models import Model, Field, ForeignKey, ManyToManyField
from django.db.models.signals import pre_save, post_save, pre_delete, m2m_changed


# Registry of event classes keyed by their ``event_class_name``; filled by the
# ``reg_event`` decorator and consulted by ``DataChangeEvent.from_dict``.
_DataChangeEvent_map = dict()

# Serialization format version written by ``to_dict`` and checked by ``from_dict``.
VERSION = 1


class DataChangeEvent(object):
    """Base class for change events that round-trip through a plain dict.

    Subclasses set ``event_class_name`` and list their attributes in
    ``_event_fields`` (which they also use as ``__slots__``). ``to_dict`` emits
    ``version``, ``event`` and every event field; ``from_dict`` reverses it.
    """

    event_class_name = None
    _event_fields = None
    __slots__ = []

    def to_dict(self):
        """Return a dict with ``version``, ``event`` and one key per event field."""
        d = {}
        d['version'] = VERSION
        d['event'] = self.event_class_name
        for key in self._event_fields:
            d[key] = getattr(self, key)
        return d

    @classmethod
    def from_dict(cls, d):
        """Rebuild an event from the output of ``to_dict``.

        The concrete class is looked up from the ``event`` key, so this can be
        called on the base class.

        Raises:
            TypeError: if ``d`` is not a dict.
            ValueError: if ``d`` was written with a different ``VERSION``.
            KeyError: if ``version``/``event`` is missing or the event name is
                not registered.
        """
        # Explicit checks instead of ``assert``: this validates external data and
        # must not disappear under ``python -O``.
        if not isinstance(d, dict):
            raise TypeError('expected a dict, got %s' % type(d).__name__)
        d = dict(d)
        version = d.pop('version')
        if version != VERSION:
            raise ValueError('unsupported event version %r (expected %r)' % (version, VERSION))
        event_class_name = d.pop('event')
        event_class = _DataChangeEvent_map[event_class_name]
        # Internal invariant of the registry, not input validation.
        assert event_class.event_class_name == event_class_name
        return event_class(**d)

    def __repr__(self):
        fields = self._event_fields or ()
        return '%s(%s)' % (
            type(self).__name__,
            ', '.join('%s=%r' % (name, getattr(self, name)) for name in fields),
        )


def reg_event(cls):
    """Class decorator: record ``cls`` in ``_DataChangeEvent_map`` by its ``event_class_name``.

    Each name may be registered once; a duplicate trips the assertion. The
    class itself is returned unchanged.
    """
    assert issubclass(cls, DataChangeEvent)
    event_name = cls.event_class_name
    assert event_name not in _DataChangeEvent_map
    _DataChangeEvent_map.update({event_name: cls})
    return cls


@reg_event
class TableLocalDataChangeEvent(DataChangeEvent):
    """Row ``pk`` of ``table`` was created, deleted, or (neither flag set) updated."""

    event_class_name = 'TableLocalDataChangeEvent'
    __slots__ = _event_fields = (
        'table',
        'pk',
        'is_create',
        'is_delete',
    )

    def __init__(self, table, pk, is_create, is_delete):
        assert isinstance(table, six.string_types)
        for flag in (is_create, is_delete):
            assert isinstance(flag, bool)
        self.table, self.pk = table, pk
        self.is_create, self.is_delete = is_create, is_delete


@reg_event
class TableForeignKeyChangeEvent(DataChangeEvent):
    """Foreign-key ``column`` of row ``pk`` in ``table`` went from ``old_value`` to ``new_value``.

    The monitor passes ``old_value=None`` for creations and ``new_value=None``
    for deletions.
    """

    event_class_name = 'TableForeignKeyChangeEvent'
    __slots__ = _event_fields = (
        'table',
        'column',
        'pk',
        'old_value',
        'new_value',
        'is_create',
        'is_delete',
    )

    def __init__(self, table, column, pk, old_value, new_value, is_create, is_delete):
        # TODO: drop is_create / is_delete; they can be derived from whether
        # old_value / new_value is None.
        for text in (table, column):
            assert isinstance(text, six.string_types)
        for flag in (is_create, is_delete):
            assert isinstance(flag, bool)
        self.table, self.column, self.pk = table, column, pk
        self.old_value, self.new_value = old_value, new_value
        self.is_create, self.is_delete = is_create, is_delete


@reg_event
class TableManyToManyFieldChangeEvent(DataChangeEvent):
    """A link ``primary_pk`` <-> ``secondary_pk`` in ``through_table`` was added or removed."""

    event_class_name = 'TableManyToManyFieldChangeEvent'
    __slots__ = _event_fields = (
        'through_table',
        'primary_column',
        'secondary_column',
        'primary_pk',
        'secondary_pk',
        'is_create',
        'is_delete',
    )

    def __init__(self, through_table, primary_column, secondary_column, primary_pk, secondary_pk, is_create, is_delete):
        for text in (through_table, primary_column, secondary_column):
            assert isinstance(text, six.string_types)
        for flag in (is_create, is_delete):
            assert isinstance(flag, bool)
        self.through_table = through_table
        self.primary_column, self.secondary_column = primary_column, secondary_column
        self.primary_pk, self.secondary_pk = primary_pk, secondary_pk
        self.is_create, self.is_delete = is_create, is_delete


# ``reg_event`` is only needed while the event classes above are defined;
# remove it so it does not remain in this module's namespace.
del reg_event


class ThroughInfo(object):
    """Table/column/model metadata of the through table behind a ManyToManyField.

    ``primary_*`` attributes come from the field's ``m2m_field_name()`` /
    ``m2m_column_name()`` side, ``secondary_*`` from the ``m2m_reverse_*`` side.
    """

    def __init__(self, fieldinfo):
        assert isinstance(fieldinfo, FieldInfo)
        assert fieldinfo.is_many_to_many

        field = fieldinfo.field
        assert isinstance(field, ManyToManyField)

        self.many_to_many_fieldinfo = fieldinfo

        self.through_table = field.m2m_db_table()
        self.through_model = getattr(field.model, field.name).through
        assert self.through_model._meta.db_table == self.through_table

        self.primary_column = field.m2m_column_name()
        self.secondary_column = field.m2m_reverse_name()

        self.primary_name = field.m2m_field_name()
        self.secondary_name = field.m2m_reverse_field_name()

        # The two foreign keys of the through model. Their attname/related_model
        # come from the public field API; the former ``.related.to`` chain
        # relied on attributes deprecated in Django 1.8/1.9 and removed in 2.0.
        through_meta = self.through_model._meta
        primary_fk = through_meta.get_field(self.primary_name)
        secondary_fk = through_meta.get_field(self.secondary_name)

        self.primary_attname = primary_fk.attname
        self.secondary_attname = secondary_fk.attname

        self.primary_target_name = field.m2m_target_field_name()
        self.secondary_target_name = field.m2m_reverse_target_field_name()

        self.primary_model = primary_fk.related_model
        self.secondary_model = secondary_fk.related_model


class FieldInfo(object):
    """Cached metadata for one field of a registered model.

    Copies the field's column/name/attname and relation-kind flags; for a
    many-to-many field it also builds a ThroughInfo in ``throughinfo``.
    """

    def __init__(self, model_info, field):
        assert isinstance(model_info, ModelInfo)
        assert isinstance(field, Field)

        self.model_info = model_info
        self.field = field

        # Naming attributes; each must be non-empty.
        self.column, self.name, self.attname = field.column, field.name, field.attname
        assert self.column
        assert self.name
        assert self.attname

        # Relation kind flags copied from the field; "foreign" covers both
        # many-to-one and one-to-one.
        self.is_many_to_one = field.many_to_one
        self.is_one_to_one = field.one_to_one
        self.is_foreign = self.is_many_to_one or self.is_one_to_one
        self.is_many_to_many = field.many_to_many
        self._related_model = field.related_model

        if self.is_many_to_many:
            self.throughinfo = ThroughInfo(self)

    @property
    def related_model_info(self):
        """ModelInfo of the model this field points to (KeyError if it is not registered)."""
        modelinfo_by_model = self.model_info.model_manager.model_modelinfo_map
        return modelinfo_by_model[self._related_model]

    def get_value(self, instance):
        """Return this field's value on ``instance``, read through ``attname``."""
        return getattr(instance, self.attname)

    def get_value_by_pk(self, pk):
        """Load the row with primary key ``pk`` and return this field's value on it."""
        row = self.model_info.get_instance_by_pk(pk=pk)
        return self.get_value(instance=row)


class ModelInfo(object):
    """Cached metadata for one registered Django model.

    Holds the model's table name, a FieldInfo per tracked field, lookup maps
    by column / field / attname, per-relation-kind field lists, and the
    FieldInfo of the primary key.
    """

    def __init__(self, model_manager, model):
        assert isinstance(model_manager, ModelManager)
        assert issubclass(model, Model)

        self.model_manager = model_manager
        self.model = model

        table = model._meta.db_table
        assert table
        self.table = table

        # Lookup maps, filled by _add_field_info.
        self.column_fieldinfo_map = dict()
        self.field_fieldinfo_map = dict()
        self.attname_fieldinfo_map = dict()

        # All tracked fields, plus sub-lists by relation kind.
        self.fieldinfo_list = list()
        self.many_to_one_fieldinfo_list = list()
        self.one_to_one_fieldinfo_list = list()
        self.foreign_fieldinfo_list = list()
        self.many_to_many_fieldinfo_list = list()

        for field in self.model._meta.get_fields():
            # Auto-created relational fields (e.g. reverse accessors) are not tracked.
            # NOTE(review): this presumably also skips multi-table-inheritance
            # parent links, whose pk lookup below would then fail — confirm.
            if field.auto_created and field.is_relation:
                continue
            self._add_field_info(field)

        assert model._meta.pk
        pk_fieldinfo = self.field_fieldinfo_map[model._meta.pk]
        self.pk_fieldinfo = pk_fieldinfo

    def _add_field_info(self, field):
        """Create the FieldInfo for ``field`` and index it in every map/list."""
        field_info = FieldInfo(self, field)

        assert field_info.column not in self.column_fieldinfo_map
        assert field_info.field not in self.field_fieldinfo_map
        assert field_info.attname not in self.attname_fieldinfo_map

        self.column_fieldinfo_map[field_info.column] = field_info
        self.field_fieldinfo_map[field] = field_info
        self.attname_fieldinfo_map[field.attname] = field_info

        self.fieldinfo_list.append(field_info)

        if field_info.is_many_to_one:
            self.many_to_one_fieldinfo_list.append(field_info)
        if field_info.is_one_to_one:
            self.one_to_one_fieldinfo_list.append(field_info)
        if field_info.is_foreign:
            self.foreign_fieldinfo_list.append(field_info)
        if field_info.is_many_to_many:
            self.many_to_many_fieldinfo_list.append(field_info)

    def get_pk_value(self, instance):
        """Return the primary-key value of ``instance``."""
        return self.pk_fieldinfo.get_value(instance)

    def get_instance_by_pk(self, pk):
        """Fetch the stored row whose primary key is ``pk``.

        Uses ``_base_manager`` rather than ``objects``. A model may not have an
        ``objects`` manager at all, and a filtering default manager could hide
        the very row being saved or deleted.

        Raises:
            self.model.DoesNotExist: if no row has that primary key.
        """
        return self.model._base_manager.get(pk=pk)


class ModelManager(object):
    """Registry of ModelInfo objects.

    Indexes registered models by table name, model class, field and (for
    many-to-many fields) through model.
    """

    def __init__(self):
        self.model_set = set()
        self.modelinfo_list = []
        self.table_modelinfo_map = {}
        self.model_modelinfo_map = {}
        self.field_fieldinfo_map = {}
        self.model_throughinfo_map = {}

    def add_model(self, model):
        """Register ``model`` and return its ModelInfo.

        Registering an already-known model returns the existing ModelInfo.
        """
        assert issubclass(model, Model)

        existing = self.model_modelinfo_map.get(model)
        if existing is not None:
            return existing

        model_info = ModelInfo(self, model)
        table = model_info.table

        assert table not in self.table_modelinfo_map
        assert model not in self.model_modelinfo_map
        self.table_modelinfo_map[table] = model_info
        self.model_modelinfo_map[model] = model_info

        for field, field_info in model_info.field_fieldinfo_map.items():
            assert field not in self.field_fieldinfo_map
            self.field_fieldinfo_map[field] = field_info

        through_infos = [fi.throughinfo for fi in model_info.many_to_many_fieldinfo_list]
        for through_info in through_infos:
            assert through_info.through_model not in self.model_throughinfo_map
            self.model_throughinfo_map[through_info.through_model] = through_info

        self.model_set.add(model)
        self.modelinfo_list.append(model_info)
        return model_info


class ModelDataMonitor(object):
    """Turn Django model signals into DataChangeEvent objects.

    Each model passed to ``add_model`` is observed through ``pre_save``,
    ``post_save`` and ``pre_delete``. The through model of each of its
    many-to-many fields is observed through ``m2m_changed``. Events go to
    ``event_handler.handle_event``. An exception raised inside a signal
    receiver is reported to ``event_handler.handle_exception`` instead of
    propagating into the code that triggered the signal.
    """

    # m2m_changed action -> (is_create, is_delete). Only the pre_* actions are
    # reported; for pre_clear the through rows are enumerated before removal.
    _M2M_ACTION_FLAGS = {
        'pre_add': (True, False),
        'pre_remove': (False, True),
        'pre_clear': (False, True),
    }

    def __init__(self, event_handler):
        self._model_set = set()
        self._model_manager = ModelManager()
        self._event_handler = event_handler

    def get_model_manager(self):
        """Return the ModelManager holding metadata of all monitored models."""
        return self._model_manager

    def add_model(self, model):
        """Start monitoring ``model``; calling it again for the same model is a no-op."""
        assert issubclass(model, Model)

        if model in self._model_set:
            return

        model_info = self._model_manager.add_model(model)

        pre_save.connect(self.on_pre_save, sender=model)
        post_save.connect(self.on_post_save, sender=model)
        pre_delete.connect(self.on_pre_delete, sender=model)

        for field_info in model_info.many_to_many_fieldinfo_list:
            m2m_changed.connect(self.on_m2m_changed, sender=field_info.throughinfo.through_model)

        self._model_set.add(model)

    def on_pre_save(self, sender, instance, raw, using, update_fields, **kwargs):
        self._on_save(sender=sender, instance=instance, raw=raw, is_pre=True, created=None)

    def on_post_save(self, sender, instance, created, raw, using, update_fields, **kwargs):
        self._on_save(sender=sender, instance=instance, raw=raw, is_pre=False, created=created)

    def _on_save(self, sender, instance, raw, is_pre, created):
        """Emit the events for one save of a monitored model.

        Updates are reported from pre_save, while the previous foreign-key
        values can still be read from the database. Inserts are reported from
        post_save with ``created=True``. Raw saves are ignored.
        """
        if raw:
            return

        try:
            assert sender == instance._meta.model
            model_info = self._model_manager.model_modelinfo_map[sender]
            pk = model_info.get_pk_value(instance)

            if is_pre:
                if pk is None:
                    # No pk yet: this save is an INSERT, reported by post_save.
                    return
                if model_info.foreign_fieldinfo_list:
                    try:
                        old_instance = model_info.get_instance_by_pk(pk)
                    except model_info.model.DoesNotExist:
                        # The pk was assigned before the first save (explicit
                        # value or a field default): still an INSERT.
                        return
                    self._dispatch_foreign_key_updates(model_info, old_instance, instance, pk)
                elif not self._row_exists(model_info, instance, pk):
                    # Same pre-assigned-pk INSERT case, for models without foreign keys.
                    return
            elif not created:
                # Updates were already reported from pre_save.
                return
            else:
                self._dispatch_foreign_key_creates(model_info, instance, pk)

            self.dispatch_event(
                TableLocalDataChangeEvent(
                    table=model_info.table,
                    pk=pk,
                    is_create=bool(created),
                    is_delete=False,
                )
            )
        except Exception:
            self.handle_exception()

    @staticmethod
    def _row_exists(model_info, instance, pk):
        """Return whether a row with ``pk`` is already stored for this model.

        Instances loaded from the database (``_state.adding`` is False) are
        taken to exist, so the common update path costs no extra query.
        """
        if not instance._state.adding:
            return True
        return model_info.model._base_manager.filter(pk=pk).exists()

    def _dispatch_foreign_key_event(self, model_info, field_info, pk, old_value, new_value, is_create, is_delete):
        """Dispatch one TableForeignKeyChangeEvent for ``field_info`` on row ``pk``."""
        self.dispatch_event(
            TableForeignKeyChangeEvent(
                table=model_info.table,
                column=field_info.column,
                pk=pk,
                old_value=old_value,
                new_value=new_value,
                is_create=is_create,
                is_delete=is_delete,
            )
        )

    def _dispatch_foreign_key_creates(self, model_info, instance, pk):
        """Report every non-null foreign key of a freshly inserted row."""
        for field_info in model_info.foreign_fieldinfo_list:
            new_value = getattr(instance, field_info.attname)
            if new_value is not None:
                self._dispatch_foreign_key_event(
                    model_info, field_info, pk,
                    old_value=None, new_value=new_value,
                    is_create=True, is_delete=False,
                )

    def _dispatch_foreign_key_updates(self, model_info, old_instance, instance, pk):
        """Report each foreign key whose stored value differs from the one being saved."""
        for field_info in model_info.foreign_fieldinfo_list:
            old_value = getattr(old_instance, field_info.attname)
            new_value = getattr(instance, field_info.attname)
            if old_value != new_value:
                self._dispatch_foreign_key_event(
                    model_info, field_info, pk,
                    old_value=old_value, new_value=new_value,
                    is_create=False, is_delete=False,
                )

    def on_pre_delete(self, sender, instance, using, **kwargs):
        """Report the deletion of a monitored row and the foreign keys it held."""
        try:
            assert sender == instance._meta.model
            model_info = self._model_manager.model_modelinfo_map[sender]
            pk = model_info.get_pk_value(instance)

            if model_info.foreign_fieldinfo_list:
                # Read foreign keys from the stored row, not from the possibly
                # modified in-memory instance.
                old_instance = model_info.get_instance_by_pk(pk)
                for field_info in model_info.foreign_fieldinfo_list:
                    old_value = getattr(old_instance, field_info.attname)
                    if old_value is None:
                        # Nothing was referenced; the create path skips None too.
                        continue
                    self._dispatch_foreign_key_event(
                        model_info, field_info, pk,
                        old_value=old_value, new_value=None,
                        is_create=False, is_delete=True,
                    )

            self.dispatch_event(
                TableLocalDataChangeEvent(
                    table=model_info.table,
                    pk=pk,
                    is_create=False,
                    is_delete=True,
                )
            )
        except Exception:
            self.handle_exception()

    def on_m2m_changed(self, sender, instance, action, reverse, model, pk_set, using, **kwargs):
        """Report links added to / removed from a monitored many-to-many relation.

        ``instance`` is the "left" object and ``model`` the class of the
        objects in ``pk_set``. When ``reverse`` is true, the left side is the
        secondary side of the through table.
        """
        try:
            flags = self._M2M_ACTION_FLAGS.get(action)
            if flags is None:
                return
            is_create, is_delete = flags

            modelinfo_map = self._model_manager.model_modelinfo_map
            left_modelinfo = modelinfo_map.get(instance._meta.model)
            right_modelinfo = modelinfo_map.get(model)
            if left_modelinfo is None or right_modelinfo is None:
                return

            left_pk = left_modelinfo.get_pk_value(instance)

            if not reverse:
                primary_modelinfo, secondary_modelinfo = left_modelinfo, right_modelinfo
            else:
                primary_modelinfo, secondary_modelinfo = right_modelinfo, left_modelinfo

            throughinfo = self._model_manager.model_throughinfo_map[sender]
            assert throughinfo.through_model is sender
            assert throughinfo.primary_model is primary_modelinfo.model
            assert throughinfo.secondary_model is secondary_modelinfo.model

            if action == 'pre_clear':
                assert pk_set is None
                pk_set = self._list_linked_pks(throughinfo, left_pk, reverse)

            for right_pk in pk_set:
                if not reverse:
                    primary_pk, secondary_pk = left_pk, right_pk
                else:
                    primary_pk, secondary_pk = right_pk, left_pk
                self.dispatch_event(
                    TableManyToManyFieldChangeEvent(
                        through_table=throughinfo.through_table,
                        primary_column=throughinfo.primary_column,
                        secondary_column=throughinfo.secondary_column,
                        primary_pk=primary_pk,
                        secondary_pk=secondary_pk,
                        is_create=is_create,
                        is_delete=is_delete,
                    )
                )
        except Exception:
            self.handle_exception()

    @staticmethod
    def _list_linked_pks(throughinfo, left_pk, reverse):
        """Return the other-side pks of every through row linked to ``left_pk``.

        Used for pre_clear, where the signal carries ``pk_set=None``.
        """
        if not reverse:
            own_attname, other_attname = throughinfo.primary_attname, throughinfo.secondary_attname
        else:
            own_attname, other_attname = throughinfo.secondary_attname, throughinfo.primary_attname
        through_rows = throughinfo.through_model._default_manager.filter(**{own_attname: left_pk})
        return list(through_rows.values_list(other_attname, flat=True))

    def dispatch_event(self, event):
        """Hand ``event`` to the configured event handler."""
        self._event_handler.handle_event(event)

    def handle_exception(self):
        """Forward the exception currently being handled to the event handler.

        Must be called from inside an ``except`` block. If the handler itself
        fails, that failure is printed to stderr rather than silently dropped.
        It is never re-raised, so monitoring cannot break the caller's save or
        delete.
        """
        exc_type, exc_value, exc_traceback = sys.exc_info()
        try:
            self._event_handler.handle_exception(exc_type, exc_value, exc_traceback)
        except Exception:
            traceback.print_exc()


class EventHandler(object):
    """Base receiver for ModelDataMonitor; subclasses must implement ``handle_event``."""

    def handle_exception(self, exc_type, exc_value, exc_traceback):
        """Print the formatted traceback of a monitoring failure to stdout."""
        formatted = traceback.format_exception(exc_type, exc_value, exc_traceback)
        print(''.join(formatted))

    def handle_event(self, event):
        """Consume one DataChangeEvent; the base class does not implement it."""
        raise NotImplementedError


class ConsoleDumpEventHandler(EventHandler):
    """Debug handler that pretty-prints every event as JSON on stdout."""

    def handle_event(self, event):
        assert isinstance(event, DataChangeEvent)
        # pk / foreign-key values are whatever the model stores and may not be
        # JSON-native (a UUID primary key, for instance); render those as text
        # instead of failing the dump with TypeError.
        print(json.dumps(event.to_dict(), indent=4, default=six.text_type))


class HeliosRPCEventHandler(EventHandler):
    """Send every event, as a dict, to the ``api/data_sync/event/post`` RPC endpoint."""

    def __init__(self, rpc_factory):
        # Imported here rather than at module level, so helios is only
        # required once this handler is constructed.
        from helios.rpc import RPCFactory
        assert isinstance(rpc_factory, RPCFactory)
        client = rpc_factory.create_client()
        self._rpc_client = client
        self._rpc_invoker = rpc_factory.create_invoker(client=client)

    def handle_event(self, event):
        assert isinstance(event, DataChangeEvent)
        payload = event.to_dict()
        post_event = self._rpc_invoker['api/data_sync/event/post']
        # NOTE(review): .unwrap() presumably raises if the RPC failed — confirm in helios.rpc.
        post_event(event=payload).unwrap()


