# -*- coding: utf-8 -*-
import inspect
from abc import ABCMeta, abstractmethod
from copy import deepcopy
from typing import Type, Any, List, Optional, Dict, Tuple

from django.utils.functional import SimpleLazyObject
from pydantic import BaseModel, BaseConfig, MissingError
from pydantic.error_wrappers import ErrorWrapper
from pydantic.fields import ModelField, Required, FieldInfo, SHAPE_SINGLETON, \
    SHAPE_LIST, SHAPE_SET, SHAPE_TUPLE, SHAPE_SEQUENCE, SHAPE_TUPLE_ELLIPSIS
from pydantic.schema import get_annotation_from_field_info
from pydantic.utils import lenient_issubclass

from rpc_framework import params, auth, exceptions
from rpc_framework.context import RPCViewContext
from rpc_framework.exceptions import CalleeTypeHintValidationError
from rpc_framework.inspector import get_typed_signature
from rpc_framework.views import RPCView


# pydantic field shapes that denote a collection of values rather than a
# single value; consulted when deciding whether an empty form value counts
# as "missing".
sequence_shapes = {
    SHAPE_SEQUENCE,
    SHAPE_LIST,
    SHAPE_TUPLE_ELLIPSIS,
    SHAPE_TUPLE,
    SHAPE_SET,
}

# Built-in container types that prevent a field from being treated as scalar.
sequence_types = (list, set, tuple)


class Interceptor(metaclass=ABCMeta):
    """Abstract base class for hooks that run against an RPC view.

    Subclasses must override :meth:`intercept`; the class itself cannot be
    instantiated because that method is abstract.
    """

    __slots__ = ()

    @abstractmethod
    def intercept(self, context, rpc_view):
        """Process ``context`` for ``rpc_view``; implemented by subclasses."""


def get_field_info(field: ModelField) -> FieldInfo:
    """Return the ``FieldInfo`` stored on the given pydantic ``ModelField``."""
    info = field.field_info
    return info


def is_scalar_field(field: ModelField) -> bool:
    """Return True when ``field`` describes a single, non-container value.

    A field is scalar only if all of the following hold:
    its shape is ``SHAPE_SINGLETON``; its type is not a pydantic ``BaseModel``
    nor a list/set/tuple/dict subclass; its ``FieldInfo`` is not a
    ``params.Body``; and every one of its sub-fields (if any) is scalar too.
    """
    info = get_field_info(field)
    # Guard clauses, checked in the same order as the original compound test.
    if field.shape != SHAPE_SINGLETON:
        return False
    if lenient_issubclass(field.type_, BaseModel):
        return False
    if lenient_issubclass(field.type_, sequence_types + (dict,)):
        return False
    if isinstance(info, params.Body):
        return False
    # No sub-fields (None or empty) means nothing further to check.
    return all(is_scalar_field(sub) for sub in field.sub_fields or ())


def get_param_field(
    *,
    param: inspect.Parameter,
    param_name: str,
    default_field_info: Type[params.Param] = params.Param,
    force_type: Optional[params.ParamTypes] = None,
    ignore_default: bool = False,
) -> ModelField:
    """Build a pydantic ``ModelField`` describing one callee parameter.

    :param param: parameter taken from the callee's typed signature.
    :param param_name: name forwarded to ``get_annotation_from_field_info``.
    :param default_field_info: ``params.Param`` subclass instantiated with the
        default value when the parameter's default is not already a
        ``FieldInfo``; its ``in_`` is also used to fill a missing location on
        a user-supplied ``params.Param``.
    :param force_type: when truthy and the default is a ``FieldInfo``,
        overwrites that ``FieldInfo``'s ``in_``.
    :param ignore_default: when true, the declared default is disregarded and
        the field becomes required.
    :returns: the constructed field. ``required`` is True when no usable
        default exists. A non-scalar field whose default was not an explicit
        ``FieldInfo`` gets its ``field_info`` replaced by ``params.Body``.
    """
    default_value: Any = Required
    had_schema = False
    # Identity comparison against sentinels: ``==`` would dispatch to the
    # default's own ``__eq__``, which is ambiguous for array-like defaults.
    if param.default is not param.empty and not ignore_default:
        default_value = param.default
    if isinstance(default_value, FieldInfo):
        had_schema = True
        field_info = default_value
        default_value = field_info.default
        # NOTE: these assignments mutate the FieldInfo object that is the
        # parameter's declared default, in place.
        if (
            isinstance(field_info, params.Param)
            and getattr(field_info, "in_", None) is None
        ):
            field_info.in_ = default_field_info.in_
        if force_type:
            field_info.in_ = force_type  # type: ignore
    else:
        field_info = default_field_info(default_value)
    # ``Required`` is the Ellipsis singleton, so identity is the right test.
    required = default_value is Required
    annotation: Any = Any
    if param.annotation is not param.empty:
        annotation = param.annotation
    annotation = get_annotation_from_field_info(annotation, field_info, param_name)
    if not field_info.alias and getattr(field_info, "convert_underscores", None):
        alias = param.name.replace("_", "-")
    else:
        alias = field_info.alias or param.name
    field = ModelField(
        name=param.name,
        type_=annotation,
        default=None if required else default_value,
        alias=alias,
        required=required,
        model_config=BaseConfig,
        class_validators={},
        field_info=field_info,
    )
    field.required = required
    # Containers/models without an explicit schema are read from the body.
    if not had_schema and not is_scalar_field(field=field):
        field.field_info = params.Body(field_info.default)

    return field


