Skip to content

Commit 08db524

Browse files
ricardoV94 and aseyboldt
committed
Handle vector parameters scalar output
Co-authored-by: Adrian Seyboldt <[email protected]>
1 parent 01ffb02 commit 08db524

File tree

3 files changed

+93
-29
lines changed

3 files changed

+93
-29
lines changed

pytensor/link/numba/dispatch/random.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -198,11 +198,11 @@ def numba_funcify_RandomVariable(op: RandomVariable, node, **kwargs):
198198
# scalar_op_fn = locals()['scalar_op_fn']
199199

200200
# @numba_basic.numba_njit
201-
# def scalar_op_fn(rng, mu, scale):
201+
# def core_op_fn(rng, mu, scale):
202202
# return rng.normal(mu, scale)
203203

204204
@numba_basic.numba_njit
205-
def scalar_op_fn(rng, p):
205+
def core_op_fn(rng, p):
206206
unif_sample = rng.uniform(0, 1)
207207
return np.searchsorted(np.cumsum(p), unif_sample)
208208

@@ -227,7 +227,7 @@ def random_wrapper(rng, size, dtype, *inputs):
227227
rng = copy(rng)
228228

229229
draws = _vectorized(
230-
scalar_op_fn,
230+
core_op_fn,
231231
input_bc_patterns_enc,
232232
output_bc_patterns_enc,
233233
output_dtypes_enc,

pytensor/link/numba/dispatch/vectorize_codegen.py

+65-24
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from numba.core.base import BaseContext
1313
from numba.core.types.misc import NoneType
1414
from numba.np import arrayobj
15-
from numba.np.ufunc.wrappers import _ArrayArgLoader
1615

1716

1817
def compute_itershape(
@@ -158,7 +157,7 @@ def make_loop_call(
158157
input_types: tuple[Any, ...],
159158
output_types: tuple[Any, ...],
160159
):
161-
# safe = (False, False)
160+
safe = (False, False)
162161

163162
n_outputs = len(outputs)
164163

@@ -183,14 +182,6 @@ def extract_array(aryty, obj):
183182
# input_scope_set = mod.add_metadata([input_scope, output_scope])
184183
# output_scope_set = mod.add_metadata([input_scope, output_scope])
185184

186-
typ = input_types[0]
187-
inp = inputs[0]
188-
shape = cgutils.unpack_tuple(builder, inp.shape)
189-
strides = cgutils.unpack_tuple(builder, inp.strides)
190-
loader = _ArrayArgLoader(typ.dtype, typ.ndim, shape[-1], False, shape, strides)
191-
192-
inputs = tuple(extract_array(aryty, ary) for aryty, ary in zip(input_types, inputs))
193-
194185
outputs = tuple(
195186
extract_array(aryty, ary) for aryty, ary in zip(output_types, outputs)
196187
)
@@ -221,13 +212,50 @@ def extract_array(aryty, obj):
221212

222213
# Load values from input arrays
223214
input_vals = []
224-
for array_info, bc in zip(inputs, input_bc):
225-
idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc)]
226-
# ptr = cgutils.get_item_pointer2(context, builder, *array_info, idxs_bc, *safe)
227-
val = loader.load(context, builder, inp.data, idxs[0] or zero)
228-
# val = builder.load(ptr)
229-
# val.set_metadata("alias.scope", input_scope_set)
230-
# val.set_metadata("noalias", output_scope_set)
215+
for input, input_type, bc in zip(inputs, input_types, input_bc):
216+
core_ndim = input_type.ndim - len(bc)
217+
218+
idxs_bc = [zero if bc else idx for idx, bc in zip(idxs, bc)] + [
219+
zero
220+
] * core_ndim
221+
ptr = cgutils.get_item_pointer2(
222+
context,
223+
builder,
224+
input.data,
225+
cgutils.unpack_tuple(builder, input.shape),
226+
cgutils.unpack_tuple(builder, input.strides),
227+
input_type.layout,
228+
idxs_bc,
229+
*safe,
230+
)
231+
if core_ndim == 0:
232+
# Retrive scalar item at index
233+
val = builder.load(ptr)
234+
# val.set_metadata("alias.scope", input_scope_set)
235+
# val.set_metadata("noalias", output_scope_set)
236+
else:
237+
# Retrieve array item at index
238+
# This is a streamlined version of Numba's `GUArrayArg.load`
239+
# TODO check layout arg!
240+
core_arry_type = types.Array(
241+
dtype=input_type.dtype, ndim=core_ndim, layout=input_type.layout
242+
)
243+
core_array = context.make_array(core_arry_type)(context, builder)
244+
core_shape = cgutils.unpack_tuple(builder, input.shape)[-core_ndim:]
245+
core_strides = cgutils.unpack_tuple(builder, input.strides)[-core_ndim:]
246+
itemsize = context.get_abi_sizeof(context.get_data_type(input_type.dtype))
247+
context.populate_array(
248+
core_array,
249+
# TODO whey do we need to bitcast?
250+
data=builder.bitcast(ptr, core_array.data.type),
251+
shape=cgutils.pack_array(builder, core_shape),
252+
strides=cgutils.pack_array(builder, core_strides),
253+
itemsize=context.get_constant(types.intp, itemsize),
254+
# TODO what is meminfo about?
255+
meminfo=None,
256+
)
257+
val = core_array._getvalue()
258+
231259
input_vals.append(val)
232260

