base_view.py 12.2 KB
# -*- coding: utf-8 -*-

import sys
import copy
import time
import traceback
import hashlib
from line_profiler import LineProfiler

from io import StringIO

from django.conf import settings
from django.views.generic import View
from django.views.decorators.csrf import csrf_exempt
from django.utils.http import cookie_date
from django.http import HttpResponseBadRequest

from helios.rpc.exceptions import RPCFaultException
from gm_protocol import GmProtocol
from alpha_types.venus.error import ERROR as ALPHA_ERROR

from libs.user import login_required
from engine.logger import logging_exception
from engine.response import json_response
from engine.rpc import get_current_rpc_invoker
from libs.utils import DictWrapperUseDot


class ErrorInfo(object):
    """Error code / message pair describing a failed RPC call or request.

    An instance is falsy only when both ``code`` and ``msg`` are ``None``, so
    ``if error_info:`` tests whether an error was actually recorded.
    """

    def __init__(self, code=None, msg=None):
        self.code = code
        self.msg = msg

    def __bool__(self):
        return self.code is not None or self.msg is not None

    # Python 2 looks up ``__nonzero__`` for truth testing, Python 3 looks up
    # ``__bool__``; defining only one made empty instances truthy on Python 3.
    __nonzero__ = __bool__

    def __repr__(self):
        return '<ErrorInfo> code: %s, msg: %s' % (self.code, self.msg)


class LoginRequiredMixin(object):
    """Mixin whose ``as_view`` wraps the resulting view in ``login_required``.

    List it before the concrete view class in the bases so that ``super()``
    reaches that class's ``as_view``.
    """

    @classmethod
    def as_view(cls, **initkwargs):
        plain_view = super(LoginRequiredMixin, cls).as_view(**initkwargs)
        guarded_view = login_required(plain_view)
        return guarded_view


