Skip to content

Commit cc96a16

Browse files
committed
Fix PermutationRV ambiguous signature
The RV always expects a vector input and `ndims_paramas` is always `[1]`. Size is no longer ignored
1 parent c53ea8f commit cc96a16

File tree

2 files changed

+40
-15
lines changed

2 files changed

+40
-15
lines changed

pytensor/tensor/random/basic.py

+32-15
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import scipy.stats as stats
66

77
import pytensor
8-
from pytensor.tensor.basic import as_tensor_variable, arange
8+
from pytensor.tensor.basic import arange, as_tensor_variable
99
from pytensor.tensor.random.op import RandomVariable
1010
from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType
1111
from pytensor.tensor.random.utils import (
@@ -2072,18 +2072,15 @@ class PermutationRV(RandomVariable):
20722072

20732073
@classmethod
20742074
def rng_fn(cls, rng, x, size):
2075-
return rng.permutation(x if x.ndim > 0 else x.item())
2075+
return rng.permutation(x)
20762076

2077-
def _infer_shape(self, size, dist_params, param_shapes=None):
2078-
param_shapes = param_shapes or [p.shape for p in dist_params]
2079-
2080-
(x,) = dist_params
2081-
(x_shape,) = param_shapes
2082-
2083-
if x.ndim == 0:
2084-
return (x,)
2085-
else:
2086-
return x_shape
2077+
def _supp_shape_from_params(self, dist_params, param_shapes=None):
2078+
return supp_shape_from_ref_param_shape(
2079+
ndim_supp=self.ndim_supp,
2080+
dist_params=dist_params,
2081+
param_shapes=param_shapes,
2082+
ref_param_idx=0,
2083+
)
20872084

20882085
def __call__(self, x, **kwargs):
20892086
r"""Randomly permute a sequence or a range of values.
@@ -2096,15 +2093,35 @@ def __call__(self, x, **kwargs):
20962093
Parameters
20972094
----------
20982095
x
2099-
If `x` is an integer, randomly permute `np.arange(x)`. If `x` is a sequence,
2100-
shuffle its elements randomly.
2096+
Elements to be shuffled.
21012097
21022098
"""
21032099
x = as_tensor_variable(x)
21042100
return super().__call__(x, dtype=x.dtype, **kwargs)
21052101

21062102

2107-
permutation = PermutationRV()
2103+
_permutation = PermutationRV()
2104+
2105+
2106+
def permutation(x, **kwargs):
2107+
r"""Randomly permute a sequence or a range of values.
2108+
2109+
Signature
2110+
---------
2111+
2112+
`(x) -> (x)`
2113+
2114+
Parameters
2115+
----------
2116+
x
2117+
If `x` is an integer, randomly permute `np.arange(x)`. If `x` is a sequence,
2118+
shuffle its elements randomly.
2119+
2120+
"""
2121+
x = as_tensor_variable(x)
2122+
if x.type.ndim == 0:
2123+
x = arange(x)
2124+
return _permutation(x, **kwargs)
21082125

21092126

21102127
__all__ = [

tests/tensor/random/test_basic.py

+8
Original file line numberDiff line numberDiff line change
@@ -1413,6 +1413,14 @@ def test_permutation_samples():
14131413
compare_sample_values(permutation, np.array([1.0, 2.0, 3.0], dtype=config.floatX))
14141414

14151415

1416+
def test_permutation_shape():
1417+
assert tuple(permutation(5).shape.eval()) == (5,)
1418+
assert tuple(permutation(np.arange(5)).shape.eval()) == (5,)
1419+
assert tuple(permutation(np.arange(10).reshape(2, 5)).shape.eval()) == (2, 5)
1420+
assert tuple(permutation(5, size=(2, 3)).shape.eval()) == (2, 3, 5)
1421+
assert tuple(permutation(np.arange(5), size=(2, 3)).shape.eval()) == (2, 3, 5)
1422+
1423+
14161424
@config.change_flags(compute_test_value="off")
14171425
def test_pickle():
14181426
# This is an interesting `Op` case, because it has `None` types and a

0 commit comments

Comments
 (0)