12
12
from numba .core .base import BaseContext
13
13
from numba .core .types .misc import NoneType
14
14
from numba .np import arrayobj
15
+ from numba .np .ufunc .wrappers import _ArrayArgLoader
15
16
16
17
17
18
def compute_itershape (
@@ -22,11 +23,11 @@ def compute_itershape(
22
23
size : list [ir .Instruction ] | None ,
23
24
):
24
25
one = ir .IntType (64 )(1 )
25
- ndim = len (in_shapes [0 ])
26
- shape = [None ] * ndim
26
+ batch_ndim = len (broadcast_pattern [0 ])
27
+ shape = [None ] * batch_ndim
27
28
if size is not None :
28
29
shape = size
29
- for i in range (ndim ):
30
+ for i in range (batch_ndim ):
30
31
for j , (bc , in_shape ) in enumerate (zip (broadcast_pattern , in_shapes )):
31
32
length = in_shape [i ]
32
33
if bc [i ]:
@@ -61,7 +62,7 @@ def compute_itershape(
61
62
)
62
63
else :
63
64
# Size is implied by the broadcast pattern
64
- for i in range (ndim ):
65
+ for i in range (batch_ndim ):
65
66
for j , (bc , in_shape ) in enumerate (zip (broadcast_pattern , in_shapes )):
66
67
length = in_shape [i ]
67
68
if bc [i ]:
@@ -96,7 +97,7 @@ def compute_itershape(
96
97
)
97
98
else :
98
99
shape [i ] = length
99
- for i in range (ndim ):
100
+ for i in range (batch_ndim ):
100
101
if shape [i ] is None :
101
102
shape [i ] = one
102
103
return shape
@@ -157,7 +158,7 @@ def make_loop_call(
157
158
input_types : tuple [Any , ...],
158
159
output_types : tuple [Any , ...],
159
160
):
160
- safe = (False , False )
161
+ # safe = (False, False)
161
162
162
163
n_outputs = len (outputs )
163
164
@@ -182,6 +183,12 @@ def extract_array(aryty, obj):
182
183
# input_scope_set = mod.add_metadata([input_scope, output_scope])
183
184
# output_scope_set = mod.add_metadata([input_scope, output_scope])
184
185
186
+ typ = input_types [0 ]
187
+ inp = inputs [0 ]
188
+ shape = cgutils .unpack_tuple (builder , inp .shape )
189
+ strides = cgutils .unpack_tuple (builder , inp .strides )
190
+ loader = _ArrayArgLoader (typ .dtype , typ .ndim , shape [- 1 ], False , shape , strides )
191
+
185
192
inputs = tuple (extract_array (aryty , ary ) for aryty , ary in zip (input_types , inputs ))
186
193
187
194
outputs = tuple (
@@ -216,8 +223,9 @@ def extract_array(aryty, obj):
216
223
input_vals = []
217
224
for array_info , bc in zip (inputs , input_bc ):
218
225
idxs_bc = [zero if bc else idx for idx , bc in zip (idxs , bc )]
219
- ptr = cgutils .get_item_pointer2 (context , builder , * array_info , idxs_bc , * safe )
220
- val = builder .load (ptr )
226
+ # ptr = cgutils.get_item_pointer2(context, builder, *array_info, idxs_bc, *safe)
227
+ val = loader .load (context , builder , inp .data , idxs [0 ] or zero )
228
+ # val = builder.load(ptr)
221
229
# val.set_metadata("alias.scope", input_scope_set)
222
230
# val.set_metadata("noalias", output_scope_set)
223
231
input_vals .append (val )
@@ -340,16 +348,21 @@ def _vectorized(
340
348
if not all (isinstance (input , types .Array ) for input in inputs ):
341
349
raise TypingError ("Vectorized inputs must be arrays." )
342
350
343
- ndim = inputs [0 ]. ndim
351
+ batch_ndim = len ( input_bc_patterns [0 ])
344
352
345
- if not all (input .ndim == ndim for input in inputs ):
353
+ if not all (input .ndim >= batch_ndim for input in inputs ):
346
354
raise TypingError ("Vectorized inputs must have the same rank." )
347
355
348
- if not all (len (pattern ) == ndim for pattern in output_bc_patterns ):
356
+ if not all (len (pattern ) >= batch_ndim for pattern in output_bc_patterns ):
349
357
raise TypingError ("Invalid output broadcasting pattern." )
350
358
351
359
scalar_signature = typingctx .resolve_function_type (
352
- scalar_func , [* constant_inputs , * [in_type .dtype for in_type in inputs ]], {}
360
+ scalar_func ,
361
+ [
362
+ * constant_inputs ,
363
+ * [in_type .dtype if in_type .ndim == 0 else in_type for in_type in inputs ],
364
+ ],
365
+ {},
353
366
)
354
367
355
368
# So we can access the constant values in codegen...
@@ -430,7 +443,7 @@ def codegen(
430
443
)
431
444
432
445
ret_types = [
433
- types .Array (numba .from_dtype (np .dtype (dtype )), ndim , "C" )
446
+ types .Array (numba .from_dtype (np .dtype (dtype )), batch_ndim , "C" )
434
447
for dtype in output_dtypes
435
448
]
436
449
0 commit comments