@@ -809,94 +809,90 @@ def sample_fn(rng, size, dtype, *parameters):
809
809
compare_jax_and_py (fgraph , [])
810
810
811
811
812
- def test_random_concrete_shape ():
813
- """JAX should compile when a `RandomVariable` is passed a concrete shape.
814
-
815
- There are three quantities that JAX considers as concrete:
816
- 1. Constants known at compile time;
817
- 2. The shape of an array.
818
- 3. `static_argnums` parameters
819
- This test makes sure that graphs with `RandomVariable`s compile when the
820
- `size` parameter satisfies either of these criteria.
821
-
822
- """
823
- rng = shared (np .random .default_rng (123 ))
824
- x_pt = pt .dmatrix ()
825
- out = pt .random .normal (0 , 1 , size = x_pt .shape , rng = rng )
826
- jax_fn = compile_random_function ([x_pt ], out )
827
- assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
828
-
829
-
830
- def test_random_concrete_shape_from_param ():
831
- rng = shared (np .random .default_rng (123 ))
832
- x_pt = pt .dmatrix ()
833
- out = pt .random .normal (x_pt , 1 , rng = rng )
834
- jax_fn = compile_random_function ([x_pt ], out )
835
- assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
836
-
837
-
838
- def test_random_concrete_shape_subtensor ():
839
- """JAX should compile when a concrete value is passed for the `size` parameter.
840
-
841
- This test ensures that the `DimShuffle` `Op` used by PyTensor to turn scalar
842
- inputs into 1d vectors is replaced by an `Op` that turns concrete scalar
843
- inputs into tuples of concrete values using the `jax_size_parameter_as_tuple`
844
- rewrite.
845
-
846
- JAX does not accept scalars as `size` or `shape` arguments, so this is a
847
- slight improvement over their API.
848
-
849
- """
850
- rng = shared (np .random .default_rng (123 ))
851
- x_pt = pt .dmatrix ()
852
- out = pt .random .normal (0 , 1 , size = x_pt .shape [1 ], rng = rng )
853
- jax_fn = compile_random_function ([x_pt ], out )
854
- assert jax_fn (np .ones ((2 , 3 ))).shape == (3 ,)
855
-
856
-
857
- def test_random_concrete_shape_subtensor_tuple ():
858
- """JAX should compile when a tuple of concrete values is passed for the `size` parameter.
859
-
860
- This test ensures that the `MakeVector` `Op` used by PyTensor to turn tuple
861
- inputs into 1d vectors is replaced by an `Op` that turns a tuple of concrete
862
- scalar inputs into tuples of concrete values using the
863
- `jax_size_parameter_as_tuple` rewrite.
864
-
865
- """
866
- rng = shared (np .random .default_rng (123 ))
867
- x_pt = pt .dmatrix ()
868
- out = pt .random .normal (0 , 1 , size = (x_pt .shape [0 ],), rng = rng )
869
- jax_fn = compile_random_function ([x_pt ], out )
870
- assert jax_fn (np .ones ((2 , 3 ))).shape == (2 ,)
871
-
872
-
873
- @pytest .mark .xfail (
874
- reason = "`size_pt` should be specified as a static argument" , strict = True
875
- )
876
- def test_random_concrete_shape_graph_input ():
877
- rng = shared (np .random .default_rng (123 ))
878
- size_pt = pt .scalar ()
879
- out = pt .random .normal (0 , 1 , size = size_pt , rng = rng )
880
- jax_fn = compile_random_function ([size_pt ], out )
881
- assert jax_fn (10 ).shape == (10 ,)
882
-
883
-
884
- def test_constant_shape_after_graph_rewriting ():
885
- size = pt .vector ("size" , shape = (2 ,), dtype = int )
886
- x = pt .random .normal (size = size )
887
- assert x .type .shape == (None , None )
888
-
889
- with pytest .raises (TypeError ):
890
- compile_random_function ([size ], x )([2 , 5 ])
891
-
892
- # Rebuild with strict=False so output type is not updated
893
- # This reflects cases where size is constant folded during rewrites but the RV node is not recreated
894
- new_x = clone_replace (x , {size : pt .constant ([2 , 5 ])}, rebuild_strict = True )
895
- assert new_x .type .shape == (None , None )
896
- assert compile_random_function ([], new_x )().shape == (2 , 5 )
897
-
898
- # Rebuild with strict=True, so output type is updated
899
- # This uses a different path in the dispatch implementation
900
- new_x = clone_replace (x , {size : pt .constant ([2 , 5 ])}, rebuild_strict = False )
901
- assert new_x .type .shape == (2 , 5 )
902
- assert compile_random_function ([], new_x )().shape == (2 , 5 )
812
+ class TestRandomShapeInputs :
813
+ def test_random_concrete_shape (self ):
814
+ """JAX should compile when a `RandomVariable` is passed a concrete shape.
815
+
816
+ There are three quantities that JAX considers as concrete:
817
+ 1. Constants known at compile time;
818
+ 2. The shape of an array.
819
+ 3. `static_argnums` parameters
820
+ This test makes sure that graphs with `RandomVariable`s compile when the
821
+ `size` parameter satisfies either of these criteria.
822
+
823
+ """
824
+ rng = shared (np .random .default_rng (123 ))
825
+ x_pt = pt .dmatrix ()
826
+ out = pt .random .normal (0 , 1 , size = x_pt .shape , rng = rng )
827
+ jax_fn = compile_random_function ([x_pt ], out )
828
+ assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
829
+
830
+ def test_random_concrete_shape_from_param (self ):
831
+ rng = shared (np .random .default_rng (123 ))
832
+ x_pt = pt .dmatrix ()
833
+ out = pt .random .normal (x_pt , 1 , rng = rng )
834
+ jax_fn = compile_random_function ([x_pt ], out )
835
+ assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
836
+
837
+ def test_random_concrete_shape_subtensor (self ):
838
+ """JAX should compile when a concrete value is passed for the `size` parameter.
839
+
840
+ This test ensures that the `DimShuffle` `Op` used by PyTensor to turn scalar
841
+ inputs into 1d vectors is replaced by an `Op` that turns concrete scalar
842
+ inputs into tuples of concrete values using the `jax_size_parameter_as_tuple`
843
+ rewrite.
844
+
845
+ JAX does not accept scalars as `size` or `shape` arguments, so this is a
846
+ slight improvement over their API.
847
+
848
+ """
849
+ rng = shared (np .random .default_rng (123 ))
850
+ x_pt = pt .dmatrix ()
851
+ out = pt .random .normal (0 , 1 , size = x_pt .shape [1 ], rng = rng )
852
+ jax_fn = compile_random_function ([x_pt ], out )
853
+ assert jax_fn (np .ones ((2 , 3 ))).shape == (3 ,)
854
+
855
+ def test_random_concrete_shape_subtensor_tuple (self ):
856
+ """JAX should compile when a tuple of concrete values is passed for the `size` parameter.
857
+
858
+ This test ensures that the `MakeVector` `Op` used by PyTensor to turn tuple
859
+ inputs into 1d vectors is replaced by an `Op` that turns a tuple of concrete
860
+ scalar inputs into tuples of concrete values using the
861
+ `jax_size_parameter_as_tuple` rewrite.
862
+
863
+ """
864
+ rng = shared (np .random .default_rng (123 ))
865
+ x_pt = pt .dmatrix ()
866
+ out = pt .random .normal (0 , 1 , size = (x_pt .shape [0 ],), rng = rng )
867
+ jax_fn = compile_random_function ([x_pt ], out )
868
+ assert jax_fn (np .ones ((2 , 3 ))).shape == (2 ,)
869
+
870
+ @pytest .mark .xfail (
871
+ reason = "`size_pt` should be specified as a static argument" , strict = True
872
+ )
873
+ def test_random_concrete_shape_graph_input (self ):
874
+ rng = shared (np .random .default_rng (123 ))
875
+ size_pt = pt .scalar ()
876
+ out = pt .random .normal (0 , 1 , size = size_pt , rng = rng )
877
+ jax_fn = compile_random_function ([size_pt ], out )
878
+ assert jax_fn (10 ).shape == (10 ,)
879
+
880
+ def test_constant_shape_after_graph_rewriting (self ):
881
+ size = pt .vector ("size" , shape = (2 ,), dtype = int )
882
+ x = pt .random .normal (size = size )
883
+ assert x .type .shape == (None , None )
884
+
885
+ with pytest .raises (TypeError ):
886
+ compile_random_function ([size ], x )([2 , 5 ])
887
+
888
+ # Rebuild with strict=False so output type is not updated
889
+ # This reflects cases where size is constant folded during rewrites but the RV node is not recreated
890
+ new_x = clone_replace (x , {size : pt .constant ([2 , 5 ])}, rebuild_strict = True )
891
+ assert new_x .type .shape == (None , None )
892
+ assert compile_random_function ([], new_x )().shape == (2 , 5 )
893
+
894
+ # Rebuild with strict=True, so output type is updated
895
+ # This uses a different path in the dispatch implementation
896
+ new_x = clone_replace (x , {size : pt .constant ([2 , 5 ])}, rebuild_strict = False )
897
+ assert new_x .type .shape == (2 , 5 )
898
+ assert compile_random_function ([], new_x )().shape == (2 , 5 )
0 commit comments