Skip to content

Commit 41c4081

Browse files
committed
Deal with single output scalar_ops
1 parent 84b14b8 commit 41c4081

File tree

1 file changed

+28
-13
lines changed

1 file changed

+28
-13
lines changed

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 28 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,7 @@ def _vectorize_bc(
443443
scalar_func,
444444
input_bc_patterns,
445445
output_bc_patterns,
446+
output_dtypes,
446447
boundscheck=False,
447448
noalias_outputs=False,
448449
):
@@ -484,8 +485,9 @@ def codegen(context, builder, signature, args):
484485
shape = cgutils.unpack_tuple(builder, iter_shape)
485486

486487
# Lower the code of the scalar function so that we can use it in the inner loop
488+
# Caching is set to false to avoid a numba bug TODO ref?
487489
inner = context.compile_subroutine(
488-
builder, scalar_func, scalar_signature
490+
builder, scalar_func, scalar_signature, caching=False,
489491
).fndesc
490492

491493
# Extract shape and stride information from the array.
@@ -546,9 +548,15 @@ def extract_array(aryty, ary):
546548

547549
# Call scalar function
548550
output_values = context.call_internal(
549-
builder, inner, scalar_signature, input_vals
551+
builder,
552+
inner,
553+
scalar_signature,
554+
input_vals,
550555
)
551-
output_values = cgutils.unpack_tuple(builder, output_values)
556+
if isinstance(scalar_signature.return_type, types.Tuple):
557+
output_values = cgutils.unpack_tuple(builder, output_values)
558+
else:
559+
output_values = [output_values]
552560

553561
# Update output value or accumulators respectively
554562
for i, ((accu, _), value) in enumerate(
@@ -614,9 +622,6 @@ def impl_vectorized(*inputs):
614622

615623
iter_shape_repeated = tuple([iter_shape_template[:] for _ in range(n_outputs)])
616624

617-
# TODO Infer from signature
618-
output_dtypes = (np.float64,) * n_outputs
619-
620625
@numba.extending.register_jitable
621626
def make_output(iter_shape, bc, dtype):
622627
shape = iter_shape
@@ -684,19 +689,29 @@ def numba_funcify_Elemwise(op, node, **kwargs):
684689

685690
assert not op.inplace_pattern
686691

687-
@register_jitable
688-
def wrapper(in1, in2):
689-
return (scalar_op_fn(in1, in2),)
692+
#scalar_wrapper = register_jitable(scalar_op_fn)
693+
scalar_wrapper = scalar_op_fn
690694

691695
ndim = node.outputs[0].ndim
692696
output_bc_patterns = tuple([(False,) * ndim for _ in node.outputs])
693697
input_bc_patterns = tuple([input_var.broadcastable for input_var in node.inputs])
694698

695-
vectorized = _vectorize_bc(wrapper, input_bc_patterns, output_bc_patterns)
699+
vectorized = _vectorize_bc(
700+
scalar_wrapper,
701+
input_bc_patterns,
702+
output_bc_patterns,
703+
output_dtypes=tuple([
704+
variable.dtype
705+
for variable in node.outputs
706+
]),
707+
)
696708

697-
@numba_njit
698-
def elemwise_wrapper(in1, in2):
699-
return vectorized(in1, in2)[0]
709+
if len(node.outputs) == 1:
710+
@numba_njit
711+
def elemwise_wrapper(*inputs):
712+
return vectorized(*inputs)[0]
713+
else:
714+
elemwise_wrapper = vectorized
700715

701716
return elemwise_wrapper
702717

0 commit comments

Comments
 (0)