Skip to content

Commit 91fa11b

Browse files
python_api: let approx() take nonnumeric values (#7710)
Co-authored-by: Bruno Oliveira <[email protected]>
1 parent f324b27 commit 91fa11b

File tree

4 files changed

+117
-21
lines changed

4 files changed

+117
-21
lines changed

AUTHORS

+1
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ Ilya Konstantinov
129129
Ionuț Turturică
130130
Iwan Briquemont
131131
Jaap Broekhuizen
132+
Jakob van Santen
132133
Jakub Mitoraj
133134
Jan Balster
134135
Janne Vanhala

changelog/7710.improvement.rst

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
Use strict equality comparison for nonnumeric types in ``approx`` instead of
2+
raising ``TypeError``.
3+
This was the undocumented behavior before 3.7, but is now officially a supported feature.

src/_pytest/python_api.py

+50-16
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from collections.abc import Mapping
55
from collections.abc import Sized
66
from decimal import Decimal
7-
from numbers import Number
7+
from numbers import Complex
88
from types import TracebackType
99
from typing import Any
1010
from typing import Callable
@@ -146,7 +146,10 @@ def __repr__(self) -> str:
146146
)
147147

148148
def __eq__(self, actual) -> bool:
149-
if set(actual.keys()) != set(self.expected.keys()):
149+
try:
150+
if set(actual.keys()) != set(self.expected.keys()):
151+
return False
152+
except AttributeError:
150153
return False
151154

152155
return ApproxBase.__eq__(self, actual)
@@ -161,8 +164,6 @@ def _check_type(self) -> None:
161164
if isinstance(value, type(self.expected)):
162165
msg = "pytest.approx() does not support nested dictionaries: key={!r} value={!r}\n full mapping={}"
163166
raise TypeError(msg.format(key, value, pprint.pformat(self.expected)))
164-
elif not isinstance(value, Number):
165-
raise _non_numeric_type_error(self.expected, at="key={!r}".format(key))
166167

167168

168169
class ApproxSequencelike(ApproxBase):
@@ -177,7 +178,10 @@ def __repr__(self) -> str:
177178
)
178179

179180
def __eq__(self, actual) -> bool:
180-
if len(actual) != len(self.expected):
181+
try:
182+
if len(actual) != len(self.expected):
183+
return False
184+
except TypeError:
181185
return False
182186
return ApproxBase.__eq__(self, actual)
183187

@@ -190,10 +194,6 @@ def _check_type(self) -> None:
190194
if isinstance(x, type(self.expected)):
191195
msg = "pytest.approx() does not support nested data structures: {!r} at index {}\n full sequence: {}"
192196
raise TypeError(msg.format(x, index, pprint.pformat(self.expected)))
193-
elif not isinstance(x, Number):
194-
raise _non_numeric_type_error(
195-
self.expected, at="index {}".format(index)
196-
)
197197

198198

199199
class ApproxScalar(ApproxBase):
@@ -211,16 +211,23 @@ def __repr__(self) -> str:
211211
For example, ``1.0 ± 1e-6``, ``(3+4j) ± 5e-6 ∠ ±180°``.
212212
"""
213213

214-
# Infinities aren't compared using tolerances, so don't show a
215-
# tolerance. Need to call abs to handle complex numbers, e.g. (inf + 1j).
216-
if math.isinf(abs(self.expected)):
214+
# Don't show a tolerance for values that aren't compared using
215+
# tolerances, i.e. non-numerics and infinities. Need to call abs to
216+
# handle complex numbers, e.g. (inf + 1j).
217+
if (not isinstance(self.expected, (Complex, Decimal))) or math.isinf(
218+
abs(self.expected)
219+
):
217220
return str(self.expected)
218221

219222
# If a sensible tolerance can't be calculated, self.tolerance will
220223
# raise a ValueError. In this case, display '???'.
221224
try:
222225
vetted_tolerance = "{:.1e}".format(self.tolerance)
223-
if isinstance(self.expected, complex) and not math.isinf(self.tolerance):
226+
if (
227+
isinstance(self.expected, Complex)
228+
and self.expected.imag
229+
and not math.isinf(self.tolerance)
230+
):
224231
vetted_tolerance += " ∠ ±180°"
225232
except ValueError:
226233
vetted_tolerance = "???"
@@ -239,6 +246,15 @@ def __eq__(self, actual) -> bool:
239246
if actual == self.expected:
240247
return True
241248

249+
# If either type is non-numeric, fall back to strict equality.
250+
# NB: we need Complex, rather than just Number, to ensure that __abs__,
251+
# __sub__, and __float__ are defined.
252+
if not (
253+
isinstance(self.expected, (Complex, Decimal))
254+
and isinstance(actual, (Complex, Decimal))
255+
):
256+
return False
257+
242258
# Allow the user to control whether NaNs are considered equal to each
243259
# other or not. The abs() calls are for compatibility with complex
244260
# numbers.
@@ -409,6 +425,18 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
409425
>>> 1 + 1e-8 == approx(1, rel=1e-6, abs=1e-12)
410426
True
411427
428+
You can also use ``approx`` to compare nonnumeric types, or dicts and
429+
sequences containing nonnumeric types, in which case it falls back to
430+
strict equality. This can be useful for comparing dicts and sequences that
431+
can contain optional values::
432+
433+
>>> {"required": 1.0000005, "optional": None} == approx({"required": 1, "optional": None})
434+
True
435+
>>> [None, 1.0000005] == approx([None,1])
436+
True
437+
>>> ["foo", 1.0000005] == approx([None,1])
438+
False
439+
412440
If you're thinking about using ``approx``, then you might want to know how
413441
it compares to other good ways of comparing floating-point numbers. All of
414442
these algorithms are based on relative and absolute tolerances and should
@@ -466,6 +494,14 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
466494
follows a fixed behavior. `More information...`__
467495
468496
__ https://docs.python.org/3/reference/datamodel.html#object.__ge__
497+
498+
.. versionchanged:: 3.7.1
499+
``approx`` raises ``TypeError`` when it encounters a dict value or
500+
sequence element of nonnumeric type.
501+
502+
.. versionchanged:: 6.1.0
503+
``approx`` falls back to strict equality for nonnumeric types instead
504+
of raising ``TypeError``.
469505
"""
470506

