@@ -176,15 +176,17 @@ def test_jax_split_not_supported(self):
176
176
UserWarning , match = "Split node does not have constant split positions."
177
177
):
178
178
fn = pytensor .function ([a ], a_splits , mode = "JAX" )
179
- # It raises an informative ConcretizationTypeError, but there's an AttributeError that surpsasses it
179
+ # It raises an informative ConcretizationTypeError, but there's an AttributeError that surasses it
180
180
with pytest .raises (AttributeError ):
181
181
fn (np .zeros ((6 , 4 ), dtype = pytensor .config .floatX ))
182
182
183
183
split_axis = iscalar ("split_axis" )
184
184
a_splits = at .split (a , splits_size = [2 , 4 ], n_splits = 2 , axis = split_axis )
185
185
with pytest .warns (UserWarning , match = "Split node does not have constant axis." ):
186
186
fn = pytensor .function ([a , split_axis ], a_splits , mode = "JAX" )
187
- with pytest .raises (jax .errors .TracerIntegerConversionError ):
187
+ # Same as above, an AttributeError surpasses the `TracerIntegerConversionError`
188
+ # Both errors are included for backwards compatibility
189
+ with pytest .raises ((AttributeError , jax .errors .TracerIntegerConversionError )):
188
190
fn (np .zeros ((6 , 6 ), dtype = pytensor .config .floatX ), 0 )
189
191
190
192
0 commit comments