# coding=utf-8
from __future__ import unicode_literals, absolute_import, print_function

import contextlib

import logging

import six
import sys
import json
import functools
import time
import warnings
import traceback
from django.conf import settings
from django.db.models import Model
from raven.contrib.django.raven_compat.models import client
from hera.models import UserPerm
from rpc.logging import GaiaRequestInfoExtractor
from .context import Context, Session, ConnectionInfo, Request, ContextManager
from .exceptions import (
    GaiaRPCFaultException, RPCValidationError,
    RPCPermissionDeniedException, RPCStaffRequiredException
)
from .tool.log_tool import profile_logger, logging_exception, info_logger
from . import validation
from . import permissions

from gm_logging.gm_rpcd import request_logging_guard_maker
import gm_logging.py_logging

import threading

# Thread-local holder recording the source of the current request;
# APIManager.dispatch stores `req_source` on it for every incoming request.
request_source = threading.local()

class MethodMiddleware(object):
    """Base class for a single step in a method's middleware chain.

    Subclasses override ``is_active_for`` to opt out for some method options,
    and ``process`` to wrap the remainder of the chain (``rest``).
    """

    def is_active_for(self, options):
        """By default a middleware applies to every method."""
        return True

    def process(self, rest, request):
        """Default behaviour: simply continue down the chain."""
        result = rest()
        return result


class SentryRecordMiddleWare(MethodMiddleware):
    """Merge the method name, session key, params and environment into the Sentry context."""

    def process(self, rest, request):
        # NOTE(review): the tag key is spelled 'sesion_key'; kept unchanged because
        # existing Sentry searches/dashboards presumably rely on it -- confirm before renaming.
        tags = {
            'method': request.method,
            'sesion_key': request.session_key,
        }
        extra = {
            'params': request.params,
            'environment': request.environment,
        }
        client.context.merge({'tags': tags, 'extra': extra})
        return rest()


class SetActiveContextMiddleware(MethodMiddleware):
    """Run the remainder of the chain with ``request.context`` made the active context."""

    def process(self, rest, request):
        activation = ContextManager.with_active_context(request.context)
        with activation:
            return rest()


class DeprecatedMethodMiddleware(MethodMiddleware):
    """Append a deprecation notice to ``response['debug']['messages']``.

    Active only when the method's ``deprecated`` option is truthy; a string
    value is appended to the notice as an explanation.
    """

    def is_active_for(self, options):
        return True if options['deprecated'] else False

    def process(self, rest, request):
        response = rest()

        deprecated = request.method_info.options['deprecated']
        parts = ['method deprecated: {}'.format(request.method)]
        if isinstance(deprecated, six.string_types):
            parts.append(deprecated)
        message = ", ".join(parts)

        debug_section = response.setdefault('debug', {})
        debug_section.setdefault('messages', []).append(message)
        return response


class LoggingExceptionMethodMiddleware(MethodMiddleware):
    """Report any exception escaping the rest of the chain, then re-raise it unchanged."""

    def is_active_for(self, options):
        return options['logging_exception']

    def process(self, rest, request):
        try:
            return rest()
        except BaseException:  # deliberately broad: the exception is always re-raised below
            # ``request_info`` is attached by APIManager.dispatch and may be falsy
            # (ConditionalSqlLogMiddleware guards for that as well). Reading its
            # attributes unconditionally would raise AttributeError here and
            # mask the original exception, so fall back to None per field.
            request_info = getattr(request, 'request_info', None)
            logging_exception(tags={
                'log_id': getattr(request_info, 'log_id', None),
                'span_id': getattr(request_info, 'span_id', None),
                'parent_span_id': getattr(request_info, 'parent_span_id', None),
            })
            raise


class CatchFaultExceptionMethodMiddleware(MethodMiddleware):
    """Turn a raised ``GaiaRPCFaultException`` into an ``{error, message, data}`` response."""

    def process(self, rest, request):
        try:
            result = rest()
        except GaiaRPCFaultException as fault:
            result = {
                'error': fault.error,
                'message': fault.message,
                'data': fault.data,
            }
        return result


class LoginRequiredMethodMiddleware(MethodMiddleware):
    """Require a logged-in session before the method runs (``login_required`` option)."""

    def is_active_for(self, options):
        return options['login_required']

    def process(self, rest, request):
        # session.login_required() presumably raises when no user is logged in
        # (it is defined in .context -- confirm); the leftover per-request debug
        # log line that used to precede it has been removed.
        request.context.session.login_required()
        return rest()


