# -*- coding: utf8 -*-

from __future__ import unicode_literals, absolute_import, print_function

import redis

from talos.backbone.rpc import RPCMixin


class ICache(object):
    """Interface for a key/value cache that wraps a backing store.

    The constructor keeps the backing store on ``self._cache_store``; every
    operation below raises ``NotImplementedError`` until a subclass
    overrides it.
    """

    def __init__(self, cache_store):
        """Remember ``cache_store`` so subclasses can delegate to it."""
        self._cache_store = cache_store

    def set(self, k, v, timeout):
        """Store ``v`` under key ``k`` with the given ``timeout``.

        NOTE(review): the unit of ``timeout`` is not fixed here; presumably
        seconds, since ``RedisCache`` hands it to ``setex`` -- confirm.
        """
        raise NotImplementedError()

    def get(self, k):
        """Return the value cached under key ``k``."""
        raise NotImplementedError()

    def mget(self, k):
        """Return the values cached under several keys at once.

        NOTE(review): despite its singular name, ``k`` is presumably a
        collection of keys (``RedisCache.mget`` takes ``ks``) -- confirm.
        """
        raise NotImplementedError()

    def keys(self, ks):
        """Return the cached keys selected by ``ks``."""
        raise NotImplementedError()


class RedisCache(ICache):
    """``ICache`` implementation that delegates every call to a redis client."""

    def __init__(self, redis_client):
        """Wrap ``redis_client``, which must be a ``redis.StrictRedis``.

        The type check is an ``assert``, so it is skipped when Python runs
        with assertions disabled.
        """
        is_strict_client = isinstance(redis_client, redis.StrictRedis)
        assert is_strict_client
        super(RedisCache, self).__init__(cache_store=redis_client)

    def set(self, k, v, timeout):
        """Store ``v`` under ``k`` via ``setex`` and return the client's reply.

        Note the argument order expected by the client: ``(key, timeout,
        value)``.
        """
        client = self._cache_store
        return client.setex(k, timeout, v)

    def get(self, k):
        """Return whatever the client's ``get`` yields for ``k``."""
        client = self._cache_store
        return client.get(k)

    def mget(self, ks):
        """Return the client's ``mget`` reply for the keys ``ks``."""
        client = self._cache_store
        return client.mget(ks)

    def keys(self, k_star):
        """Return the keys matching the pattern ``k_star``.

        NOTE(review): this goes through the client's ``keys`` call, which
        presumably maps to the Redis KEYS command and walks the whole
        keyspace -- confirm that is acceptable for the callers.
        """
        client = self._cache_store
        return client.keys(k_star)


class RpcServiceModelCache(object):
    """Cache layer that stores one entry per model item, keyed by its id.

    Cache keys are built from ``_default_k_format``, i.e.
    ``talos:dac:{model_name}:{id}``. A given ``model_name`` can be bound to
    only one instance: the constructor asserts that the name has not been
    registered before.
    """

    # Key under which ``mget`` reports the ids that were not found.
    missing_k = 'missing'
    _default_k_format = 'talos:dac:{model_name}:{id}'

    # Registry of model names already in use. It is a class attribute, so
    # it is shared by every instance.
    __registered_models = []

    def __init__(self, cache, model_name):
        """Bind this cache layer to ``cache`` for the model ``model_name``.

        :param cache: object providing ``get``/``mget``/``set``/``keys``
            (the calls made below), e.g. an ``ICache`` implementation.
        :param model_name: name interpolated into every cache key.
        :raises AssertionError: if ``model_name`` is already registered
            (only while assertions are enabled).
        """
        self._cache = cache

        assert model_name not in self.__registered_models, (
            'model %r is already registered' % (model_name,)
        )
        self._model_name = model_name
        self.__registered_models.append(model_name)

    def _make_key(self, id):
        """Return the cache key for item ``id`` of this model."""
        return self._default_k_format.format(
            model_name=self._model_name, id=id
        )

    def _get_cached_keys(self):
        """Return the backend's keys matching this model's ``*`` pattern."""
        pattern = self._make_key('*')
        return self._cache.keys(pattern)

    def get(self, id):
        """Return the backend's value for item ``id``."""
        ck = self._make_key(id)
        return self._cache.get(ck)

    def mget(self, ids):
        """Look up several items at once.

        Returns a dict of the form::

            {
                id: value,
                id: value,
                'missing': [id, ...],
            }

        Every id whose backend value is ``None`` is listed under
        ``missing_k``; all other ids map to their cached value. An empty or
        ``None`` ``ids`` yields ``{'missing': []}`` without touching the
        backend.
        """
        result = {
            self.missing_k: [],
        }

        # ``ids`` is walked twice (to build the keys and to pair ids with
        # values), so materialise it once; otherwise a generator would be
        # exhausted after the first pass and every id silently dropped.
        wanted = list(ids or ())
        if not wanted:
            return result

        cache_keys = [self._make_key(item_id) for item_id in wanted]
        cached = self._cache.mget(cache_keys)

        for item_id, cached_info in zip(wanted, cached):
            # Only ``None`` counts as a miss: a falsy-but-present value
            # (empty string, 0, empty container) is still a cache hit.
            if cached_info is not None:
                result[item_id] = cached_info
            else:
                result[self.missing_k].append(item_id)

        return result

    def set(self, id, value, timeout=120):
        """Cache ``value`` for item ``id``; ``timeout`` is passed through to
        the backend's ``set``."""
        ck = self._make_key(id)
        self._cache.set(ck, value, timeout)

    def mset(self, id_value_pair, timeout=120):
        """Cache every ``(id, value)`` pair from the iterable
        ``id_value_pair``, one backend ``set`` call per pair."""
        for item_id, value in id_value_pair:
            self.set(item_id, value, timeout)


class ServiceBase(RPCMixin):
    """Base class for services that issue RPC calls, optionally in parallel."""

    def __init__(self):
        """Start with no pending parallel calls and obtain the RPC invoker."""
        # Pending parallel calls, keyed by endpoint.
        self._rpc_calls = {}
        self.rpc = self.get_rpc_invoker()

    def add_rpc_parallel_call(self, endpoint, **kwargs):
        """Start a parallel RPC call to ``endpoint`` with ``kwargs``.

        The pending call is stored under ``endpoint``, so issuing a second
        call to the same endpoint replaces the earlier one.
        """
        pending_call = self.rpc.parallel[endpoint](**kwargs)
        self._rpc_calls[endpoint] = pending_call

    def get_parallel_result(self):
        """Return ``{endpoint: result}`` for every pending parallel call.

        Each pending call is resolved with ``unwrap()``. The pending calls
        are not cleared afterwards.

        Raises ``AssertionError`` when no parallel call has been added.
        """
        pending = self._rpc_calls
        assert len(pending) >= 1, 'there is no parallel calls!'

        return {
            endpoint: call.unwrap()
            for endpoint, call in pending.items()
        }

    @staticmethod
    def get_gaia_graph_scheme():
        """Return the ``schema`` object from ``gql.schema``.

        The import happens at call time rather than at module import.
        """
        from gql.schema import schema as gaia_schema
        return gaia_schema
