7
7
import numba
8
8
import numpy as np
9
9
from llvmlite import ir
10
- from numba import TypingError , types
10
+ from numba import TypingError , literal_unroll , types
11
11
from numba .core import cgutils
12
12
from numba .cpython .unsafe .tuple import tuple_setitem
13
13
from numba .np import arrayobj
@@ -653,8 +653,8 @@ def impl(*inputs):
653
653
iter_shape = iter_shape_template
654
654
for i in range (ndim ):
655
655
maxval = 1
656
- for j in range ( n_inputs ):
657
- maxval = max (maxval , inputs [ j ] .shape [i ])
656
+ for inp in literal_unroll ( inputs ):
657
+ maxval = max (maxval , inp .shape [i ])
658
658
659
659
iter_shape = tuple_setitem (iter_shape , i , maxval )
660
660
@@ -667,12 +667,20 @@ def impl(*inputs):
667
667
)
668
668
669
669
outputs = make_outputs (iter_shape_rep , output_bc_patterns , output_dtypes )
670
+ #outputs = (np.empty(inputs[0].shape),)
671
+ #iter_shape = inputs[0].shape
670
672
671
- for input_ , bcs in zip (inputs , input_bc_patterns ):
673
+ i = 0
674
+ for input_ in literal_unroll (inputs ):
675
+ bcs = input_bc_patterns [i ]
672
676
check_broadcasting (input_ , bcs , iter_shape )
677
+ i = i + 1
673
678
674
- for out , bcs in zip (outputs , output_bc_patterns ):
679
+ i = 0
680
+ for out in literal_unroll (outputs ):
681
+ bcs = output_bc_patterns [i ]
675
682
check_broadcasting (out , bcs , iter_shape )
683
+ i = i + 1
676
684
677
685
loop_call (* outputs , * inputs , iter_shape )
678
686
return outputs
0 commit comments