@@ -34,9 +34,8 @@ class ScipyRandomVariable(RandomVariable):
34
34
35
35
"""
36
36
37
- @classmethod
38
37
@abc .abstractmethod
39
- def rng_fn_scipy (cls , rng , * args , ** kwargs ):
38
+ def rng_fn_scipy (cls , * args , ** kwargs ):
40
39
r"""
41
40
42
41
`RandomVariable`\s implementations that want to use SciPy-based samplers
@@ -46,24 +45,30 @@ def rng_fn_scipy(cls, rng, *args, **kwargs):
46
45
47
46
"""
48
47
49
- @classmethod
50
- def rng_fn (cls , * args , ** kwargs ):
51
- size = args [- 1 ]
52
- res = cls .rng_fn_scipy (* args , ** kwargs )
48
+ def rng_fn (self , * args ):
49
+ rng , * params , size , _ = args
50
+ return self .rng_fn_scipy (rng , * params , size )
51
+
52
+ def perform (self , node , inputs , outputs ):
53
+ super ().perform (node , inputs , outputs )
54
+
55
+ _ , batch_shape , _ , * params = inputs
56
+ _ , draws_container = outputs
57
+ [draws ] = draws_container
53
58
54
- if np .ndim (res ) == 0 :
59
+ if np .ndim (draws ) == 0 :
55
60
# The sample is an `np.number`, and is not writeable, or non-NumPy
56
61
# type, so we need to clone/create a usable NumPy result
57
- res = np .asarray (res )
62
+ draws = np .asarray (draws )
58
63
59
- if size is None :
64
+ if batch_shape is None :
60
65
# SciPy will sometimes drop broadcastable dimensions; we need to
61
66
# check and, if necessary, add them back
62
- exp_shape = broadcast_shapes ( * [ np . shape ( a ) for a in args [ 1 : - 1 ]])
63
- if res . shape != exp_shape :
64
- return np .broadcast_to ( res , exp_shape ). copy ( )
67
+ missing_ndim = node . outputs [ 1 ]. type . ndim - draws . ndim
68
+ if missing_ndim :
69
+ draws = np .expand_dims ( draws , tuple ( range ( missing_ndim )) )
65
70
66
- return res
71
+ draws_container [ 0 ] = draws
67
72
68
73
69
74
class UniformRV (RandomVariable ):
@@ -423,7 +428,7 @@ class GammaRV(RandomVariable):
423
428
dtype = "floatX"
424
429
_print_name = ("Gamma" , "\\ operatorname{Gamma}" )
425
430
426
- def __call__ (self , shape , scale , size = None , ** kwargs ):
431
+ def __call__ (self , shape_param , scale , size = None , ** kwargs ):
427
432
r"""Draw samples from a gamma distribution.
428
433
429
434
Signature
@@ -433,7 +438,7 @@ def __call__(self, shape, scale, size=None, **kwargs):
433
438
434
439
Parameters
435
440
----------
436
- shape
441
+ shape_param
437
442
The shape :math:`\alpha` of the gamma distribution. Must be positive.
438
443
scale
439
444
The scale :math:`1/\beta` of the gamma distribution. Must be positive.
@@ -444,7 +449,7 @@ def __call__(self, shape, scale, size=None, **kwargs):
444
449
is returned.
445
450
446
451
"""
447
- return super ().__call__ (shape , scale , size = size , ** kwargs )
452
+ return super ().__call__ (shape_param , scale , size = size , ** kwargs )
448
453
449
454
450
455
_gamma = GammaRV ()
@@ -672,7 +677,7 @@ class WeibullRV(RandomVariable):
672
677
dtype = "floatX"
673
678
_print_name = ("Weibull" , "\\ operatorname{Weibull}" )
674
679
675
- def __call__ (self , shape , size = None , ** kwargs ):
680
+ def __call__ (self , shape_param , size = None , ** kwargs ):
676
681
r"""Draw samples from a weibull distribution.
677
682
678
683
Signature
@@ -682,7 +687,7 @@ def __call__(self, shape, size=None, **kwargs):
682
687
683
688
Parameters
684
689
----------
685
- shape
690
+ shape_param
686
691
The shape :math:`k` of the distribution. Must be positive.
687
692
size
688
693
Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k`
@@ -691,7 +696,7 @@ def __call__(self, shape, size=None, **kwargs):
691
696
is returned.
692
697
693
698
"""
694
- return super ().__call__ (shape , size = size , ** kwargs )
699
+ return super ().__call__ (shape_param , size = size , ** kwargs )
695
700
696
701
697
702
weibull = WeibullRV ()
@@ -863,7 +868,7 @@ def __call__(self, mean=None, cov=None, size=None, **kwargs):
863
868
return super ().__call__ (mean , cov , size = size , ** kwargs )
864
869
865
870
@classmethod
866
- def rng_fn (cls , rng , mean , cov , size ):
871
+ def rng_fn (cls , rng , mean , cov , size , core_shape = None ):
867
872
if mean .ndim > 1 or cov .ndim > 2 :
868
873
# Neither SciPy nor NumPy implement parameter broadcasting for
869
874
# multivariate normals (or any other multivariate distributions),
@@ -932,7 +937,7 @@ def __call__(self, alphas, size=None, **kwargs):
932
937
return super ().__call__ (alphas , size = size , ** kwargs )
933
938
934
939
@classmethod
935
- def rng_fn (cls , rng , alphas , size ):
940
+ def rng_fn (cls , rng , alphas , size , core_shape = None ):
936
941
if alphas .ndim > 1 :
937
942
if size is None :
938
943
size = ()
@@ -1213,7 +1218,7 @@ class InvGammaRV(ScipyRandomVariable):
1213
1218
dtype = "floatX"
1214
1219
_print_name = ("InverseGamma" , "\\ operatorname{InverseGamma}" )
1215
1220
1216
- def __call__ (self , shape , scale , size = None , ** kwargs ):
1221
+ def __call__ (self , shape_param , scale , size = None , ** kwargs ):
1217
1222
r"""Draw samples from an inverse-gamma distribution.
1218
1223
1219
1224
Signature
@@ -1223,7 +1228,7 @@ def __call__(self, shape, scale, size=None, **kwargs):
1223
1228
1224
1229
Parameters
1225
1230
----------
1226
- shape
1231
+ shape_param
1227
1232
Shape parameter :math:`\alpha` of the distribution. Must be positive.
1228
1233
scale
1229
1234
Scale parameter :math:`\beta` of the distribution. Must be
@@ -1234,7 +1239,7 @@ def __call__(self, shape, scale, size=None, **kwargs):
1234
1239
`None`, in which case a single sample is returned.
1235
1240
1236
1241
"""
1237
- return super ().__call__ (shape , scale , size = size , ** kwargs )
1242
+ return super ().__call__ (shape_param , scale , size = size , ** kwargs )
1238
1243
1239
1244
@classmethod
1240
1245
def rng_fn_scipy (cls , rng , shape , scale , size ):
@@ -1748,7 +1753,7 @@ def __call__(self, n, p, size=None, **kwargs):
1748
1753
return super ().__call__ (n , p , size = size , ** kwargs )
1749
1754
1750
1755
@classmethod
1751
- def rng_fn (cls , rng , n , p , size ):
1756
+ def rng_fn (cls , rng , n , p , size , core_shape = None ):
1752
1757
if n .ndim > 0 or p .ndim > 1 :
1753
1758
size = tuple (size or ())
1754
1759
@@ -1812,7 +1817,7 @@ def __call__(self, p, size=None, **kwargs):
1812
1817
return super ().__call__ (p , size = size , ** kwargs )
1813
1818
1814
1819
@classmethod
1815
- def rng_fn (cls , rng , p , size ):
1820
+ def rng_fn (cls , rng , p , size , core_shape = None ):
1816
1821
if size is None :
1817
1822
size = p .shape [:- 1 ]
1818
1823
else :
@@ -1901,10 +1906,10 @@ def __init__(self, *args, ndim_supp: int, p_none: bool, signature=None, **kwargs
1901
1906
def rng_fn (self , * params ):
1902
1907
# Should we split into two Ops depending on p_none or not?
1903
1908
if self .p_none :
1904
- rng , a , replace , size = params
1909
+ rng , a , replace , size , core_shape = params
1905
1910
p = None
1906
1911
else :
1907
- rng , a , p , replace , size = params
1912
+ rng , a , p , replace , size , core_shape = params
1908
1913
1909
1914
batch_ndim = a .ndim - self .ndims_params [0 ]
1910
1915
@@ -1982,7 +1987,7 @@ class PermutationRV(RandomVariable):
1982
1987
_print_name = ("permutation" , "\\ operatorname{permutation}" )
1983
1988
1984
1989
@classmethod
1985
- def rng_fn (cls , rng , x , size ):
1990
+ def rng_fn (cls , rng , x , size , core_shape = None ):
1986
1991
return rng .permutation (x )
1987
1992
1988
1993
def __call__ (self , x , dtype = None , ** kwargs ):
0 commit comments