-
Notifications
You must be signed in to change notification settings - Fork 129
Open
Labels
Description
Description
There is a test helper class (`InferShapeTester`) that can be used for this, but it requires running pytest, which is not ideal for someone implementing their own Op. We could expose that functionality as a standalone `verify_infer_shape` function, analogous to `verify_grad`.
pytensor/tests/unittest_tools.py
Lines 178 to 265 in 3523bfa
class InferShapeTester:
    """Test mixin for checking an Op's ``infer_shape`` implementation.

    Subclasses may define a ``mode`` attribute; ``setup_method`` extends it
    (or the default compilation mode when it is absent/``None``) with the
    ``canonicalize`` rewrites so shape-inference rewrites are applied.
    """

    def setup_method(self):
        # Take into account any mode that may be defined in a child class;
        # it can also be None.
        mode = getattr(self, "mode", None)
        if mode is None:
            mode = pytensor.compile.get_default_mode()
        # This mode seems to be the minimal one including the shape_i
        # optimizations, if we don't want to enumerate them explicitly.
        self.mode = mode.including("canonicalize")

    def _compile_and_check(
        self,
        inputs,
        outputs,
        numeric_inputs,
        cls,
        excluding=None,
        warn=True,
        check_topo=True,
    ):
        """Test only the ``infer_shape`` method of an Op.

        Two functions are compiled: one computing ``outputs`` and one
        computing only their shapes. Both are evaluated on
        ``numeric_inputs`` and the inferred shapes must match the actual
        output shapes.

        When testing with input values whose shapes take the same value
        over different dimensions (for instance, a square matrix, or a
        tensor3 with shape (n, n, n) or (m, n, m)), it is not possible to
        detect whether the output shape was computed correctly or whether
        dimensions with equal lengths were mixed up. For instance, if
        ``infer_shape`` uses the width of a matrix instead of its height,
        testing only with square matrices will not detect the problem.

        Parameters
        ----------
        inputs : list of Variable
            Symbolic inputs of the graph.
        outputs : list of Variable
            Symbolic outputs whose shapes are checked.
        numeric_inputs : list
            Concrete values for ``inputs``, in the same order.
        cls : type
            The Op class under test.
        excluding : sequence of str, optional
            Rewrite names/tags to exclude from the compilation mode.
        warn : bool
            If True, log a warning (once) when an input has repeated
            non-broadcastable dimension lengths.
        check_topo : bool
            If True, check that the Op was removed from the compiled shape
            graph and is present in the compiled output graph. False is
            useful to test the not-implemented case.
        """
        mode = self.mode
        if excluding:
            mode = mode.excluding(*excluding)

        if warn:
            for var, inp in zip(inputs, numeric_inputs):
                if isinstance(inp, int | float | list | tuple):
                    inp = var.type.filter(inp)
                if not hasattr(inp, "shape"):
                    continue
                # Remove broadcasted dims: they are known to be 1, so they
                # cannot be confused with another dimension.
                if hasattr(var.type, "broadcastable"):
                    shp = [
                        inp.shape[i]
                        for i in range(inp.ndim)
                        if not var.type.broadcastable[i]
                    ]
                else:
                    shp = inp.shape
                if len(set(shp)) != len(shp):
                    _logger.warning(
                        "While testing shape inference for %r, we received an"
                        " input with a shape that has some repeated values: %r"
                        ", like a square matrix. This makes it impossible to"
                        " check if the values for these dimensions have been"
                        " correctly used, or if they have been mixed up.",
                        cls,
                        inp.shape,
                    )
                    break

        outputs_function = pytensor.function(inputs, outputs, mode=mode)
        # Now that we have full shape information at the type level, it's
        # possible/more likely that shape-computing graphs will not need the
        # inputs to the graph for which the shape is computed.
        shapes_function = pytensor.function(
            inputs, [o.shape for o in outputs], mode=mode, on_unused_input="ignore"
        )

        # Check that the Op is removed from the compiled shape function.
        if check_topo:
            # toposort() yields Apply nodes, so compare their Ops with the Ops
            # that produced the outputs; comparing nodes against output
            # Variables (as before) could never match.
            output_ops = [out.owner.op for out in outputs if out.owner is not None]
            topo_shape = shapes_function.maker.fgraph.toposort()
            assert not any(node.op in output_ops for node in topo_shape), topo_shape
            topo_out = outputs_function.maker.fgraph.toposort()
            assert any(isinstance(node.op, cls) for node in topo_out)

        # Check that the shape produced agrees with the actual shape.
        numeric_outputs = outputs_function(*numeric_inputs)
        numeric_shapes = shapes_function(*numeric_inputs)
        for out, shape in zip(numeric_outputs, numeric_shapes, strict=True):
            # Compare as tuples so a rank mismatch fails cleanly instead of
            # relying on NumPy broadcasting a tuple against an array.
            assert tuple(out.shape) == tuple(shape), (out.shape, shape)
The helper could probably do with some cleanup as well.