Skip to content

Add support for serialized rollback #721

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ _build
.Python
.eggs
*.egg
.idea/
11 changes: 11 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
Changelog
=========

NEXT
----

Features
^^^^^^^^
* Add support for serialized rollback in transactional tests. (#721)
Thanks to Piotr Karkut for `the bug report
<https://github.com/pytest-dev/pytest-django/issues/329>`_.



3.6.0 (2019-10-17)
------------------

Expand Down
28 changes: 27 additions & 1 deletion docs/helpers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ on what marks are and for notes on using_ them.
``pytest.mark.django_db`` - request database access
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. :py:function:: pytest.mark.django_db([transaction=False, reset_sequences=False]):
.. :py:function:: pytest.mark.django_db([transaction=False, reset_sequences=False, serialized_rollback=False]):

This is used to mark a test function as requiring the database. It
will ensure the database is set up correctly for the test. Each test
Expand Down Expand Up @@ -47,6 +47,14 @@ test will fail when trying to access the database.
effect. Please be aware that not all databases support this feature.
For details see :py:attr:`django.test.TransactionTestCase.reset_sequences`.

:type serialized_rollback: bool
:param serialized_rollback:
The ``serialized_rollback`` argument enables `rollback emulation`_.
After a `django.test.TransactionTestCase`_ runs, the database is
flushed, destroying data created in data migrations. This is the
default behavior of Django. Setting ``serialized_rollback=True``
tells Django to restore that data.

.. note::

If you want access to the Django database *inside a fixture*
Expand All @@ -63,6 +71,7 @@ test will fail when trying to access the database.
Test classes that subclass Python's ``unittest.TestCase`` need to have the
marker applied in order to access the database.

.. _rollback emulation: https://docs.djangoproject.com/en/stable/topics/testing/overview/#rollback-emulation
.. _django.test.TestCase: https://docs.djangoproject.com/en/dev/topics/testing/overview/#testcase
.. _django.test.TransactionTestCase: https://docs.djangoproject.com/en/dev/topics/testing/overview/#transactiontestcase

Expand Down Expand Up @@ -242,6 +251,17 @@ sequences (if your database supports it). This is only required for
fixtures which need database access themselves. A test function should
normally use the ``pytest.mark.django_db`` mark with ``transaction=True`` and ``reset_sequences=True``.

``django_db_serialized_rollback``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

When the ``transactional_db`` fixture is enabled, this fixture can be
added to trigger `rollback emulation`_ and thus restores data created
in data migrations after each transaction test. This is only required
for fixtures which need to enforce this behavior. A test function
would use ``pytest.mark.django_db(serialized_rollback=True)``
to request this behavior.


``live_server``
~~~~~~~~~~~~~~~

Expand All @@ -251,6 +271,12 @@ or by requesting it's string value: ``unicode(live_server)``. You can
also directly concatenate a string to form a URL: ``live_server +
'/foo``.

Since the live server and the tests run in different threads, they
cannot share a database transaction. For this reason, ``live_server``
depends on the ``transactional_db`` fixture. If tests depend on data
created in data migrations, you should add the ``serialized_rollback``
fixture.

.. note:: Combining database access fixtures.

When using multiple database fixtures together, only one of them is
Expand Down
26 changes: 24 additions & 2 deletions pytest_django/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
"db",
"transactional_db",
"django_db_reset_sequences",
"django_db_serialized_rollback",
"admin_user",
"django_user_model",
"django_username_field",
Expand Down Expand Up @@ -124,7 +125,8 @@ def teardown_database():


def _django_db_fixture_helper(
request, django_db_blocker, transactional=False, reset_sequences=False
request, django_db_blocker, transactional=False, reset_sequences=False,
serialized_rollback=False
):
if is_django_unittest(request):
return
Expand All @@ -149,6 +151,7 @@ class ResetSequenceTestCase(django_case):
from django.test import TestCase as django_case

test_case = django_case(methodName="__init__")
test_case.serialized_rollback = serialized_rollback
test_case._pre_setup()
request.addfinalizer(test_case._post_teardown)

Expand Down Expand Up @@ -207,13 +210,16 @@ def db(request, django_db_setup, django_db_blocker):
"""
if "django_db_reset_sequences" in request.fixturenames:
request.getfixturevalue("django_db_reset_sequences")
if "django_db_serialized_rollback" in request.fixturenames:
request.getfixturevalue("django_db_serialized_rollback")
if (
"transactional_db" in request.fixturenames
or "live_server" in request.fixturenames
):
request.getfixturevalue("transactional_db")
else:
_django_db_fixture_helper(request, django_db_blocker, transactional=False)
_django_db_fixture_helper(request, django_db_blocker, transactional=False,
serialized_rollback=False)


@pytest.fixture(scope="function")
Expand All @@ -232,6 +238,8 @@ def transactional_db(request, django_db_setup, django_db_blocker):
"""
if "django_db_reset_sequences" in request.fixturenames:
request.getfixturevalue("django_db_reset_sequences")
if "django_db_serialized_rollback" in request.fixturenames:
request.getfixturevalue("django_db_serialized_rollback")
_django_db_fixture_helper(request, django_db_blocker, transactional=True)


Expand All @@ -253,6 +261,20 @@ def django_db_reset_sequences(request, django_db_setup, django_db_blocker):
)


