# coding=utf-8
from __future__ import unicode_literals, print_function, absolute_import

from django.db import models
from .data import DataSource, DataRelation, ReversedDataRelation


class ModelDataSource(DataSource):
    """Data source that stands for one Django model's database table.

    Only the table name is recorded; item lookup is not implemented here.
    """

    def __init__(self, table_name):
        # Run the DataSource initialiser before recording our own state.
        super(ModelDataSource, self).__init__()
        self.table_name = table_name

    def __getitem__(self, item):
        # Item lookup is not implemented for model-backed sources.
        raise NotImplementedError

    def __repr__(self):
        return '{name}(table={table})'.format(
            name=type(self).__name__,
            table=self.table_name,
        )

    def _to_extra_json(self):
        """Return the extra JSON fields describing this data source."""
        extra = {'table': self.table_name}
        return extra


class ForeignKeyDataRelation(DataRelation):
    """Relation induced by a ``ForeignKey`` / ``OneToOneField``.

    The primary data source is the model that declares ``field``; the
    secondary data source is the model the field points to.
    """

    def __init__(self, field, ddsm):
        """
        :param field: the foreign-key model field defining the relation.
        :param ddsm: manager providing ``model_data_source(model)``; both the
            declaring model and the target model must already be registered
            there (``DjangoDataSourceManager.model_data_source`` raises
            ``KeyError`` otherwise).
        """
        model = field.model
        related_model = field.related_model

        model_data_source = ddsm.model_data_source(model)
        related_model_data_source = ddsm.model_data_source(related_model)

        # Attribute names used for ORM lookups, db names used for repr/JSON.
        self.pk_attname = model._meta.pk.attname
        self.fk_attname = field.attname
        self.model = model
        self.table_name = model._meta.db_table
        self.column_name = field.column
        super(ForeignKeyDataRelation, self).__init__(
            primary_data_source=model_data_source,
            secondary_data_source=related_model_data_source,
        )

    def get_related_pk_list(self, primary_pk):
        """Return the target primary key referenced by row ``primary_pk``.

        The result is a list with at most one element. It is empty when the
        primary row does not exist or its foreign key is NULL.
        """
        # ``_default_manager`` works for models without an ``objects`` manager.
        manager = self.model._default_manager
        try:
            instance = manager.only(self.fk_attname).get(pk=primary_pk)
        except self.model.DoesNotExist:
            # TODO: decide how a missing primary row should be reported;
            # for now it is treated as having no related rows.
            return []
        fk = getattr(instance, self.fk_attname)
        if fk is None:
            return []
        return [fk]

    def get_reversed_related_pk_list(self, secondary_pk):
        """Return a ``values_list`` queryset of primary keys of rows pointing at ``secondary_pk``."""
        return self.model._default_manager.filter(**{
            self.fk_attname: secondary_pk
        }).values_list(self.pk_attname, flat=True)

    @property
    def reversed(self):
        return ReversedDataRelation(self)

    def __repr__(self):
        return '{}(table={}, column={})'.format(
            self.__class__.__name__,
            self.table_name,
            self.column_name,
        )

    def _to_extra_json(self):
        return {
            'table': self.table_name,
            'column': self.column_name,
        }


class M2MDataRelation(DataRelation):
    """Relation induced by a ``ManyToManyField``.

    The primary data source is the model declaring ``field``; the secondary
    data source is the target model. Both directions are resolved by
    querying the intermediate ("through") model.
    """

    def __init__(self, field, ddsm):
        """
        :param field: the ``ManyToManyField`` defining the relation.
        :param ddsm: manager providing ``model_data_source(model)``; both
            end models must already be registered there.
        """
        primary_model = field.model
        # Same public API used by ForeignKeyDataRelation for the target model.
        secondary_model = field.related_model
        # ``field.related`` was removed in Django 1.10: use ``remote_field``
        # (Django >= 1.9) and fall back to ``rel`` on older releases.
        remote_field = getattr(field, 'remote_field', None) or field.rel
        through_model = remote_field.through

        primary_model_data_source = ddsm.model_data_source(primary_model)
        secondary_model_data_source = ddsm.model_data_source(secondary_model)

        # The through model's foreign keys to the primary and secondary sides;
        # ``get_field`` raises FieldDoesNotExist with a clear message if absent.
        through_meta = through_model._meta
        primary_field = through_meta.get_field(field.m2m_field_name())
        secondary_field = through_meta.get_field(field.m2m_reverse_field_name())

        self.primary_attname = primary_field.attname
        self.secondary_attname = secondary_field.attname

        self.primary_table_name = primary_model._meta.db_table
        self.through_table_name = through_meta.db_table
        self.secondary_table_name = secondary_model._meta.db_table
        self.through_model = through_model
        super(M2MDataRelation, self).__init__(
            primary_data_source=primary_model_data_source,
            secondary_data_source=secondary_model_data_source,
        )

    def get_related_pk_list(self, primary_pk):
        """Return a ``values_list`` queryset of non-NULL secondary keys linked to ``primary_pk``."""
        # ``_default_manager``: custom through models need not define ``objects``.
        return self.through_model._default_manager.filter(**{
            self.primary_attname: primary_pk
        }).exclude(**{
            self.secondary_attname: None
        }).values_list(self.secondary_attname, flat=True)

    def get_reversed_related_pk_list(self, secondary_pk):
        """Return a ``values_list`` queryset of non-NULL primary keys linked to ``secondary_pk``."""
        return self.through_model._default_manager.filter(**{
            self.secondary_attname: secondary_pk
        }).exclude(**{
            self.primary_attname: None
        }).values_list(self.primary_attname, flat=True)

    @property
    def reversed(self):
        return ReversedDataRelation(self)

    def __repr__(self):
        return '{}(table={}, through_table={})'.format(
            self.__class__.__name__,
            self.primary_table_name,
            self.through_table_name,
        )

    def _to_extra_json(self):
        return {
            'primary_table': self.primary_table_name,
            'through_table': self.through_table_name,
        }