class StaffRequiredMethodMiddleware(MethodMiddleware):
    """Allow only logged-in staff users; raises ``RPCStaffRequiredException`` otherwise."""

    def is_active_for(self, options):
        return options['staff_required']

    def process(self, rest, request):
        sess = request.context.session
        sess.login_required()
        if sess.user.is_staff:
            return rest()
        raise RPCStaffRequiredException


class PermRequiredMethodMiddleware(MethodMiddleware):
    """Check the ``perm_required`` permission via ``UserPerm.check_perm``.

    Raises ``RPCPermissionDeniedException`` when the logged-in user lacks it.
    """

    def is_active_for(self, options):
        return options['perm_required']

    def process(self, rest, request):
        sess = request.context.session
        sess.login_required()
        user = sess.user
        required_perm = request.method_info.options['perm_required']
        has_perm = UserPerm.check_perm(user, required_perm)
        if not has_perm:
            raise RPCPermissionDeniedException
        return rest()


class ConditionalSqlLogMiddleware(MethodMiddleware):
    """Force Django's SQL debug cursor for traced requests from whitelisted developer IPs."""

    def is_active_for(self, options):
        return options['conditional_sql_log']

    def process(self, rest, request):
        # Only requests with tracing enabled whose user_ip is listed in
        # settings.GM_DEVELOPER_IP_WHITE_LIST get SQL debug logging.
        request_info = request.request_info
        enable_sql_log = request_info and request_info.trace_enabled
        if enable_sql_log:
            user_ip = request_info.user_ip
            user_ip_white_list = settings.GM_DEVELOPER_IP_WHITE_LIST
            if user_ip in user_ip_white_list:
                with self.enable_django_sql_debug_cursor():
                    return rest()
        return rest()

    @staticmethod
    @contextlib.contextmanager
    def enable_django_sql_debug_cursor():
        """Set ``force_debug_cursor`` on every Django connection for the duration of the block.

        The previous values are restored in ``finally`` so that an exception raised
        by the wrapped method cannot leave debug cursors enabled on (reused)
        connections.
        """
        from django.db import connections

        old_value_list = []
        try:
            for alias in connections:
                c = connections[alias]
                # Record the old value before changing it, so a failure mid-loop
                # still restores every connection already touched.
                old_value_list.append((c, c.force_debug_cursor))
                c.force_debug_cursor = True
            yield
        finally:
            for c, force_debug_cursor in old_value_list:
                c.force_debug_cursor = force_debug_cursor

    @staticmethod
    def install():
        """Route ``django.db.backends`` DEBUG logs to gm_logging's inspect handler.

        No-op unless ``settings.ENABLE_CONDITIONAL_SQL_LOG`` is set.
        """
        if not settings.ENABLE_CONDITIONAL_SQL_LOG:
            return
        logger = logging.getLogger('django.db.backends')
        logger.setLevel(logging.DEBUG)
        logger.addHandler(gm_logging.py_logging.inspect_handler)


ConditionalSqlLogMiddleware.install()


class ReturnDataOnlyMethodMiddleware(MethodMiddleware):
    """Wrap the method's raw return value in a successful ``{error, message, data}`` envelope."""

    def is_active_for(self, options):
        return options['return_data_only']

    def process(self, rest, request):
        payload = rest()
        return {'error': 0, 'message': '', 'data': payload}


class DispatchWithContextMethodMiddleware(MethodMiddleware):
    """Terminal dispatcher: call ``func(context, **params)`` when ``with_context`` is set."""

    def is_active_for(self, options):
        return options['with_context']

    def process(self, rest, request):
        handler = request.method_info.func
        return handler(request.context, **request.params)


class DispatchWithoutContextMethodMiddleware(MethodMiddleware):
    """Terminal dispatcher: call ``func(**params)`` when ``with_context`` is not set."""

    def is_active_for(self, options):
        wants_context = options['with_context']
        return not wants_context

    def process(self, rest, request):
        handler = request.method_info.func
        return handler(**request.params)


class TailMethodMiddleware(MethodMiddleware):
    """Last link of the chain: reaching it means no dispatcher invoked the method."""

    def process(self, rest, request):
        failure = Exception('method not invoked')
        raise failure


