import threading
from functools import wraps

# Module-level thread-local storage. ``thread_local_router`` sets attributes
# on this object on entry and puts them back on exit.
threadlocal = threading.local()


class thread_local_router(object):
    """Set attributes on ``threadlocal`` for the duration of a call or block.

    Usable both as a decorator and as a context manager.  On entry, every
    keyword argument passed to the constructor is set as an attribute on the
    module-level ``threadlocal`` object.  On exit, each of those attributes
    is put back to the state it had just before entry: its previous value
    is restored, or the attribute is removed if it did not exist.  Useful
    for passing variables down the stack without threading them through
    every function signature.

    Entries may be nested (including recursive calls of a decorated
    function); each exit undoes only its own entry.

    Usage:

    @thread_local_router(DB_FOR_READ_OVERRIDE='foobar')
    def override(request):
        ...

    with thread_local_router(DB_FOR_READ_OVERRIDE='foobar'):
        ...

    """

    # Marks "attribute did not exist before entry"; a dedicated object is
    # used because None may be a legitimate stored value.
    _MISSING = object()

    def __init__(self, **kwargs):
        """Remember the attribute names/values to apply on entry."""
        self.options = kwargs
        # A decorator instance is shared by every thread that calls the
        # wrapped function, so the saved state is kept per thread.
        self._saved = threading.local()

    def __enter__(self):
        """Apply the options to ``threadlocal`` and return this instance."""
        missing = self._MISSING
        previous = {
            attr: getattr(threadlocal, attr, missing)
            for attr in self.options
        }

        # A stack (rather than a single slot) lets the same instance be
        # entered again before it has exited, e.g. on recursion.
        try:
            stack = self._saved.stack
        except AttributeError:
            stack = self._saved.stack = []
        stack.append(previous)

        for attr, value in self.options.items():
            setattr(threadlocal, attr, value)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Undo the matching ``__enter__``.

        Returns None, so an exception raised inside the block propagates.
        """
        previous = self._saved.stack.pop()
        for attr, old_value in previous.items():
            if old_value is self._MISSING:
                try:
                    delattr(threadlocal, attr)
                except AttributeError:
                    # The wrapped code already removed it; the goal
                    # (attribute absent) is met, and raising here would
                    # mask any exception coming out of the block.
                    pass
            else:
                setattr(threadlocal, attr, old_value)

    def __call__(self, func):
        """Decorate ``func`` so that each call runs inside this context."""

        @wraps(func)
        def inner(*args, **kwargs):
            # This instance is its own context manager: __enter__ runs
            # before the call and __exit__ after it, even if func raises.
            with self:
                return func(*args, **kwargs)

        return inner