Skip to content

Commit 194b871

Browse files
committed
Group JAX random shape input tests
1 parent 9ac810c commit 194b871

File tree

1 file changed

+87
-91
lines changed

1 file changed

+87
-91
lines changed

tests/link/jax/test_random.py

Lines changed: 87 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -809,94 +809,90 @@ def sample_fn(rng, size, dtype, *parameters):
809809
compare_jax_and_py(fgraph, [])
810810

811811

812-
def test_random_concrete_shape():
813-
"""JAX should compile when a `RandomVariable` is passed a concrete shape.
814-
815-
There are three quantities that JAX considers as concrete:
816-
1. Constants known at compile time;
817-
2. The shape of an array.
818-
3. `static_argnums` parameters
819-
This test makes sure that graphs with `RandomVariable`s compile when the
820-
`size` parameter satisfies either of these criteria.
821-
822-
"""
823-
rng = shared(np.random.default_rng(123))
824-
x_pt = pt.dmatrix()
825-
out = pt.random.normal(0, 1, size=x_pt.shape, rng=rng)
826-
jax_fn = compile_random_function([x_pt], out)
827-
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
828-
829-
830-
def test_random_concrete_shape_from_param():
831-
rng = shared(np.random.default_rng(123))
832-
x_pt = pt.dmatrix()
833-
out = pt.random.normal(x_pt, 1, rng=rng)
834-
jax_fn = compile_random_function([x_pt], out)
835-
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
836-
837-
838-
def test_random_concrete_shape_subtensor():
839-
"""JAX should compile when a concrete value is passed for the `size` parameter.
840-
841-
This test ensures that the `DimShuffle` `Op` used by PyTensor to turn scalar
842-
inputs into 1d vectors is replaced by an `Op` that turns concrete scalar
843-
inputs into tuples of concrete values using the `jax_size_parameter_as_tuple`
844-
rewrite.
845-
846-
JAX does not accept scalars as `size` or `shape` arguments, so this is a
847-
slight improvement over their API.
848-
849-
"""
850-
rng = shared(np.random.default_rng(123))
851-
x_pt = pt.dmatrix()
852-
out = pt.random.normal(0, 1, size=x_pt.shape[1], rng=rng)
853-
jax_fn = compile_random_function([x_pt], out)
854-
assert jax_fn(np.ones((2, 3))).shape == (3,)
855-
856-
857-
def test_random_concrete_shape_subtensor_tuple():
858-
"""JAX should compile when a tuple of concrete values is passed for the `size` parameter.
859-
860-
This test ensures that the `MakeVector` `Op` used by PyTensor to turn tuple
861-
inputs into 1d vectors is replaced by an `Op` that turns a tuple of concrete
862-
scalar inputs into tuples of concrete values using the
863-
`jax_size_parameter_as_tuple` rewrite.
864-
865-
"""
866-
rng = shared(np.random.default_rng(123))
867-
x_pt = pt.dmatrix()
868-
out = pt.random.normal(0, 1, size=(x_pt.shape[0],), rng=rng)
869-
jax_fn = compile_random_function([x_pt], out)
870-
assert jax_fn(np.ones((2, 3))).shape == (2,)
871-
872-
873-
@pytest.mark.xfail(
874-
reason="`size_pt` should be specified as a static argument", strict=True
875-
)
876-
def test_random_concrete_shape_graph_input():
877-
rng = shared(np.random.default_rng(123))
878-
size_pt = pt.scalar()
879-
out = pt.random.normal(0, 1, size=size_pt, rng=rng)
880-
jax_fn = compile_random_function([size_pt], out)
881-
assert jax_fn(10).shape == (10,)
882-
883-
884-
def test_constant_shape_after_graph_rewriting():
885-
size = pt.vector("size", shape=(2,), dtype=int)
886-
x = pt.random.normal(size=size)
887-
assert x.type.shape == (None, None)
888-
889-
with pytest.raises(TypeError):
890-
compile_random_function([size], x)([2, 5])
891-
892-
# Rebuild with strict=False so output type is not updated
893-
# This reflects cases where size is constant folded during rewrites but the RV node is not recreated
894-
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=True)
895-
assert new_x.type.shape == (None, None)
896-
assert compile_random_function([], new_x)().shape == (2, 5)
897-
898-
# Rebuild with strict=True, so output type is updated
899-
# This uses a different path in the dispatch implementation
900-
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=False)
901-
assert new_x.type.shape == (2, 5)
902-
assert compile_random_function([], new_x)().shape == (2, 5)
812+
class TestRandomShapeInputs:
813+
def test_random_concrete_shape(self):
814+
"""JAX should compile when a `RandomVariable` is passed a concrete shape.
815+
816+
There are three quantities that JAX considers as concrete:
817+
1. Constants known at compile time;
818+
2. The shape of an array.
819+
3. `static_argnums` parameters
820+
This test makes sure that graphs with `RandomVariable`s compile when the
821+
`size` parameter satisfies either of these criteria.
822+
823+
"""
824+
rng = shared(np.random.default_rng(123))
825+
x_pt = pt.dmatrix()
826+
out = pt.random.normal(0, 1, size=x_pt.shape, rng=rng)
827+
jax_fn = compile_random_function([x_pt], out)
828+
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
829+
830+
def test_random_concrete_shape_from_param(self):
831+
rng = shared(np.random.default_rng(123))
832+
x_pt = pt.dmatrix()
833+
out = pt.random.normal(x_pt, 1, rng=rng)
834+
jax_fn = compile_random_function([x_pt], out)
835+
assert jax_fn(np.ones((2, 3))).shape == (2, 3)
836+
837+
def test_random_concrete_shape_subtensor(self):
838+
"""JAX should compile when a concrete value is passed for the `size` parameter.
839+
840+
This test ensures that the `DimShuffle` `Op` used by PyTensor to turn scalar
841+
inputs into 1d vectors is replaced by an `Op` that turns concrete scalar
842+
inputs into tuples of concrete values using the `jax_size_parameter_as_tuple`
843+
rewrite.
844+
845+
JAX does not accept scalars as `size` or `shape` arguments, so this is a
846+
slight improvement over their API.
847+
848+
"""
849+
rng = shared(np.random.default_rng(123))
850+
x_pt = pt.dmatrix()
851+
out = pt.random.normal(0, 1, size=x_pt.shape[1], rng=rng)
852+
jax_fn = compile_random_function([x_pt], out)
853+
assert jax_fn(np.ones((2, 3))).shape == (3,)
854+
855+
def test_random_concrete_shape_subtensor_tuple(self):
856+
"""JAX should compile when a tuple of concrete values is passed for the `size` parameter.
857+
858+
This test ensures that the `MakeVector` `Op` used by PyTensor to turn tuple
859+
inputs into 1d vectors is replaced by an `Op` that turns a tuple of concrete
860+
scalar inputs into tuples of concrete values using the
861+
`jax_size_parameter_as_tuple` rewrite.
862+
863+
"""
864+
rng = shared(np.random.default_rng(123))
865+
x_pt = pt.dmatrix()
866+
out = pt.random.normal(0, 1, size=(x_pt.shape[0],), rng=rng)
867+
jax_fn = compile_random_function([x_pt], out)
868+
assert jax_fn(np.ones((2, 3))).shape == (2,)
869+
870+
@pytest.mark.xfail(
871+
reason="`size_pt` should be specified as a static argument", strict=True
872+
)
873+
def test_random_concrete_shape_graph_input(self):
874+
rng = shared(np.random.default_rng(123))
875+
size_pt = pt.scalar()
876+
out = pt.random.normal(0, 1, size=size_pt, rng=rng)
877+
jax_fn = compile_random_function([size_pt], out)
878+
assert jax_fn(10).shape == (10,)
879+
880+
def test_constant_shape_after_graph_rewriting(self):
881+
size = pt.vector("size", shape=(2,), dtype=int)
882+
x = pt.random.normal(size=size)
883+
assert x.type.shape == (None, None)
884+
885+
with pytest.raises(TypeError):
886+
compile_random_function([size], x)([2, 5])
887+
888+
# Rebuild with strict=False so output type is not updated
889+
# This reflects cases where size is constant folded during rewrites but the RV node is not recreated
890+
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=True)
891+
assert new_x.type.shape == (None, None)
892+
assert compile_random_function([], new_x)().shape == (2, 5)
893+
894+
# Rebuild with strict=True, so output type is updated
895+
# This uses a different path in the dispatch implementation
896+
new_x = clone_replace(x, {size: pt.constant([2, 5])}, rebuild_strict=False)
897+
assert new_x.type.shape == (2, 5)
898+
assert compile_random_function([], new_x)().shape == (2, 5)

0 commit comments

Comments
 (0)