# coding=utf-8
from __future__ import unicode_literals, print_function, absolute_import

import six
import django.db.models

# Attribute name under which the decorators below attach their lookup storage
# to a function. Wrapped in str() so it is a native str on Python 2 even though
# unicode_literals is in effect for this module.
_stored_attr = str('gaia_db_opt_stored')

# Positions of the four collections inside the stored 4-tuple.
(
    _select_related_index,
    _prefetch_related_index,
    _select_nested_index,
    _prefetch_nested_index,
) = range(4)


def _get_or_create_stored(func):
    """
    Return the lookup storage attached to ``func``, attaching an empty one first if needed.

    The storage is a 4-tuple of mutable sets, indexed by ``_select_related_index``,
    ``_prefetch_related_index``, ``_select_nested_index`` and ``_prefetch_nested_index``.
    Repeated calls return the same tuple, so successive decorators accumulate into it.
    """
    if hasattr(func, _stored_attr):
        return getattr(func, _stored_attr)
    setattr(func, _stored_attr, (set(), set(), set(), set()))
    return getattr(func, _stored_attr)


def _is_stored(func):
    """Return True if ``func`` already carries lookup storage (i.e. ``_get_or_create_stored`` ran on it)."""
    has_storage = hasattr(func, _stored_attr)
    return has_storage


def _get_stored(func):
    """
    Return the lookup storage of ``func`` without creating it.

    A function that was never decorated yields a 4-tuple of empty frozensets, so
    callers can index and iterate the result exactly like real storage.
    """
    empty_storage = tuple(frozenset() for _ in range(4))
    return getattr(func, _stored_attr, empty_storage)


def _get_nested_func(nested_func):
    """
    Resolve the target of a ``select_nested`` / ``prefetch_nested`` declaration.

    A callable that already carries lookup storage is used as-is. Anything else is
    treated as a zero-argument factory and called to obtain the real function.
    NOTE(review): presumably this lets a declaration refer to a function defined
    later in the module (e.g. ``lambda: other_func``) — confirm with callers.
    """
    return nested_func if _is_stored(nested_func) else nested_func()


def select_related(*fields):
    """
    Build a decorator that records ``fields`` as select_related lookups on a function.

    The decorated function is returned unchanged; the lookups are stored on it and
    later applied by ``add_related``. Applying the decorator several times to the
    same function accumulates the fields.

    :param fields: lookup strings; any prefix given to ``add_related`` or to the
        nesting decorators is prepended verbatim when they are collected
    :raises TypeError: if any field is not a string
    :return: the decorator
    """
    for field in fields:
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if not isinstance(field, six.string_types):
            raise TypeError(
                'select_related() fields must be strings, got %r' % (field,)
            )

    def decorator(func):
        stored = _get_or_create_stored(func)
        stored[_select_related_index].update(fields)
        return func

    return decorator


def prefetch_related(*fields):
    """
    Build a decorator that records ``fields`` as prefetch_related lookups on a function.

    The decorated function is returned unchanged; the lookups are stored on it and
    later applied by ``add_related``. Applying the decorator several times to the
    same function accumulates the fields.

    :param fields: lookup strings; any prefix given to ``add_related`` or to the
        nesting decorators is prepended verbatim when they are collected
    :raises TypeError: if any field is not a string
    :return: the decorator
    """
    for field in fields:
        # Explicit check rather than ``assert`` so validation survives ``python -O``.
        if not isinstance(field, six.string_types):
            raise TypeError(
                'prefetch_related() fields must be strings, got %r' % (field,)
            )

    def decorator(func):
        stored = _get_or_create_stored(func)
        stored[_prefetch_related_index].update(fields)
        return func

    return decorator


def select_nested(nested_func, prefix=''):
    """
    Build a decorator that makes a function inherit ``nested_func``'s lookups under ``prefix``.

    When lookups are collected, ``nested_func``'s select_related lookups remain
    select_related and its prefetch_related lookups remain prefetch_related. All of
    them are prefixed with ``prefix``.

    :param nested_func: a function carrying recorded lookups, or a zero-argument
        callable returning one (an undecorated callable is called to resolve it
        when lookups are collected)
    :param prefix: string concatenated verbatim in front of the nested lookups, so
        it normally needs a trailing ``'__'``
    :raises TypeError: if ``prefix`` is not a string
    :return: the decorator (which returns the decorated function unchanged)
    """
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if not isinstance(prefix, six.string_types):
        raise TypeError(
            'select_nested() prefix must be a string, got %r' % (prefix,)
        )

    def decorator(func):
        stored = _get_or_create_stored(func)
        stored[_select_nested_index].add((nested_func, prefix))
        return func

    return decorator