class BaseView(View):
    """Base class-based view for JSON APIs backed by RPC calls.

    Provides:
      * RPC helpers: ``call_rpc`` for a single call, and
        ``add_parallel_rpc_call_info`` + ``shoot_rpc_calls_in_parallel`` for
        a batch of calls.
      * JSON envelope helpers: ``ok``, ``error``,
        ``parameter_invalid_response`` and ``as_json_response``.
      * Tracking helpers that forward data to ``request.logger.app``.
      * DEBUG-only extras: a ``debug`` GET parameter attaches RPC/request
        info to JSON responses, and ``_profile`` replaces the response with a
        line profile of the handler.
    """

    # Applied by ``as_view``; the first entry ends up outermost.
    decorators = [csrf_exempt, ]

    def get_ErrorInfo(self, error_code):
        """Build an ``ErrorInfo`` for ``error_code`` with its ALPHA_ERROR description."""
        return ErrorInfo(code=error_code, msg=ALPHA_ERROR.getDesc(error_code))

    def _set_tracking_data(self, key, value):
        """Best-effort: forward ``key=value`` to ``self._request.logger.app``.

        Failures (no request bound yet, request without ``logger``, ...) are
        deliberately ignored so tracking can never break a request.
        """
        try:
            self._request.logger.app(**{key: value})
        except Exception:
            pass

    def set_track(self, *args, **kwargs):
        """Record tracking (analytics event) data via ``request.logger.app``.

        Each name in ``args`` is looked up in the request's GET, then POST,
        parameters and is included only when its value is truthy. ``kwargs``
        are included as given and override same-named ``args`` values.
        """
        data = {}
        for k in args:
            v = self._request.GET.get(k) or self._request.POST.get(k)
            if v:
                data[k] = v
        data.update(kwargs)
        self._request.logger.app(**data)

    def _set_user_type(self, key, value):
        """Store ``key=value`` in ``self.user_type`` and forward it to tracking.

        ``self.user_type`` is echoed in every JSON response built by
        ``as_json_response``. Best-effort: errors are ignored.
        """
        try:
            if not hasattr(self, 'user_type'):
                self.user_type = {}
            self.user_type[key] = value
            self._set_tracking_data(key, value)
        except Exception:
            pass

    def _get_user_type(self):
        """Return the ``user_type`` dict collected so far, or ``{}`` if none."""
        return getattr(self, 'user_type', {})

    def __init__(self, *args, **kwargs):
        super(BaseView, self).__init__(*args, **kwargs)

        self.gm_protocol_helper = GmProtocol(settings.API_HOST, settings.WEB_API_HOST)

        # Bound in dispatch() from request.rpc or the current RPC invoker.
        self._rpc_client = None
        # One debug record per call_rpc() call (endpoint, args, result, duration).
        self._rpc_call_info = []
        # '_profile' payloads returned in RPC results while profiling.
        self._rpc_profile_info = []
        # True when dispatch() is profiling this request (DEBUG and ?_profile).
        self._profile = False

        # Parallel RPC calls queued before being resolved together.
        self._rpc_calls = {}
        # DEBUG only: endpoint / args / result of each queued parallel call.
        self._parallel_rpc_info = {}
        # name -> (error_info, result), filled by shoot_rpc_calls_in_parallel().
        self.parallel_rpc_call_result = {}

    def add_parallel_rpc_call_info(self, name, api_endpoint, **kwargs):
        """Queue an RPC call to be resolved by ``shoot_rpc_calls_in_parallel``.

        Args:
            name: key under which the result will be stored.
            api_endpoint: RPC endpoint to call.
            **kwargs: call arguments. Unlike ``call_rpc``, no profiling flag
                is attached.

        Requires ``dispatch`` to have bound ``self._rpc_client``.
        """
        self._rpc_calls[name] = self._rpc_client.parallel[api_endpoint](**kwargs)
        if settings.DEBUG:
            self._parallel_rpc_info[name] = {'endpoint': api_endpoint, 'args': kwargs}

    def shoot_rpc_calls_in_parallel(self):
        """Resolve every call queued by ``add_parallel_rpc_call_info``.

        Returns:
            ``self.parallel_rpc_call_result``, a dict mapping each call name
            to ``(error_info, result)``. ``error_info`` is ``None`` on
            success. On failure it is an ``ErrorInfo`` and ``result`` is
            ``{}``.

        In DEBUG, unexpected (non-RPC-fault) exceptions are re-raised with
        their original type and traceback instead of becoming an ``ErrorInfo``.
        """
        for name, pending_call in self._rpc_calls.items():
            result = {}
            error_info = None

            try:
                result = pending_call.unwrap()
                if settings.DEBUG:
                    self._parallel_rpc_info[name]['result'] = copy.deepcopy(result)
            except RPCFaultException as e:
                error_info = ErrorInfo(code=e.error, msg=self.get_error_message(e.error))
            except Exception:
                if settings.DEBUG:
                    raise
                logging_exception()
                error_info = ErrorInfo(
                    code=ALPHA_ERROR.DEFAULT_ERROR_CODE,
                    msg=self.get_error_message(),
                )

            self.parallel_rpc_call_result[name] = (error_info, result)

        return self.parallel_rpc_call_result

    @classmethod
    def as_view(cls, **initkwargs):
        """Build the view function and wrap it with ``cls.decorators``.

        Decorators are applied last-to-first, so ``decorators[0]`` is outermost.
        """
        view = super(BaseView, cls).as_view(**initkwargs)
        for deco in reversed(cls.decorators):
            view = deco(view)
        return view

    def ok(self, data=None, message='', extra=None):
        """Return a success (``error: 0``) JSON response."""
        return self.as_json_response(0, message=message, data=data, extra=extra)

    def error(self, error=1, message='服务器开小差了~', data=None, error_code=None):
        """Return an error JSON response.

        When ``error`` is anything other than ``1`` (typically an
        ``ErrorInfo``), ``message`` is cleared so that ``as_json_response``
        uses the ErrorInfo's own message. This applies whether the message
        was the default or passed explicitly.
        """
        # NOTE: True == 1, so error=True keeps the message too.
        if error != 1:
            message = ''
        return self.as_json_response(error, message=message, data=data, error_code=error_code)

    def parameter_invalid_response(self):
        """Return an HTTP 400 error response for invalid request parameters."""
        return self.as_json_response(True, '参数不合法', status=400)

    def as_json_response(self, error, message='', data=None, status=200, error_code=None, extra=None):
        """Return an HTTP response with the standard JSON envelope.

        Args:
            error: an ``ErrorInfo``, or any value whose truthiness means
                "error". A truthy ``ErrorInfo`` supplies ``error_code`` and
                the fallback message.
            message: response message. When non-empty it takes precedence
                over ``ErrorInfo.msg``. For non-ErrorInfo errors with an
                explicit ``error_code``, it is replaced by that code's
                ALPHA_ERROR description.
            data: included as ``data`` when not ``None``.
            status: HTTP status. Forced to 403 for ``LOGIN_REQUIRED``.
            error_code: explicit error code for non-ErrorInfo errors. When
                omitted on an error, ALPHA_ERROR.DEFAULT_ERROR_CODE is used.
            extra: included as ``extra``.
        """
        if isinstance(error, ErrorInfo) and error:
            result = {
                'error': 1,
                'error_code': error.code,
                'extra': extra,
            }
            result['message'] = message or error.msg
        else:
            result = {
                'error': 1 if error else 0,
                'message': message,
                'extra': extra,
            }
            if error:
                if error_code is None:
                    result['error_code'] = ALPHA_ERROR.DEFAULT_ERROR_CODE
                else:
                    result['error_code'] = error_code
                    result['message'] = ALPHA_ERROR.getDesc(error_code)

        if data is not None:
            result['data'] = data

        # Not logged in: always answer with HTTP 403.
        if result.get('error_code') == ALPHA_ERROR.LOGIN_REQUIRED:
            status = 403

        result['user_type'] = self._get_user_type()

        if settings.DEBUG and self._request.GET.get('debug'):
            result['_debug'] = {
                'rpc_call_info': self._rpc_call_info,
                'parallel_rpc_info': self._parallel_rpc_info,
                'session': self._request.session.session_key,
                'request': {
                    # list(): QueryDict.items() is a generator on Python 3,
                    # which JSON encoders cannot serialize.
                    'GET': list(self._request.GET.items()),
                    'POST': list(self._request.POST.items()),
                },
            }

        return json_response(result, status)

    def call_rpc(self, endpoint, raw=True, **kwargs):
        """Call RPC ``endpoint`` with ``kwargs``.

        Args:
            endpoint: endpoint name, looked up on ``self._rpc_client``.
            raw: when true (default), return the unwrapped result as-is.
                When false, wrap it in ``DictWrapperUseDot`` for attribute
                access. This is only valid when the result is a dict, e.g.::

                    obj = DictWrapperUseDot({'key': 1, 'nested': {'k': 2}})
                    obj.key, obj.nested.k

            **kwargs: arguments forwarded to the RPC call.

        Returns:
            ``(error_info, result)``. ``error_info`` is ``None`` on success.
            Otherwise it is an ``ErrorInfo`` and ``result`` must not be
            trusted: it is ``{}`` unless the failure happened after
            unwrapping. Check ``error_info`` after every call.

        In DEBUG, unexpected (non-RPC-fault) exceptions are re-raised with
        their original type and traceback.
        """
        result = {}
        error_info = None

        if self._profile:
            kwargs['_profile'] = True

        # Record endpoint/args before calling, so the debug entry identifies
        # the call even when invoking it raises.
        call_info = {'endpoint': endpoint, 'args': kwargs}
        start_rpc_call = time.time()
        try:
            resp = self._rpc_client[endpoint](**kwargs)
            result = resp.unwrap()
            call_info['result'] = copy.deepcopy(result)

            # Move the backend's profile report out of the returned result.
            # Tolerate results that are not dicts or that carry no report.
            if self._profile and isinstance(result, dict) and '_profile' in result:
                self._rpc_profile_info.append(result.pop('_profile'))

        except RPCFaultException as e:
            error_info = ErrorInfo(code=e.error, msg=self.get_error_message(e.error))

        except Exception:
            if settings.DEBUG:
                raise
            logging_exception()
            error_info = ErrorInfo(
                code=ALPHA_ERROR.DEFAULT_ERROR_CODE,
                msg=self.get_error_message(),
            )
            call_info['result'] = traceback.format_exc()

        call_info['duration'] = time.time() - start_rpc_call
        self._rpc_call_info.append(call_info)

        if raw:
            return error_info, result
        return error_info, DictWrapperUseDot(result)

    def get_error_message(self, rpc_exception_code=-1):
        """Return the ALPHA_ERROR description for ``rpc_exception_code``.

        Falls back to a generic "server busy" message when the description
        is missing or empty. ``-1`` is the default code (presumably without
        a description, so it yields the generic message).
        """
        default_msg = '服务器忙,请稍后再试'
        msg = ALPHA_ERROR.getDesc(rpc_exception_code)
        return msg or default_msg

    def _inner_dispatch_for_baseview(self, request, *args, **kwargs):
        """Delegate to Django's ``View.dispatch`` (method-based handler lookup)."""
        # TODO: validate and clean args
        return super(BaseView, self).dispatch(request, *args, **kwargs)

    def dispatch(self, request, *args, **kwargs):
        """Validate paging, bind request/RPC client, then run the handler.

        Returns HTTP 400 when the ``start_num`` GET parameter exceeds
        ``settings.START_NUM_UPLIMIT``. A missing or non-integer value counts
        as 0.

        In DEBUG with ``_profile`` in GET, the handler is line-profiled and
        the response body is replaced by the plain-text profile followed by
        the RPC profile reports.
        """
        start_num = request.GET.get('start_num')
        try:
            start_num = int(start_num)
        except (TypeError, ValueError):
            start_num = 0

        if start_num > settings.START_NUM_UPLIMIT:
            return HttpResponseBadRequest()

        # Bind the request before tracking. _set_tracking_data reads
        # self._request and silently does nothing when it is missing.
        self._request = request
        self._set_tracking_data('from', request.GET.get('from', ''))

        rpc_client = getattr(request, 'rpc', None)
        if rpc_client is None:
            rpc_client = get_current_rpc_invoker()
        self._rpc_client = rpc_client

        if not (settings.DEBUG and '_profile' in request.GET):
            return self._inner_dispatch_for_baseview(request, *args, **kwargs)

        self._profile = True

        profiler = LineProfiler()
        handler = getattr(self, request.method.lower(), self.http_method_not_allowed)
        profiler.add_function(handler)

        profiler.enable()
        try:
            response = self._inner_dispatch_for_baseview(request, *args, **kwargs)
        finally:
            # Always stop profiling, even if the handler raises.
            profiler.disable()

        out = StringIO()
        profiler.print_stats(stream=out, stripzeros=True)

        content = out.getvalue()
        rpc_profile_info = '#' * 10 + " RPC PROFILE " + '#' * 10 + '\n'
        rpc_profile_info += '\n'.join(self._rpc_profile_info)
        content += rpc_profile_info
        response.content = content
        response['Content-type'] = 'text/plain'
        return response

    def set_response_cookie(self, response, session_key):
        """Set the ``settings.USER_COOKIE_NAME`` cookie to ``session_key``.

        Max-Age and Expires both come from ``settings.SESSION_COOKIE_AGE``.
        Domain is ``settings.SESSION_COOKIE_DOMAIN_IGENGMEI``. Path, secure
        and httponly come from the corresponding ``SESSION_COOKIE_*``
        settings. Returns ``response``.
        """
        expires_time = time.time() + settings.SESSION_COOKIE_AGE
        expires = cookie_date(expires_time)

        domain = settings.SESSION_COOKIE_DOMAIN_IGENGMEI
        response.set_cookie(
            settings.USER_COOKIE_NAME,
            session_key,
            max_age=settings.SESSION_COOKIE_AGE,
            expires=expires,
            domain=domain,
            path=settings.SESSION_COOKIE_PATH,
            secure=settings.SESSION_COOKIE_SECURE or None,
            httponly=settings.SESSION_COOKIE_HTTPONLY or None
        )
        return response

    def del_cookie(self, response):
        """Delete the ``settings.USER_COOKIE_NAME`` cookie and return ``response``."""
        # NOTE(review): set_response_cookie sets this cookie with a specific
        # domain/path. Deleting without them may not clear that cookie in
        # browsers; confirm where the cookie is set before changing.
        response.delete_cookie(settings.USER_COOKIE_NAME)
        return response

    @property
    def parallel_client(self):
        """The parallel interface of the RPC client bound by ``dispatch``.

        Before ``dispatch`` runs, falls back to ``self.request.rpc``.
        """
        rpc_client = self._rpc_client
        if rpc_client is None:
            rpc_client = self.request.rpc
        return rpc_client.parallel


class BaseViewLoginRequired(LoginRequiredMixin, BaseView):
    """Base view for APIs that require a logged-in user.

    ``LoginRequiredMixin`` precedes ``BaseView`` in the MRO, so ``as_view``
    returns ``BaseView``'s decorated view wrapped in ``login_required``.
    ``__init__`` is inherited unchanged from ``BaseView``.
    """


def get_offset_count(request):
    """Read ``page`` / ``count`` pagination parameters from ``request.GET``.

    Returns:
        ``(offset, count)`` where ``offset = count * (page - 1)``.

    A missing or non-integer ``page`` falls back to 1, and a ``page`` below 1
    is also treated as 1, since it would otherwise produce a negative offset.
    A missing, non-integer or negative ``count`` falls back to 10.
    """
    try:
        page = int(request.GET.get('page', 1))
    except (TypeError, ValueError):
        page = 1
    if page < 1:
        page = 1

    try:
        count = int(request.GET.get('count', 10))
    except (TypeError, ValueError):
        count = 10
    if count < 0:
        count = 10

    offset = count * (page - 1)

    return offset, count