
Add utility verify_infer_shape #1088

Open
@ricardoV94

Description


There is a test class that can be subclassed for this, but it requires running pytest, which is not ideal for someone implementing their own Op. We could expose that functionality as a `verify_infer_shape` function, analogous to `verify_grad`.

```python
import logging

import numpy as np

import pytensor

_logger = logging.getLogger(__name__)


class InferShapeTester:
    def setup_method(self):
        # Take into account any mode that may be defined in a child class,
        # and it can be None
        mode = getattr(self, "mode", None)
        if mode is None:
            mode = pytensor.compile.get_default_mode()
        # This mode seems to be the minimal one including the shape_i
        # optimizations, if we don't want to enumerate them explicitly.
        self.mode = mode.including("canonicalize")

    def _compile_and_check(
        self,
        inputs,
        outputs,
        numeric_inputs,
        cls,
        excluding=None,
        warn=True,
        check_topo=True,
    ):
        """This tests the infer_shape method only.

        When testing with input values with shapes that take the same
        value over different dimensions (for instance, a square
        matrix, or a tensor3 with shape (n, n, n), or (m, n, m)), it
        is not possible to detect if the output shape was computed
        correctly, or if some shapes with the same value have been
        mixed up. For instance, if the infer_shape uses the width of a
        matrix instead of its height, then testing with only square
        matrices will not detect the problem. If warn=True, we emit a
        warning when testing with such values.

        :param check_topo: If True, we check that the Op was removed
            from the graph. False is useful to test the not-implemented case.
        """
        mode = self.mode
        if excluding:
            mode = mode.excluding(*excluding)

        if warn:
            for var, inp in zip(inputs, numeric_inputs):
                if isinstance(inp, int | float | list | tuple):
                    inp = var.type.filter(inp)
                if not hasattr(inp, "shape"):
                    continue
                # Remove broadcasted dims, as it is sure they can't be
                # changed, to prevent the same-dim problem.
                if hasattr(var.type, "broadcastable"):
                    shp = [
                        inp.shape[i]
                        for i in range(inp.ndim)
                        if not var.type.broadcastable[i]
                    ]
                else:
                    shp = inp.shape
                if len(set(shp)) != len(shp):
                    _logger.warning(
                        "While testing shape inference for %r, we received an"
                        " input with a shape that has some repeated values: %r"
                        ", like a square matrix. This makes it impossible to"
                        " check if the values for these dimensions have been"
                        " correctly used, or if they have been mixed up.",
                        cls,
                        inp.shape,
                    )
                    break

        outputs_function = pytensor.function(inputs, outputs, mode=mode)

        # Now that we have full shape information at the type level, it's
        # possible/more likely that shape-computing graphs will not need the
        # inputs to the graph for which the shape is computed
        shapes_function = pytensor.function(
            inputs, [o.shape for o in outputs], mode=mode, on_unused_input="ignore"
        )

        # Check that the Op is removed from the compiled function.
        if check_topo:
            topo_shape = shapes_function.maker.fgraph.toposort()
            assert not any(isinstance(t.op, cls) for t in topo_shape)
        topo_out = outputs_function.maker.fgraph.toposort()
        assert any(isinstance(t.op, cls) for t in topo_out)

        # Check that the shape produced agrees with the actual shape.
        numeric_outputs = outputs_function(*numeric_inputs)
        numeric_shapes = shapes_function(*numeric_inputs)
        for out, shape in zip(numeric_outputs, numeric_shapes):
            assert np.all(out.shape == shape), (out.shape, shape)
```

It could probably do with some cleaning as well.
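
For reference, a standalone helper could look roughly like the sketch below. The name `verify_infer_shape`, its signature, and the `op_class` parameter are assumptions on my part; it just repackages the checks from `_compile_and_check` above without the pytest class machinery.

```python
# Minimal sketch of the proposed utility, assuming a free-standing function
# mirroring verify_grad. The name and signature are hypothetical.
import numpy as np

import pytensor


def verify_infer_shape(inputs, outputs, numeric_inputs, op_class, mode=None):
    """Check that `op_class.infer_shape` agrees with the actual output shapes.

    `inputs`/`outputs` are symbolic variables (outputs as a list),
    `numeric_inputs` are the concrete values to evaluate with, and
    `op_class` is the Op class whose shape inference is being tested.
    """
    if mode is None:
        # "canonicalize" pulls in the shape_i rewrites needed to replace
        # the Op's output shapes with the inferred shape graph.
        mode = pytensor.compile.get_default_mode().including("canonicalize")

    # Function computing the actual outputs (the Op must remain in this graph).
    outputs_fn = pytensor.function(inputs, outputs, mode=mode)

    # Function computing only the output shapes; if infer_shape works, the Op
    # itself should be rewritten out of this graph.
    shapes_fn = pytensor.function(
        inputs, [o.shape for o in outputs], mode=mode, on_unused_input="ignore"
    )

    # The Op should not be needed to compute the shapes...
    assert not any(
        isinstance(node.op, op_class)
        for node in shapes_fn.maker.fgraph.toposort()
    ), f"{op_class.__name__} was not removed from the shape graph"
    # ...but it should still be present in the graph computing the values.
    assert any(
        isinstance(node.op, op_class)
        for node in outputs_fn.maker.fgraph.toposort()
    )

    # Finally, the inferred shapes must match the shapes of the actual outputs.
    for out, shape in zip(outputs_fn(*numeric_inputs), shapes_fn(*numeric_inputs)):
        assert np.all(out.shape == shape), (out.shape, shape)
```

Usage would then be a single call from anywhere (no test class or pytest run needed), e.g. passing the symbolic inputs, the Op's outputs as a list, matching numpy values, and the Op class, in the same way `_compile_and_check` is called today.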
