Commit a46dbd94 authored by 张宇's avatar 张宇

exception handler

parent e9ec4ed1
...@@ -257,4 +257,4 @@ fabric.properties ...@@ -257,4 +257,4 @@ fabric.properties
# End of https://www.gitignore.io/api/python,django,pycharm # End of https://www.gitignore.io/api/python,django,pycharm
.idea/* .idea/*
courier courier/settings_local.py
\ No newline at end of file \ No newline at end of file
# -*- coding: utf-8 -*-
from django.conf import settings
class MessageRouter(object):
MODEL_PREFIXS = 'api.models.message'
TABLES = ['api_conversation', 'api_conversationuserstatus', 'api_message', 'api_conversationtags']
def validate(self, model):
return model.__module__.startswith(self.MODEL_PREFIXS) and model._meta.db_table in self.TABLES
def db_for_read(self, model, **hints):
if self.validate(model):
return getattr(settings, 'MESSAGE_SLAVE_DB_NAME', settings.MESSAGE_DB_NAME)
return None
def db_for_write(self, model, **hints):
if self.validate(model):
return settings.MESSAGE_DB_NAME
return None
def allow_relation(self, obj1, obj2, **hints):
if self.validate(obj1) and self.validate(obj2):
return True
return None
def allow_migrate(self, db, app_label, model_name=None, **hints):
return True
...@@ -2,6 +2,5 @@ ...@@ -2,6 +2,5 @@
from gm_types.utils.enum import Enum, unique from gm_types.utils.enum import Enum, unique
class CourierError(int, Enum): class Error(int, Enum):
pass PARAMS_INVALID = (10001, '非法参数')
\ No newline at end of file
...@@ -11,7 +11,9 @@ from adapter.old_system import bind_prefix ...@@ -11,7 +11,9 @@ from adapter.old_system import bind_prefix
from adapter.rpcd.exceptions import RPCPermanentError from adapter.rpcd.exceptions import RPCPermanentError
from api.models.message import ConversationUserStatus from api.models.message import ConversationUserStatus
from rpc import gaia_client from rpc import gaia_client
from rpc_framework.decorators import rpc_view from rpc_framework.context import RPCViewContext
from rpc_framework.decorators import rpc_view, interceptor_classes
from rpc_framework.interceptors import CalleeParametersInterceptor
from search.utils import search_conversation_from_es from search.utils import search_conversation_from_es
from services.unread.stat import UserUnread from services.unread.stat import UserUnread
...@@ -234,7 +236,6 @@ def message_conversation_list_v3(user_ids: List[int], ...@@ -234,7 +236,6 @@ def message_conversation_list_v3(user_ids: List[int],
@rpc_view('message/conversation/can_send') @rpc_view('message/conversation/can_send')
def check_can_send_message(context, target_uid: str) -> Dict: @interceptor_classes([CalleeParametersInterceptor])
# print(context.session_id) def check_can_send_message(context: RPCViewContext, target_uid: str) -> Dict:
print('+++', context, dir(context), context._context.session_id)
return {'a': 'b'} return {'a': 'b'}
...@@ -18,3 +18,4 @@ elasticsearch==2.3.0 ...@@ -18,3 +18,4 @@ elasticsearch==2.3.0
kafka-python==1.4.7 kafka-python==1.4.7
gunicorn==20.0.4 gunicorn==20.0.4
#djangorestframework==3.11.0 #djangorestframework==3.11.0
pydantic==1.3
\ No newline at end of file
...@@ -3,7 +3,9 @@ import sys ...@@ -3,7 +3,9 @@ import sys
from contextlib import contextmanager from contextlib import contextmanager
from typing import Optional, Tuple from typing import Optional, Tuple
from gm_logging import RequestInfo
from gm_rpcd.internals.context import Context from gm_rpcd.internals.context import Context
from gm_rpcd.internals.protocol.request import Request
from rpc_framework import exceptions from rpc_framework import exceptions
from rpc_framework.settings import api_settings from rpc_framework.settings import api_settings
...@@ -28,6 +30,26 @@ def wrap_attributeerrors(): ...@@ -28,6 +30,26 @@ def wrap_attributeerrors():
class RPCViewContext(object): class RPCViewContext(object):
'''
Wrapper allowing to enhance a `gm_rpcd.internals.context.Context` instance.
Kwargs:
- context(gm_rpcd.internals.context.Context). The original context instance.
- parsers_classes(list/tuple). The parsers to use for parsing the
request content.
- authentication_classes(list/tuple). The authentications used to try
authenticating the request's user.
Usage Examples:
inspect property of `gm_rpcd.internals.context.Context` instance:
print(context.request.request_id)
print(context.request.session_id)
print(context.request.method)
print(context.request.params)
print(context.request.environment)
print(context.request_info.user_id)
'''
def __init__(self, def __init__(self,
context: Context, context: Context,
authenticators: Optional[Tuple]=None, authenticators: Optional[Tuple]=None,
...@@ -62,6 +84,14 @@ class RPCViewContext(object): ...@@ -62,6 +84,14 @@ class RPCViewContext(object):
self._user = value self._user = value
self._request.user = value self._request.user = value
@property
def request(self) -> Optional[Request]:
return getattr(self._context, '_Context__request', None)
@property
def request_info(self) -> Optional[RequestInfo]:
return getattr(self._context, '_Context__request_info', None)
def _authenticate(self): def _authenticate(self):
""" """
Attempt to authenticate the request using each authentication instance Attempt to authenticate the request using each authentication instance
...@@ -99,4 +129,14 @@ class RPCViewContext(object): ...@@ -99,4 +129,14 @@ class RPCViewContext(object):
else: else:
self.auth = None self.auth = None
def __getattr__(self, attr):
"""
If an attribute does not exist on this instance, then we also attempt
to proxy it to the underlying HttpRequest object.
"""
try:
return getattr(self._request, attr)
except AttributeError:
return self.__getattribute__(attr)
...@@ -19,7 +19,9 @@ def rpc_view(endpoint): ...@@ -19,7 +19,9 @@ def rpc_view(endpoint):
WrappedAPIView.endpoint = endpoint WrappedAPIView.endpoint = endpoint
def handler(self, context, *args, **kwargs): def handler(self, context, *args, **kwargs):
return func(context, *args, **kwargs) return func(context, *args, **kwargs)
setattr(WrappedAPIView, 'handler', handler) setattr(WrappedAPIView, 'handler', handler)
setattr(WrappedAPIView, 'callee', func)
WrappedAPIView.__name__ = func.__name__ WrappedAPIView.__name__ = func.__name__
WrappedAPIView.__module__ = func.__module__ WrappedAPIView.__module__ = func.__module__
...@@ -30,6 +32,9 @@ def rpc_view(endpoint): ...@@ -30,6 +32,9 @@ def rpc_view(endpoint):
WrappedAPIView.parser_classes = getattr(func, 'parser_classes', WrappedAPIView.parser_classes = getattr(func, 'parser_classes',
RPCView.parser_classes) RPCView.parser_classes)
WrappedAPIView.interceptor_classes = getattr(func, 'interceptor_classes',
RPCView.interceptor_classes)
WrappedAPIView.authentication_classes = getattr(func, 'authentication_classes', WrappedAPIView.authentication_classes = getattr(func, 'authentication_classes',
RPCView.authentication_classes) RPCView.authentication_classes)
...@@ -44,3 +49,52 @@ def rpc_view(endpoint): ...@@ -44,3 +49,52 @@ def rpc_view(endpoint):
return WrappedAPIView.rpc_bind() return WrappedAPIView.rpc_bind()
return decorator return decorator
def renderer_classes(renderer_classes):
def decorator(func):
func.renderer_classes = renderer_classes
return func
return decorator
def parser_classes(parser_classes):
def decorator(func):
func.parser_classes = parser_classes
return func
return decorator
def interceptor_classes(interceptor_classes):
def decorator(func):
func.interceptor_classes = interceptor_classes
return func
return decorator
def authentication_classes(authentication_classes):
def decorator(func):
func.authentication_classes = authentication_classes
return func
return decorator
def throttle_classes(throttle_classes):
def decorator(func):
func.throttle_classes = throttle_classes
return func
return decorator
def permission_classes(permission_classes):
def decorator(func):
func.permission_classes = permission_classes
return func
return decorator
def schema(view_inspector):
def decorator(func):
func.schema = view_inspector
return func
return decorator
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from abc import ABCMeta, abstractmethod
from typing import Sequence
from gm_rpcd.internals.exceptions import RPCDFaultException
from pydantic import ValidationError, create_model
from pydantic.error_wrappers import ErrorList
from extension.types import Error
def _check_methods(C, *methods):
mro = C.__mro__
for method in methods:
for B in mro:
if method in B.__dict__:
if B.__dict__[method] is None:
return NotImplemented
break
else:
return NotImplemented
return True
class RPCViewBaseException(metaclass=ABCMeta):
code: int
message: str
__slots__ = ()
@classmethod
def __subclasshook__(cls, C):
if cls is RPCViewBaseException:
return _check_methods(C, "as_rpcd_fault")
return NotImplemented
@abstractmethod
def as_rpcd_fault(self):
return RPCDFaultException(error=self.error, message=self.message)
RequestErrorModel = create_model("Request")
class CalleeTypeHintValidationError(ValidationError):
def __init__(self, errors: Sequence[ErrorList]) -> None:
return super().__init__(errors, RequestErrorModel)
def as_rpcd_fault(self):
return RPCDFaultException(code=Error.PARAMS_INVALID, message=self.json())
class RPCViewBaseException(Exception):
pass
# -*- coding: utf-8 -*-
import inspect
from typing import Callable, Dict, Any
from pydantic.typing import ForwardRef, evaluate_forwardref
def get_typed_annotation(param: inspect.Parameter, globalns: Dict[str, Any]) -> Any:
annotation = param.annotation
if isinstance(annotation, str):
annotation = ForwardRef(annotation)
annotation = evaluate_forwardref(annotation, globalns, globalns)
return annotation
def get_typed_signature(call: Callable) -> inspect.Signature:
signature = inspect.signature(call)
globalns = getattr(call, "__globals__", {})
typed_params = [
inspect.Parameter(
name=param.name,
kind=param.kind,
default=param.default,
annotation=get_typed_annotation(param, globalns),
)
for param in signature.parameters.values()
]
typed_signature = inspect.Signature(typed_params)
return typed_signature
if __name__ == '__main__':
def a(a: int=10):
pass
res = get_typed_signature(a)
print(res, type(res))
\ No newline at end of file
# -*- coding: utf-8 -*-
import inspect
from abc import ABCMeta, abstractmethod
from copy import deepcopy
from typing import Type, Any, List, Optional, Dict, Tuple
from pydantic import BaseModel, BaseConfig, MissingError
from pydantic.error_wrappers import ErrorWrapper
from pydantic.fields import ModelField, Required, FieldInfo, SHAPE_SINGLETON, \
SHAPE_LIST, SHAPE_SET, SHAPE_TUPLE, SHAPE_SEQUENCE, SHAPE_TUPLE_ELLIPSIS
from pydantic.schema import get_annotation_from_field_info
from pydantic.utils import lenient_issubclass
from rpc_framework import params
from rpc_framework.context import RPCViewContext
from rpc_framework.exceptions import CalleeTypeHintValidationError
from rpc_framework.inspector import get_typed_signature
from rpc_framework.views import RPCView
sequence_shapes = {
SHAPE_LIST,
SHAPE_SET,
SHAPE_TUPLE,
SHAPE_SEQUENCE,
SHAPE_TUPLE_ELLIPSIS,
}
sequence_types = (list, set, tuple)
class Interceptor(metaclass=ABCMeta):
__slots__ = ()
@abstractmethod
def intercept(self, context, rpc_view):
pass
def get_field_info(field: ModelField) -> FieldInfo:
return field.field_info
def is_scalar_field(field: ModelField) -> bool:
field_info = get_field_info(field)
if not (
field.shape == SHAPE_SINGLETON
and not lenient_issubclass(field.type_, BaseModel)
and not lenient_issubclass(field.type_, sequence_types + (dict,))
and not isinstance(field_info, params.Body)
):
return False
if field.sub_fields:
if not all(is_scalar_field(f) for f in field.sub_fields):
return False
return True
def get_param_field(
*,
param: inspect.Parameter,
param_name: str,
default_field_info: Type[params.Param] = params.Param,
force_type: params.ParamTypes = None,
ignore_default: bool = False,
) -> ModelField:
default_value = Required
had_schema = False
if not param.default == param.empty and ignore_default is False:
default_value = param.default
if isinstance(default_value, FieldInfo):
had_schema = True
field_info = default_value
default_value = field_info.default
if (
isinstance(field_info, params.Param)
and getattr(field_info, "in_", None) is None
):
field_info.in_ = default_field_info.in_
if force_type:
field_info.in_ = force_type # type: ignore
else:
field_info = default_field_info(default_value)
required = default_value == Required
annotation: Any = Any
if not param.annotation == param.empty:
annotation = param.annotation
annotation = get_annotation_from_field_info(annotation, field_info, param_name)
if not field_info.alias and getattr(field_info, "convert_underscores", None):
alias = param.name.replace("_", "-")
else:
alias = field_info.alias or param.name
field = ModelField(
name=param.name,
type_=annotation,
default=None if required else default_value,
alias=alias,
required=required,
model_config=BaseConfig,
class_validators={},
field_info=field_info,
)
field.required = required
if not had_schema and not is_scalar_field(field=field):
field.field_info = params.Body(field_info.default)
return field
def get_callee_field_list(callee):
param_field_list = []
endpoint_signature = get_typed_signature(callee)
signature_params = endpoint_signature.parameters
if inspect.isgeneratorfunction(callee) or inspect.isasyncgenfunction(callee):
# check_dependency_contextmanagers()
pass
for param_name, param in signature_params.items():
param_field = get_param_field(
param=param, default_field_info=params.Query, param_name=param_name
)
param_field_list.append(param_field)
return param_field_list
def is_null_subfield_in_pack(packed_field_info, subfield, subfield_value):
return (isinstance(packed_field_info, params.Form) and subfield_value == "") or \
(
isinstance(packed_field_info, params.Form)
and subfield.shape in sequence_shapes
and len(subfield_value) == 0
)
def check_parameters_fit_signature(
required_params: List[ModelField],
received_body: Optional[Dict],
is_pack_params: bool=False
) -> Tuple[Dict[str, Any], List[ErrorWrapper]]:
values = {}
errors = []
if required_params:
if is_pack_params:
packed_field = required_params[0]
packed_field_info = get_field_info(packed_field)
embed = getattr(packed_field_info, "embed", None)
if len(required_params) == 1 and not embed:
received_body = {packed_field.alias: received_body}
for field in required_params:
value: Any = None
if received_body is not None:
value = received_body.get(field.alias)
if (
value is None
or (is_pack_params and is_null_subfield_in_pack(packed_field_info, field, value))
):
if field.required:
errors.append(
ErrorWrapper(MissingError(), loc=("body", field.alias))
)
else:
values[field.name] = deepcopy(field.default)
continue
print(value, values)
v_, errors_ = field.validate(value, values, loc=("body", field.alias))
if isinstance(errors_, ErrorWrapper):
errors.append(errors_)
elif isinstance(errors_, list):
errors.extend(errors_)
else:
values[field.name] = v_
return values, errors
class CalleeParametersInterceptor(Interceptor):
def intercept(self, context: RPCViewContext, rpc_view: RPCView) -> bool:
callee_field_list = get_callee_field_list(rpc_view.callee)
validated_values, errors = check_parameters_fit_signature(callee_field_list, context.request.params)
if errors:
raise CalleeTypeHintValidationError(errors)
return True
\ No newline at end of file
from enum import Enum
from typing import Any, Callable, Sequence
from pydantic.fields import FieldInfo
class ParamTypes(Enum):
query = "query"
header = "header"
path = "path"
cookie = "cookie"
class Param(FieldInfo):
in_: ParamTypes
def __init__(
self,
default: Any,
*,
alias: str = None,
title: str = None,
description: str = None,
gt: float = None,
ge: float = None,
lt: float = None,
le: float = None,
min_length: int = None,
max_length: int = None,
regex: str = None,
deprecated: bool = None,
**extra: Any,
):
self.deprecated = deprecated
super().__init__(
default,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
regex=regex,
**extra,
)
class Path(Param):
in_ = ParamTypes.path
def __init__(
self,
default: Any,
*,
alias: str = None,
title: str = None,
description: str = None,
gt: float = None,
ge: float = None,
lt: float = None,
le: float = None,
min_length: int = None,
max_length: int = None,
regex: str = None,
deprecated: bool = None,
**extra: Any,
):
self.in_ = self.in_
super().__init__(
...,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
regex=regex,
deprecated=deprecated,
**extra,
)
class Query(Param):
in_ = ParamTypes.query
def __init__(
self,
default: Any,
*,
alias: str = None,
title: str = None,
description: str = None,
gt: float = None,
ge: float = None,
lt: float = None,
le: float = None,
min_length: int = None,
max_length: int = None,
regex: str = None,
deprecated: bool = None,
**extra: Any,
):
super().__init__(
default,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
regex=regex,
deprecated=deprecated,
**extra,
)
class Header(Param):
in_ = ParamTypes.header
def __init__(
self,
default: Any,
*,
alias: str = None,
convert_underscores: bool = True,
title: str = None,
description: str = None,
gt: float = None,
ge: float = None,
lt: float = None,
le: float = None,
min_length: int = None,
max_length: int = None,
regex: str = None,
deprecated: bool = None,
**extra: Any,
):
self.convert_underscores = convert_underscores
super().__init__(
default,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
regex=regex,
deprecated=deprecated,
**extra,
)
class Cookie(Param):
in_ = ParamTypes.cookie
def __init__(
self,
default: Any,
*,
alias: str = None,
title: str = None,
description: str = None,
gt: float = None,
ge: float = None,
lt: float = None,
le: float = None,
min_length: int = None,
max_length: int = None,
regex: str = None,
deprecated: bool = None,
**extra: Any,
):
super().__init__(
default,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
regex=regex,
deprecated=deprecated,
**extra,
)
class Body(FieldInfo):
def __init__(
self,
default: Any,
*,
embed: bool = False,
media_type: str = "application/json",
alias: str = None,
title: str = None,
description: str = None,
gt: float = None,
ge: float = None,
lt: float = None,
le: float = None,
min_length: int = None,
max_length: int = None,
regex: str = None,
**extra: Any,
):
self.embed = embed
self.media_type = media_type
super().__init__(
default,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
regex=regex,
**extra,
)
class Form(Body):
def __init__(
self,
default: Any,
*,
media_type: str = "application/x-www-form-urlencoded",
alias: str = None,
title: str = None,
description: str = None,
gt: float = None,
ge: float = None,
lt: float = None,
le: float = None,
min_length: int = None,
max_length: int = None,
regex: str = None,
**extra: Any,
):
super().__init__(
default,
embed=True,
media_type=media_type,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
regex=regex,
**extra,
)
class File(Form):
def __init__(
self,
default: Any,
*,
media_type: str = "multipart/form-data",
alias: str = None,
title: str = None,
description: str = None,
gt: float = None,
ge: float = None,
lt: float = None,
le: float = None,
min_length: int = None,
max_length: int = None,
regex: str = None,
**extra: Any,
):
super().__init__(
default,
media_type=media_type,
alias=alias,
title=title,
description=description,
gt=gt,
ge=ge,
lt=lt,
le=le,
min_length=min_length,
max_length=max_length,
regex=regex,
**extra,
)
class Depends:
def __init__(self, dependency: Callable = None, *, use_cache: bool = True):
self.dependency = dependency
self.use_cache = use_cache
class Security(Depends):
def __init__(
self,
dependency: Callable = None,
*,
scopes: Sequence[str] = None,
use_cache: bool = True,
):
super().__init__(dependency=dependency, use_cache=use_cache)
self.scopes = scopes or []
# # -*- coding: utf-8 -*-
# class Result(dict):
# def __init__(self, raw_data=None, exception=False):
# self.raw_data = raw_data
# self.exception = exception
#
# @property
# def value(self):
# if self.exception:
# return {}
# return self.raw_data
\ No newline at end of file
...@@ -82,7 +82,7 @@ DEFAULTS = { ...@@ -82,7 +82,7 @@ DEFAULTS = {
# 'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.views.get_view_description', # 'VIEW_DESCRIPTION_FUNCTION': 'rest_framework.views.get_view_description',
# Exception handling # Exception handling
# 'EXCEPTION_HANDLER': 'rest_framework.views.exception_handler', 'EXCEPTION_HANDLER': 'rpc_framework.views.exception_handler',
# 'NON_FIELD_ERRORS_KEY': 'non_field_errors', # 'NON_FIELD_ERRORS_KEY': 'non_field_errors',
# Testing # Testing
...@@ -140,7 +140,7 @@ IMPORT_STRINGS = [ ...@@ -140,7 +140,7 @@ IMPORT_STRINGS = [
# 'DEFAULT_PAGINATION_CLASS', # 'DEFAULT_PAGINATION_CLASS',
# 'DEFAULT_FILTER_BACKENDS', # 'DEFAULT_FILTER_BACKENDS',
# 'DEFAULT_SCHEMA_CLASS', # 'DEFAULT_SCHEMA_CLASS',
# 'EXCEPTION_HANDLER', 'EXCEPTION_HANDLER',
# 'TEST_REQUEST_RENDERER_CLASSES', # 'TEST_REQUEST_RENDERER_CLASSES',
# 'UNAUTHENTICATED_USER', # 'UNAUTHENTICATED_USER',
# 'UNAUTHENTICATED_TOKEN', # 'UNAUTHENTICATED_TOKEN',
......
# -*- coding: utf-8 -*- # -*- coding: utf-8 -*-
from django.db import connection, transaction
from rpc_framework import exceptions
from rpc_framework.context import RPCViewContext from rpc_framework.context import RPCViewContext
from rpc_framework.generic.base import RPCAbstractView from rpc_framework.generic.base import RPCAbstractView
from rpc_framework.settings import api_settings from rpc_framework.settings import api_settings
def set_rollback():
atomic_requests = connection.settings_dict.get('ATOMIC_REQUESTS', False)
if atomic_requests and connection.in_atomic_block:
transaction.set_rollback(True)
def exception_handler(exc, wrapped_context) -> None:
if issubclass(exc.__class__, exceptions.RPCViewBaseException):
set_rollback()
raise exc.as_rpcd_fault() # type: RPCDFaultException
return None
class RPCView(RPCAbstractView): class RPCView(RPCAbstractView):
# The following policies may be set at either globally, or per-view. # The following policies may be set at either globally, or per-view.
# renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES # renderer_classes = api_settings.DEFAULT_RENDERER_CLASSES
...@@ -38,6 +55,9 @@ class RPCView(RPCAbstractView): ...@@ -38,6 +55,9 @@ class RPCView(RPCAbstractView):
view.initkwargs = initkwargs view.initkwargs = initkwargs
return view return view
# Note: Views are made CSRF exempt from within `as_view` as to prevent
# accidental removal of this exemption in cases where `dispatch` needs to
# be overridden.
def dispatch(self, context, *args, **kwargs): def dispatch(self, context, *args, **kwargs):
""" """
`.dispatch()` is pretty much the same as Django's regular dispatch, `.dispatch()` is pretty much the same as Django's regular dispatch,
...@@ -47,47 +67,21 @@ class RPCView(RPCAbstractView): ...@@ -47,47 +67,21 @@ class RPCView(RPCAbstractView):
self.kwargs = kwargs self.kwargs = kwargs
context = self.initialize_context(context, *args, **kwargs) context = self.initialize_context(context, *args, **kwargs)
self.context = context self.context = context
# self.headers = self.default_response_headers # deprecate?
try: try:
self.initial(context, *args, **kwargs) self.initial(context, *args, **kwargs)
# Get the appropriate handler method # Get the appropriate handler method
handler = self.handler handler = self.handler
response = handler(context, *args, **kwargs) result = handler(context, *args, **kwargs)
except Exception as exc: except Exception as exc:
print('+' * 10 , exc) print('error::', exc, type(exc))
response = self.handle_exception(exc) result = self.handle_exception(exc)
self.response = self.finalize_response(context, response, *args, **kwargs)
return self.response
self.result = self.finalize_response(context, result, *args, **kwargs)
return self.result
# @property
# def allowed_methods(self):
# """
# Wrap Django's private `_allowed_methods` interface in a public property.
# """
# return self._allowed_methods()
#
# @property
# def default_response_headers(self):
# headers = {
# 'Allow': ', '.join(self.allowed_methods),
# }
# if len(self.renderer_classes) > 1:
# headers['Vary'] = 'Accept'
# return headers
#
# def http_method_not_allowed(self, request, *args, **kwargs):
# """
# If `request.method` does not correspond to a handler method,
# determine what kind of exception to raise.
# """
# raise exceptions.MethodNotAllowed(request.method)
#
# def permission_denied(self, request, message=None): # def permission_denied(self, request, message=None):
# """ # """
# If request is not permitted, determine what kind of exception to raise. # If request is not permitted, determine what kind of exception to raise.
...@@ -137,19 +131,19 @@ class RPCView(RPCAbstractView): ...@@ -137,19 +131,19 @@ class RPCView(RPCAbstractView):
# 'kwargs': getattr(self, 'kwargs', {}), # 'kwargs': getattr(self, 'kwargs', {}),
# 'request': getattr(self, 'request', None) # 'request': getattr(self, 'request', None)
# } # }
#
# def get_exception_handler_context(self): def get_exception_handler_context_wrapper(self):
# """ """
# Returns a dict that is passed through to EXCEPTION_HANDLER, Returns a dict that is passed through to EXCEPTION_HANDLER,
# as the `context` argument. as the `context` argument.
# """ """
# return { return {
# 'view': self, 'view': self,
# 'args': getattr(self, 'args', ()), 'args': getattr(self, 'args', ()),
# 'kwargs': getattr(self, 'kwargs', {}), 'kwargs': getattr(self, 'kwargs', {}),
# 'request': getattr(self, 'request', None) 'context': getattr(self, 'context', None)
# } }
#
# def get_view_name(self): # def get_view_name(self):
# """ # """
# Return the view name, as used in OPTIONS responses and in the # Return the view name, as used in OPTIONS responses and in the
...@@ -210,14 +204,14 @@ class RPCView(RPCAbstractView): ...@@ -210,14 +204,14 @@ class RPCView(RPCAbstractView):
self._negotiator = self.content_negotiation_class() self._negotiator = self.content_negotiation_class()
return self._negotiator return self._negotiator
# def get_exception_handler(self): def get_exception_handler(self):
# """ """
# Returns the exception handler that this view uses. Returns the exception handler that this view uses.
# """ """
# return self.settings.EXCEPTION_HANDLER return self.settings.EXCEPTION_HANDLER
#
# # API policy implementation methods # API policy implementation methods
#
# def perform_content_negotiation(self, request, force=False): # def perform_content_negotiation(self, request, force=False):
# """ # """
# Determine which renderer and media type to use render the response. # Determine which renderer and media type to use render the response.
...@@ -243,6 +237,14 @@ class RPCView(RPCAbstractView): ...@@ -243,6 +237,14 @@ class RPCView(RPCAbstractView):
# context.user # context.user
return None return None
def perform_interceptor(self, context):
"""
Check if the request should be permitted.
Raises an appropriate exception if the request is not permitted.
"""
for interceptor in self.get_interceptors():
interceptor.intercept(context, self)
def check_permissions(self, context): def check_permissions(self, context):
""" """
Check if the request should be permitted. Check if the request should be permitted.
...@@ -326,39 +328,30 @@ class RPCView(RPCAbstractView): ...@@ -326,39 +328,30 @@ class RPCView(RPCAbstractView):
# request.version, request.versioning_scheme = version, scheme # request.version, request.versioning_scheme = version, scheme
# Ensure that the incoming request is permitted # Ensure that the incoming request is permitted
self.perform_interceptor(context)
self.perform_authentication(context) self.perform_authentication(context)
self.check_permissions(context) self.check_permissions(context)
self.check_throttles(context) self.check_throttles(context)
def finalize_response(self, context, response, *args, **kwargs): def finalize_response(self, context, result, *args, **kwargs):
return response return result
# def handle_exception(self, exc): def handle_exception(self, exc):
# """ """
# Handle any exception that occurs, by returning an appropriate response, Handle any exception that occurs, by returning an appropriate response,
# or re-raising the error. or re-raising the error.
# """ """
# if isinstance(exc, (exceptions.NotAuthenticated, exception_handler = self.get_exception_handler()
# exceptions.AuthenticationFailed)):
# # WWW-Authenticate header for 401 responses, else coerce to 403 wrapped_context = self.get_exception_handler_context_wrapper()
# auth_header = self.get_authenticate_header(self.request) result = exception_handler(exc, wrapped_context)
#
# if auth_header: # if not result:
# exc.auth_header = auth_header
# else:
# exc.status_code = status.HTTP_403_FORBIDDEN
#
# exception_handler = self.get_exception_handler()
#
# context = self.get_exception_handler_context()
# response = exception_handler(exc, context)
#
# if response is None:
# self.raise_uncaught_exception(exc) # self.raise_uncaught_exception(exc)
#
# response.exception = True # result.exception = True
# return response return result
#
# def raise_uncaught_exception(self, exc): # def raise_uncaught_exception(self, exc):
# if settings.DEBUG: # if settings.DEBUG:
# request = self.request # request = self.request
...@@ -366,44 +359,3 @@ class RPCView(RPCAbstractView): ...@@ -366,44 +359,3 @@ class RPCView(RPCAbstractView):
# use_plaintext_traceback = renderer_format not in ('html', 'api', 'admin') # use_plaintext_traceback = renderer_format not in ('html', 'api', 'admin')
# request.force_plaintext_errors(use_plaintext_traceback) # request.force_plaintext_errors(use_plaintext_traceback)
# raise exc # raise exc
\ No newline at end of file
#
# # Note: Views are made CSRF exempt from within `as_view` as to prevent
# # accidental removal of this exemption in cases where `dispatch` needs to
# # be overridden.
# def dispatch(self, request, *args, **kwargs):
# """
# `.dispatch()` is pretty much the same as Django's regular dispatch,
# but with extra hooks for startup, finalize, and exception handling.
# """
# self.args = args
# self.kwargs = kwargs
# request = self.initialize_request(request, *args, **kwargs)
# self.request = request
# self.headers = self.default_response_headers # deprecate?
#
# try:
# self.initial(request, *args, **kwargs)
#
# # Get the appropriate handler method
# if request.method.lower() in self.http_method_names:
# handler = getattr(self, request.method.lower(),
# self.http_method_not_allowed)
# else:
# handler = self.http_method_not_allowed
#
# response = handler(request, *args, **kwargs)
#
# except Exception as exc:
# response = self.handle_exception(exc)
#
# self.response = self.finalize_response(request, response, *args, **kwargs)
# return self.response
#
# def options(self, request, *args, **kwargs):
# """
# Handler method for HTTP 'OPTIONS' request.
# """
# if self.metadata_class is None:
# return self.http_method_not_allowed(request, *args, **kwargs)
# data = self.metadata_class().determine_metadata(request, self)
# return Response(data, status=status.HTTP_200_OK)
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment