# coding=utf8
from __future__ import unicode_literals, absolute_import, print_function


from . import basic
from ..typing.basic import Bottom, Any
from ..typing.simple import Null, Boolean, Integer, Real, String
from ..typing.constructed import Union, TypedArray, RecordType
from ..parser.typedef import parse_type_pragma, parse_code_pragma
from ..typedef import nodes
from ..documents.basic import TypeDocument, FunctionDocument


# Constructor used for array typings (both the ``Array`` global name and
# array-type nodes are built through it).
array_type = TypedArray


def union_type(*choices):
    """Build a ``Union`` typing whose ``choices`` are the given arguments.

    :param choices: the alternative typings, passed positionally.
    :return: ``Union(choices=choices)`` where ``choices`` is the argument tuple.
    """
    return Union(choices=choices)


# Type names that alpha type definitions can reference, mapped to the callable
# that is invoked to build the corresponding typing object.
alpha_global_types = {
    'None':  Null,
    'Bool':  Boolean,
    'Int':   Integer,
    'Real':  Real,
    'Str':   String,
    'Array': array_type,
    'Union': union_type,
    'Any':   Any,
    'Void':  Bottom,
}


class Alpha(basic.Configuration):
    """Configuration that reads a ``Params`` and a ``Result`` type definition
    from a docstring's code pragma and turns them into a ``FunctionDocument``.
    """

    # Names of the two type definitions every parsed docstring must provide.
    params_type_name = 'Params'
    result_type_name = 'Result'

    def parse(self, docstring):
        """Parse ``docstring``; simply delegates to :meth:`parse_typedef`."""
        return self.parse_typedef(docstring)

    def parse_typedef(self, docstring):
        """Build a ``FunctionDocument`` from the type definitions in ``docstring``.

        :param docstring: object providing ``straight_text`` (the text handed
            to ``parse_code_pragma``), ``conf_args`` and ``rest_text``.
        :return: a ``FunctionDocument`` whose ``params`` and ``result`` are
            ``TypeDocument`` instances for the two required type definitions.
        :raises Exception: when either required type name is not defined.
        """
        tree = parse_code_pragma(docstring.straight_text)

        # First pass: type name -> TypeDef node.
        typedef_map = dict(tree.accept(AlphaGetTypedef()))

        required_names = [self.params_type_name, self.result_type_name]
        for required in required_names:
            if required not in typedef_map:
                raise Exception("Expected type {} not found in {}".format(repr(required), repr(typedef_map.keys())))

        # Built-in type names that the docstring redefines are dropped from
        # the table handed to the typing visitor.
        global_types = {
            type_name: factory
            for type_name, factory in alpha_global_types.items()
            if type_name not in typedef_map
        }

        # Second pass: type name -> typing object.
        typing_map = dict(tree.accept(AlphaTypingVisitor(global_types)))

        assert frozenset(typedef_map.keys()) == frozenset(typing_map.keys())

        documents = {}
        for name in required_names:
            typedef = typedef_map[name]
            ast = typedef.type_
            comments = [typedef.block_comment.text, typedef.line_comment.text]
            # Comments whose text is None are left out of the description.
            description = '\n\n'.join(text for text in comments if text is not None)

            documents[name] = TypeDocument(
                name=name,
                conf_args=docstring.conf_args,
                ast=ast,
                typing=typing_map[name],
                description=description,
            )

        return FunctionDocument(
            params=documents[self.params_type_name],
            result=documents[self.result_type_name],
            description=docstring.rest_text,
        )


class AlphaGetTypedef(nodes.Visitor):
    """Visitor that collects the ``TypeDef`` nodes of a ``Code`` tree.

    ``Alpha.parse_typedef`` passes the result of visiting the tree to
    ``dict()``, so the visit is expected to yield ``(type_name, TypeDef)``
    pairs.  NOTE(review): the dispatch to ``visit_*`` / ``visit__pre_*`` /
    ``visit__default`` is implemented by ``nodes.Visitor`` -- confirm the
    exact calling convention there.
    """

    def visit__default(self, node, **kwargs):
        """Return ``node`` unchanged, ignoring any keyword arguments."""
        return node

    def visit_TypeName(self, node):
        """Return the name carried by a ``TypeName`` node."""
        assert isinstance(node, nodes.TypeName)
        return node.name

    def visit_TypeDef(self, node, type_name, type_, block_comment, line_comment):
        """Return ``(type_name, node)``, keeping the ``TypeDef`` node itself."""
        assert isinstance(node, nodes.TypeDef)
        return (type_name, node)

    def visit__pre_Code(self, node, members):
        """Filter the members of a ``Code`` node down to its ``TypeDef`` nodes.

        ``Pragma`` members are dropped silently.

        :return: ``{'members': [...]}`` holding only the ``TypeDef`` members.
        :raises Exception: when a member is neither a ``TypeDef`` nor a
            ``Pragma``; the message names the offending member.
        """
        assert isinstance(node, nodes.Code)
        kept = []
        for member in members:
            if isinstance(member, nodes.TypeDef):
                kept.append(member)
            elif not isinstance(member, nodes.Pragma):
                # Report the unexpected member, not the enclosing Code node;
                # fall back to the class name if it has no ``node_name``.
                member_name = getattr(member, 'node_name', type(member).__name__)
                raise Exception("Unexpected Node {} in Code".format(repr(member_name)))
        return {
            'members': kept
        }

    def visit_Code(self, node, members):
        """Return the ``members`` received for the ``Code`` node as-is."""
        assert isinstance(node, nodes.Code)
        return members