471507
# Delegate the comparison to a class that knows how to deal with the type
@@ -487,8 +523,6 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
487523

488524
if isinstance(expected, Decimal):
489525
cls = ApproxDecimal # type: Type[ApproxBase]
490-
elif isinstance(expected, Number):
491-
cls = ApproxScalar
492526
elif isinstance(expected, Mapping):
493527
cls = ApproxMapping
494528
elif _is_numpy_array(expected):
@@ -501,7 +535,7 @@ def approx(expected, rel=None, abs=None, nan_ok: bool = False) -> ApproxBase:
501535
):
502536
cls = ApproxSequencelike
503537
else:
504-
raise _non_numeric_type_error(expected, at=None)
538+
cls = ApproxScalar
505539

506540
return cls(expected, rel, abs, nan_ok)
507541

testing/python/approx.py

+63-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import operator
2+
import sys
23
from decimal import Decimal
34
from fractions import Fraction
45
from operator import eq
@@ -329,6 +330,9 @@ def test_tuple_wrong_len(self):
329330
assert (1, 2) != approx((1,))
330331
assert (1, 2) != approx((1, 2, 3))
331332

333+
def test_tuple_vs_other(self):
334+
assert 1 != approx((1,))
335+
332336
def test_dict(self):
333337
actual = {"a": 1 + 1e-7, "b": 2 + 1e-8}
334338
# Dictionaries became ordered in python3.6, so switch up the order here
@@ -346,6 +350,13 @@ def test_dict_wrong_len(self):
346350
assert {"a": 1, "b": 2} != approx({"a": 1, "c": 2})
347351
assert {"a": 1, "b": 2} != approx({"a": 1, "b": 2, "c": 3})
348352

353+
def test_dict_nonnumeric(self):
354+
assert {"a": 1.0, "b": None} == pytest.approx({"a": 1.0, "b": None})
355+
assert {"a": 1.0, "b": 1} != pytest.approx({"a": 1.0, "b": None})
356+
357+
def test_dict_vs_other(self):
358+
assert 1 != approx({"a": 0})
359+
349360
def test_numpy_array(self):
350361
np = pytest.importorskip("numpy")
351362

@@ -463,20 +474,67 @@ def test_foo():
463474
["*At index 0 diff: 3 != 4 ± {}".format(expected), "=* 1 failed in *="]
464475
)
465476

477+
@pytest.mark.parametrize(
478+
"x, name",
479+
[
480+
pytest.param([[1]], "data structures", id="nested-list"),
481+
pytest.param({"key": {"key": 1}}, "dictionaries", id="nested-dict"),
482+
],
483+
)
484+
def test_expected_value_type_error(self, x, name):
485+
with pytest.raises(
486+
TypeError,
487+
match=r"pytest.approx\(\) does not support nested {}:".format(name),
488+
):
489+
approx(x)
490+
466491
@pytest.mark.parametrize(
467492
"x",
468493
[
469494
pytest.param(None),
470495
pytest.param("string"),
471496
pytest.param(["string"], id="nested-str"),
472-
pytest.param([[1]], id="nested-list"),
473497
pytest.param({"key": "string"}, id="dict-with-string"),
474-
pytest.param({"key": {"key": 1}}, id="nested-dict"),
475498
],
476499
)
477-
def test_expected_value_type_error(self, x):
478-
with pytest.raises(TypeError):
479-
approx(x)
500+
def test_nonnumeric_okay_if_equal(self, x):
501+
assert x == approx(x)
502+
503+
@pytest.mark.parametrize(
504+
"x",
505+
[
506+
pytest.param("string"),
507+
pytest.param(["string"], id="nested-str"),
508+
pytest.param({"key": "string"}, id="dict-with-string"),
509+
],
510+
)
511+
def test_nonnumeric_false_if_unequal(self, x):
512+
"""For nonnumeric types, x != pytest.approx(y) reduces to x != y"""
513+
assert "ab" != approx("abc")
514+
assert ["ab"] != approx(["abc"])
515+
# in particular, both of these should return False
516+
assert {"a": 1.0} != approx({"a": None})
517+
assert {"a": None} != approx({"a": 1.0})
518+
519+
assert 1.0 != approx(None)
520+
assert None != approx(1.0) # noqa: E711
521+
522+
assert 1.0 != approx([None])
523+
assert None != approx([1.0]) # noqa: E711
524+
525+
@pytest.mark.skipif(sys.version_info < (3, 7), reason="requires ordered dicts")
526+
def test_nonnumeric_dict_repr(self):
527+
"""Dicts with non-numerics and infinites have no tolerances"""
528+
x1 = {"foo": 1.0000005, "bar": None, "foobar": inf}
529+
assert (
530+
repr(approx(x1))
531+
== "approx({'foo': 1.0000005 ± 1.0e-06, 'bar': None, 'foobar': inf})"
532+
)
533+
534+
def test_nonnumeric_list_repr(self):
535+
"""Lists with non-numerics and infinites have no tolerances"""
536+
x1 = [1.0000005, None, inf]
537+
assert repr(approx(x1)) == "approx([1.0000005 ± 1.0e-06, None, inf])"
480538

481539
@pytest.mark.parametrize(
482540
"op",

0 commit comments

Comments
 (0)