Skip to content

Commit d3aa283

Browse files
ricardoV94jessegrabowskiaseyboldt
committed
Adapt Numba vectorize iterator for RandomVariables
Co-authored-by: Jesse Grabowski <[email protected]> Co-authored-by: Adrian Seyboldt <[email protected]>
1 parent 18dcf62 commit d3aa283

File tree

3 files changed

+332
-159
lines changed

3 files changed

+332
-159
lines changed

pytensor/link/numba/dispatch/basic.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,16 @@ def numba_njit(*args, **kwargs):
6262
kwargs.setdefault("no_cpython_wrapper", True)
6363
kwargs.setdefault("no_cfunc_wrapper", True)
6464

65-
# Supress caching warnings
65+
# Suppress cache warning for internal functions
66+
# We have to add an ansi escape code for optional bold text by numba
6667
warnings.filterwarnings(
6768
"ignore",
68-
message='Cannot cache compiled function "numba_funcified_fgraph" as it uses dynamic globals',
69+
message=(
70+
"(\x1b\\[1m)*" # ansi escape code for bold text
71+
"Cannot cache compiled function "
72+
'"(numba_funcified_fgraph|store_core_outputs)" '
73+
"as it uses dynamic globals"
74+
),
6975
category=NumbaWarning,
7076
)
7177

pytensor/link/numba/dispatch/elemwise.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
_jit_options,
2525
_vectorized,
2626
encode_literals,
27+
store_core_outputs,
2728
)
2829
from pytensor.link.utils import compile_function_src, get_name_for_object
2930
from pytensor.scalar.basic import (
@@ -480,10 +481,15 @@ def numba_funcify_Elemwise(op, node, **kwargs):
480481
**kwargs,
481482
)
482483

484+
nin = len(node.inputs)
485+
nout = len(node.outputs)
486+
core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout)
487+
483488
input_bc_patterns = tuple([inp.type.broadcastable for inp in node.inputs])
484489
output_bc_patterns = tuple([out.type.broadcastable for out in node.outputs])
485490
output_dtypes = tuple(out.type.dtype for out in node.outputs)
486491
inplace_pattern = tuple(op.inplace_pattern.items())
492+
core_output_shapes = tuple(() for _ in range(nout))
487493

488494
# numba doesn't support nested literals right now...
489495
input_bc_patterns_enc = encode_literals(input_bc_patterns)
@@ -493,12 +499,15 @@ def numba_funcify_Elemwise(op, node, **kwargs):
493499

494500
def elemwise_wrapper(*inputs):
495501
return _vectorized(
496-
scalar_op_fn,
502+
core_op_fn,
497503
input_bc_patterns_enc,
498504
output_bc_patterns_enc,
499505
output_dtypes_enc,
500506
inplace_pattern_enc,
507+
(), # constant_inputs
501508
inputs,
509+
core_output_shapes, # core_shapes
510+
None, # size
502511
)
503512

504513
# Pure python implementation, that will be used in tests

0 commit comments

Comments
 (0)