Skip to content

Commit e310075

Browse files
committed
Add support for serialized rollback
1 parent 9dcc8cf commit e310075

File tree

6 files changed

+90
-12
lines changed

6 files changed

+90
-12
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,4 @@ _build
1616
.Python
1717
.eggs
1818
*.egg
19+
.idea/

docs/changelog.rst

+11
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,17 @@
11
Changelog
22
=========
33

4+
NEXT
5+
----
6+
7+
Features
8+
^^^^^^^^
9+
* Add support for serialized rollback in transactional tests. (#721)
10+
Thanks to Piotr Karkut for `the bug report
11+
<https://github.com/pytest-dev/pytest-django/issues/329>`_.
12+
13+
14+
415
3.6.0 (2019-10-17)
516
------------------
617

docs/helpers.rst

+27-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ on what marks are and for notes on using_ them.
1616
``pytest.mark.django_db`` - request database access
1717
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1818

19-
.. :py:function:: pytest.mark.django_db([transaction=False, reset_sequences=False]):
19+
.. :py:function:: pytest.mark.django_db([transaction=False, reset_sequences=False, serialized_rollback=False]):
2020
2121
This is used to mark a test function as requiring the database. It
2222
will ensure the database is set up correctly for the test. Each test
@@ -47,6 +47,14 @@ test will fail when trying to access the database.
4747
effect. Please be aware that not all databases support this feature.
4848
For details see :py:attr:`django.test.TransactionTestCase.reset_sequences`.
4949

50+
:type serialized_rollback: bool
51+
:param serialized_rollback:
52+
The ``serialized_rollback`` argument enables `rollback emulation`_.
53+
After a `django.test.TransactionTestCase`_ runs, the database is
54+
flushed, destroying data created in data migrations. This is the
55+
default behavior of Django. Setting ``serialized_rollback=True``
56+
tells Django to restore that data.
57+
5058
.. note::
5159

5260
If you want access to the Django database *inside a fixture*
@@ -63,6 +71,7 @@ test will fail when trying to access the database.
6371
Test classes that subclass Python's ``unittest.TestCase`` need to have the
6472
marker applied in order to access the database.
6573

74+
.. _rollback emulation: https://docs.djangoproject.com/en/stable/topics/testing/overview/#rollback-emulation
6675
.. _django.test.TestCase: https://docs.djangoproject.com/en/dev/topics/testing/overview/#testcase
6776
.. _django.test.TransactionTestCase: https://docs.djangoproject.com/en/dev/topics/testing/overview/#transactiontestcase
6877

@@ -242,6 +251,17 @@ sequences (if your database supports it). This is only required for
242251
fixtures which need database access themselves. A test function should
243252
normally use the ``pytest.mark.django_db`` mark with ``transaction=True`` and ``reset_sequences=True``.
244253

254+
``django_db_serialized_rollback``
255+
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
256+
257+
When the ``transactional_db`` fixture is enabled, this fixture can be
258+
added to trigger `rollback emulation`_ and thus restores data created
259+
in data migrations after each transaction test. This is only required
260+
for fixtures which need to enforce this behavior. A test function
261+
would use :py:func:`~pytest.mark.django_db(serialized_rollback=True)`
262+
to request this behavior.
263+
264+
245265
``live_server``
246266
~~~~~~~~~~~~~~~
247267

@@ -251,6 +271,12 @@ or by requesting it's string value: ``unicode(live_server)``. You can
251271
also directly concatenate a string to form a URL: ``live_server +
252272
'/foo``.
253273

274+
Since the live server and the tests run in different threads, they
275+
cannot share a database transaction. For this reason, ``live_server``
276+
depends on the ``transactional_db`` fixture. If tests depend on data
277+
created in data migrations, you should add the ``serialized_rollback``
278+
fixture.
279+
254280
.. note:: Combining database access fixtures.
255281

256282
When using multiple database fixtures together, only one of them is

pytest_django/fixtures.py

+24-2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
"db",
1919
"transactional_db",
2020
"django_db_reset_sequences",
21+
"django_db_serialized_rollback",
2122
"admin_user",
2223
"django_user_model",
2324
"django_username_field",
@@ -124,7 +125,8 @@ def teardown_database():
124125

125126

126127
def _django_db_fixture_helper(
127-
request, django_db_blocker, transactional=False, reset_sequences=False
128+
request, django_db_blocker, transactional=False, reset_sequences=False,
129+
serialized_rollback=False
128130
):
129131
if is_django_unittest(request):
130132
return
@@ -149,6 +151,7 @@ class ResetSequenceTestCase(django_case):
149151
from django.test import TestCase as django_case
150152

151153
test_case = django_case(methodName="__init__")
154+
test_case.serialized_rollback = serialized_rollback
152155
test_case._pre_setup()
153156
request.addfinalizer(test_case._post_teardown)
154157

@@ -207,13 +210,16 @@ def db(request, django_db_setup, django_db_blocker):
207210
"""
208211
if "django_db_reset_sequences" in request.fixturenames:
209212
request.getfixturevalue("django_db_reset_sequences")
213+
if "django_db_serialized_rollback" in request.fixturenames:
214+
request.getfixturevalue("django_db_serialized_rollback")
210215
if (
211216
"transactional_db" in request.fixturenames
212217
or "live_server" in request.fixturenames
213218
):
214219
request.getfixturevalue("transactional_db")
215220
else:
216-
_django_db_fixture_helper(request, django_db_blocker, transactional=False)
221+
_django_db_fixture_helper(request, django_db_blocker, transactional=False,
222+
serialized_rollback=False)
217223

218224

219225
@pytest.fixture(scope="function")
@@ -232,6 +238,8 @@ def transactional_db(request, django_db_setup, django_db_blocker):
232238
"""
233239
if "django_db_reset_sequences" in request.fixturenames:
234240
request.getfixturevalue("django_db_reset_sequences")
241+
if "django_db_serialized_rollback" in request.fixturenames:
242+
request.getfixturevalue("django_db_serialized_rollback")
235243
_django_db_fixture_helper(request, django_db_blocker, transactional=True)
236244

237245

@@ -253,6 +261,20 @@ def django_db_reset_sequences(request, django_db_setup, django_db_blocker):
253261
)
254262

