from __future__ import absolute_import from collections import defaultdict from contextlib import contextmanager from kombu.utils.limits import TokenBucket from celery.worker import state from celery.utils.timeutils import rate from celery.tests.case import AppCase, Mock, patch, body_from_sig class test_default_strategy(AppCase): def setup(self): @self.app.task(shared=False) def add(x, y): return x + y self.add = add class Context(object): def __init__(self, sig, s, reserved, consumer, message, body): self.sig = sig self.s = s self.reserved = reserved self.consumer = consumer self.message = message self.body = body def __call__(self, **kwargs): return self.s( self.message, self.body, self.message.ack, self.message.reject, [], **kwargs ) def was_reserved(self): return self.reserved.called def was_rate_limited(self): assert not self.was_reserved() return self.consumer._limit_task.called def was_scheduled(self): assert not self.was_reserved() assert not self.was_rate_limited() return self.consumer.timer.call_at.called def event_sent(self): return self.consumer.event_dispatcher.send.call_args def get_request(self): if self.was_reserved(): return self.reserved.call_args[0][0] if self.was_rate_limited(): return self.consumer._limit_task.call_args[0][0] if self.was_scheduled(): return self.consumer.timer.call_at.call_args[0][0] raise ValueError('request not handled') @contextmanager def _context(self, sig, rate_limits=True, events=True, utc=True, limit=None): self.assertTrue(sig.type.Strategy) reserved = Mock() consumer = Mock() consumer.task_buckets = defaultdict(lambda: None) if limit: bucket = TokenBucket(rate(limit), capacity=1) consumer.task_buckets[sig.task] = bucket consumer.disable_rate_limits = not rate_limits consumer.event_dispatcher.enabled = events s = sig.type.start_strategy(self.app, consumer, task_reserved=reserved) self.assertTrue(s) message = Mock() body = body_from_sig(self.app, sig, utc=utc) yield self.Context(sig, s, reserved, consumer, message, body) def test_when_logging_disabled(self): with patch('celery.worker.strategy.logger') as logger: logger.isEnabledFor.return_value = False with self._context(self.add.s(2, 2)) as C: C() self.assertFalse(logger.info.called) def test_task_strategy(self): with self._context(self.add.s(2, 2)) as C: C() self.assertTrue(C.was_reserved()) req = C.get_request() C.consumer.on_task_request.assert_called_with(req) self.assertTrue(C.event_sent()) def test_when_events_disabled(self): with self._context(self.add.s(2, 2), events=False) as C: C() self.assertTrue(C.was_reserved()) self.assertFalse(C.event_sent()) def test_eta_task(self): with self._context(self.add.s(2, 2).set(countdown=10)) as C: C() self.assertTrue(C.was_scheduled()) C.consumer.qos.increment_eventually.assert_called_with() def test_eta_task_utc_disabled(self): with self._context(self.add.s(2, 2).set(countdown=10), utc=False) as C: C() self.assertTrue(C.was_scheduled()) C.consumer.qos.increment_eventually.assert_called_with() def test_when_rate_limited(self): task = self.add.s(2, 2) with self._context(task, rate_limits=True, limit='1/m') as C: C() self.assertTrue(C.was_rate_limited()) def test_when_rate_limited__limits_disabled(self): task = self.add.s(2, 2) with self._context(task, rate_limits=False, limit='1/m') as C: C() self.assertTrue(C.was_reserved()) def test_when_revoked(self): task = self.add.s(2, 2) task.freeze() state.revoked.add(task.id) try: with self._context(task) as C: C() with self.assertRaises(ValueError): C.get_request() finally: state.revoked.discard(task.id)