Skip to content

Commit b61d31c

Browse files
committed
Don't run unrelated tests in altenarnative backends
1 parent 558f084 commit b61d31c

File tree

6 files changed

+67
-64
lines changed

6 files changed

+67
-64
lines changed

tests/link/jax/test_elemwise.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,11 @@
1515
from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
1616
from pytensor.tensor.type import matrix, tensor, vector, vectors
1717
from tests.link.jax.test_basic import compare_jax_and_py
18-
from tests.tensor.test_elemwise import TestElemwise
18+
from tests.tensor.test_elemwise import check_elemwise_runtime_broadcast
1919

2020

2121
def test_elemwise_runtime_broadcast():
22-
TestElemwise.check_runtime_broadcast(get_mode("JAX"))
22+
check_elemwise_runtime_broadcast(get_mode("JAX"))
2323

2424

2525
def test_jax_Dimshuffle():

tests/link/jax/test_tensor_basic.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from pytensor.graph.op import get_test_value
1515
from pytensor.tensor.type import iscalar, matrix, scalar, vector
1616
from tests.link.jax.test_basic import compare_jax_and_py
17-
from tests.tensor.test_basic import TestAlloc
17+
from tests.tensor.test_basic import check_alloc_runtime_broadcast
1818

1919

2020
def test_jax_Alloc():
@@ -54,7 +54,7 @@ def compare_shape_dtype(x, y):
5454

5555

5656
def test_alloc_runtime_broadcast():
57-
TestAlloc.check_runtime_broadcast(get_mode("JAX"))
57+
check_alloc_runtime_broadcast(get_mode("JAX"))
5858

5959

6060
def test_jax_MakeVector():

tests/link/numba/test_elemwise.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@
2424
scalar_my_multi_out,
2525
set_test_value,
2626
)
27-
from tests.tensor.test_elemwise import TestElemwise, careduce_benchmark_tester
27+
from tests.tensor.test_elemwise import (
28+
careduce_benchmark_tester,
29+
check_elemwise_runtime_broadcast,
30+
)
2831

2932

3033
rng = np.random.default_rng(42849)
@@ -124,7 +127,7 @@ def test_Elemwise(inputs, input_vals, output_fn, exc):
124127

125128
@pytest.mark.xfail(reason="Logic had to be reversed due to surprising segfaults")
126129
def test_elemwise_runtime_broadcast():
127-
TestElemwise.check_runtime_broadcast(get_mode("NUMBA"))
130+
check_elemwise_runtime_broadcast(get_mode("NUMBA"))
128131

129132

130133
def test_elemwise_speed(benchmark):

tests/link/numba/test_tensor_basic.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
compare_shape_dtype,
1717
set_test_value,
1818
)
19-
from tests.tensor.test_basic import TestAlloc
19+
from tests.tensor.test_basic import check_alloc_runtime_broadcast
2020

2121

2222
pytest.importorskip("numba")
@@ -52,7 +52,7 @@ def test_Alloc(v, shape):
5252

5353

5454
def test_alloc_runtime_broadcast():
55-
TestAlloc.check_runtime_broadcast(get_mode("NUMBA"))
55+
check_alloc_runtime_broadcast(get_mode("NUMBA"))
5656

5757

5858
def test_AllocEmpty():

tests/tensor/test_basic.py

+27-27
Original file line numberDiff line numberDiff line change
@@ -716,6 +716,32 @@ def test_masked_array_not_implemented(
716716
ptb.as_tensor(x)
717717

718718

719+
def check_alloc_runtime_broadcast(mode):
720+
"""Check we emmit a clear error when runtime broadcasting would occur according to Numpy rules."""
721+
floatX = config.floatX
722+
x_v = vector("x", shape=(None,))
723+
724+
out = alloc(x_v, 5, 3)
725+
f = pytensor.function([x_v], out, mode=mode)
726+
TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1)
727+
728+
np.testing.assert_array_equal(
729+
f(x=np.zeros((3,), dtype=floatX)),
730+
np.zeros((5, 3), dtype=floatX),
731+
)
732+
with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
733+
f(x=np.zeros((1,), dtype=floatX))
734+
735+
out = alloc(specify_shape(x_v, (1,)), 5, 3)
736+
f = pytensor.function([x_v], out, mode=mode)
737+
TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1)
738+
739+
np.testing.assert_array_equal(
740+
f(x=np.zeros((1,), dtype=floatX)),
741+
np.zeros((5, 3), dtype=floatX),
742+
)
743+
744+
719745
class TestAlloc:
720746
dtype = config.floatX
721747
mode = mode_opt
@@ -729,32 +755,6 @@ def check_allocs_in_fgraph(fgraph, n):
729755
== n
730756
)
731757