255263

264+
@pytest.fixture(scope="function")
265+
def django_db_serialized_rollback(request, django_db_setup, django_db_blocker):
266+
"""Enable serialized rollback after transaction test cases
267+
268+
This fixture only has an effect when the ``transactional_db``
269+
fixture is active, which happen as a side-effect of requesting
270+
``live_server``.
271+
272+
"""
273+
_django_db_fixture_helper(
274+
request, django_db_blocker, transactional=True, serialized_rollback=True
275+
)
276+
277+
256278
@pytest.fixture()
257279
def client():
258280
"""A Django test client instance."""

pytest_django/plugin.py

+13-8
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from .fixtures import django_username_field # noqa
3535
from .fixtures import live_server # noqa
3636
from .fixtures import django_db_reset_sequences # noqa
37+
from .fixtures import django_db_serialized_rollback # noqa
3738
from .fixtures import rf # noqa
3839
from .fixtures import settings # noqa
3940
from .fixtures import transactional_db # noqa
@@ -497,14 +498,17 @@ def django_db_blocker():
497498
def _django_db_marker(request):
498499
"""Implement the django_db marker, internal to pytest-django.
499500
500-
This will dynamically request the ``db``, ``transactional_db`` or
501-
``django_db_reset_sequences`` fixtures as required by the django_db marker.
501+
This will dynamically request the ``db``, ``transactional_db``,
502+
``django_db_reset_sequences`` or ``django_db_serialized_rollback``
503+
fixtures as required by the django_db marker.
502504
"""
503505
marker = request.node.get_closest_marker("django_db")
504506
if marker:
505-
transaction, reset_sequences = validate_django_db(marker)
507+
transaction, reset_sequences, serialized_rollback = validate_django_db(marker)
506508
if reset_sequences:
507509
request.getfixturevalue("django_db_reset_sequences")
510+
elif serialized_rollback:
511+
request.getfixturevalue("django_db_serialized_rollback")
508512
elif transaction:
509513
request.getfixturevalue("transactional_db")
510514
else:
@@ -805,15 +809,16 @@ def restore(self):
805809
def validate_django_db(marker):
806810
"""Validate the django_db marker.
807811
808-
It checks the signature and creates the ``transaction`` and
809-
``reset_sequences`` attributes on the marker which will have the
810-
correct values.
812+
It checks the signature and creates the ``transaction``,
813+
``reset_sequences`` and ``serialized_rollback`` attributes on
814+
the marker which will have the correct values.
811815
812816
A sequence reset is only allowed when combined with a transaction.
817+
A serialized rollback is only allowed when combined with a transaction.
813818
"""
814819

815-
def apifun(transaction=False, reset_sequences=False):
816-
return transaction, reset_sequences
820+
def apifun(transaction=False, reset_sequences=False, serialized_rollback=False):
821+
return transaction, reset_sequences, serialized_rollback
817822

818823
return apifun(*marker.args, **marker.kwargs)
819824

tests/test_database.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,13 @@ def non_zero_sequences_counter(db):
5050
class TestDatabaseFixtures:
5151
"""Tests for the different database fixtures."""
5252

53-
@pytest.fixture(params=["db", "transactional_db", "django_db_reset_sequences"])
53+
@pytest.fixture(params=["db", "transactional_db", "django_db_reset_sequences",
54+
"django_db_serialized_rollback"])
5455
def all_dbs(self, request):
5556
if request.param == "django_db_reset_sequences":
5657
return request.getfixturevalue("django_db_reset_sequences")
58+
elif request.param == "django_db_serialized_rollback":
59+
return request.getfixturevalue("django_db_serialized_rollback")
5760
elif request.param == "transactional_db":
5861
return request.getfixturevalue("transactional_db")
5962
elif request.param == "db":
@@ -215,6 +218,16 @@ def test_reset_sequences_enabled(self, request):
215218
marker = request.node.get_closest_marker("django_db")
216219
assert marker.kwargs["reset_sequences"]
217220

221+
@pytest.mark.django_db
222+
def test_serialized_rollback_disabled(self, request):
223+
marker = request.node.get_closest_marker("django_db")
224+
assert not marker.kwargs
225+
226+
@pytest.mark.django_db(serialized_rollback=True)
227+
def test_serialized_rollback_enabled(self, request):
228+
marker = request.node.get_closest_marker("django_db")
229+
assert marker.kwargs["serialized_rollback"]
230+
218231

219232
def test_unittest_interaction(django_testdir):
220233
"Test that (non-Django) unittests cannot access the DB."

0 commit comments

Comments
 (0)