Skip to content

Commit 981be2a

Browse files
committed
Revert numba runtime broadcast check
1 parent 5bbfc96 commit 981be2a

File tree

2 files changed

+5
-9
lines changed

2 files changed

+5
-9
lines changed

pytensor/link/numba/dispatch/elemwise_codegen.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,20 +35,15 @@ def compute_itershape(
3535
with builder.if_then(
3636
builder.icmp_unsigned("!=", length, shape[i]), likely=False
3737
):
38-
with builder.if_else(
39-
builder.or_(
40-
builder.icmp_unsigned("==", length, one),
41-
builder.icmp_unsigned("==", shape[i], one),
42-
)
43-
) as (
38+
with builder.if_else(builder.icmp_unsigned("==", length, one)) as (
4439
then,
4540
otherwise,
4641
):
4742
with then:
4843
msg = (
49-
"Runtime broadcasting not allowed. "
50-
"One input had a distinct dimension length of 1, but was not marked as broadcastable.\n"
51-
"If broadcasting was intended, use `specify_broadcastable` on the relevant input."
44+
f"Incompatible shapes for input {j} and axis {i} of "
45+
f"elemwise. Input {j} has shape 1, but is not statically "
46+
"known to have shape 1, and thus not broadcastable."
5247
)
5348
ctx.call_conv.return_user_exc(builder, ValueError, (msg,))
5449
with otherwise:

tests/link/numba/test_elemwise.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def test_Elemwise(inputs, input_vals, output_fn, exc):
121121
compare_numba_and_py(out_fg, input_vals)
122122

123123

124+
@pytest.mark.xfail(reason="Logic had to be reversed due to surprising segfaults")
124125
def test_elemwise_runtime_shape_error():
125126
TestElemwise.check_runtime_shapes_error(get_mode("NUMBA"))
126127

0 commit comments

Comments
 (0)