
Commit 372efd6

Introduce core shape in RandomVariable Ops
1 parent b9bd344
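
This commit threads an explicit core shape input through `RandomVariable` Ops, next to the existing batch shape (previously called `size`). As a rough orientation for the diffs below, the two shapes compose as `batch_shape + core_shape`. A minimal NumPy sketch of that convention (illustrative only, not code from this commit):

import numpy as np

rng = np.random.default_rng(0)
alpha = np.array([1.0, 2.0, 3.0])   # one core parameter vector

batch_shape = (4,)                  # independent replications (formerly `size`)
core_shape = alpha.shape            # support shape of a single draw: (3,)

draws = rng.dirichlet(alpha, size=batch_shape)
assert draws.shape == batch_shape + core_shape  # (4, 3)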

10 files changed: +323 -366 lines

pytensor/link/jax/dispatch/random.py: +11 -9

@@ -93,13 +93,13 @@ def jax_funcify_RandomVariable(op, node, **kwargs):
     if None in out_size:
         assert_size_argument_jax_compatible(node)

-        def sample_fn(rng, size, *parameters):
-            return jax_sample_fn(op)(rng, size, out_dtype, *parameters)
+        def sample_fn(rng, batch_shape, core_shape, *parameters):
+            return jax_sample_fn(op)(rng, batch_shape, out_dtype, *parameters)

     else:

-        def sample_fn(rng, size, *parameters):
-            return jax_sample_fn(op)(rng, out_size, out_dtype, *parameters)
+        def sample_fn(rng, batch_shape, core_shape, *parameters):
+            return jax_sample_fn(op)(rng, batch_shape, out_dtype, *parameters)

     return sample_fn

@@ -305,7 +305,7 @@ def jax_sample_fn_binomial(op):
     from numpyro.distributions.util import binomial

-    def sample_fn(rng, size, dtype, n, p):
+    def sample_fn(rng, size, core_shape, n, p):
         rng_key = rng["jax_state"]
         rng_key, sampling_key = jax.random.split(rng_key, 2)

@@ -328,11 +328,11 @@ def jax_sample_fn_multinomial(op):
     from numpyro.distributions.util import multinomial

-    def sample_fn(rng, size, dtype, n, p):
+    def sample_fn(rng, batch_shape, core_shape, n, p):
         rng_key = rng["jax_state"]
         rng_key, sampling_key = jax.random.split(rng_key, 2)

-        sample = multinomial(key=sampling_key, n=n, p=p, shape=size)
+        sample = multinomial(key=sampling_key, n=n, p=p, shape=batch_shape)

         rng["jax_state"] = rng_key

@@ -351,12 +351,14 @@ def jax_sample_fn_vonmises(op):
     from numpyro.distributions.util import von_mises_centered

-    def sample_fn(rng, size, dtype, mu, kappa):
+    dtype = op.dtype
+
+    def sample_fn(rng, batch_shape, core_shape, mu, kappa):
         rng_key = rng["jax_state"]
         rng_key, sampling_key = jax.random.split(rng_key, 2)

         sample = von_mises_centered(
-            key=sampling_key, concentration=kappa, shape=size, dtype=dtype
+            key=sampling_key, concentration=kappa, shape=batch_shape, dtype=dtype
         )
         sample = (sample + mu + np.pi) % (2.0 * np.pi) - np.pi
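
The JAX dispatchers above now receive both `batch_shape` and `core_shape`, but the numpyro-backed samplers only consume the batch shape, so `core_shape` is accepted and ignored. A minimal sketch of the new calling convention, assuming the rng-dict protocol shown in the diff (this standalone `sample_fn` is illustrative, not the library's code):

import jax

def sample_fn(rng, batch_shape, core_shape, p):
    rng_key = rng["jax_state"]
    rng_key, sampling_key = jax.random.split(rng_key, 2)
    # Scalar-support distribution: only the batch shape matters here.
    sample = jax.random.bernoulli(sampling_key, p=p, shape=batch_shape)
    rng["jax_state"] = rng_key
    return sample

rng = {"jax_state": jax.random.PRNGKey(0)}
draws = sample_fn(rng, batch_shape=(5,), core_shape=(), p=0.3)
assert draws.shape == (5,)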

pytensor/link/numba/dispatch/random.py: +26 -19

@@ -22,6 +22,7 @@
 from pytensor.link.utils import (
     compile_function_src,
 )
+from pytensor.tensor import NoneConst, get_vector_length
 from pytensor.tensor.random.op import RandomVariable

@@ -61,7 +62,6 @@ def numba_core_rv_funcify(op: Op, node: Apply) -> Callable:
 @numba_core_rv_funcify.register(ptr.BinomialRV)
 @numba_core_rv_funcify.register(ptr.NegativeBinomialRV)
 @numba_core_rv_funcify.register(ptr.MultinomialRV)
-@numba_core_rv_funcify.register(ptr.DirichletRV)
 @numba_core_rv_funcify.register(ptr.ChoiceRV)  # the `p` argument is not supported
 @numba_core_rv_funcify.register(ptr.PermutationRV)
 def numba_core_rv_default(op, node):

@@ -155,6 +155,18 @@ def random_fn(rng, mean, cov, out):
     return random_fn


+@numba_core_rv_funcify.register(ptr.DirichletRV)
+def core_DirichletRV(op, node):
+    @numba_basic.numba_njit
+    def random_fn(rng, alpha):
+        y = np.empty_like(alpha)
+        for i in range(len(alpha)):
+            y[i] = rng.gamma(alpha[i], 1.0)
+        return y / y.sum()
+
+    return random_fn
+
+
 @numba_core_rv_funcify.register(ptr.GumbelRV)
 def core_GumbelRV(op, node):
     """Code adapted from Numpy Implementation

@@ -239,20 +251,15 @@ def random_fn(rng, mu, kappa):

 @numba_funcify.register(ptr.RandomVariable)
 def numba_funcify_RandomVariable(op: RandomVariable, node, **kwargs):
-    size = op.size_param(node)
+    batch_shape = op.batch_shape_param(node)
+    core_shape = op.core_shape_param(node)
     dist_params = op.dist_params(node)
-
-    # None sizes are represented as empty tuple for the time being
-    # https://github.com/pymc-devs/pytensor/issues/568
-    [size_len] = size.type.shape
-    size_is_None = size_len == 0
-
+    batch_shape_len = (
+        None if NoneConst.equals(batch_shape) else get_vector_length(batch_shape)
+    )
+    core_shape_len = get_vector_length(core_shape)
     inplace = op.inplace

-    # TODO: Add core_shape to node.inputs
-    if op.ndim_supp > 0:
-        raise NotImplementedError("Multivariate RandomVariable not implemented yet")
-
     core_op_fn = numba_core_rv_funcify(op, node)
     if not getattr(core_op_fn, "handles_out", False):
         nin = 1 + len(dist_params)  # rng + params

@@ -270,7 +277,7 @@ def numba_funcify_RandomVariable(op: RandomVariable, node, **kwargs):
     output_dtypes = encode_literals((node.default_output().type.dtype,))
     inplace_pattern = encode_literals(())

-    def random_wrapper(rng, size, *inputs):
+    def random_wrapper(rng, batch_shape, core_shape, *dist_params):
         if not inplace:
             rng = copy(rng)

@@ -281,19 +288,19 @@ def random_wrapper(rng, size, *inputs):
             output_dtypes,
             inplace_pattern,
             (rng,),
-            inputs,
-            ((),),  # TODO: correct core_shapes
+            dist_params,
+            (numba_ndarray.to_fixed_tuple(core_shape, core_shape_len),),
             None
-            if size_is_None
-            else numba_ndarray.to_fixed_tuple(size, size_len),  # size
+            if batch_shape_len is None
+            else numba_ndarray.to_fixed_tuple(batch_shape, batch_shape_len),
         )
         return rng, draws

-    def random(rng, size, *inputs):
+    def random(rng, batch_shape, core_shape, *dist_params):
         pass

     @overload(random)
-    def ov_random(rng, size, *inputs):
+    def ov_random(rng, batch_shape, core_shape, *dist_params):
         return random_wrapper

     return random
pytensor/tensor/random/basic.py: +34 -29

@@ -34,9 +34,8 @@ class ScipyRandomVariable(RandomVariable):
     """

-    @classmethod
     @abc.abstractmethod
-    def rng_fn_scipy(cls, rng, *args, **kwargs):
+    def rng_fn_scipy(cls, *args, **kwargs):
         r"""
         `RandomVariable`\s implementations that want to use SciPy-based samplers

@@ -46,24 +45,30 @@ def rng_fn_scipy(cls, rng, *args, **kwargs):
         """

-    @classmethod
-    def rng_fn(cls, *args, **kwargs):
-        size = args[-1]
-        res = cls.rng_fn_scipy(*args, **kwargs)
+    def rng_fn(self, *args):
+        rng, *params, size, _ = args
+        return self.rng_fn_scipy(rng, *params, size)
+
+    def perform(self, node, inputs, outputs):
+        super().perform(node, inputs, outputs)
+
+        _, batch_shape, _, *params = inputs
+        _, draws_container = outputs
+        [draws] = draws_container

-        if np.ndim(res) == 0:
+        if np.ndim(draws) == 0:
             # The sample is an `np.number`, and is not writeable, or non-NumPy
             # type, so we need to clone/create a usable NumPy result
-            res = np.asarray(res)
+            draws = np.asarray(draws)

-        if size is None:
+        if batch_shape is None:
             # SciPy will sometimes drop broadcastable dimensions; we need to
             # check and, if necessary, add them back
-            exp_shape = broadcast_shapes(*[np.shape(a) for a in args[1:-1]])
-            if res.shape != exp_shape:
-                return np.broadcast_to(res, exp_shape).copy()
+            missing_ndim = node.outputs[1].type.ndim - draws.ndim
+            if missing_ndim:
+                draws = np.expand_dims(draws, tuple(range(missing_ndim)))

-        return res
+        draws_container[0] = draws


 class UniformRV(RandomVariable):

@@ -423,7 +428,7 @@ class GammaRV(RandomVariable):
     dtype = "floatX"
     _print_name = ("Gamma", "\\operatorname{Gamma}")

-    def __call__(self, shape, scale, size=None, **kwargs):
+    def __call__(self, shape_param, scale, size=None, **kwargs):
         r"""Draw samples from a gamma distribution.

         Signature

@@ -433,7 +438,7 @@ def __call__(self, shape_param, scale, size=None, **kwargs):

         Parameters
         ----------
-        shape
+        shape_param
            The shape :math:`\alpha` of the gamma distribution. Must be positive.
        scale
            The scale :math:`1/\beta` of the gamma distribution. Must be positive.

@@ -444,7 +449,7 @@ def __call__(self, shape_param, scale, size=None, **kwargs):
        is returned.

         """
-        return super().__call__(shape, scale, size=size, **kwargs)
+        return super().__call__(shape_param, scale, size=size, **kwargs)


 _gamma = GammaRV()

@@ -672,7 +677,7 @@ class WeibullRV(RandomVariable):
     dtype = "floatX"
     _print_name = ("Weibull", "\\operatorname{Weibull}")

-    def __call__(self, shape, size=None, **kwargs):
+    def __call__(self, shape_param, size=None, **kwargs):
         r"""Draw samples from a weibull distribution.

         Signature

@@ -682,7 +687,7 @@ def __call__(self, shape_param, size=None, **kwargs):

         Parameters
         ----------
-        shape
+        shape_param
            The shape :math:`k` of the distribution. Must be positive.
        size
            Sample shape. If the given size is, e.g. `(m, n, k)` then `m * n * k`

@@ -691,7 +696,7 @@ def __call__(self, shape_param, size=None, **kwargs):
        is returned.

         """
-        return super().__call__(shape, size=size, **kwargs)
+        return super().__call__(shape_param, size=size, **kwargs)


 weibull = WeibullRV()

@@ -863,7 +868,7 @@ def __call__(self, mean=None, cov=None, size=None, **kwargs):
         return super().__call__(mean, cov, size=size, **kwargs)

     @classmethod
-    def rng_fn(cls, rng, mean, cov, size):
+    def rng_fn(cls, rng, mean, cov, size, core_shape=None):
         if mean.ndim > 1 or cov.ndim > 2:
             # Neither SciPy nor NumPy implement parameter broadcasting for
             # multivariate normals (or any other multivariate distributions),

@@ -932,7 +937,7 @@ def __call__(self, alphas, size=None, **kwargs):
         return super().__call__(alphas, size=size, **kwargs)

     @classmethod
-    def rng_fn(cls, rng, alphas, size):
+    def rng_fn(cls, rng, alphas, size, core_shape=None):
         if alphas.ndim > 1:
             if size is None:
                 size = ()

@@ -1213,7 +1218,7 @@ class InvGammaRV(ScipyRandomVariable):
     dtype = "floatX"
     _print_name = ("InverseGamma", "\\operatorname{InverseGamma}")

-    def __call__(self, shape, scale, size=None, **kwargs):
+    def __call__(self, shape_param, scale, size=None, **kwargs):
         r"""Draw samples from an inverse-gamma distribution.

         Signature

@@ -1223,7 +1228,7 @@ def __call__(self, shape_param, scale, size=None, **kwargs):

         Parameters
         ----------
-        shape
+        shape_param
            Shape parameter :math:`\alpha` of the distribution. Must be positive.
        scale
            Scale parameter :math:`\beta` of the distribution. Must be

@@ -1234,7 +1239,7 @@ def __call__(self, shape_param, scale, size=None, **kwargs):
        `None`, in which case a single sample is returned.

         """
-        return super().__call__(shape, scale, size=size, **kwargs)
+        return super().__call__(shape_param, scale, size=size, **kwargs)

     @classmethod
     def rng_fn_scipy(cls, rng, shape, scale, size):

@@ -1748,7 +1753,7 @@ def __call__(self, n, p, size=None, **kwargs):
         return super().__call__(n, p, size=size, **kwargs)

     @classmethod
-    def rng_fn(cls, rng, n, p, size):
+    def rng_fn(cls, rng, n, p, size, core_shape=None):
         if n.ndim > 0 or p.ndim > 1:
             size = tuple(size or ())

@@ -1812,7 +1817,7 @@ def __call__(self, p, size=None, **kwargs):
         return super().__call__(p, size=size, **kwargs)

     @classmethod
-    def rng_fn(cls, rng, p, size):
+    def rng_fn(cls, rng, p, size, core_shape=None):
         if size is None:
             size = p.shape[:-1]
         else:

@@ -1901,10 +1906,10 @@ def __init__(self, *args, ndim_supp: int, p_none: bool, signature=None, **kwargs
     def rng_fn(self, *params):
         # Should we split into two Ops depending on p_none or not?
         if self.p_none:
-            rng, a, replace, size = params
+            rng, a, replace, size, core_shape = params
             p = None
         else:
-            rng, a, p, replace, size = params
+            rng, a, p, replace, size, core_shape = params

         batch_ndim = a.ndim - self.ndims_params[0]

@@ -1982,7 +1987,7 @@ class PermutationRV(RandomVariable):
     _print_name = ("permutation", "\\operatorname{permutation}")

     @classmethod
-    def rng_fn(cls, rng, x, size):
+    def rng_fn(cls, rng, x, size, core_shape=None):
         return rng.permutation(x)

     def __call__(self, x, dtype=None, **kwargs):