Skip to content

Feature/serialized rollback #956

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,16 @@
Changelog
=========

<<<<<<< HEAD
NEXT
----

Features
^^^^^^^^
* Add support for serialized rollback in transactional tests. (#721)
Thanks to Piotr Karkut for `the bug report
<https://github.com/pytest-dev/pytest-django/issues/329>`_.
=======
v4.4.0 (2021-06-06)
-------------------

Expand Down Expand Up @@ -40,6 +50,7 @@ Bugfixes
^^^^^^^^

* Disable atomic durability check on non-transactional tests (#910).
>>>>>>> master


v4.1.0 (2020-10-22)
Expand Down
37 changes: 36 additions & 1 deletion docs/helpers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,11 @@ dynamically in a hook or fixture.
``pytest.mark.django_db`` - request database access
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

<<<<<<< HEAD
.. py:function:: pytest.mark.django_db([transaction=False, reset_sequences=False, serialized_rollback=False])
=======
.. py:function:: pytest.mark.django_db([transaction=False, reset_sequences=False, databases=None])
>>>>>>> master

This is used to mark a test function as requiring the database. It
will ensure the database is set up correctly for the test. Each test
Expand Down Expand Up @@ -56,6 +60,15 @@ dynamically in a hook or fixture.
effect. Please be aware that not all databases support this feature.
For details see :py:attr:`django.test.TransactionTestCase.reset_sequences`.

<<<<<<< HEAD
:type serialized_rollback: bool
:param serialized_rollback:
The ``serialized_rollback`` argument enables `rollback emulation`_.
After a `django.test.TransactionTestCase`_ runs, the database is
flushed, destroying data created in data migrations. This is the
default behavior of Django. Setting ``serialized_rollback=True``
tells Django to restore that data.
=======

:type databases: Union[Iterable[str], str, None]
:param databases:
Expand All @@ -72,6 +85,7 @@ dynamically in a hook or fixture.
to specify all configured databases.
For details see :py:attr:`django.test.TransactionTestCase.databases` and
:py:attr:`django.test.TestCase.databases`.
>>>>>>> master

.. note::

Expand All @@ -88,7 +102,11 @@ dynamically in a hook or fixture.
Test classes that subclass :class:`django.test.TestCase` will have access to
the database always to make them compatible with existing Django tests.
Test classes that subclass Python's :class:`unittest.TestCase` need to have
the marker applied in order to access the database.
marker applied in order to access the database.

.. _rollback emulation: https://docs.djangoproject.com/en/stable/topics/testing/overview/#rollback-emulation
.. _django.test.TestCase: https://docs.djangoproject.com/en/dev/topics/testing/overview/#testcase
.. _django.test.TransactionTestCase: https://docs.djangoproject.com/en/dev/topics/testing/overview/#transactiontestcase


``pytest.mark.urls`` - override the urlconf
Expand Down Expand Up @@ -333,6 +351,17 @@ use the :func:`pytest.mark.django_db` mark with ``transaction=True`` and

.. fixture:: live_server

``django_db_serialized_rollback``
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

When the ``transactional_db`` fixture is enabled, this fixture can be
added to trigger `rollback emulation`_ and thus restores data created
in data migrations after each transaction test. This is only required
for fixtures which need to enforce this behavior. A test function
would use ``pytest.mark.django_db(serialized_rollback=True)``
to request this behavior.


``live_server``
~~~~~~~~~~~~~~~

Expand All @@ -342,6 +371,12 @@ or by requesting it's string value: ``str(live_server)``. You can
also directly concatenate a string to form a URL: ``live_server +
'/foo'``.

Since the live server and the tests run in different threads, they
cannot share a database transaction. For this reason, ``live_server``
depends on the ``transactional_db`` fixture. If tests depend on data
created in data migrations, you should add the ``serialized_rollback``
fixture.

.. note:: Combining database access fixtures.

When using multiple database fixtures together, only one of them is
Expand Down
70 changes: 58 additions & 12 deletions pytest_django/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,14 @@
from contextlib import contextmanager
from functools import partial
from typing import (
Any, Callable, Generator, Iterable, List, Optional, Tuple, Union,
Any,
Callable,
Generator,
Iterable,
List,
Optional,
Tuple,
Union,
)

import pytest
Expand All @@ -28,6 +35,7 @@
"db",
"transactional_db",
"django_db_reset_sequences",
"django_db_serialized_rollback",
"admin_user",
"django_user_model",
"django_username_field",
Expand Down Expand Up @@ -143,6 +151,7 @@ def _django_db_fixture_helper(
django_db_blocker,
transactional: bool = False,
reset_sequences: bool = False,
serialized_rollback: bool = False,
) -> None:
if is_django_unittest(request):
return
Expand All @@ -152,7 +161,9 @@ def _django_db_fixture_helper(
return

_databases = getattr(
request.node, "_pytest_django_databases", None,
request.node,
"_pytest_django_databases",
None,
) # type: Optional[_DjangoDbDatabases]

django_db_blocker.unblock()
Expand All @@ -163,6 +174,21 @@ def _django_db_fixture_helper(

if transactional:
test_case_class = django.test.TransactionTestCase

if reset_sequences:

class ResetSequenceTestCase(test_case_class):
reset_sequences = True

test_case_class = ResetSequenceTestCase

if serialized_rollback:

class SerializedRollbackTestCase(test_case_class):
serialized_rollback = True

test_case_class = SerializedRollbackTestCase

else:
test_case_class = django.test.TestCase

Expand Down Expand Up @@ -245,13 +271,17 @@ def db(
"""
if "django_db_reset_sequences" in request.fixturenames:
request.getfixturevalue("django_db_reset_sequences")
if "django_db_serialized_rollback" in request.fixturenames:
request.getfixturevalue("django_db_serialized_rollback")
if (
"transactional_db" in request.fixturenames
or "live_server" in request.fixturenames
):
request.getfixturevalue("transactional_db")
else:
_django_db_fixture_helper(request, django_db_blocker, transactional=False)
_django_db_fixture_helper(
request, django_db_blocker, transactional=False, serialized_rollback=False
)


@pytest.fixture(scope="function")
Expand All @@ -274,6 +304,8 @@ def transactional_db(
"""
if "django_db_reset_sequences" in request.fixturenames:
request.getfixturevalue("django_db_reset_sequences")
if "django_db_serialized_rollback" in request.fixturenames:
request.getfixturevalue("django_db_serialized_rollback")
_django_db_fixture_helper(request, django_db_blocker, transactional=True)


Expand All @@ -299,6 +331,20 @@ def django_db_reset_sequences(
)


@pytest.fixture(scope="function")
def django_db_serialized_rollback(request, django_db_setup, django_db_blocker):
"""Enable serialized rollback after transaction test cases

This fixture only has an effect when the ``transactional_db``
fixture is active, which happen as a side-effect of requesting
``live_server``.

"""
_django_db_fixture_helper(
request, django_db_blocker, transactional=True, serialized_rollback=True
)


@pytest.fixture()
def client() -> "django.test.client.Client":
"""A Django test client instance."""
Expand Down Expand Up @@ -462,9 +508,11 @@ def live_server(request):
"""
skip_if_no_django()

addr = request.config.getvalue("liveserver") or os.getenv(
"DJANGO_LIVE_TEST_SERVER_ADDRESS"
) or "localhost"
addr = (
request.config.getvalue("liveserver")
or os.getenv("DJANGO_LIVE_TEST_SERVER_ADDRESS")
or "localhost"
)

server = live_server_helper.LiveServer(addr)
request.addfinalizer(server.stop)
Expand Down Expand Up @@ -549,11 +597,7 @@ def django_assert_max_num_queries(pytestconfig):


@contextmanager
def _capture_on_commit_callbacks(
*,
using: Optional[str] = None,
execute: bool = False
):
def _capture_on_commit_callbacks(*, using: Optional[str] = None, execute: bool = False):
from django.db import DEFAULT_DB_ALIAS, connections
from django.test import TestCase

Expand All @@ -574,7 +618,9 @@ def _capture_on_commit_callbacks(
callback()

else:
with TestCase.captureOnCommitCallbacks(using=using, execute=execute) as callbacks:
with TestCase.captureOnCommitCallbacks(
using=using, execute=execute
) as callbacks:
yield callbacks


Expand Down
45 changes: 30 additions & 15 deletions pytest_django/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import pytest

from .django_compat import is_django_unittest # noqa
from .fixtures import _live_server_helper # noqa
from .fixtures import _live_server_helper # noqa; noqa
from .fixtures import admin_client # noqa
from .fixtures import admin_user # noqa
from .fixtures import async_client # noqa
Expand All @@ -40,6 +40,7 @@
from .fixtures import rf # noqa
from .fixtures import settings # noqa
from .fixtures import transactional_db # noqa
from .fixtures import django_db_serialized_rollback
from .lazy_django import django_settings_is_configured, skip_if_no_django


Expand Down Expand Up @@ -380,15 +381,15 @@ def get_order_number(test: pytest.Item) -> int:
if issubclass(test_cls, TransactionTestCase):
return 1

marker_db = test.get_closest_marker('django_db')
marker_db = test.get_closest_marker("django_db")
if not marker_db:
transaction = None
else:
transaction = validate_django_db(marker_db)[0]
if transaction is True:
return 1

fixtures = getattr(test, 'fixturenames', [])
fixtures = getattr(test, "fixturenames", [])
if "transactional_db" in fixtures:
return 1

Expand Down Expand Up @@ -417,7 +418,8 @@ def django_test_environment(request) -> None:
if django_settings_is_configured():
_setup_django()
from django.test.utils import (
setup_test_environment, teardown_test_environment,
setup_test_environment,
teardown_test_environment,
)

debug_ini = request.config.getini("django_debug_mode")
Expand Down Expand Up @@ -454,18 +456,26 @@ def django_db_blocker() -> "Optional[_DatabaseBlocker]":
def _django_db_marker(request) -> None:
"""Implement the django_db marker, internal to pytest-django.

This will dynamically request the ``db``, ``transactional_db`` or
``django_db_reset_sequences`` fixtures as required by the django_db marker.
This will dynamically request the ``db``, ``transactional_db``,
``django_db_reset_sequences`` or ``django_db_serialized_rollback``
fixtures as required by the django_db marker.
"""
marker = request.node.get_closest_marker("django_db")
if marker:
transaction, reset_sequences, databases = validate_django_db(marker)
(
transaction,
reset_sequences,
serialized_rollback,
databases,
) = validate_django_db(marker)

# TODO: Use pytest Store (item.store) once that's stable.
request.node._pytest_django_databases = databases

if reset_sequences:
request.getfixturevalue("django_db_reset_sequences")
elif serialized_rollback:
request.getfixturevalue("django_db_serialized_rollback")
elif transaction:
request.getfixturevalue("transactional_db")
else:
Expand All @@ -486,6 +496,7 @@ def _django_setup_unittest(
# Before pytest 5.4: https://github.com/pytest-dev/pytest/issues/5991
# After pytest 5.4: https://github.com/pytest-dev/pytest-django/issues/824
from _pytest.unittest import TestCaseFunction

original_runtest = TestCaseFunction.runtest

def non_debugging_runtest(self) -> None:
Expand Down Expand Up @@ -641,13 +652,15 @@ def __mod__(self, var: str) -> str:
from django.conf import settings as dj_settings

if dj_settings.TEMPLATES:
dj_settings.TEMPLATES[0]["OPTIONS"]["string_if_invalid"] = InvalidVarException()
dj_settings.TEMPLATES[0]["OPTIONS"][
"string_if_invalid"
] = InvalidVarException()


@pytest.fixture(autouse=True)
def _template_string_if_invalid_marker(request) -> None:
"""Apply the @pytest.mark.ignore_template_errors marker,
internal to pytest-django."""
internal to pytest-django."""
marker = request.keywords.get("ignore_template_errors", None)
if os.environ.get(INVALID_TEMPLATE_VARS_ENV, "false") == "true":
if marker and django_settings_is_configured():
Expand Down Expand Up @@ -742,18 +755,20 @@ def validate_django_db(marker) -> "_DjangoDb":
"""Validate the django_db marker.

It checks the signature and creates the ``transaction``,
``reset_sequences`` and ``databases`` attributes on the marker
which will have the correct values.
``reset_sequences`` and ``serialized_rollback`` attributes on
the marker which will have the correct values.

A sequence reset is only allowed when combined with a transaction.
A serialized rollback is only allowed when combined with a transaction.
"""

def apifun(
transaction: bool = False,
reset_sequences: bool = False,
transaction=False,
reset_sequences=False,
serialized_rollback=False,
databases: "_DjangoDbDatabases" = None,
) -> "_DjangoDb":
return transaction, reset_sequences, databases
):
return transaction, reset_sequences, serialized_rollback, databases

return apifun(*marker.args, **marker.kwargs)

Expand Down
Loading