5
5
import scipy .stats as stats
6
6
7
7
import pytensor
8
- from pytensor .tensor .basic import as_tensor_variable , arange
8
+ from pytensor .tensor .basic import arange , as_tensor_variable
9
9
from pytensor .tensor .random .op import RandomVariable
10
10
from pytensor .tensor .random .type import RandomGeneratorType , RandomStateType
11
11
from pytensor .tensor .random .utils import (
@@ -2072,18 +2072,15 @@ class PermutationRV(RandomVariable):
2072
2072
2073
2073
@classmethod
2074
2074
def rng_fn (cls , rng , x , size ):
2075
- return rng .permutation (x if x . ndim > 0 else x . item () )
2075
+ return rng .permutation (x )
2076
2076
2077
- def _infer_shape (self , size , dist_params , param_shapes = None ):
2078
- param_shapes = param_shapes or [p .shape for p in dist_params ]
2079
-
2080
- (x ,) = dist_params
2081
- (x_shape ,) = param_shapes
2082
-
2083
- if x .ndim == 0 :
2084
- return (x ,)
2085
- else :
2086
- return x_shape
2077
+ def _supp_shape_from_params (self , dist_params , param_shapes = None ):
2078
+ return supp_shape_from_ref_param_shape (
2079
+ ndim_supp = self .ndim_supp ,
2080
+ dist_params = dist_params ,
2081
+ param_shapes = param_shapes ,
2082
+ ref_param_idx = 0 ,
2083
+ )
2087
2084
2088
2085
def __call__ (self , x , ** kwargs ):
2089
2086
r"""Randomly permute a sequence or a range of values.
@@ -2096,15 +2093,35 @@ def __call__(self, x, **kwargs):
2096
2093
Parameters
2097
2094
----------
2098
2095
x
2099
- If `x` is an integer, randomly permute `np.arange(x)`. If `x` is a sequence,
2100
- shuffle its elements randomly.
2096
+ Elements to be shuffled.
2101
2097
2102
2098
"""
2103
2099
x = as_tensor_variable (x )
2104
2100
return super ().__call__ (x , dtype = x .dtype , ** kwargs )
2105
2101
2106
2102
2107
- permutation = PermutationRV ()
2103
+ _permutation = PermutationRV ()
2104
+
2105
+
2106
+ def permutation (x , ** kwargs ):
2107
+ r"""Randomly permute a sequence or a range of values.
2108
+
2109
+ Signature
2110
+ ---------
2111
+
2112
+ `(x) -> (x)`
2113
+
2114
+ Parameters
2115
+ ----------
2116
+ x
2117
+ If `x` is an integer, randomly permute `np.arange(x)`. If `x` is a sequence,
2118
+ shuffle its elements randomly.
2119
+
2120
+ """
2121
+ x = as_tensor_variable (x )
2122
+ if x .type .ndim == 0 :
2123
+ x = arange (x )
2124
+ return _permutation (x , ** kwargs )
2108
2125
2109
2126
2110
2127
__all__ = [
0 commit comments