233261
inner_codegen = context.get_function(scalar_func, scalar_signature)
@@ -350,17 +378,30 @@ def _vectorized(
350378

351379
batch_ndim = len(input_bc_patterns[0])
352380

353-
if not all(input.ndim >= batch_ndim for input in inputs):
354-
raise TypingError("Vectorized inputs must have the same rank.")
381+
if not all(
382+
len(pattern) == batch_ndim for pattern in input_bc_patterns + output_bc_patterns
383+
):
384+
raise TypingError(
385+
"Vectorized broadcastable patterns must have the same length."
386+
)
355387

356-
if not all(len(pattern) >= batch_ndim for pattern in output_bc_patterns):
357-
raise TypingError("Invalid output broadcasting pattern.")
388+
core_input_types = []
389+
for input_type, bc_pattern in zip(inputs, input_bc_patterns):
390+
core_ndim = input_type.ndim - len(bc_pattern)
391+
# TODO: Reconsider this
392+
if core_ndim == 0:
393+
core_input_type = input_type.dtype
394+
else:
395+
core_input_type = types.Array(
396+
dtype=input_type.dtype, ndim=core_ndim, layout=input_type.layout
397+
)
398+
core_input_types.append(core_input_type)
358399

359-
scalar_signature = typingctx.resolve_function_type(
400+
core_signature = typingctx.resolve_function_type(
360401
scalar_func,
361402
[
362403
*constant_inputs,
363-
*[in_type.dtype if in_type.ndim == 0 else in_type for in_type in inputs],
404+
*core_input_types,
364405
],
365406
{},
366407
)
@@ -415,7 +456,7 @@ def codegen(
415456
ctx,
416457
builder,
417458
scalar_func,
418-
scalar_signature,
459+
core_signature,
419460
iter_shape,
420461
constant_inputs,
421462
inputs,

tests/link/numba/test_random.py

+25-2
Original file line numberDiff line numberDiff line change
@@ -645,7 +645,30 @@ def test_rng_non_default_update():
645645

646646

647647
def test_categorical_rv():
648-
x = pt.random.categorical(p=[[0.5, 0, 0, 0.5], [0, 0.5, 0.5, 0]], size=(2,))
648+
p = np.array(
649+
[
650+
[
651+
[1.0, 0, 0, 0],
652+
[0.0, 1.0, 0, 0],
653+
[0.0, 0, 1.0, 0],
654+
],
655+
[
656+
[0, 0, 0, 1.0],
657+
[0, 0, 0, 1.0],
658+
[0, 0, 0, 1.0],
659+
],
660+
]
661+
)
662+
x = pt.random.categorical(p=p, size=None)
649663
updates = {x.owner.inputs[0]: x.owner.outputs[0]}
650664
fn = function([], x, updates=updates, mode="NUMBA")
651-
print([fn() for _ in range(50)])
665+
res = fn()
666+
assert np.all(np.argmax(p, axis=-1) == res)
667+
668+
# Batch size
669+
x = pt.random.categorical(p=p[None], size=(3, *p.shape[:-1]))
670+
fn = function([], x, updates=updates, mode="NUMBA")
671+
new_res = fn()
672+
assert new_res.shape == (3, *res.shape)
673+
for new_res_row in new_res:
674+
assert np.all(new_res_row == res)

0 commit comments

Comments (0)