from concurrent.futures import (
    ThreadPoolExecutor,
    as_completed
)
from functools import wraps

from django.db import connections

from gm_tracer.context import register_tracer, current_tracer
from helios.rpc import get_mesh_info, set_mesh_info


class GroupRoutine(object):
    """Run a group of callables concurrently and collect their results by name.

    Typical use::

        routine = GroupRoutine()
        routine.submit('a', func_a, arg)
        routine.submit('b', func_b)
        routine.go()
        routine.results  # {'a': ..., 'b': ...}

    Each submitted callable runs in the executor with the tracer and mesh info
    that were current when this routine was constructed. The worker thread's
    Django DB connections are closed after each callable finishes.
    """

    def __init__(self, workers=5, executor=ThreadPoolExecutor):
        """
        :param workers: passed to ``executor`` as ``max_workers``.
        :param executor: callable (normally an executor class) invoked as
            ``executor(max_workers=workers)`` to build the pool.
        """
        self.executor = executor(max_workers=workers)
        self._start_id = 0
        # future -> submission sequence number (from next_id)
        self.futures = {}
        # future -> result name given to submit()
        self.result_loader = {}
        # result name -> return value; filled in by go()
        self._results = {}
        self._done = False
        # Captured in the constructing thread so workers can re-install it.
        self.trace_info = self._get_trace_info()

    def _get_trace_info(self):
        """Snapshot the current tracer and mesh info of the calling thread."""
        return {
            'tracer': current_tracer(),
            'mesh_attr': get_mesh_info(),
        }

    @property
    def next_id(self):
        """Return the next submission sequence number (advances on each read)."""
        gr_id = self._start_id
        self._start_id += 1
        return gr_id

    def _wrap(self, func):
        """Wrap *func* so it runs with the captured tracing context.

        The wrapper also closes the worker thread's DB connections when *func*
        finishes, whether it returns or raises.
        """
        tracer = self.trace_info['tracer']
        mesh_attr = self.trace_info['mesh_attr']

        @wraps(func)
        def deco(*args, **kwargs):
            register_tracer(tracer)
            set_mesh_info(mesh_attr)
            try:
                return func(*args, **kwargs)
            finally:
                # Django keeps DB connections per thread. Close this worker's
                # connections here, not in a done-callback: add_done_callback
                # runs the callback in the *submitting* thread when the future
                # has already finished, which would close that thread's
                # connections instead.
                connections.close_all()

        return deco

    def submit(self, result_name, func, *args, **kwargs):
        """Schedule ``func(*args, **kwargs)``; its return value is stored under *result_name*.

        If two tasks share a *result_name*, only one of their values is kept.
        """
        future = self.executor.submit(self._wrap(func), *args, **kwargs)
        self.futures[future] = self.next_id
        self.result_loader[future] = result_name

    def on_done(self, future):
        """Close the *calling* thread's Django DB connections.

        Kept for backward compatibility. It is no longer registered as a
        done-callback, because that could run it in the submitting thread.
        Per-task connection cleanup now happens inside the worker (see ``_wrap``).
        """
        connections.close_all()

    def go(self):
        """Wait for all submitted tasks and collect their results.

        Calling it again after a successful run is a no-op. The executor is
        shut down afterwards even if a task raised. A task's exception
        propagates to the caller, and the routine then stays not done.
        """
        if self._done:
            return

        results = {}
        try:
            for future in as_completed(self.futures):
                results[self.result_loader[future]] = future.result()
        finally:
            # Release the worker threads whether or not a task failed.
            self.executor.shutdown(wait=True)

        self._results = results
        self._done = True

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Shut the executor down (waiting for tasks); never suppress exceptions."""
        self.executor.shutdown(wait=True)
        return False

    @property
    def done(self):
        """True once go() has collected every result successfully."""
        return self._done

    @property
    def results(self):
        """Mapping of result name -> return value.

        :raises RuntimeError: if go() has not completed successfully.
        """
        if self._done:
            return self._results

        raise RuntimeError("Routine is not finish")


if __name__ == "__main__":
    # Demo: fetch the same URL three times concurrently and compare the total
    # wall time against the per-task time.
    import time

    import requests

    def fetch(url):
        """Demo task: GET *url*, sleep 2 s, print this task's elapsed time, return *url*."""
        began = time.time()
        requests.get(url)
        time.sleep(2)
        print(time.time() - began)
        return url

    overall_start = time.time()

    routine = GroupRoutine()
    targets = ["https://www.baidu.com/"] * 3
    for index, target in enumerate(targets):
        routine.submit(index, fetch, target)
    routine.go()

    # "总耗时" means "total elapsed time".
    print("总耗时:", time.time() - overall_start)
    print(routine.results)