Skip to content

Commit a5b92ac

Browse files
ricardoV94jessegrabowskiaseyboldt
committed
Adapt Elemwise iterator for Numba Generators
Co-authored-by: Jesse Grabowski <[email protected]> Co-authored-by: Adrian Seyboldt <[email protected]>
1 parent 13807b4 commit a5b92ac

File tree

2 files changed

+324
-157
lines changed

2 files changed

+324
-157
lines changed

pytensor/link/numba/dispatch/elemwise.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from pytensor.link.numba.dispatch.vectorize_codegen import (
2424
_vectorized,
2525
encode_literals,
26+
store_core_outputs,
2627
)
2728
from pytensor.link.utils import compile_function_src, get_name_for_object
2829
from pytensor.scalar.basic import (
@@ -483,10 +484,15 @@ def numba_funcify_Elemwise(op, node, **kwargs):
483484
op.scalar_op, node=scalar_node, parent_node=node, fastmath=flags, **kwargs
484485
)
485486

487+
nin = len(node.inputs)
488+
nout = len(node.outputs)
489+
core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout)
490+
486491
input_bc_patterns = tuple([inp.type.broadcastable for inp in node.inputs])
487492
output_bc_patterns = tuple([out.type.broadcastable for out in node.inputs])
488493
output_dtypes = tuple(out.type.dtype for out in node.outputs)
489494
inplace_pattern = tuple(op.inplace_pattern.items())
495+
core_output_shapes = tuple(() for _ in range(nout))
490496

491497
# numba doesn't support nested literals right now...
492498
input_bc_patterns_enc = encode_literals(input_bc_patterns)
@@ -496,12 +502,15 @@ def numba_funcify_Elemwise(op, node, **kwargs):
496502

497503
def elemwise_wrapper(*inputs):
498504
return _vectorized(
499-
scalar_op_fn,
505+
core_op_fn,
500506
input_bc_patterns_enc,
501507
output_bc_patterns_enc,
502508
output_dtypes_enc,
503509
inplace_pattern_enc,
510+
(), # constant_inputs
504511
inputs,
512+
core_output_shapes, # core_shapes
513+
None, # size
505514
)
506515

507516
# Pure python implementation, that will be used in tests

0 commit comments

Comments
 (0)