class DjangoDataSourceManager(object):
    """Registry of data sources and data relations built from Django models.

    Models, foreign keys and many-to-many fields are registered with the
    ``add_*`` methods. They can then be looked up by Django object (model or
    field) or by database names (table, column, through table). Relations
    require their end models to be registered first.
    """

    def __init__(self):
        self._tds = dict()     # db table name -> ModelDataSource
        self._mds = dict()     # model class -> ModelDataSource
        self._fcdr = dict()    # (db table, column) -> ForeignKeyDataRelation
        self._ffdr = dict()    # FK / one-to-one field -> ForeignKeyDataRelation
        self._m2mtdr = dict()  # through db table -> M2MDataRelation
        self._m2mfdr = dict()  # ManyToManyField -> M2MDataRelation

    def get_table_data_source(self, table):
        """Return the data source for db table ``table``, or None if unknown."""
        return self._tds.get(table)

    def model_data_source(self, model):
        """Return the data source of a registered model; KeyError if not added."""
        assert issubclass(model, models.Model)
        return self._mds[model]

    def get_foreign_column_data_relation(self, table, column):
        """Return the FK relation stored at ``(table, column)``, or None."""
        return self._fcdr.get((table, column))

    def foreign_key_data_relation(self, field):
        """Return the relation of a registered FK field; KeyError if not added."""
        assert isinstance(field, (models.ForeignKey, models.OneToOneField))
        return self._ffdr[field]

    def get_many_to_many_table_data_relation(self, through_table):
        """Return the M2M relation using db table ``through_table``, or None."""
        return self._m2mtdr.get(through_table)

    def many_to_many_field_data_relation(self, field):
        """Return the relation of a registered M2M field; KeyError if not added."""
        assert isinstance(field, models.ManyToManyField)
        return self._m2mfdr[field]

    # Original (misspelled) name, kept so existing callers keep working.
    many_to_may_field_data_relation = many_to_many_field_data_relation

    def add_model(self, model):
        """Register ``model``'s data source; registering it again is a no-op."""
        assert issubclass(model, models.Model)

        if model in self._mds:
            return

        table = model._meta.db_table
        mds = ModelDataSource(table_name=table)

        # Each db table may back only one registered model.
        # NOTE(review): proxy models share their concrete model's db_table and
        # would trip this assertion if both are registered.
        assert table not in self._tds

        self._tds[table] = mds
        self._mds[model] = mds

    def add_foreign_key(self, field):
        """Register the relation for FK ``field``; both end models must be added."""
        assert isinstance(field, (models.ForeignKey, models.OneToOneField))

        if field in self._ffdr:
            return

        fkdr = ForeignKeyDataRelation(field=field, ddsm=self)

        table = field.model._meta.db_table
        column = field.column
        assert (table, column) not in self._fcdr
        self._fcdr[table, column] = fkdr
        self._ffdr[field] = fkdr

    def add_m2m_field(self, field):
        """Register the relation for M2M ``field``; both end models must be added."""
        assert isinstance(field, models.ManyToManyField)

        if field in self._m2mfdr:
            return

        m2mdr = M2MDataRelation(field, ddsm=self)
        # ``field.related`` was removed in Django 1.10: use ``remote_field``
        # (Django >= 1.9) and fall back to ``rel`` on older releases.
        remote_field = getattr(field, 'remote_field', None) or field.rel
        through_table = remote_field.through._meta.db_table
        assert through_table not in self._m2mtdr
        self._m2mtdr[through_table] = m2mdr
        self._m2mfdr[field] = m2mdr
