Skip to content

Commit a5d54c8

Browse files
committed
Check for runtime broadcasting in Blockwise Ops
1 parent 893dc18 commit a5d54c8

File tree

2 files changed

+69
-1
lines changed

2 files changed

+69
-1
lines changed

pytensor/tensor/blockwise.py

+18
Original file line numberDiff line numberDiff line change
@@ -355,12 +355,30 @@ def core_func(*inner_inputs):
355355
self._gufunc = np.vectorize(core_func, signature=self.signature)
356356
return self._gufunc
357357

358+
def _check_runtime_broadcast(self, node, inputs):
    """Raise if any batch dimension would be broadcast at runtime.

    Parameters
    ----------
    node : Apply
        The node whose symbolic inputs declare which batch dimensions
        are statically known to be broadcastable.
    inputs : sequence of numpy.ndarray
        The concrete runtime values, positionally matching ``node.inputs``.

    Raises
    ------
    ValueError
        If some input has runtime length 1 along a batch dimension that
        was not declared broadcastable while another input has a
        different length there (i.e. numpy would silently broadcast).
    """
    batch_ndim = self._batch_ndim_from_outputs(node.outputs)

    # For each batch-dimension position, gather (runtime_length, declared_broadcastable)
    # pairs across all inputs.  `inp` is the runtime ndarray, `sym` the symbolic input.
    # (Renamed from `input` to avoid shadowing the builtin.)
    for dims_and_bcast in zip(
        *[
            zip(inp.shape[:batch_ndim], sym.type.broadcastable[:batch_ndim])
            for inp, sym in zip(inputs, node.inputs)
        ]
    ):
        # Disallowed combination: at least one input has length != 1 here,
        # and another has length 1 without having been marked broadcastable.
        if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
            raise ValueError(
                "Runtime broadcasting not allowed. "
                "At least one input has a distinct batch dimension length of 1, but was not marked as broadcastable.\n"
                "If broadcasting was intended, use `specify_broadcastable` on the relevant input."
            )
373+
358374
def perform(self, node, inputs, output_storage):
359375
gufunc = self._gufunc
360376

361377
if gufunc is None:
362378
gufunc = self._create_gufunc(node)
363379

380+
self._check_runtime_broadcast(node, inputs)
381+
364382
res = gufunc(*inputs)
365383
if not isinstance(res, tuple):
366384
res = (res,)

tests/tensor/test_blockwise.py

+51-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytest
66

77
import pytensor
8-
from pytensor import config
8+
from pytensor import config, function
99
from pytensor.gradient import grad
1010
from pytensor.graph import Apply, Op
1111
from pytensor.graph.replace import vectorize_node
@@ -38,6 +38,56 @@ def test_vectorize_blockwise():
3838
assert new_vect_node.inputs[0] is tns4
3939

4040

41+
def check_blockwise_runtime_broadcasting(mode):
    """Compile a batched matmul and verify runtime-broadcasting checks.

    Valid (equal) batch dimensions must evaluate; a length-1 batch
    dimension not declared broadcastable, or plainly mismatched batch
    dimensions, must raise ``ValueError``.
    """
    lhs = tensor("a", shape=(None, 3, 5))
    rhs = tensor("b", shape=(None, 5, 3))

    fn = function([lhs, rhs], lhs @ rhs, mode=mode)
    # The compiled graph must still go through a Blockwise op.
    assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise)

    def ones_pair(batch_left, batch_right):
        # Build a pair of all-ones operands with the given batch lengths.
        return (
            np.ones((batch_left, 3, 5)).astype(config.floatX),
            np.ones((batch_right, 5, 3)).astype(config.floatX),
        )

    # Matching batch dimensions (including 1 == 1) are valid; each inner
    # product of ones over length 5 equals 5.
    for batch_dim in (2, 1):
        left, right = ones_pair(batch_dim, batch_dim)
        np.testing.assert_allclose(fn(left, right), np.full((batch_dim, 3, 3), 5.0))

    # A runtime length-1 batch dimension that was not declared
    # broadcastable must be rejected, whichever side it is on.
    for batch_left, batch_right in ((1, 2), (2, 1)):
        with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
            fn(*ones_pair(batch_left, batch_right))

    # Genuinely incompatible batch dimensions also raise, but the
    # message text differs per backend, so only the type is checked.
    with pytest.raises(ValueError):
        fn(*ones_pair(2, 3))
85+
86+
@pytest.mark.parametrize("mode", ("FAST_COMPILE", "FAST_RUN"))
def test_runtime_broadcast(mode):
    """Run the runtime-broadcasting checks under each compilation mode."""
    check_blockwise_runtime_broadcasting(mode)
89+
90+
4191
class TestOp(Op):
4292
def make_node(self, *inputs):
4393
return Apply(self, inputs, [i.type() for i in inputs])

0 commit comments

Comments
 (0)