732-
@staticmethod
733-
def check_runtime_broadcast(mode):
734-
"""Check we emmit a clear error when runtime broadcasting would occur according to Numpy rules."""
735-
floatX = config.floatX
736-
x_v = vector("x", shape=(None,))
737-
738-
out = alloc(x_v, 5, 3)
739-
f = pytensor.function([x_v], out, mode=mode)
740-
TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1)
741-
742-
np.testing.assert_array_equal(
743-
f(x=np.zeros((3,), dtype=floatX)),
744-
np.zeros((5, 3), dtype=floatX),
745-
)
746-
with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
747-
f(x=np.zeros((1,), dtype=floatX))
748-
749-
out = alloc(specify_shape(x_v, (1,)), 5, 3)
750-
f = pytensor.function([x_v], out, mode=mode)
751-
TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1)
752-
753-
np.testing.assert_array_equal(
754-
f(x=np.zeros((1,), dtype=floatX)),
755-
np.zeros((5, 3), dtype=floatX),
756-
)
757-
758758
def setup_method(self):
759759
self.rng = np.random.default_rng(seed=utt.fetch_seed())
760760

@@ -912,7 +912,7 @@ def test_alloc_of_view_linker(self):
912912

913913
@pytest.mark.parametrize("mode", (Mode("py"), Mode("c")))
914914
def test_runtime_broadcast(self, mode):
915-
self.check_runtime_broadcast(mode)
915+
check_alloc_runtime_broadcast(mode)
916916

917917

918918
def test_infer_static_shape():

tests/tensor/test_elemwise.py

+29-29
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,33 @@ def test_any_grad(self):
705705
assert np.all(gx_val == 0)
706706

707707

708+
def check_elemwise_runtime_broadcast(mode):
709+
"""Check we emmit a clear error when runtime broadcasting would occur according to Numpy rules."""
710+
x_v = matrix("x")
711+
m_v = vector("m")
712+
713+
z_v = x_v - m_v
714+
f = pytensor.function([x_v, m_v], z_v, mode=mode)
715+
716+
# Test invalid broadcasting by either x or m
717+
for x_sh, m_sh in [((2, 1), (3,)), ((2, 3), (1,))]:
718+
x = np.ones(x_sh).astype(config.floatX)
719+
m = np.zeros(m_sh).astype(config.floatX)
720+
721+
# This error is introduced by PyTensor, so it's the same across different backends
722+
with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
723+
f(x, m)
724+
725+
x = np.ones((2, 3)).astype(config.floatX)
726+
m = np.zeros((1,)).astype(config.floatX)
727+
728+
x = np.ones((2, 4)).astype(config.floatX)
729+
m = np.zeros((3,)).astype(config.floatX)
730+
# This error is backend specific, and may have different types
731+
with pytest.raises((ValueError, TypeError)):
732+
f(x, m)
733+
734+
708735
class TestElemwise(unittest_tools.InferShapeTester):
709736
def test_elemwise_grad_bool(self):
710737
x = scalar(dtype="bool")
@@ -750,42 +777,15 @@ def test_input_dimensions_overflow(self):
750777
g = pytensor.function([a, b, c, d, e, f], s, mode=Mode(linker="py"))
751778
g(*[np.zeros(2**11, config.floatX) for i in range(6)])
752779

753-
@staticmethod
754-
def check_runtime_broadcast(mode):
755-
"""Check we emmit a clear error when runtime broadcasting would occur according to Numpy rules."""
756-
x_v = matrix("x")
757-
m_v = vector("m")
758-
759-
z_v = x_v - m_v
760-
f = pytensor.function([x_v, m_v], z_v, mode=mode)
761-
762-
# Test invalid broadcasting by either x or m
763-
for x_sh, m_sh in [((2, 1), (3,)), ((2, 3), (1,))]:
764-
x = np.ones(x_sh).astype(config.floatX)
765-
m = np.zeros(m_sh).astype(config.floatX)
766-
767-
# This error is introduced by PyTensor, so it's the same across different backends
768-
with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
769-
f(x, m)
770-
771-
x = np.ones((2, 3)).astype(config.floatX)
772-
m = np.zeros((1,)).astype(config.floatX)
773-
774-
x = np.ones((2, 4)).astype(config.floatX)
775-
m = np.zeros((3,)).astype(config.floatX)
776-
# This error is backend specific, and may have different types
777-
with pytest.raises((ValueError, TypeError)):
778-
f(x, m)
779-
780780
def test_runtime_broadcast_python(self):
781-
self.check_runtime_broadcast(Mode(linker="py"))
781+
check_elemwise_runtime_broadcast(Mode(linker="py"))
782782

783783
@pytest.mark.skipif(
784784
not pytensor.config.cxx,
785785
reason="G++ not available, so we need to skip this test.",
786786
)
787787
def test_runtime_broadcast_c(self):
788-
self.check_runtime_broadcast(Mode(linker="c"))
788+
check_elemwise_runtime_broadcast(Mode(linker="c"))
789789

790790
def test_str(self):
791791
op = Elemwise(ps.add, inplace_pattern={0: 0}, name=None)

0 commit comments

Comments
 (0)