@@ -705,6 +705,33 @@ def test_any_grad(self):
705
705
assert np .all (gx_val == 0 )
706
706
707
707
708
+ def check_elemwise_runtime_broadcast (mode ):
709
+ """Check we emmit a clear error when runtime broadcasting would occur according to Numpy rules."""
710
+ x_v = matrix ("x" )
711
+ m_v = vector ("m" )
712
+
713
+ z_v = x_v - m_v
714
+ f = pytensor .function ([x_v , m_v ], z_v , mode = mode )
715
+
716
+ # Test invalid broadcasting by either x or m
717
+ for x_sh , m_sh in [((2 , 1 ), (3 ,)), ((2 , 3 ), (1 ,))]:
718
+ x = np .ones (x_sh ).astype (config .floatX )
719
+ m = np .zeros (m_sh ).astype (config .floatX )
720
+
721
+ # This error is introduced by PyTensor, so it's the same across different backends
722
+ with pytest .raises (ValueError , match = "Runtime broadcasting not allowed" ):
723
+ f (x , m )
724
+
725
+ x = np .ones ((2 , 3 )).astype (config .floatX )
726
+ m = np .zeros ((1 ,)).astype (config .floatX )
727
+
728
+ x = np .ones ((2 , 4 )).astype (config .floatX )
729
+ m = np .zeros ((3 ,)).astype (config .floatX )
730
+ # This error is backend specific, and may have different types
731
+ with pytest .raises ((ValueError , TypeError )):
732
+ f (x , m )
733
+
734
+
708
735
class TestElemwise (unittest_tools .InferShapeTester ):
709
736
def test_elemwise_grad_bool (self ):
710
737
x = scalar (dtype = "bool" )
@@ -750,42 +777,15 @@ def test_input_dimensions_overflow(self):
750
777
g = pytensor .function ([a , b , c , d , e , f ], s , mode = Mode (linker = "py" ))
751
778
g (* [np .zeros (2 ** 11 , config .floatX ) for i in range (6 )])
752
779
753
- @staticmethod
754
- def check_runtime_broadcast (mode ):
755
- """Check we emmit a clear error when runtime broadcasting would occur according to Numpy rules."""
756
- x_v = matrix ("x" )
757
- m_v = vector ("m" )
758
-
759
- z_v = x_v - m_v
760
- f = pytensor .function ([x_v , m_v ], z_v , mode = mode )
761
-
762
- # Test invalid broadcasting by either x or m
763
- for x_sh , m_sh in [((2 , 1 ), (3 ,)), ((2 , 3 ), (1 ,))]:
764
- x = np .ones (x_sh ).astype (config .floatX )
765
- m = np .zeros (m_sh ).astype (config .floatX )
766
-
767
- # This error is introduced by PyTensor, so it's the same across different backends
768
- with pytest .raises (ValueError , match = "Runtime broadcasting not allowed" ):
769
- f (x , m )
770
-
771
- x = np .ones ((2 , 3 )).astype (config .floatX )
772
- m = np .zeros ((1 ,)).astype (config .floatX )
773
-
774
- x = np .ones ((2 , 4 )).astype (config .floatX )
775
- m = np .zeros ((3 ,)).astype (config .floatX )
776
- # This error is backend specific, and may have different types
777
- with pytest .raises ((ValueError , TypeError )):
778
- f (x , m )
779
-
780
780
def test_runtime_broadcast_python (self ):
781
- self . check_runtime_broadcast (Mode (linker = "py" ))
781
+ check_elemwise_runtime_broadcast (Mode (linker = "py" ))
782
782
783
783
@pytest .mark .skipif (
784
784
not pytensor .config .cxx ,
785
785
reason = "G++ not available, so we need to skip this test." ,
786
786
)
787
787
def test_runtime_broadcast_c (self ):
788
- self . check_runtime_broadcast (Mode (linker = "c" ))
788
+ check_elemwise_runtime_broadcast (Mode (linker = "c" ))
789
789
790
790
def test_str (self ):
791
791
op = Elemwise (ps .add , inplace_pattern = {0 : 0 }, name = None )
0 commit comments