@pytest.fixture(scope="function")
def django_db_serialized_rollback(request, django_db_setup, django_db_blocker):
"""Enable serialized rollback after transaction test cases

This fixture only has an effect when the ``transactional_db``
fixture is active, which happen as a side-effect of requesting
``live_server``.

"""
_django_db_fixture_helper(
request, django_db_blocker, transactional=True, serialized_rollback=True
)


@pytest.fixture()
def client():
"""A Django test client instance."""
Expand Down
21 changes: 13 additions & 8 deletions pytest_django/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from .fixtures import django_username_field # noqa
from .fixtures import live_server # noqa
from .fixtures import django_db_reset_sequences # noqa
from .fixtures import django_db_serialized_rollback # noqa
from .fixtures import rf # noqa
from .fixtures import settings # noqa
from .fixtures import transactional_db # noqa
Expand Down Expand Up @@ -497,14 +498,17 @@ def django_db_blocker():
def _django_db_marker(request):
"""Implement the django_db marker, internal to pytest-django.

This will dynamically request the ``db``, ``transactional_db`` or
``django_db_reset_sequences`` fixtures as required by the django_db marker.
This will dynamically request the ``db``, ``transactional_db``,
``django_db_reset_sequences`` or ``django_db_serialized_rollback``
fixtures as required by the django_db marker.
"""
marker = request.node.get_closest_marker("django_db")
if marker:
transaction, reset_sequences = validate_django_db(marker)
transaction, reset_sequences, serialized_rollback = validate_django_db(marker)
if reset_sequences:
request.getfixturevalue("django_db_reset_sequences")
elif serialized_rollback:
request.getfixturevalue("django_db_serialized_rollback")
elif transaction:
request.getfixturevalue("transactional_db")
else:
Expand Down Expand Up @@ -805,15 +809,16 @@ def restore(self):
def validate_django_db(marker):
"""Validate the django_db marker.

It checks the signature and creates the ``transaction`` and
``reset_sequences`` attributes on the marker which will have the
correct values.
It checks the signature and creates the ``transaction``,
``reset_sequences`` and ``serialized_rollback`` attributes on
the marker which will have the correct values.

A sequence reset is only allowed when combined with a transaction.
A serialized rollback is only allowed when combined with a transaction.
"""

def apifun(transaction=False, reset_sequences=False):
return transaction, reset_sequences
def apifun(transaction=False, reset_sequences=False, serialized_rollback=False):
return transaction, reset_sequences, serialized_rollback

return apifun(*marker.args, **marker.kwargs)

Expand Down
15 changes: 14 additions & 1 deletion tests/test_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,13 @@ def non_zero_sequences_counter(db):
class TestDatabaseFixtures:
"""Tests for the different database fixtures."""

@pytest.fixture(params=["db", "transactional_db", "django_db_reset_sequences"])
@pytest.fixture(params=["db", "transactional_db", "django_db_reset_sequences",
"django_db_serialized_rollback"])
def all_dbs(self, request):
if request.param == "django_db_reset_sequences":
return request.getfixturevalue("django_db_reset_sequences")
elif request.param == "django_db_serialized_rollback":
return request.getfixturevalue("django_db_serialized_rollback")
elif request.param == "transactional_db":
return request.getfixturevalue("transactional_db")
elif request.param == "db":
Expand Down Expand Up @@ -215,6 +218,16 @@ def test_reset_sequences_enabled(self, request):
marker = request.node.get_closest_marker("django_db")
assert marker.kwargs["reset_sequences"]

@pytest.mark.django_db
def test_serialized_rollback_disabled(self, request):
marker = request.node.get_closest_marker("django_db")
assert not marker.kwargs

@pytest.mark.django_db(serialized_rollback=True)
def test_serialized_rollback_enabled(self, request):
marker = request.node.get_closest_marker("django_db")
assert marker.kwargs["serialized_rollback"]


def test_unittest_interaction(django_testdir):
"Test that (non-Django) unittests cannot access the DB."
Expand Down