def get_callee_field_list(callee):
    """Return one ``ModelField`` per parameter of ``callee``'s typed signature.

    Fields are built by ``get_param_field`` with ``params.Query`` as the
    fallback ``FieldInfo`` type, in signature order.
    """
    signature_params = get_typed_signature(callee).parameters
    return [
        get_param_field(
            param=param, default_field_info=params.Query, param_name=param_name
        )
        for param_name, param in signature_params.items()
    ]


def is_null_subfield_in_pack(packed_field_info, subfield, subfield_value) -> bool:
    """Return True when a packed sub-field value should be treated as absent.

    Only applies when ``packed_field_info`` is a ``params.Form``. In that case
    an empty string is absent, and so is an empty sized value supplied for a
    sub-field whose shape is in ``sequence_shapes``.
    """
    if not isinstance(packed_field_info, params.Form):
        return False
    if subfield_value == "":
        return True
    # Guard ``len``: a client may send a scalar (e.g. an int) for a
    # sequence-shaped field; leave that for validation to report instead of
    # raising TypeError here.
    return (
        subfield.shape in sequence_shapes
        and hasattr(subfield_value, "__len__")
        and len(subfield_value) == 0
    )


def check_parameters_fit_signature(
    required_params: List[ModelField],
    received_body: Optional[Dict],
    is_pack_params: bool = False,
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
    """Validate ``received_body`` against the given fields.

    :param required_params: fields to look up (by ``alias``) and validate.
    :param received_body: incoming parameters; expected to be mapping-like.
        A non-mapping body (e.g. a positional list) yields a missing-field
        error per field instead of an ``AttributeError``.
    :param is_pack_params: when true and there is exactly one field whose
        ``FieldInfo`` has no truthy ``embed``, the whole body is treated as
        that field's value. Form-style "empty" values are also treated as
        missing (see ``is_null_subfield_in_pack``).
    :returns: ``(values, errors)``. ``values`` maps ``field.name`` to the
        validated value, or to a deep copy of ``field.default`` for absent
        optional fields. ``errors`` collects ``ErrorWrapper`` objects located
        at ``("body", alias)``.
    """
    values: Dict[str, Any] = {}
    errors: List[ErrorWrapper] = []

    if not required_params:
        return values, errors

    packed_field_info = None
    if is_pack_params:
        packed_field = required_params[0]
        packed_field_info = get_field_info(packed_field)
        embed = getattr(packed_field_info, "embed", None)
        if len(required_params) == 1 and not embed:
            received_body = {packed_field.alias: received_body}

    for field in required_params:
        loc = ("body", field.alias)
        value: Any = None
        if received_body is not None:
            try:
                value = received_body.get(field.alias)
            except AttributeError:
                # Body is not mapping-like; report rather than crash.
                errors.append(ErrorWrapper(MissingError(), loc=loc))
                continue
        if value is None or (
            is_pack_params
            and is_null_subfield_in_pack(packed_field_info, field, value)
        ):
            if field.required:
                errors.append(ErrorWrapper(MissingError(), loc=loc))
            else:
                # Copy so a mutable default is never shared between calls.
                values[field.name] = deepcopy(field.default)
            continue
        validated, field_errors = field.validate(value, values, loc=loc)
        if isinstance(field_errors, ErrorWrapper):
            errors.append(field_errors)
        elif isinstance(field_errors, list):
            errors.extend(field_errors)
        else:
            values[field.name] = validated
    return values, errors


class CalleeParametersInterceptor(Interceptor):
    """Reject calls whose request params do not fit the callee's signature."""

    def intercept(self, context: RPCViewContext, rpc_view: RPCView) -> bool:
        """Validate ``context.request.params`` against ``rpc_view.callee``.

        :returns: ``True`` when validation produces no errors.
        :raises CalleeTypeHintValidationError: built from the collected
            validation errors otherwise.
        """
        callee_field_list = get_callee_field_list(rpc_view.callee)
        # Only the errors are used; the coerced values are discarded.
        # NOTE(review): confirm whether the callee should receive the
        # validated values instead of the raw params.
        _, errors = check_parameters_fit_signature(
            callee_field_list, context.request.params
        )
        if errors:
            raise CalleeTypeHintValidationError.from_error_wrapper_list(errors)
        return True


class SessionUserInterceptor(Interceptor):
    """Attach the requesting user to the context and require that it is set."""

    def intercept(self, context: RPCViewContext, rpc_view: RPCView) -> bool:
        """Store ``auth.get_user(context)`` on ``context.user``.

        :returns: ``True`` when ``context.user`` is truthy after assignment.
        :raises exceptions.AuthenticationFailed: when it is falsy.
        """
        context.user = auth.get_user(context)
        if not context.user:
            raise exceptions.AuthenticationFailed()
        return True