class ListInterfaceDescriptor(object):
    """Metadata for a paginated "list" function, plus a decorator that clamps paging kwargs.

    Every instance is appended to the class-level ``instance_list`` so that
    APIManager.validate can report decorated functions that were never
    registered as RPC methods.
    """

    # Registry of all descriptors ever created (intentionally shared by the class).
    instance_list = []

    def __init__(self, func, offset_name, limit_name, element_model, element_func_list):
        """
        :param func: the decorated (wrapped) function.
        :param offset_name: name of the offset keyword argument, or None.
        :param limit_name: name of the limit keyword argument (required).
        :param element_model: optional Django ``Model`` subclass of the list elements.
        :param element_func_list: optional list of callables/properties; defaults to [].
        """
        assert isinstance(offset_name, six.string_types + (type(None),))
        assert isinstance(limit_name, six.string_types)
        if element_model is not None:
            assert isinstance(element_model, type)
            assert issubclass(element_model, Model)
        if element_func_list is None:
            element_func_list = []
        for ef in element_func_list:
            assert any([
                callable(ef),
                isinstance(ef, property),
            ])

        self.func = func
        self.offset_name = offset_name
        self.limit_name = limit_name
        self.element_model = element_model
        self.element_func_list = element_func_list

        self.instance_list.append(self)

    @classmethod
    def decorator(cls, offset_name=None, limit_name=None, element_model=None, element_func_list=None):
        """Build a decorator that clamps paging kwargs and attaches a descriptor.

        The wrapper caps ``kwargs[limit_name]`` at ``settings.COUNT_LIMIT`` and
        ``kwargs[offset_name]`` at ``settings.OFFSET_LIMIT``. Only keyword
        arguments are clamped; positional paging arguments pass through as-is.
        The descriptor is exposed as ``wrapper.list_interface_descriptor``.
        """

        def inner_decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                limit = kwargs.get(limit_name, 0)
                if limit > settings.COUNT_LIMIT:
                    kwargs[limit_name] = settings.COUNT_LIMIT

                # Clamp deep offsets, fixing
                # http://sentry.wanmeizhensuo.com/sentry/prod-gaia/issues/1903072/ :
                # Elasticsearch rejects requests where from + size exceeds
                # index.max_result_window (10000), e.g.
                # TransportError(500, 'search_phase_execution_exception',
                # 'Result window is too large ... but was [25980]').
                if offset_name is not None:
                    offset = kwargs.get(offset_name, 0)
                    if offset > settings.OFFSET_LIMIT:
                        kwargs[offset_name] = settings.OFFSET_LIMIT

                return func(*args, **kwargs)

            descriptor = cls(
                func=wrapper,
                offset_name=offset_name,
                limit_name=limit_name,
                element_model=element_model,
                element_func_list=element_func_list,
            )
            wrapper.list_interface_descriptor = descriptor

            return wrapper

        return inner_decorator

    @classmethod
    def get_descriptor_from_decorated_func(cls, func):
        """Return the descriptor that ``decorator`` attached to *func*, or None.

        Previously a stub that always returned None.
        """
        descriptor = getattr(func, 'list_interface_descriptor', None)
        if isinstance(descriptor, cls):
            return descriptor
        return None


