Skip to content

Commit cd2085e

Browse files
committed
approx(): Detect type errors earlier.
1 parent ad305e7 commit cd2085e

File tree

2 files changed

+56
-19
lines changed

2 files changed

+56
-19
lines changed

src/_pytest/python_api.py

+49-19
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import math
22
import sys
3+
from numbers import Number
4+
from decimal import Decimal
35

46
import py
57
from six.moves import zip, filterfalse
@@ -29,6 +31,9 @@ def _cmp_raises_type_error(self, other):
2931
"Comparison operators other than == and != not supported by approx objects"
3032
)
3133

34+
def _non_numeric_type_error(value):
35+
return TypeError("cannot make approximate comparisons to non-numeric values, e.g. {}".format(value))
36+
3237

3338
# builtin pytest.approx helper
3439

@@ -39,7 +44,7 @@ class ApproxBase(object):
3944
or sequences of numbers.
4045
"""
4146

42-
# Tell numpy to use our `__eq__` operator instead of its
47+
# Tell numpy to use our `__eq__` operator instead of its.
4348
__array_ufunc__ = None
4449
__array_priority__ = 100
4550

@@ -48,6 +53,7 @@ def __init__(self, expected, rel=None, abs=None, nan_ok=False):
4853
self.abs = abs
4954
self.rel = rel
5055
self.nan_ok = nan_ok
56+
self._check_type()
5157

5258
def __repr__(self):
5359
raise NotImplementedError
@@ -75,6 +81,17 @@ def _yield_comparisons(self, actual):
7581
"""
7682
raise NotImplementedError
7783

84+
def _check_type(self):
85+
"""
86+
Raise a TypeError if the expected value is not a valid type.
87+
"""
88+
# This is only a concern if the expected value is a sequence. In every
89+
# other case, the approx() function ensures that the expected value has
90+
# a numeric type. For this reason, the default is to do nothing. The
91+
# classes that deal with sequences should reimplement this method to
92+
# raise if there are any non-numeric elements in the sequence.
93+
pass
94+
7895

7996
class ApproxNumpy(ApproxBase):
8097
"""
@@ -151,6 +168,13 @@ def _yield_comparisons(self, actual):
151168
for k in self.expected.keys():
152169
yield actual[k], self.expected[k]
153170

171+
def _check_type(self):
172+
for x in self.expected.values():
173+
if isinstance(x, type(self.expected)):
174+
raise TypeError("pytest.approx() does not support nested dictionaries, e.g. {}".format(self.expected))
175+
elif not isinstance(x, Number):
176+
raise _non_numeric_type_error(self.expected)
177+
154178

155179
class ApproxSequence(ApproxBase):
156180
"""
@@ -174,6 +198,13 @@ def __eq__(self, actual):
174198
def _yield_comparisons(self, actual):
175199
return zip(actual, self.expected)
176200

201+
def _check_type(self):
202+
for x in self.expected:
203+
if isinstance(x, type(self.expected)):
204+
raise TypeError("pytest.approx() does not support nested data structures, e.g. {}".format(self.expected))
205+
elif not isinstance(x, Number):
206+
raise _non_numeric_type_error(self.expected)
207+
177208

178209
class ApproxScalar(ApproxBase):
179210
"""
@@ -294,8 +325,6 @@ class ApproxDecimal(ApproxScalar):
294325
"""
295326
Perform approximate comparisons where the expected value is a decimal.
296327
"""
297-
from decimal import Decimal
298-
299328
DEFAULT_ABSOLUTE_TOLERANCE = Decimal("1e-12")
300329
DEFAULT_RELATIVE_TOLERANCE = Decimal("1e-6")
301330

@@ -453,32 +482,33 @@ def approx(expected, rel=None, abs=None, nan_ok=False):
453482
__ https://docs.python.org/3/reference/datamodel.html#object.__ge__
454483
"""
455484

456-
from decimal import Decimal
457-
458485
# Delegate the comparison to a class that knows how to deal with the type
459486
# of the expected value (e.g. int, float, list, dict, numpy.array, etc).
460487
#
461-
# This architecture is really driven by the need to support numpy arrays.
462-
# The only way to override `==` for arrays without requiring that approx be
463-
# the left operand is to inherit the approx object from `numpy.ndarray`.
464-
# But that can't be a general solution, because it requires (1) numpy to be
465-
# installed and (2) the expected value to be a numpy array. So the general
466-
# solution is to delegate each type of expected value to a different class.
488+
# The primary responsibility of these classes is to implement ``__eq__()``
489+
# and ``__repr__()``. The former is used to actually check if some
490+
# "actual" value is equivalent to the given expected value within the
491+
# allowed tolerance. The latter is used to show the user the expected
492+
# value and tolerance, in the case that a test failed.
467493
#
468-
# This has the advantage that it made it easy to support mapping types
469-
# (i.e. dict). The old code accepted mapping types, but would only compare
470-
# their keys, which is probably not what most people would expect.
494+
# The actual logic for making approximate comparisons can be found in
495+
# ApproxScalar, which is used to compare individual numbers. All of the
496+
# other Approx classes eventually delegate to this class. The ApproxBase
497+
# class provides some convenient methods and overloads, but isn't really
498+
# essential.
471499

472-
if _is_numpy_array(expected):
473-
cls = ApproxNumpy
500+
if isinstance(expected, Decimal):
501+
cls = ApproxDecimal
502+
elif isinstance(expected, Number):
503+
cls = ApproxScalar
474504
elif isinstance(expected, Mapping):
475505
cls = ApproxMapping
476506
elif isinstance(expected, Sequence) and not isinstance(expected, STRING_TYPES):
477507
cls = ApproxSequence
478-
elif isinstance(expected, Decimal):
479-
cls = ApproxDecimal
508+
elif _is_numpy_array(expected):
509+
cls = ApproxNumpy
480510
else:
481-
cls = ApproxScalar
511+
raise _non_numeric_type_error(expected)
482512

483513
return cls(expected, rel, abs, nan_ok)
484514

testing/python/approx.py

+7
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,13 @@ def test_foo():
441441
["*At index 0 diff: 3 != 4 * {}".format(expected), "=* 1 failed in *="]
442442
)
443443

444+
@pytest.mark.parametrize(
445+
'x', [None, 'string', ['string'], [[1]], {'key': 'string'}, {'key': {'key': 1}}]
446+
)
447+
def test_expected_value_type_error(self, x):
448+
with pytest.raises(TypeError):
449+
approx(x)
450+
444451
@pytest.mark.parametrize(
445452
"op",
446453
[

0 commit comments

Comments
 (0)