Skip to content

Commit 14651fb

Browse files
Add support for negative axis in specify_broadcastable (#710)
1 parent bcd81c7 commit 14651fb

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

pytensor/tensor/shape.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from typing import cast
55

66
import numpy as np
7+
from numpy.core.numeric import normalize_axis_tuple # type: ignore
78

89
import pytensor
910
from pytensor.gradient import DisconnectedType
@@ -994,9 +995,7 @@ def specify_broadcastable(x, *axes):
994995
if not axes:
995996
return x
996997

997-
if max(axes) >= x.type.ndim:
998-
raise ValueError("Trying to specify broadcastable of non-existent dimension")
999-
998+
axes = normalize_axis_tuple(axes, x.type.ndim)
1000999
shape_info = [1 if i in axes else s for i, s in enumerate(x.type.shape)]
10011000
return specify_shape(x, shape_info)
10021001

tests/tensor/test_shape.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -562,16 +562,22 @@ def test_basic(self):
562562
x = matrix()
563563
assert specify_broadcastable(x, 0).type.shape == (1, None)
564564
assert specify_broadcastable(x, 1).type.shape == (None, 1)
565+
assert specify_broadcastable(x, -1).type.shape == (None, 1)
565566
assert specify_broadcastable(x, 0, 1).type.shape == (1, 1)
566567

567568
x = row()
568569
assert specify_broadcastable(x, 0) is x
569570
assert specify_broadcastable(x, 1) is not x
571+
assert specify_broadcastable(x, -2) is x
570572

571573
def test_validation(self):
572574
x = matrix()
573-
with pytest.raises(ValueError, match="^Trying to specify broadcastable of*"):
574-
specify_broadcastable(x, 2)
575+
axis = 2
576+
with pytest.raises(
577+
ValueError,
578+
match=f"axis {axis} is out of bounds for array of dimension {axis}",
579+
):
580+
specify_broadcastable(x, axis)
575581

576582

577583
class TestRopLop(RopLopChecker):

0 commit comments

Comments
 (0)