class MethodInfo(object):
    """A registered RPC method: its callable, resolved options and active middleware chain."""

    # Ordered outermost-first: each active middleware wraps every active one after it.
    # The two Dispatch* middlewares are mutually exclusive (both key on 'with_context')
    # and do not call ``rest``, so TailMethodMiddleware only fires if neither ran.
    middleware_list = [
        SetActiveContextMiddleware(),
        DeprecatedMethodMiddleware(),
        LoggingExceptionMethodMiddleware(),
        SentryRecordMiddleWare(),
        CatchFaultExceptionMethodMiddleware(),
        LoginRequiredMethodMiddleware(),
        StaffRequiredMethodMiddleware(),
        PermRequiredMethodMiddleware(),
        ConditionalSqlLogMiddleware(),
        ReturnDataOnlyMethodMiddleware(),
        DispatchWithContextMethodMiddleware(),
        DispatchWithoutContextMethodMiddleware(),
        TailMethodMiddleware(),
    ]

    # Every supported option with its default; any other key is rejected in __init__.
    option_defaults = {
        'with_context': True,
        'show_document': True,
        'login_required': False,
        'operation_perm_required': False,
        'return_data_only': True,
        'logging_exception': True,
        'perm_required': None,
        'staff_required': False,
        'conditional_sql_log': settings.ENABLE_CONDITIONAL_SQL_LOG,
        'readonly': False,
        'deprecated': False,
        'allow_direct_call': False,
    }

    option_keys = set(option_defaults.keys())

    def __init__(self, name, func, options):
        """
        :param name: RPC method name (text).
        :param func: the callable implementing the method.
        :param options: dict of option overrides; missing keys take ``option_defaults``.
        :raises Exception: if ``options`` contains an unknown key.
        """
        # six.string_types instead of the Python-2-only ``basestring``,
        # consistent with the rest of this module.
        assert isinstance(name, six.string_types)
        assert callable(func)
        assert isinstance(options, dict)
        self.name = name
        self.func = func
        self.list_interface_descriptor = getattr(func, 'list_interface_descriptor', None)

        options = dict(options)  # copy so the caller's dict is not mutated
        for k in options:
            if k not in self.option_keys:
                raise Exception("Unexpected Option: {}\n"
                                "Valid Options: {}".format(k, ', '.join(sorted(self.option_keys))))
        for k, d in self.option_defaults.items():
            options.setdefault(k, d)
        self.options = options

        self.o_with_context = self.options['with_context']

        active_middleware_list = []
        for m in self.middleware_list:
            assert isinstance(m, MethodMiddleware)
            if m.is_active_for(options):
                active_middleware_list.append(m)
        self.active_middleware_list = tuple(active_middleware_list)
        self._request_processor = self._get_request_processor()

    def _get_request_processor(self):
        """Return a callable that runs a request through the active middleware chain.

        The chain is built innermost-first per request: each middleware's
        ``process`` receives a zero-argument ``rest`` that continues with the
        next middleware (the last one receives ``rest=None``).
        """
        reversed_active_middleware_list = tuple(reversed(self.active_middleware_list))

        def request_processor(request):
            rest = None
            for m in reversed_active_middleware_list:
                assert isinstance(m, MethodMiddleware)
                rest = functools.partial(m.process, rest, request)
            return rest()

        return request_processor

    def process_request(self, request):
        """Run *request* through this method's middleware chain and return the response."""
        return self._request_processor(request)

    def direct_call(self, *args, **kwargs):
        """Call the underlying function directly, bypassing all middleware.

        Only allowed when the ``allow_direct_call`` option is set; whether this
        is an appropriate design is still under consideration.

        :raises Exception: if ``allow_direct_call`` is not enabled.
        """
        if not self.options['allow_direct_call']:
            raise Exception('direct_call is not enabled for: {}'.format(self.name))
        return self.func(*args, **kwargs)


