File tree 2 files changed +10
-5
lines changed
2 files changed +10
-5
lines changed Original file line number Diff line number Diff line change 4
4
from typing import cast
5
5
6
6
import numpy as np
7
+ from numpy .core .numeric import normalize_axis_tuple # type: ignore
7
8
8
9
import pytensor
9
10
from pytensor .gradient import DisconnectedType
@@ -994,9 +995,7 @@ def specify_broadcastable(x, *axes):
994
995
if not axes :
995
996
return x
996
997
997
- if max (axes ) >= x .type .ndim :
998
- raise ValueError ("Trying to specify broadcastable of non-existent dimension" )
999
-
998
+ axes = normalize_axis_tuple (axes , x .type .ndim )
1000
999
shape_info = [1 if i in axes else s for i , s in enumerate (x .type .shape )]
1001
1000
return specify_shape (x , shape_info )
1002
1001
Original file line number Diff line number Diff line change @@ -562,16 +562,22 @@ def test_basic(self):
562
562
x = matrix ()
563
563
assert specify_broadcastable (x , 0 ).type .shape == (1 , None )
564
564
assert specify_broadcastable (x , 1 ).type .shape == (None , 1 )
565
+ assert specify_broadcastable (x , - 1 ).type .shape == (None , 1 )
565
566
assert specify_broadcastable (x , 0 , 1 ).type .shape == (1 , 1 )
566
567
567
568
x = row ()
568
569
assert specify_broadcastable (x , 0 ) is x
569
570
assert specify_broadcastable (x , 1 ) is not x
571
+ assert specify_broadcastable (x , - 2 ) is x
570
572
571
573
def test_validation (self ):
572
574
x = matrix ()
573
- with pytest .raises (ValueError , match = "^Trying to specify broadcastable of*" ):
574
- specify_broadcastable (x , 2 )
575
+ axis = 2
576
+ with pytest .raises (
577
+ ValueError ,
578
+ match = f"axis { axis } is out of bounds for array of dimension { axis } " ,
579
+ ):
580
+ specify_broadcastable (x , axis )
575
581
576
582
577
583
class TestRopLop (RopLopChecker ):
You can’t perform that action at this time.
0 commit comments