def prefetch_nested(nested_func, prefix=''):
    """
    Build a decorator that makes a function inherit ``nested_func``'s lookups as prefetches.

    When lookups are collected, every lookup of ``nested_func`` is prefixed with
    ``prefix`` and applied as prefetch_related. This includes its select_related
    lookups, which are demoted.

    :param nested_func: a function carrying recorded lookups, or a zero-argument
        callable returning one (an undecorated callable is called to resolve it
        when lookups are collected)
    :param prefix: string concatenated verbatim in front of the nested lookups, so
        it normally needs a trailing ``'__'``
    :raises TypeError: if ``prefix`` is not a string
    :return: the decorator (which returns the decorated function unchanged)
    """
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if not isinstance(prefix, six.string_types):
        raise TypeError(
            'prefetch_nested() prefix must be a string, got %r' % (prefix,)
        )

    def decorator(func):
        stored = _get_or_create_stored(func)
        stored[_prefetch_nested_index].add((nested_func, prefix))
        return func

    return decorator


def _add_select_related(query_set, fields):
    """
    Apply ``select_related(*fields)`` to ``query_set``, or return it untouched if there is nothing to add.

    Skipping the call for an empty field list is essential: ``QuerySet.select_related()``
    with no arguments means "follow every non-null foreign key".

    :param query_set: queryset to extend
    :param fields: any iterable of lookup strings (list, tuple, set, generator);
        ``None`` is treated as empty
    :return: ``query_set`` itself when ``fields`` is empty, otherwise the new queryset
    """
    # Materialise first: ``len()`` fails on generators and ``None``.
    fields = tuple(fields) if fields is not None else ()
    if not fields:
        return query_set
    return query_set.select_related(*fields)


def _add_prefetch_related(query_set, fields):
    """
    Apply ``prefetch_related(*fields)`` to ``query_set``, or return it untouched if there is nothing to add.

    Skipping the call for an empty field list avoids a needless queryset clone.

    :param query_set: queryset to extend
    :param fields: any iterable of lookup strings (list, tuple, set, generator);
        ``None`` is treated as empty
    :return: ``query_set`` itself when ``fields`` is empty, otherwise the new queryset
    """
    # Materialise first: ``len()`` fails on generators and ``None``.
    fields = tuple(fields) if fields is not None else ()
    if not fields:
        return query_set
    return query_set.prefetch_related(*fields)


def __get_fields(func, prefix=''):
    """
    Recursively collect the lookups recorded on ``func`` and on every function it nests.

    :param func: function whose recorded lookups are collected (an undecorated
        function contributes nothing)
    :param prefix: string prepended verbatim to every collected lookup
    :return: ``(select_fields, prefetch_fields)`` as two lists; duplicates are kept
    """
    stored = _get_stored(func)

    select_fields = [prefix + field for field in stored[_select_related_index]]
    prefetch_fields = [prefix + field for field in stored[_prefetch_related_index]]

    # Each pair: (nested declarations, whether the nested function's selects are
    # demoted to prefetches). select_nested keeps the nested select/prefetch split;
    # prefetch_nested turns everything beneath it into prefetch_related.
    nested_groups = (
        (stored[_select_nested_index], False),
        (stored[_prefetch_nested_index], True),
    )
    for declarations, demote_selects in nested_groups:
        for nested_func, nested_prefix in declarations:
            child_selects, child_prefetches = __get_fields(
                func=_get_nested_func(nested_func),
                prefix=prefix + nested_prefix,
            )
            target = prefetch_fields if demote_selects else select_fields
            target.extend(child_selects)
            prefetch_fields.extend(child_prefetches)

    return select_fields, prefetch_fields


def add_related(query_set, func, prefix=''):
    """
    Add the select_related / prefetch_related lookups described by ``func`` and ``prefix``.

    Lookups recorded on ``func`` by this module's decorators, including those
    inherited through ``select_nested`` / ``prefetch_nested``, are prefixed with
    ``prefix`` and applied to ``query_set``.

    :param query_set: the original QuerySet
    :param func: the function that will access members of the QuerySet's objects
    :param prefix: prefix of the query_set members that ``func`` accesses
    :raises TypeError: if ``query_set`` is not a ``django.db.models.QuerySet``
    :return: the QuerySet with select_related and prefetch_related applied
    """
    # Explicit check rather than ``assert`` so validation survives ``python -O``.
    if not isinstance(query_set, django.db.models.QuerySet):
        raise TypeError(
            'add_related() expects a django QuerySet, got %r' % (type(query_set),)
        )
    select_fields, prefetch_fields = __get_fields(
        func=func,
        prefix=prefix,
    )
    query_set = _add_select_related(
        query_set=query_set,
        fields=select_fields,
    )
    query_set = _add_prefetch_related(
        query_set=query_set,
        fields=prefetch_fields,
    )
    return query_set

