# coding=utf8
from __future__ import unicode_literals, absolute_import, print_function

from . import basic
from .. import accessor
from .basic import Type, ConstructedType


class Union(ConstructedType):
    """A constructed type that accepts a value matching any one of ``choices``.

    Choices are tried in order during unification. The first choice that
    unifies without a conflict wins.
    """

    def __init__(self, choices):
        """Store ``choices`` as a list; each element must be a ``basic.Type``."""
        self.choices = list(choices)
        for c in self.choices:
            assert isinstance(c, basic.Type)

    def access(self, ac):
        """Return the choice selected by a ``accessor.UnionAccessor``.

        Returns ``basic.Bottom()`` when ``ac`` is not a ``UnionAccessor`` or
        when its ``index`` falls outside ``0 .. len(self.choices) - 1``.
        """
        if not isinstance(ac, accessor.UnionAccessor):
            return basic.Bottom()
        # Upper bound is exclusive: index == len(self.choices) is out of
        # range and must yield Bottom rather than an IndexError.
        if not (0 <= ac.index < len(self.choices)):
            return basic.Bottom()
        return self.choices[ac.index]

    def unify(self, chain, item):
        """Unify ``item`` against each choice in order, stopping at the first success.

        Each attempt uses ``chain % accessor.UnionAccessor(index)`` as its
        access chain. A failing choice raises ``accessor.ConflictException``;
        its ``conflict`` is collected.

        Raises:
            The exception exposed by ``accessor.UnionConflict(...).exception``.
            It carries every collected per-choice conflict. It is raised when
            no choice unifies, including when ``choices`` is empty.
        """
        el = []
        for index, c in enumerate(self.choices):
            try:
                ch = chain % accessor.UnionAccessor(index)
                c.unify(ch, item)
            except accessor.ConflictException as e:
                el.append(e.conflict)
            else:
                # This choice accepted the item; the union is satisfied.
                return
        raise accessor.UnionConflict(chain=chain, expected=self, got=item, choices=el).exception


class TypedArray(ConstructedType):
    """A homogeneous list type: every element must unify with ``element_type``."""

    def __init__(self, element_type):
        """Remember the type every list element is checked against."""
        # Stored as given; no isinstance check is performed here.
        self.element_type = element_type

    def access(self, ac):
        """Return ``element_type`` for an ``accessor.IndexAccessor``, else ``basic.Bottom()``."""
        if isinstance(ac, accessor.IndexAccessor):
            return self.element_type
        return basic.Bottom()

    def unify(self, chain, item):
        """Check that ``item`` is a list and unify each element with the element type.

        Each element is unified under ``chain % accessor.IndexAccessor(position)``.
        A non-list ``item`` raises the exception exposed by
        ``accessor.Conflict(...).exception``.
        """
        if isinstance(item, list):
            for position, element in enumerate(item):
                index_ac = accessor.IndexAccessor(position)
                element_type = self.access(index_ac)
                element_chain = chain % index_ac
                element_type.unify(element_chain, element)
            return
        raise accessor.Conflict(chain=chain, expected=self, got=item, message="not a list").exception


class RecordType(ConstructedType):
    """A dict-shaped type with a fixed set of string keys, each mapped to a ``basic.Type``."""

    # ``basestring`` exists only on Python 2; fall back to ``str`` on Python 3
    # so constructing a RecordType does not raise NameError there.
    try:
        _STRING_TYPES = (basestring,)
    except NameError:
        _STRING_TYPES = (str,)

    def __init__(self, **entry_map):
        """Build a record from keyword arguments mapping key names to ``basic.Type`` values."""
        for k, v in entry_map.items():
            assert isinstance(k, RecordType._STRING_TYPES)
            assert isinstance(v, basic.Type)
        self.entry_map = dict(entry_map)
        self.key_set = frozenset(self.entry_map.keys())

    def access(self, ac):
        """Return the type declared for ``ac.attr``.

        Returns ``basic.Bottom()`` when ``ac`` is not an
        ``accessor.AttributeAccessor`` or when the attribute is not declared.
        """
        if not isinstance(ac, accessor.AttributeAccessor):
            return basic.Bottom()
        return self.entry_map.get(ac.attr, basic.Bottom())

    def unify(self, chain, item):
        """Unify a dict ``item`` key-by-key against this record.

        Every key present in either the record or ``item`` is visited:
          - a key missing from ``item`` is unified as the value ``basic.Bottom()``;
          - a key not declared in the record is checked against ``basic.Bottom()``
            (via ``access``). Whether that rejects the value depends on
            ``basic.Bottom.unify`` — NOTE(review): presumably it raises a conflict.

        A non-dict ``item`` raises the exception exposed by
        ``accessor.Conflict(...).exception``.
        """
        if not isinstance(item, dict):
            raise accessor.Conflict(chain=chain, expected=self, got=item, message="not a dict").exception

        for k in self.key_set.union(item.keys()):
            ac = accessor.AttributeAccessor(k)
            value = item.get(k, basic.Bottom())  # TODO: handle non-Python values in unify
            self.access(ac).unify(chain % ac, value)
