Skip to content

Commit 93a02c0

Browse files
committed
Unroll some loops in llvm-elemwise
1 parent: 6ab30d2; commit: 93a02c0

File tree

1 file changed

+13
-5
lines changed

1 file changed

+13
-5
lines changed

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 13 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -7,7 +7,7 @@
77
import numba
88
import numpy as np
99
from llvmlite import ir
10-
from numba import TypingError, types
10+
from numba import TypingError, literal_unroll, types
1111
from numba.core import cgutils
1212
from numba.cpython.unsafe.tuple import tuple_setitem
1313
from numba.np import arrayobj
@@ -653,8 +653,8 @@ def impl(*inputs):
653653
iter_shape = iter_shape_template
654654
for i in range(ndim):
655655
maxval = 1
656-
for j in range(n_inputs):
657-
maxval = max(maxval, inputs[j].shape[i])
656+
for inp in literal_unroll(inputs):
657+
maxval = max(maxval, inp.shape[i])
658658

659659
iter_shape = tuple_setitem(iter_shape, i, maxval)
660660

@@ -667,12 +667,20 @@ def impl(*inputs):
667667
)
668668

669669
outputs = make_outputs(iter_shape_rep, output_bc_patterns, output_dtypes)
670+
#outputs = (np.empty(inputs[0].shape),)
671+
#iter_shape = inputs[0].shape
670672

671-
for input_, bcs in zip(inputs, input_bc_patterns):
673+
i = 0
674+
for input_ in literal_unroll(inputs):
675+
bcs = input_bc_patterns[i]
672676
check_broadcasting(input_, bcs, iter_shape)
677+
i = i + 1
673678

674-
for out, bcs in zip(outputs, output_bc_patterns):
679+
i = 0
680+
for out in literal_unroll(outputs):
681+
bcs = output_bc_patterns[i]
675682
check_broadcasting(out, bcs, iter_shape)
683+
i = i + 1
676684

677685
loop_call(*outputs, *inputs, iter_shape)
678686
return outputs

0 commit comments

Comments
 (0)