class AlphaTypingVisitor(nodes.Visitor):
    """Visitor that converts type-definition nodes into typing objects.

    ``Alpha.parse_typedef`` passes the result of visiting the tree to
    ``dict()``, so the visit is expected to yield ``(type_name, typing)``
    pairs.  NOTE(review): the dispatch to ``visit_*`` / ``visit__pre_*`` is
    implemented by ``nodes.Visitor`` -- confirm the exact calling convention
    there.
    """

    def __init__(self, global_types, **kwargs):
        """Store the table used to resolve type names.

        :param global_types: mapping from type name to a callable that builds
            the typing object for that name.
        :param kwargs: forwarded unchanged to ``nodes.Visitor.__init__``.
        """
        super(AlphaTypingVisitor, self).__init__(**kwargs)
        self._global_types = global_types

    def visit_Symbol(self, node):
        """Return the name carried by a ``Symbol`` node."""
        assert isinstance(node, nodes.Symbol)
        return node.name

    def visit_TypeName(self, node):
        """Return the name carried by a ``TypeName`` node."""
        assert isinstance(node, nodes.TypeName)
        return node.name

    def visit_LineComment(self, node):
        """Line comments contribute nothing to the typing; return ``None``."""
        assert isinstance(node, nodes.LineComment)
        return None

    def visit_BlockComment(self, node):
        """Block comments contribute nothing to the typing; return ``None``."""
        assert isinstance(node, nodes.BlockComment)
        return None

    def visit_Pragma(self, node):
        """Pragmas contribute nothing to the typing; return ``None``."""
        assert isinstance(node, nodes.Pragma)
        return None

    def visit_NamedType(self, node, type_name):
        """Look up ``type_name`` in the global types and call it with no arguments.

        A name missing from the table propagates the lookup error
        (``KeyError`` when the table is a dict).
        """
        assert isinstance(node, nodes.NamedType)
        return self._global_types[type_name]()

    def visit_ArrayType(self, node, element_type):
        """Return ``array_type(element_type)``."""
        assert isinstance(node, nodes.ArrayType)
        return array_type(element_type)

    def visit_GenericInstanceType(self, node, type_name, type_arguments):
        """Look up ``type_name`` and call it with ``type_arguments`` unpacked.

        A name missing from the table propagates the lookup error
        (``KeyError`` when the table is a dict).
        """
        assert isinstance(node, nodes.GenericInstanceType)
        return self._global_types[type_name](*type_arguments)

    def visit_RecordType(self, node, entries):
        """Build a ``RecordType`` whose keyword arguments come from ``entries``.

        ``entries`` is turned into a dict first, so it must be an iterable of
        ``(symbol, typing)`` pairs as produced by ``visit_RecordTypeEntry``.
        """
        assert isinstance(node, nodes.RecordType)
        return RecordType(**dict(entries))

    def visit_RecordTypeEntry(self, node, symbol, element_type, block_comment, line_comment):
        """Return ``(symbol, element_type)``; the comments are ignored."""
        assert isinstance(node, nodes.RecordTypeEntry)
        return (symbol, element_type)

    def visit_UnionType(self, node, choices):
        """Return ``union_type(*choices)``."""
        assert isinstance(node, nodes.UnionType)
        return union_type(*choices)

    def visit_TypeDef(self, node, type_name, type_, block_comment, line_comment):
        """Return ``(type_name, type_)``; the comments are ignored."""
        assert isinstance(node, nodes.TypeDef)
        return (type_name, type_)

    def visit__pre_Code(self, node, members):
        """Filter the members of a ``Code`` node down to its ``TypeDef`` nodes.

        ``Pragma`` members are dropped silently.

        :return: ``{'members': [...]}`` holding only the ``TypeDef`` members.
        :raises Exception: when a member is neither a ``TypeDef`` nor a
            ``Pragma``; the message names the offending member.
        """
        assert isinstance(node, nodes.Code)
        kept = []
        for member in members:
            if isinstance(member, nodes.TypeDef):
                kept.append(member)
            elif not isinstance(member, nodes.Pragma):
                # Report the unexpected member, not the enclosing Code node;
                # fall back to the class name if it has no ``node_name``.
                member_name = getattr(member, 'node_name', type(member).__name__)
                raise Exception("Unexpected Node {} in Code".format(repr(member_name)))
        return {
            'members': kept
        }

    def visit_Code(self, node, members):
        """Return the ``members`` received for the ``Code`` node as-is."""
        assert isinstance(node, nodes.Code)
        return members
