# -*- coding: UTF-8 -*-

from django.db.models import Q
from django.utils.html import escape
from django.conf import settings

from rpc.tool.log_tool import info_logger
from rpc.tool.dict_mixin import to_dict

# Tuple of built-in container types (usable as the second argument of
# isinstance()).
LIST_TYPE = (list, tuple, set)
# Tuple of text types; ``unicode`` is a Python 2 builtin.
STR_TYPE = (str, unicode)
# Database aliases in priority order: the ``HERA_READ_DB`` setting ('' when
# the setting is not defined), followed by the 'default' alias.
DEFAULT_DBS = [getattr(settings, 'HERA_READ_DB', ''), 'default']


def get_db_key(dbs):
    """Return the first alias from ``dbs`` found in ``settings.DATABASES``.

    ``dbs`` is scanned in order. ``None`` is returned when none of the
    aliases is configured (or when ``dbs`` is empty).
    """
    configured = (alias for alias in dbs if alias in settings.DATABASES)
    return next(configured, None)


class BrExpandedDict(dict):
    """
    Dictionary built from a flat mapping whose keys use bracket notation.

    A key such as ``'person[1][firstname]'`` is read as the path
    ``person -> 1 -> firstname``: the intermediate dictionaries are created
    on demand and the value is stored under the last path component. Keys
    without brackets are stored at the top level unchanged.

    >>> d = BrExpandedDict({'person[1][firstname]': ['Simon']})
    >>> d['person']['1']['firstname']
    ['Simon']

    """
    def __init__(self, key_to_list_mapping):
        for flat_key, value in key_to_list_mapping.items():
            # 'a[b][c]' -> ['a', 'b', 'c']
            path = flat_key.replace(']', '').split('[')
            leaf = path.pop()

            # Walk (and create) the nested dictionaries leading to the leaf.
            node = self
            for name in path:
                node = node.setdefault(name, {})

            try:
                node[leaf] = value
            except TypeError:
                # The slot reached does not support item assignment (e.g. a
                # plain string stored earlier under the same name); the
                # value is skipped.
                pass


