
Commit 01ffb02

Break things to understand vector inputs
1 parent 7fda730 commit 01ffb02

3 files changed: +48 −23

pytensor/link/numba/dispatch/random.py (+15 −10)

@@ -181,10 +181,10 @@ def numba_funcify_RandomVariable(op: RandomVariable, node, **kwargs):
     if op.ndim_supp > 0:
         raise NotImplementedError("Multivariate random variables not supported yet")

-    if any(ndim_param > 0 for ndim_param in op.ndims_params):
-        raise NotImplementedError(
-            "Random variables with non scalar core inputs not supported yet"
-        )
+    # if any(ndim_param > 0 for ndim_param in op.ndims_params):
+    #     raise NotImplementedError(
+    #         "Random variables with non scalar core inputs not supported yet"
+    #     )

     # TODO: Use dispatch, so users can define the core case
     # Use string repr for default like below

@@ -197,14 +197,19 @@ def numba_funcify_RandomVariable(op: RandomVariable, node, **kwargs):
     # exec(inner_code)
     # scalar_op_fn = locals()['scalar_op_fn']

+    # @numba_basic.numba_njit
+    # def scalar_op_fn(rng, mu, scale):
+    #     return rng.normal(mu, scale)
+
     @numba_basic.numba_njit
-    def scalar_op_fn(rng, mu, scale):
-        return rng.normal(mu, scale)
+    def scalar_op_fn(rng, p):
+        unif_sample = rng.uniform(0, 1)
+        return np.searchsorted(np.cumsum(p), unif_sample)

-    ndim = node.default_output().ndim
-    output_bc_patterns = ((False,) * ndim,)
+    batch_ndim = node.default_output().ndim - op.ndim_supp
+    output_bc_patterns = ((False,) * batch_ndim,)
     input_bc_patterns = tuple(
-        [input_var.broadcastable for input_var in node.inputs[3:]]
+        [input_var.broadcastable[:batch_ndim] for input_var in node.inputs[3:]]
     )
     output_dtypes = (node.default_output().type.dtype,)
     inplace_pattern = ()  # tuple(op.inplace_pattern.items())
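The new scalar_op_fn draws a categorical sample by inverse-CDF lookup: cumulatively sum the probability vector and find where a uniform draw lands. A minimal NumPy sketch of the same idea, runnable outside numba (the probability vector is just an example):

import numpy as np

rng = np.random.default_rng(123)
p = np.array([0.5, 0.0, 0.0, 0.5])
cdf = np.cumsum(p)                # [0.5, 0.5, 0.5, 1.0]
u = rng.uniform(0, 1)             # one uniform draw on [0, 1)
sample = np.searchsorted(cdf, u)  # first index where cdf[i] >= u: 0 or 3 here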
@@ -325,7 +330,7 @@ def body_fn(a):
     )


-@numba_funcify.register(ptr.CategoricalRV)
+# @numba_funcify.register(ptr.CategoricalRV)
 def numba_funcify_CategoricalRV(op, node, **kwargs):
     out_dtype = node.outputs[1].type.numpy_dtype
     size_len = int(get_vector_length(node.inputs[1]))

pytensor/link/numba/dispatch/vectorize_codegen.py (+26 −13)

@@ -12,6 +12,7 @@
 from numba.core.base import BaseContext
 from numba.core.types.misc import NoneType
 from numba.np import arrayobj
+from numba.np.ufunc.wrappers import _ArrayArgLoader


 def compute_itershape(

@@ -22,11 +23,11 @@ def compute_itershape(
     size: list[ir.Instruction] | None,
 ):
     one = ir.IntType(64)(1)
-    ndim = len(in_shapes[0])
-    shape = [None] * ndim
+    batch_ndim = len(broadcast_pattern[0])
+    shape = [None] * batch_ndim
     if size is not None:
         shape = size
-        for i in range(ndim):
+        for i in range(batch_ndim):
             for j, (bc, in_shape) in enumerate(zip(broadcast_pattern, in_shapes)):
                 length = in_shape[i]
                 if bc[i]:

@@ -61,7 +62,7 @@ def compute_itershape(
                 )
     else:
         # Size is implied by the broadcast pattern
-        for i in range(ndim):
+        for i in range(batch_ndim):
             for j, (bc, in_shape) in enumerate(zip(broadcast_pattern, in_shapes)):
                 length = in_shape[i]
                 if bc[i]:

@@ -96,7 +97,7 @@ def compute_itershape(
                     )
                 else:
                     shape[i] = length
-    for i in range(ndim):
+    for i in range(batch_ndim):
         if shape[i] is None:
             shape[i] = one
     return shape
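For intuition, the shape logic that compute_itershape emits as LLVM IR corresponds roughly to this pure-Python sketch (an illustration only; the real function also emits runtime checks that non-broadcast lengths agree):

def compute_itershape_py(in_shapes, bc_patterns):
    # Per batch dim, take the length of any non-broadcast input;
    # dims where every input broadcasts collapse to 1.
    batch_ndim = len(bc_patterns[0])
    shape = [1] * batch_ndim
    for i in range(batch_ndim):
        for bc, in_shape in zip(bc_patterns, in_shapes):
            if not bc[i]:
                shape[i] = in_shape[i]
    return tuple(shape)

compute_itershape_py([(2, 1), (1, 3)], [(False, True), (True, False)])  # -> (2, 3)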
@@ -157,7 +158,7 @@ def make_loop_call(
     input_types: tuple[Any, ...],
     output_types: tuple[Any, ...],
 ):
-    safe = (False, False)
+    # safe = (False, False)

     n_outputs = len(outputs)

@@ -182,6 +183,12 @@ def extract_array(aryty, obj):
     # input_scope_set = mod.add_metadata([input_scope, output_scope])
     # output_scope_set = mod.add_metadata([input_scope, output_scope])

+    typ = input_types[0]
+    inp = inputs[0]
+    shape = cgutils.unpack_tuple(builder, inp.shape)
+    strides = cgutils.unpack_tuple(builder, inp.strides)
+    loader = _ArrayArgLoader(typ.dtype, typ.ndim, shape[-1], False, shape, strides)
+
     inputs = tuple(extract_array(aryty, ary) for aryty, ary in zip(input_types, inputs))

     outputs = tuple(
@@ -216,8 +223,9 @@ def extract_array(aryty, obj):
         input_vals = []
         for array_info, bc in zip(inputs, input_bc):
             idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc)]
-            ptr = cgutils.get_item_pointer2(context, builder, *array_info, idxs_bc, *safe)
-            val = builder.load(ptr)
+            # ptr = cgutils.get_item_pointer2(context, builder, *array_info, idxs_bc, *safe)
+            val = loader.load(context, builder, inp.data, idxs[0] or zero)
+            # val = builder.load(ptr)
             # val.set_metadata("alias.scope", input_scope_set)
             # val.set_metadata("noalias", output_scope_set)
             input_vals.append(val)
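The load in the loop body reads each input at broadcast-adjusted indices. In pure-Python terms it behaves like the sketch below (conceptual only; _ArrayArgLoader is a numba internal that computes the flat offset from shapes and strides instead of indexing an ndarray):

import numpy as np

def load_broadcast(arr, idxs, bc_pattern):
    # Mirror idxs_bc above: broadcast dims always read position 0, so a
    # length-1 input can serve every iteration index along that dim.
    return arr[tuple(0 if bc else idx for idx, bc in zip(idxs, bc_pattern))]

a = np.array([[1.0, 2.0, 3.0]])            # shape (1, 3), broadcast along dim 0
load_broadcast(a, (5, 2), (True, False))   # -> 3.0, as if a were tiled along dim 0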
@@ -340,16 +348,21 @@ def _vectorized(
     if not all(isinstance(input, types.Array) for input in inputs):
         raise TypingError("Vectorized inputs must be arrays.")

-    ndim = inputs[0].ndim
+    batch_ndim = len(input_bc_patterns[0])

-    if not all(input.ndim == ndim for input in inputs):
+    if not all(input.ndim >= batch_ndim for input in inputs):
         raise TypingError("Vectorized inputs must have the same rank.")

-    if not all(len(pattern) == ndim for pattern in output_bc_patterns):
+    if not all(len(pattern) >= batch_ndim for pattern in output_bc_patterns):
         raise TypingError("Invalid output broadcasting pattern.")

     scalar_signature = typingctx.resolve_function_type(
-        scalar_func, [*constant_inputs, *[in_type.dtype for in_type in inputs]], {}
+        scalar_func,
+        [
+            *constant_inputs,
+            *[in_type.dtype if in_type.ndim == 0 else in_type for in_type in inputs],
+        ],
+        {},
     )

     # So we can access the constant values in codegen...

@@ -430,7 +443,7 @@ def codegen(
         )

     ret_types = [
-        types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C")
+        types.Array(numba.from_dtype(np.dtype(dtype)), batch_ndim, "C")
         for dtype in output_dtypes
     ]
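The batch_ndim changes above split batch dimensions (looped over) from core dimensions (passed whole to the scalar kernel), the same split NumPy's gufunc signatures express. A rough analogy using np.vectorize, not the code path this commit touches:

import numpy as np

rng = np.random.default_rng(0)

def sample_one(p):
    # Core input "(p)" is a full probability vector; output "()" is a scalar.
    return np.searchsorted(np.cumsum(p), rng.uniform(0, 1))

categorical = np.vectorize(sample_one, signature="(p)->()")
categorical([[0.5, 0, 0, 0.5], [0, 0.5, 0.5, 0]])  # batch shape (2,)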

tests/link/numba/test_random.py (+7 −0)

@@ -642,3 +642,10 @@ def test_rng_non_default_update():
     ref = np.random.default_rng(2).normal(size=10)
     np.testing.assert_allclose(fn(), ref)
     np.testing.assert_allclose(fn(), ref)
+
+
+def test_categorical_rv():
+    x = pt.random.categorical(p=[[0.5, 0, 0, 0.5], [0, 0.5, 0.5, 0]], size=(2,))
+    updates = {x.owner.inputs[0]: x.owner.outputs[0]}
+    fn = function([], x, updates=updates, mode="NUMBA")
+    print([fn() for _ in range(50)])
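The new test only prints draws. One way the same fixture could be asserted statistically (a sketch reusing fn and the probability table from the test above; sample size and tolerance are loose assumptions):

import numpy as np

draws = np.array([fn() for _ in range(500)])  # shape (500, 2)
# Zero-probability categories must never appear: row 0 allows only {0, 3},
# row 1 only {1, 2}, per the zeros in p.
assert set(draws[:, 0]) <= {0, 3}
assert set(draws[:, 1]) <= {1, 2}
# Each admissible category has probability 0.5, so frequencies should be ~0.5.
np.testing.assert_allclose((draws[:, 0] == 0).mean(), 0.5, atol=0.1)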
