from __future__ import absolute_import import os from contextlib import contextmanager from celery.fixups.django import ( _maybe_close_fd, fixup, DjangoFixup, DjangoWorkerFixup, ) from celery.tests.case import ( AppCase, Mock, patch, patch_many, patch_modules, mask_modules, ) class FixupCase(AppCase): Fixup = None @contextmanager def fixup_context(self, app): with patch('celery.fixups.django.DjangoWorkerFixup.validate_models'): with patch('celery.fixups.django.symbol_by_name') as symbyname: with patch('celery.fixups.django.import_module') as impmod: f = self.Fixup(app) yield f, impmod, symbyname class test_DjangoFixup(FixupCase): Fixup = DjangoFixup def test_fixup(self): with patch('celery.fixups.django.DjangoFixup') as Fixup: with patch.dict(os.environ, DJANGO_SETTINGS_MODULE=''): fixup(self.app) self.assertFalse(Fixup.called) with patch.dict(os.environ, DJANGO_SETTINGS_MODULE='settings'): with mask_modules('django'): with self.assertWarnsRegex(UserWarning, 'but Django is'): fixup(self.app) self.assertFalse(Fixup.called) with patch_modules('django'): fixup(self.app) self.assertTrue(Fixup.called) def test_maybe_close_fd(self): with patch('os.close'): _maybe_close_fd(Mock()) _maybe_close_fd(object()) def test_init(self): with self.fixup_context(self.app) as (f, importmod, sym): self.assertTrue(f) def se(name): if name == 'django.utils.timezone:now': raise ImportError() return Mock() sym.side_effect = se self.assertTrue(self.Fixup(self.app)._now) def test_install(self): self.app.loader = Mock() with self.fixup_context(self.app) as (f, _, _): with patch_many('os.getcwd', 'sys.path', 'celery.fixups.django.signals') as (cw, p, sigs): cw.return_value = '/opt/vandelay' f.install() sigs.worker_init.connect.assert_called_with(f.on_worker_init) self.assertEqual(self.app.loader.now, f.now) self.assertEqual(self.app.loader.mail_admins, f.mail_admins) p.append.assert_called_with('/opt/vandelay') def test_now(self): with self.fixup_context(self.app) as (f, _, _): self.assertTrue(f.now(utc=True)) self.assertFalse(f._now.called) self.assertTrue(f.now(utc=False)) self.assertTrue(f._now.called) def test_mail_admins(self): with self.fixup_context(self.app) as (f, _, _): f.mail_admins('sub', 'body', True) f._mail_admins.assert_called_with( 'sub', 'body', fail_silently=True, ) def test_on_worker_init(self): with self.fixup_context(self.app) as (f, _, _): with patch('celery.fixups.django.DjangoWorkerFixup') as DWF: f.on_worker_init() DWF.assert_called_with(f.app) DWF.return_value.install.assert_called_with() self.assertIs(f._worker_fixup, DWF.return_value) class test_DjangoWorkerFixup(FixupCase): Fixup = DjangoWorkerFixup def test_init(self): with self.fixup_context(self.app) as (f, importmod, sym): self.assertTrue(f) def se(name): if name == 'django.db:close_old_connections': raise ImportError() return Mock() sym.side_effect = se self.assertIsNone(self.Fixup(self.app)._close_old_connections) def test_install(self): self.app.conf = {'CELERY_DB_REUSE_MAX': None} self.app.loader = Mock() with self.fixup_context(self.app) as (f, _, _): with patch_many('celery.fixups.django.signals') as (sigs, ): f.install() sigs.beat_embedded_init.connect.assert_called_with( f.close_database, ) sigs.worker_ready.connect.assert_called_with(f.on_worker_ready) sigs.task_prerun.connect.assert_called_with(f.on_task_prerun) sigs.task_postrun.connect.assert_called_with(f.on_task_postrun) sigs.worker_process_init.connect.assert_called_with( f.on_worker_process_init, ) def test_on_worker_process_init(self): with self.fixup_context(self.app) as (f, _, _): with patch('celery.fixups.django._maybe_close_fd') as mcf: _all = f._db.connections.all = Mock() conns = _all.return_value = [ Mock(), Mock(), ] conns[0].connection = None with patch.object(f, 'close_cache'): with patch.object(f, '_close_database'): f.on_worker_process_init() mcf.assert_called_with(conns[1].connection) f.close_cache.assert_called_with() f._close_database.assert_called_with() mcf.reset_mock() _all.side_effect = AttributeError() f.on_worker_process_init() mcf.assert_called_with(f._db.connection.connection) f._db.connection = None f.on_worker_process_init() def test_on_task_prerun(self): task = Mock() with self.fixup_context(self.app) as (f, _, _): task.request.is_eager = False with patch.object(f, 'close_database'): f.on_task_prerun(task) f.close_database.assert_called_with() task.request.is_eager = True with patch.object(f, 'close_database'): f.on_task_prerun(task) self.assertFalse(f.close_database.called) def test_on_task_postrun(self): task = Mock() with self.fixup_context(self.app) as (f, _, _): with patch.object(f, 'close_cache'): task.request.is_eager = False with patch.object(f, 'close_database'): f.on_task_postrun(task) self.assertTrue(f.close_database.called) self.assertTrue(f.close_cache.called) # when a task is eager, do not close connections with patch.object(f, 'close_cache'): task.request.is_eager = True with patch.object(f, 'close_database'): f.on_task_postrun(task) self.assertFalse(f.close_database.called) self.assertFalse(f.close_cache.called) def test_close_database(self): with self.fixup_context(self.app) as (f, _, _): f._close_old_connections = Mock() f.close_database() f._close_old_connections.assert_called_with() f._close_old_connections = None with patch.object(f, '_close_database') as _close: f.db_reuse_max = None f.close_database() _close.assert_called_with() _close.reset_mock() f.db_reuse_max = 10 f._db_recycles = 3 f.close_database() self.assertFalse(_close.called) self.assertEqual(f._db_recycles, 4) _close.reset_mock() f._db_recycles = 20 f.close_database() _close.assert_called_with() self.assertEqual(f._db_recycles, 1) def test__close_database(self): with self.fixup_context(self.app) as (f, _, _): conns = [Mock(), Mock(), Mock()] conns[1].close.side_effect = KeyError('already closed') f.database_errors = (KeyError, ) f._db.connections = Mock() # ConnectionHandler f._db.connections.all.side_effect = lambda: conns f._close_database() conns[0].close.assert_called_with() conns[1].close.assert_called_with() conns[2].close.assert_called_with() conns[1].close.side_effect = KeyError('omg') with self.assertRaises(KeyError): f._close_database() class Object(object): pass o = Object() o.close_connection = Mock() f._db = o f._close_database() o.close_connection.assert_called_with() def test_close_cache(self): with self.fixup_context(self.app) as (f, _, _): f.close_cache() f._cache.cache.close.assert_called_with() f._cache.cache.close.side_effect = TypeError() f.close_cache() def test_on_worker_ready(self): with self.fixup_context(self.app) as (f, _, _): f._settings.DEBUG = False f.on_worker_ready() with self.assertWarnsRegex(UserWarning, r'leads to a memory leak'): f._settings.DEBUG = True f.on_worker_ready() def test_mysql_errors(self): with patch_modules('MySQLdb'): import MySQLdb as mod mod.DatabaseError = Mock() mod.InterfaceError = Mock() mod.OperationalError = Mock() with self.fixup_context(self.app) as (f, _, _): self.assertIn(mod.DatabaseError, f.database_errors) self.assertIn(mod.InterfaceError, f.database_errors) self.assertIn(mod.OperationalError, f.database_errors) with mask_modules('MySQLdb'): with self.fixup_context(self.app): pass def test_pg_errors(self): with patch_modules('psycopg2'): import psycopg2 as mod mod.DatabaseError = Mock() mod.InterfaceError = Mock() mod.OperationalError = Mock() with self.fixup_context(self.app) as (f, _, _): self.assertIn(mod.DatabaseError, f.database_errors) self.assertIn(mod.InterfaceError, f.database_errors) self.assertIn(mod.OperationalError, f.database_errors) with mask_modules('psycopg2'): with self.fixup_context(self.app): pass def test_sqlite_errors(self): with patch_modules('sqlite3'): import sqlite3 as mod mod.DatabaseError = Mock() mod.InterfaceError = Mock() mod.OperationalError = Mock() with self.fixup_context(self.app) as (f, _, _): self.assertIn(mod.DatabaseError, f.database_errors) self.assertIn(mod.InterfaceError, f.database_errors) self.assertIn(mod.OperationalError, f.database_errors) with mask_modules('sqlite3'): with self.fixup_context(self.app): pass def test_oracle_errors(self): with patch_modules('cx_Oracle'): import cx_Oracle as mod mod.DatabaseError = Mock() mod.InterfaceError = Mock() mod.OperationalError = Mock() with self.fixup_context(self.app) as (f, _, _): self.assertIn(mod.DatabaseError, f.database_errors) self.assertIn(mod.InterfaceError, f.database_errors) self.assertIn(mod.OperationalError, f.database_errors) with mask_modules('cx_Oracle'): with self.fixup_context(self.app): pass