Skip to content

Commit 06d9a49

Browse files
committed
Use infer_shape of core_op to infer Blockwise core shapes
This can only be done when the output of infer_shape of the core_op depends only on the input shapes, and not their values.
1 parent fa0ab9d commit 06d9a49

File tree

2 files changed

+78
-3
lines changed

2 files changed

+78
-3
lines changed

pytensor/tensor/blockwise.py

+26-3
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@
77
from pytensor import config
88
from pytensor.compile.builders import OpFromGraph
99
from pytensor.gradient import DisconnectedType
10-
from pytensor.graph.basic import Apply, Constant
10+
from pytensor.graph import FunctionGraph
11+
from pytensor.graph.basic import Apply, Constant, ancestors
1112
from pytensor.graph.null_type import NullType
1213
from pytensor.graph.op import Op
1314
from pytensor.graph.replace import (
@@ -179,15 +180,37 @@ def infer_shape(
179180

180181
batch_shape = broadcast_shape(*batch_shapes, arrays_are_shapes=True)
181182

183+
# Try to extract the core shapes from the core_op
184+
if hasattr(self.core_op, "infer_shape"):
185+
dummy_core_node = self._create_dummy_core_node(node.inputs)
186+
dummy_core_inputs = dummy_core_node.inputs
187+
dummy_fgraph = FunctionGraph(outputs=dummy_core_node.outputs, clone=False)
188+
core_input_shapes = [
189+
input_shape[batch_ndims:] for input_shape in input_shapes
190+
]
191+
core_output_shapes = self.core_op.infer_shape(
192+
dummy_fgraph, dummy_core_node, core_input_shapes
193+
)
194+
182195
out_shapes = []
183-
for output, sig in zip(node.outputs, self.outputs_sig):
196+
for o, (output, sig) in enumerate(zip(node.outputs, self.outputs_sig)):
184197
core_out_shape = []
185198
for i, dim_name in enumerate(sig):
186199
# The output dim is the same as another input dim
187200
if dim_name in core_dims:
188201
core_out_shape.append(core_dims[dim_name])
189202
else:
190-
# TODO: We could try to make use of infer_shape of core_op
203+
if hasattr(self.core_op, "infer_shape"):
204+
# If the input values are needed to compute the dimension length, we can't use the infer_shape
205+
# of the core_node as the value is not constant across batch dims of the Blockwise
206+
core_out_dim = core_output_shapes[o][i]
207+
if not (
208+
set(dummy_core_inputs) & set(ancestors([core_out_dim]))
209+
):
210+
core_out_shape.append(core_out_dim)
211+
continue
212+
213+
# Fallback shape requires evaluating the Blockwise Op
191214
core_out_shape.append(Shape_i(batch_ndims + i)(output))
192215
out_shapes.append((*batch_shape, *core_out_shape))
193216

tests/tensor/test_blockwise.py

+52
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,58 @@ def test_blockwise_shape():
215215
assert tuple(shape_fn(inp1_test, inp2_test)[1]) == (7, 5, 4)
216216

217217

218+
def test_blockwise_infer_core_shape():
219+
class TestOpWithInferShape(Op):
220+
def make_node(self, a, b):
221+
assert a.type.ndim == 1
222+
assert b.type.ndim == 1
223+
c = tensor(shape=(None,))
224+
d = tensor(shape=(None,))
225+
return Apply(self, [a, b], [c, d])
226+
227+
def perform(self, node, inputs, outputs):
228+
a, b = inputs
229+
c, d = outputs
230+
c[0] = np.arange(a.size + b.size)
231+
d[0] = np.arange(a.sum() + b.sum())
232+
233+
def infer_shape(self, fgraph, node, input_shapes):
234+
# First output shape depends only on input_shapes
235+
# Second output shape depends on input values
236+
x, y = node.inputs
237+
[(x_shape,), (y_shape,)] = input_shapes
238+
return (x_shape + y_shape,), (x.sum() + y.sum(),)
239+
240+
blockwise_op = Blockwise(
241+
core_op=TestOpWithInferShape(), signature="(a),(b)->(c),(d)"
242+
)
243+
244+
a = tensor("a", shape=(5, 3))
245+
b = tensor("b", shape=(1, 4))
246+
c, d = blockwise_op(a, b)
247+
assert c.type.shape == (5, None)
248+
assert d.type.shape == (5, None)
249+
250+
c_shape_fn = pytensor.function([a, b], c.shape)
251+
# c_shape can be computed from the input shapes alone
252+
assert not any(
253+
isinstance(getattr(n.op, "core_op", n.op), TestOpWithInferShape)
254+
for n in c_shape_fn.maker.fgraph.apply_nodes
255+
)
256+
257+
d_shape_fn = pytensor.function([a, b], d.shape)
258+
# d_shape cannot be computed from the input shapes alone
259+
assert any(
260+
isinstance(getattr(n.op, "core_op", n.op), TestOpWithInferShape)
261+
for n in d_shape_fn.maker.fgraph.apply_nodes
262+
)
263+
264+
a_test = np.zeros(a.type.shape)
265+
b_test = np.zeros(b.type.shape)
266+
assert tuple(c_shape_fn(a_test, b_test)) == (5, 7)
267+
assert tuple(d_shape_fn(a_test, b_test)) == (5, 0)
268+
269+
218270
class BlockwiseOpTester:
219271
"""Base class to test Blockwise works for specific Ops"""
220272

0 commit comments

Comments
 (0)