class APIManager(object):
    """Registry of RPC methods and the entry point that dispatches requests to them."""

    def __init__(self):
        self._freezed = False
        self.method_map = {}
        self.documented_method_map = {}
        self._request_info_extractor = GaiaRequestInfoExtractor()

    def register_method_info(self, method_info):
        """Register *method_info* under its name.

        Methods whose ``show_document`` option is set are also added to
        ``documented_method_map``.

        :raises Exception: if a method with the same name is already registered.
        """
        assert isinstance(method_info, MethodInfo)
        name = method_info.name
        if name in self.method_map:
            raise Exception("method {} already registered".format(name))

        if self._freezed and settings.DEBUG:
            # Join the frames so the warning shows a readable stack trace rather
            # than the repr of a list of strings.
            warnings.warn('method registered after freezed: {}\n{}'.format(
                name,
                ''.join(traceback.format_stack()),
            ))

        self.method_map[name] = method_info
        if method_info.options['show_document']:
            self.documented_method_map[name] = method_info

    def bind_general(self, name, options, handler):
        """Create and register a MethodInfo for *handler* under *name*.

        *handler* may itself be a MethodInfo (re-binding an existing method under
        another name); its function is reused, and both bindings must agree on
        ``allow_direct_call``.

        :returns: the newly registered MethodInfo.
        :raises Exception: on duplicate name or mismatched ``allow_direct_call``.
        """
        original_method_info = None
        if isinstance(handler, MethodInfo):
            original_method_info = handler
            handler = handler.func
        method_info = MethodInfo(
            name=name,
            func=handler,
            options=options,
        )
        self.register_method_info(method_info)

        if original_method_info is not None:
            if original_method_info.options['allow_direct_call'] != method_info.options['allow_direct_call']:
                raise Exception('different option allow_direct_call between: {}, {}'.format(
                    original_method_info.name,
                    method_info.name,
                ))

        return method_info

    def freeze(self):
        """Mark registration as complete (later registrations warn in DEBUG) and validate."""
        self._freezed = True
        self.validate()

    def validate(self):
        """Print to stderr every ListInterfaceDescriptor whose function was never registered."""
        traced_lld = set()
        for mi in self.method_map.values():
            assert isinstance(mi, MethodInfo)
            lld = mi.list_interface_descriptor
            if not lld:
                continue
            traced_lld.add(lld)
        all_lld = frozenset(ListInterfaceDescriptor.instance_list)
        assert traced_lld.issubset(all_lld)
        if traced_lld != all_lld:
            for lld in all_lld - traced_lld:
                assert isinstance(lld, ListInterfaceDescriptor)
                print('Not traced ListInterfaceDescriptor for: {}.{}, {}'.format(
                    getattr(lld.func, '__module__', ''),
                    getattr(lld.func, '__name__', ''),
                    repr(lld.func)
                ), file=sys.stderr)

    def get_method_list(self):
        """Return the names of all registered methods."""
        return self.method_map.keys()

    def get_method_info(self, method):
        """Return the MethodInfo registered as *method*, or None."""
        return self.method_map.get(method, None)

    def dispatch(self, request, connection_info=None):
        """Dispatch *request* inside a request-logging guard and return the response dict.

        Attaches the extracted request info to ``request.request_info`` and
        records its ``req_source`` in the thread-local ``request_source``.
        """
        request_info = self._request_info_extractor.get_request_info(request)
        request_source.req_source = getattr(request_info, 'req_source', None)
        request.request_info = request_info
        with request_logging_guard_maker(request_info) as guard:
            request.context.logger = guard.logger
            result = self._dispatch(request=request, connection_info=connection_info)
            guard.set_errno(result['error'])
            return result

    @staticmethod
    def _cpu_clock():
        """Return processor time, preferring ``time.process_time``.

        ``time.clock`` was removed in Python 3.8; it is only used as a fallback
        on interpreters without ``process_time`` (Python 2).
        """
        clock = getattr(time, 'process_time', None) or time.clock
        return clock()

    @staticmethod
    def _process_with_profiler(method_info, request):
        """Run *request* under line_profiler and attach the report to ``data['_profile']``.

        Removes the ``_profile`` flag from ``request.params`` first so it is not
        forwarded to the method as a keyword argument. Producing/attaching the
        report is best-effort: failures are logged, never raised.
        """
        del request.params['_profile']

        import inspect
        from six import StringIO
        import line_profiler

        profiler = line_profiler.LineProfiler()
        profiler.add_function(method_info.func)
        profiler.add_module(inspect.getmodule(method_info.func))
        response = profiler.runcall(method_info.process_request, request)

        try:
            output = StringIO()
            profiler.print_stats(stream=output, stripzeros=True)
            report = output.getvalue()
            info_logger.info(report)
            response['data']['_profile'] = report
        except Exception:
            # Previously swallowed silently; keep it non-fatal but visible.
            logging.getLogger(__name__).warning(
                'failed to attach line_profiler report for %s', method_info.name,
                exc_info=True,
            )
        return response

    def _dispatch(self, request, connection_info=None):
        """Resolve and run the requested method, always writing a profile log line.

        Unknown methods produce an ``error: 404`` response. When DEBUG is on and
        the params contain ``_profile``, the call is run under line_profiler.
        Exceptions propagate; they are logged with error ``'unexpected'``.
        """
        time_start = time.time()
        clock_start = self._cpu_clock()

        method_info = self.method_map.get(request.method, None)

        # Freeze the original params before any middleware can mutate them.
        params_serialized = json.dumps(request.params)

        error = 'unexpected'
        try:
            if method_info:
                request.method_info = method_info
                if '_profile' in request.params and settings.DEBUG:
                    response = self._process_with_profiler(method_info, request)
                else:
                    response = method_info.process_request(request)
            else:
                response = {
                    'error': 404,
                    'data': None,
                    'message': 'Method not found: method={}'.format(request.method),
                }
            error = response['error']
        finally:
            time_end = time.time()
            clock_end = self._cpu_clock()

            format_string = (
                "!fmt6 {client_ip} {method} {nested} "
                "{time:.6f} {clock:.6f} {error} {session_key} {params}"
            )
            message = format_string.format(
                client_ip=connection_info.client_ip if connection_info else '-',
                time=time_end - time_start,
                clock=clock_end - clock_start,
                error=error,
                method=request.method,
                nested='nested' if request.is_nested_call else '-',
                session_key=request.session_key or '-',
                params=params_serialized,
            )
            profile_logger.info(message)

        return response
