Skip to content

Commit 769809e

Browse files
ricardoV94 and aseyboldt committed
WIP: Establish code to handle vector outputs
Co-authored-by: Adrian Seyboldt <[email protected]>
1 parent 08db524 commit 769809e

File tree

4 files changed

+196
-158
lines changed

4 files changed

+196
-158
lines changed

pytensor/link/numba/dispatch/elemwise.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -492,9 +492,12 @@ def numba_funcify_Elemwise(op, node, **kwargs):
492492
scalar_op_fn = numba_funcify(
493493
op.scalar_op, node=scalar_node, parent_node=node, fastmath=flags, **kwargs
494494
)
495+
# TODO: Implement me
496+
# scalar_op_fn = save_outputs_of_scalar_fn(op, scalar_op_fn)
495497

496498
ndim = node.outputs[0].ndim
497-
output_bc_patterns = tuple([(False,) * ndim for _ in node.outputs])
499+
nout = len(node.outputs)
500+
output_bc_patterns = tuple([(False,) * ndim for _ in nout])
498501
input_bc_patterns = tuple([input_var.broadcastable for input_var in node.inputs])
499502
output_dtypes = tuple(variable.dtype for variable in node.outputs)
500503
inplace_pattern = tuple(op.inplace_pattern.items())
@@ -516,6 +519,7 @@ def elemwise_wrapper(*inputs):
516519
inplace_pattern_enc,
517520
(), # constant_inputs
518521
inputs,
522+
[() for _ in range(nout)], # core_shapes
519523
None, # size
520524
)
521525

pytensor/link/numba/dispatch/random.py

+65-54
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pickle
33
from collections.abc import Callable
44
from copy import copy
5+
from functools import singledispatch
56
from textwrap import dedent, indent
67
from typing import Any
78

@@ -168,25 +169,10 @@ def impl(rng):
168169
return impl
169170

170171

171-
@numba_funcify.register(ptr.RandomVariable)
172-
def numba_funcify_RandomVariable(op: RandomVariable, node, **kwargs):
173-
_, size, _, *args = node.inputs
174-
# None sizes are represented as empty tuple for the time being
175-
# https://github.com/pymc-devs/pytensor/issues/568
176-
[size_len] = size.type.shape
177-
size_is_None = size_len == 0
178-
179-
inplace = op.inplace
180-
181-
if op.ndim_supp > 0:
182-
raise NotImplementedError("Multivariate random variables not supported yet")
183-
184-
# if any(ndim_param > 0 for ndim_param in op.ndims_params):
185-
# raise NotImplementedError(
186-
# "Random variables with non scalar core inputs not supported yet"
187-
# )
172+
@singledispatch
173+
def core_rv_fn(op: Op):
174+
"""Return the core function for a random variable operation."""
188175

189-
# TODO: Use dispatch, so users can define the core case
190176
# Use string repr for default like below
191177
# inner_code = dedent(f"""
192178
# @numba_basic.numba_njit
@@ -197,15 +183,67 @@ def numba_funcify_RandomVariable(op: RandomVariable, node, **kwargs):
197183
# exec(inner_code)
198184
# scalar_op_fn = locals()['scalar_op_fn']
199185

200-
# @numba_basic.numba_njit
201-
# def core_op_fn(rng, mu, scale):
202-
# return rng.normal(mu, scale)
186+
raise NotImplementedError()
187+
203188

189+
@core_rv_fn.register(ptr.NormalRV)
190+
def core_NormalRV(op):
204191
@numba_basic.numba_njit
205-
def core_op_fn(rng, p):
192+
def random_fn(rng, mu, scale, out):
193+
out[...] = rng.normal(mu, scale)
194+
195+
random_fn.handles_out = True
196+
return random_fn
197+
198+
199+
@core_rv_fn.register(ptr.CategoricalRV)
200+
def core_CategoricalRV(op):
201+
@numba_basic.numba_njit
202+
def random_fn(rng, p, out):
206203
unif_sample = rng.uniform(0, 1)
207-
return np.searchsorted(np.cumsum(p), unif_sample)
204+
# TODO: Check if LLVM can lift constant cumsum(p) out of the loop
205+
out[...] = np.searchsorted(np.cumsum(p), unif_sample)
206+
207+
random_fn.handles_out = True
208+
return random_fn
209+
210+
211+
@core_rv_fn.register(ptr.MvNormalRV)
212+
def core_MvNormalRV(op):
213+
@numba.njit
214+
def random_fn(rng, mean, cov, out):
215+
chol = np.linalg.cholesky(cov)
216+
stdnorm = rng.normal(size=cov.shape[-1])
217+
# np.dot(chol, stdnorm, out=out)
218+
# out[...] += mean
219+
out[...] = mean + np.dot(chol, stdnorm)
208220

