Skip to content

Commit 36df379

Browse files
aseyboldtricardoV94
authored andcommitted
fix(numba): Correlty report the elemwise output type
1 parent 849c3b8 commit 36df379

File tree

2 files changed

+25
-6
lines changed

2 files changed

+25
-6
lines changed

pytensor/link/numba/dispatch/elemwise.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -604,12 +604,16 @@ def codegen(
604604
builder, sig.return_type, [out._getvalue() for out in outputs]
605605
)
606606

607-
ret_type = types.Tuple(
608-
[
609-
types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C")
610-
for dtype in output_dtypes
611-
]
612-
)
607+
ret_types = [
608+
types.Array(numba.from_dtype(np.dtype(dtype)), ndim, "C")
609+
for dtype in output_dtypes
610+
]
611+
612+
for output_idx, input_idx in inplace_pattern:
613+
ret_types[output_idx] = input_types[input_idx]
614+
615+
ret_type = types.Tuple(ret_types)
616+
613617
if len(output_dtypes) == 1:
614618
ret_type = ret_type.types[0]
615619
sig = ret_type(*arg_types)

tests/link/numba/test_elemwise.py

+15
Original file line numberDiff line numberDiff line change
@@ -605,3 +605,18 @@ def test_fused_elemwise_benchmark(benchmark):
605605
# JIT compile first
606606
func()
607607
benchmark(func)
608+
609+
610+
def test_elemwise_out_type():
611+
# Create a graph with an elemwise
612+
# Ravel failes if the elemwise output type is reported incorrectly
613+
x = at.matrix()
614+
y = (2 * x).ravel()
615+
616+
# Pass in the input as mutable, to trigger the inplace rewrites
617+
func = pytensor.function([pytensor.In(x, mutable=True)], y, mode="NUMBA")
618+
619+
# Apply it to a numpy array that is neither C or F contigous
620+
x_val = np.broadcast_to(np.zeros((3,)), (6, 3))
621+
622+
assert func(x_val).shape == (18,)

0 commit comments

Comments
 (0)