@@ -2002,6 +2002,11 @@ def _supp_shape_from_params(self, dist_params, param_shapes=None):
         a_shape = tuple(a.shape) if param_shapes is None else tuple(param_shapes[0])
         a_batch_ndim = len(a_shape) - self.ndims_params[0]
         a_core_shape = a_shape[a_batch_ndim:]
+        core_shape_ndim = core_shape.type.ndim
+        if core_shape_ndim > 1:
+            # Batch core shapes are only valid if homogeneous or broadcasted,
+            # as otherwise they would imply ragged choice arrays
+            core_shape = core_shape[(0,) * (core_shape_ndim - 1)]
         return tuple(core_shape) + a_core_shape[1:]

     def rng_fn(self, *params):
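The indexing trick added above collapses a batched core_shape to its core vector by indexing every leading (batch) dimension at 0; this is only lossless when all batch entries agree, which is exactly the homogeneity condition the comment states. A minimal numpy sketch of the same idea (the shapes and values here are illustrative assumptions, not taken from the Op):

    import numpy as np

    # Hypothetical batched core_shape: two broadcasted batch dims around a
    # single core entry, so every batch element is identical (homogeneous)
    core_shape = np.broadcast_to(np.array([3]), (4, 2, 1))

    # Index all leading batch dims at 0, keeping only the core shape vector
    core_shape = core_shape[(0,) * (core_shape.ndim - 1)]
    print(tuple(int(dim) for dim in core_shape))  # (3,)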
@@ -2011,15 +2016,11 @@ def rng_fn(self, *params):
             rng, a, core_shape, size = params
             p = None

+        if core_shape.ndim > 1:
+            core_shape = core_shape[(0,) * (core_shape.ndim - 1)]
         core_shape = tuple(core_shape)

-        # We don't have access to the node in rng_fn for easy computation of batch_ndim :(
-        a_batch_ndim = batch_ndim = a.ndim - self.ndims_params[0]
-        if p is not None:
-            p_batch_ndim = p.ndim - self.ndims_params[1]
-            batch_ndim = max(batch_ndim, p_batch_ndim)
-        size_ndim = 0 if size is None else len(size)
-        batch_ndim = max(batch_ndim, size_ndim)
+        batch_ndim = a.ndim - self.ndims_params[0]

         if batch_ndim == 0:
             # Numpy choice fails with size=() if a.ndim > 1 is batched
@@ -2031,16 +2032,16 @@ def rng_fn(self, *params):
         # Numpy choice doesn't have a concept of batch dims
         if size is None:
             if p is None:
-                size = a.shape[:a_batch_ndim]
+                size = a.shape[:batch_ndim]
             else:
                 size = np.broadcast_shapes(
-                    a.shape[:a_batch_ndim],
-                    p.shape[:p_batch_ndim],
+                    a.shape[:batch_ndim],
+                    p.shape[:batch_ndim],
                 )

-        a = np.broadcast_to(a, size + a.shape[a_batch_ndim:])
+        a = np.broadcast_to(a, size + a.shape[batch_ndim:])
         if p is not None:
-            p = np.broadcast_to(p, size + p.shape[p_batch_ndim:])
+            p = np.broadcast_to(p, size + p.shape[batch_ndim:])

         a_indexed_shape = a.shape[len(size) + 1 :]
         out = np.empty(size + core_shape + a_indexed_shape, dtype=a.dtype)
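Because numpy's choice has no notion of batch dimensions, the function above broadcasts the parameters to a common batch shape and then draws once per batch index. A self-contained sketch of that pattern, assuming a vector-valued `a` (ndims_params[0] == 1), no `p`, and a made-up core_shape of (3,):

    import numpy as np

    rng = np.random.default_rng(42)

    # `a` has one batch dim (2) and one core dim (5)
    a = rng.normal(size=(2, 5))
    batch_ndim = a.ndim - 1  # mirrors a.ndim - self.ndims_params[0]
    core_shape = (3,)        # three draws without replacement per batch entry

    # With size=None the batch dims of `a` become the output batch shape
    size = a.shape[:batch_ndim]
    a = np.broadcast_to(a, size + a.shape[batch_ndim:])

    # Numpy choice doesn't have a concept of batch dims: loop over them
    out = np.empty(size + core_shape, dtype=a.dtype)
    for idx in np.ndindex(size):
        out[idx] = rng.choice(a[idx], size=core_shape, replace=False)

    print(out.shape)  # (2, 3)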
@@ -2143,26 +2144,26 @@ class PermutationRV(RandomVariable):
     def _supp_shape_from_params(self, dist_params, param_shapes=None):
         [x] = dist_params
         x_shape = tuple(x.shape if param_shapes is None else param_shapes[0])
-        if x.type.ndim == 0:
-            return (x,)
+        if self.ndims_params[0] == 0:
+            # Implicit arange, this is only valid for homogeneous arrays
+            # Otherwise it would imply a ragged permutation array.
+            return (x.ravel()[0],)
         else:
             batch_x_ndim = x.type.ndim - self.ndims_params[0]
             return x_shape[batch_x_ndim:]

     def rng_fn(self, rng, x, size):
         # We don't have access to the node in rng_fn :(
-        x_batch_ndim = x.ndim - self.ndims_params[0]
-        batch_ndim = max(x_batch_ndim, 0 if size is None else len(size))
+        batch_ndim = x.ndim - self.ndims_params[0]

         if batch_ndim:
             # rng.permutation has no concept of batch dims
-            x_core_shape = x.shape[x_batch_ndim:]
             if size is None:
-                size = x.shape[:x_batch_ndim]
+                size = x.shape[:batch_ndim]
             else:
-                x = np.broadcast_to(x, size + x_core_shape)
+                x = np.broadcast_to(x, size + x.shape[batch_ndim:])

-            out = np.empty(size + x_core_shape, dtype=x.dtype)
+            out = np.empty(size + x.shape[batch_ndim:], dtype=x.dtype)
             for idx in np.ndindex(size):
                 out[idx] = rng.permutation(x[idx])
             return out
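The same looping pattern handles batched permutations: broadcast `x` to the batch shape, then shuffle each batch entry independently. A minimal sketch, assuming `x` is a batch of vectors (ndims_params[0] == 1) and size is None:

    import numpy as np

    rng = np.random.default_rng(0)

    # `x` has one batch dim (3) and one core dim (4)
    x = np.arange(12).reshape(3, 4)
    batch_ndim = x.ndim - 1  # mirrors x.ndim - self.ndims_params[0]

    # rng.permutation has no concept of batch dims: loop over them
    size = x.shape[:batch_ndim]
    out = np.empty(size + x.shape[batch_ndim:], dtype=x.dtype)
    for idx in np.ndindex(size):
        out[idx] = rng.permutation(x[idx])  # each row shuffled on its own

    print(out.shape)  # (3, 4)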