221+
random_fn.handles_out = True
222+
return random_fn
223+
224+
225+
@numba_funcify.register(ptr.RandomVariable)
226+
def numba_funcify_RandomVariable(op: RandomVariable, node, **kwargs):
227+
_, size, _, *args = node.inputs
228+
# None sizes are represented as empty tuple for the time being
229+
# https://github.com/pymc-devs/pytensor/issues/568
230+
[size_len] = size.type.shape
231+
size_is_None = size_len == 0
232+
233+
inplace = op.inplace
234+
235+
# TODO: Add core_shape to node.inputs
236+
if op.ndim_supp > 0:
237+
raise NotImplementedError("Multivariate RandomVariable not implemented yet")
238+
239+
# TODO: Create a wrapper (string processing?) that takes a core function without outputs
240+
# and saves those outputs in the variables passed by `_vectorized`
241+
core_op_fn = core_rv_fn(op)
242+
if not getattr(core_op_fn, "handles_out", False):
243+
# core_op_fn = store_core_outputs(op, core_op_fn)
244+
raise NotImplementedError()
245+
246+
# TODO: Refactor this code, it's the same with Elemwise
209247
batch_ndim = node.default_output().ndim - op.ndim_supp
210248
output_bc_patterns = ((False,) * batch_ndim,)
211249
input_bc_patterns = tuple(
@@ -234,12 +272,14 @@ def random_wrapper(rng, size, dtype, *inputs):
234272
inplace_pattern_enc,
235273
(rng,),
236274
inputs,
237-
None if size_is_None else numba_ndarray.to_fixed_tuple(size, size_len),
275+
((),), # TODO: correct core_shapes
276+
None
277+
if size_is_None
278+
else numba_ndarray.to_fixed_tuple(size, size_len), # size
238279
)
239280
return rng, draws
240281

241282
def random(rng, size, dtype, *inputs):
242-
# TODO: Add code that will be tested for coverage
243283
pass
244284

245285
@overload(random)
@@ -330,35 +370,6 @@ def body_fn(a):
330370
)
331371

332372

333-
# @numba_funcify.register(ptr.CategoricalRV)
334-
def numba_funcify_CategoricalRV(op, node, **kwargs):
335-
out_dtype = node.outputs[1].type.numpy_dtype
336-
size_len = int(get_vector_length(node.inputs[1]))
337-
p_ndim = node.inputs[-1].ndim
338-
339-
@numba_basic.numba_njit
340-
def categorical_rv(rng, size, dtype, p):
341-
if not size_len:
342-
size_tpl = p.shape[:-1]
343-
else:
344-
size_tpl = numba_ndarray.to_fixed_tuple(size, size_len)
345-
p = np.broadcast_to(p, size_tpl + p.shape[-1:])
346-
347-
# Workaround https://github.com/numba/numba/issues/8975
348-
if not size_len and p_ndim == 1:
349-
unif_samples = np.asarray(np.random.uniform(0, 1))
350-
else:
351-
unif_samples = np.random.uniform(0, 1, size_tpl)
352-
353-
res = np.empty(size_tpl, dtype=out_dtype)
354-
for idx in np.ndindex(*size_tpl):
355-
res[idx] = np.searchsorted(np.cumsum(p[idx]), unif_samples[idx])
356-
357-
return (rng, res)
358-
359-
return categorical_rv
360-
361-
362373
@numba_funcify.register(ptr.DirichletRV)
363374
def numba_funcify_DirichletRV(op, node, **kwargs):
364375
out_dtype = node.outputs[1].type.numpy_dtype

0 commit comments

Comments (0)