class DataTable(object):
    def __init__(self, model, init_q=None):
        """Store the model to query and an optional base filter.

        ``model`` is the class whose ``objects`` manager is queried.
        ``init_q``, when not ``None``, is the starting filter that the
        search conditions are AND-ed onto.
        """
        self.init_q = init_q
        self.model = model

    def parse_request_data(self, req_dict):
        """
        Convert flat DataTables request parameters into an options dict.

        ``req_dict`` maps bracket-notation keys (``columns[0][name]``,
        ``order[0][dir]``, ``search[value]``, ...) to their values. The
        returned dict has this shape::

            {
                'draw': draw,
                'columns': [name, ...],
                'paging': {'start': int, 'length': int},
                'order': [{'name': name, 'dir': dir}, ...],
                'search': {
                    '_global': {value, regex},
                    name: {value, regex},
                },
            }

        ``start`` and ``length`` default to 0 and 10 when absent. A request
        without any ``order[...]`` or ``columns[...]`` keys (e.g. ordering
        disabled on the client) yields empty lists instead of raising
        ``KeyError``. ``draw`` is required.
        """
        params = BrExpandedDict(req_dict)

        # 'order' and 'columns' arrive as {'0': {...}, '1': {...}, ...};
        # rebuild each as a list in index order. search_opts() below relies
        # on params['columns'] being that list.
        for key in ('order', 'columns'):
            indexed = params.get(key) or {}
            params[key] = [indexed[str(i)] for i in range(len(indexed))]

        columns = params['columns']
        return {
            'draw': params['draw'],
            'columns': [col['name'] for col in columns],
            'paging': {
                'start': int(params.get('start', 0)),
                'length': int(params.get('length', 10)),
            },
            # Each order entry references its column by position.
            'order': [{
                'name': columns[int(o['column'])]['name'],
                'dir': o['dir'],
            } for o in params['order']],
            'search': self.search_opts(params),
        }

    def search_opts(self, params):
        """
        Collect the active search terms from parsed request parameters.

        ``params`` is expected to hold ``'search'`` (a dict with ``value``
        and ``regex``) and ``'columns'`` (a list of column dicts with
        ``name``, ``searchable`` and a nested ``search`` dict).

        Returns a dict with a ``'_global'`` entry when the global search
        value is non-empty, plus one entry per searchable column that has a
        non-empty search value. Every entry is
        ``{'value': <text>, 'regex': <bool>}``.

        Flags arrive as the strings ``'true'`` / ``'false'``, so they are
        compared against ``'true'``; a plain truthiness test would treat the
        string ``'false'`` as enabled.
        """
        def is_true(raw):
            return raw is True or raw == 'true'

        search_opts = {}

        global_search = params.get('search') or {}
        global_value = global_search.get('value')
        if global_value:
            search_opts['_global'] = {
                'value': global_value,
                'regex': is_true(global_search.get('regex')),
            }

        for col in params.get('columns') or []:
            col_search = col.get('search') or {}
            col_value = col_search.get('value')
            if is_true(col.get('searchable')) and col_value:
                search_opts[col['name']] = {
                    'value': col_value,
                    'regex': is_true(col_search.get('regex')),
                }
        return search_opts

    def process(self, req_data=None, global_fields=None):
        """
        Run a DataTables request end to end and return the response dict.

        Builds the filtered, ordered queryset, counts it, applies the
        requested page window and serialises the page through
        ``build_resp_data``.

        ``req_data`` holds the flat request parameters; ``global_fields``
        names the field(s) the global search box applies to.

        A negative ``length`` (DataTables sends -1 for "all rows") returns
        every row from ``start`` onwards, and a negative ``start`` is
        treated as 0; querysets do not support negative slice bounds.
        """
        objs = self.build_queryset(req_data, global_fields)
        # NOTE(review): the count is taken before the hidden-order exclusion
        # below, so the reported totals include hidden rows -- confirm this
        # is intended.
        cnt = objs.count()

        # ceo feature, roll back soon
        from api.models import Order
        if self.model == Order:
            objs = objs.exclude(rel__is_hidden=True)

        options = self.parse_request_data(req_data)
        paging_opts = options.pop('paging')
        if paging_opts:
            start = max(paging_opts['start'], 0)
            length = paging_opts['length']
            if length < 0:
                # "All rows" requested: open-ended slice.
                objs = objs[start:]
            else:
                objs = objs[start:start + length]

        return self.build_resp_data(options, cnt, objs)

    ##################################################
    # return queryset without paging
    ##################################################
    def build_queryset(self, req_data=None, global_fields=None):
        """
        Build the filtered and ordered (but unpaged) queryset for a request.

        The filter starts from ``self.init_q`` (or an empty ``Q``) and is
        AND-ed with the global search over ``global_fields`` and with the
        per-column searches. The query runs on the first alias of
        ``DEFAULT_DBS`` that is configured.

        When the request carries no ordering, ``order_by`` is not called at
        all, so the model's own default ordering stays in effect.
        """
        options = self.parse_request_data(req_data)
        srch_opts = options.pop('search', {})
        order_opts = options.pop('order')

        using = get_db_key(DEFAULT_DBS)

        q = self.init_q if self.init_q is not None else Q()
        global_srch = srch_opts.pop('_global', None)
        q &= self.global_query(global_fields, global_srch)
        q &= self.fields_query(srch_opts)

        # A '-' prefix requests descending order.
        order_fields = [
            '-%s' % (o['name']) if o['dir'] == 'desc' else o['name']
            for o in (order_opts or [])
        ]

        objs = self.model.objects.using(using).filter(q).distinct()
        if order_fields:
            objs = objs.order_by(*order_fields)
        return objs

    ##################################################
    # response data
    ##################################################
    def build_resp_data(self, options, total, objs):
        """
        Assemble the response dict for one page of results.

        ``options`` supplies ``draw`` (returned as an int) and ``columns``
        (the field names to serialise). ``total`` is reported as both
        ``recordsTotal`` and ``recordsFiltered``. ``objs`` are the rows of
        the current page, serialised via ``build_objs_data``.
        """
        draw = int(options['draw'])
        rows = self.build_objs_data(objs, options['columns'])
        return {
            'draw': draw,
            'recordsTotal': total,
            'recordsFiltered': total,
            'data': rows,
        }

    def build_objs_data(self, objs, columns):
        """Serialise every object in ``objs`` with ``build_obj_data``.

        Returns a list with one entry per object, in iteration order.
        """
        rows = []
        for record in objs:
            rows.append(self.build_obj_data(record, columns))
        return rows

    def build_obj_data(self, obj, columns):
        """
        Serialise one object into a row dict.

        The row starts from ``to_dict(obj, columns)``. For every column
        that has a callable ``getval_<column>`` attribute on this instance,
        the cell is replaced by ``getval_<column>(obj)``; string results are
        HTML-escaped. If computing or escaping such a value raises, the cell
        becomes an empty string so one bad getter does not fail the whole
        row.

        ``DT_RowId`` (``row_<id>``) and ``DT_RowData`` (``{'pk': <id>}``)
        are always added.
        """
        obj_data = to_dict(obj, columns)
        for field in columns:
            fn_val = getattr(self, 'getval_{}'.format(field), None)
            if not callable(fn_val):
                continue
            try:
                v = fn_val(obj)
                if isinstance(v, (str, unicode)):
                    v = escape(v)
            except Exception:
                # Best effort by design. ``Exception`` (not a bare except)
                # so KeyboardInterrupt / SystemExit still propagate.
                v = u''
            obj_data[field] = v

        obj_data['DT_RowId'] = u'row_{}'.format(obj.id)
        obj_data['DT_RowData'] = {'pk': unicode(obj.id)}
        return obj_data

    ##################################################
    # search
    ##################################################
    def _search_value(self, srch_opt):
        """Return the search value as a string, or a list if it has commas.

        The raw ``srch_opt['value']`` is stripped of surrounding whitespace
        and split on ','. A value without a comma is returned as the single
        stripped string; otherwise the list of pieces is returned (pieces
        are not stripped individually).
        """
        pieces = srch_opt['value'].strip().split(',')
        if len(pieces) == 1:
            return pieces[0]
        return pieces

    def _default_query(self, srch_key, srch_val, regex=False):
        """
        Build the default ``Q`` filter for one field.

        ``srch_val`` is either a string or a list of strings (the latter
        when the search text contained commas).

        * ``regex`` true -- fuzzy search: the text is split on whitespace
          and every word must occur in the field (``__icontains``, AND-ed).
          For a list, each item is handled that way and the items are
          OR-ed, mirroring the ``__in`` behaviour of the exact branch.
          Items without any word are ignored.
        * ``regex`` false -- exact match; a list becomes an ``__in`` lookup.
        """
        if regex:
            # regex=True: perform a fuzzy (substring) search.
            lookup = '{}__icontains'.format(srch_key)
            # A comma-separated search arrives as a list; a plain string has
            # a single alternative.
            if isinstance(srch_val, list):
                alternatives = srch_val
            else:
                alternatives = [srch_val]

            q = None
            for alternative in alternatives:
                words = alternative.split()
                if not words:
                    continue
                words_q = Q()
                for val in words:
                    words_q &= Q(**{lookup: val})
                q = words_q if q is None else q | words_q
            if q is None:
                q = Q()
        else:
            if isinstance(srch_val, list):
                key = '{}__in'.format(srch_key)
            else:
                key = srch_key
            q = Q(**{key: srch_val})

        return q

    def _qry_time_range(self, srch_key, srch_val, regex=False):
        """
        Build a ``Q`` filter for a ``[start, end]`` pair on ``srch_key``.

        ``srch_val`` must be a two-item list; anything else yields an empty
        ``Q``. A non-empty ``end`` is extended with `` 23:59:59``.

        * both bounds given -> ``__range`` over ``[start, end]``
        * only ``end`` given -> ``__lt`` end
        * only ``start`` given -> ``__gte`` start
        * neither given -> empty ``Q``

        ``regex`` is accepted but not used.
        """
        if not (isinstance(srch_val, list) and len(srch_val) == 2):
            return Q()

        lower, upper = srch_val
        if upper:
            upper = '{} 23:59:59'.format(upper)
        else:
            upper = None

        if not lower and not upper:
            return Q()

        if lower and upper:
            suffix, bound = 'range', [lower, upper]
        elif upper:
            suffix, bound = 'lt', upper
        else:
            suffix, bound = 'gte', lower
        return Q(**{'{}__{}'.format(srch_key, suffix): bound})

    def global_query(self, keys, gsrch):
        """
        Build the ``Q`` filter for the global search box.

        ``keys`` is a field name or an iterable of field names; ``gsrch`` is
        the ``{'value', 'regex'}`` dict of the global search. The default
        query of every field is OR-ed together. An empty ``Q`` is returned
        when there are no keys, no search dict, or the search text is blank.
        """
        if not (keys and gsrch):
            return Q()

        # Ignore searches made only of surrounding whitespace.
        if not gsrch['value'].strip():
            return Q()

        needle = self._search_value(gsrch)
        fields = [keys] if isinstance(keys, (str, unicode)) else keys

        combined = Q()
        for field in fields:
            combined |= self._default_query(field, needle, gsrch['regex'])
        return combined

    def fields_query(self, srch_opts):
        """
        Build the AND-ed ``Q`` filter for the per-column searches.

        ``srch_opts`` maps a field name to its ``{'value', 'regex'}`` dict.
        For each field the builder is the instance's ``query_<field>``
        attribute when present, otherwise ``_default_query``; it is called
        as ``builder(field, value, regex)``. The chosen builder is logged
        through ``info_logger`` for every field.
        """
        combined = Q()
        for field, fsrch in srch_opts.iteritems():
            value = self._search_value(fsrch)
            builder = getattr(self, 'query_{}'.format(field), self._default_query)
            info_logger.info(builder)
            combined &= builder(field, value, fsrch['regex'])
        return combined
