12
12
from numba .core .base import BaseContext
13
13
from numba .core .types .misc import NoneType
14
14
from numba .np import arrayobj
15
- from numba .np .ufunc .wrappers import _ArrayArgLoader
16
15
17
16
18
17
def compute_itershape (
@@ -158,7 +157,7 @@ def make_loop_call(
158
157
input_types : tuple [Any , ...],
159
158
output_types : tuple [Any , ...],
160
159
):
161
- # safe = (False, False)
160
+ safe = (False , False )
162
161
163
162
n_outputs = len (outputs )
164
163
@@ -183,14 +182,6 @@ def extract_array(aryty, obj):
183
182
# input_scope_set = mod.add_metadata([input_scope, output_scope])
184
183
# output_scope_set = mod.add_metadata([input_scope, output_scope])
185
184
186
- typ = input_types [0 ]
187
- inp = inputs [0 ]
188
- shape = cgutils .unpack_tuple (builder , inp .shape )
189
- strides = cgutils .unpack_tuple (builder , inp .strides )
190
- loader = _ArrayArgLoader (typ .dtype , typ .ndim , shape [- 1 ], False , shape , strides )
191
-
192
- inputs = tuple (extract_array (aryty , ary ) for aryty , ary in zip (input_types , inputs ))
193
-
194
185
outputs = tuple (
195
186
extract_array (aryty , ary ) for aryty , ary in zip (output_types , outputs )
196
187
)
@@ -221,13 +212,50 @@ def extract_array(aryty, obj):
221
212
222
213
# Load values from input arrays
223
214
input_vals = []
224
- for array_info , bc in zip (inputs , input_bc ):
225
- idxs_bc = [zero if bc else idx for idx , bc in zip (idxs , bc )]
226
- # ptr = cgutils.get_item_pointer2(context, builder, *array_info, idxs_bc, *safe)
227
- val = loader .load (context , builder , inp .data , idxs [0 ] or zero )
228
- # val = builder.load(ptr)
229
- # val.set_metadata("alias.scope", input_scope_set)
230
- # val.set_metadata("noalias", output_scope_set)
215
+ for input , input_type , bc in zip (inputs , input_types , input_bc ):
216
+ core_ndim = input_type .ndim - len (bc )
217
+
218
+ idxs_bc = [zero if bc else idx for idx , bc in zip (idxs , bc )] + [
219
+ zero
220
+ ] * core_ndim
221
+ ptr = cgutils .get_item_pointer2 (
222
+ context ,
223
+ builder ,
224
+ input .data ,
225
+ cgutils .unpack_tuple (builder , input .shape ),
226
+ cgutils .unpack_tuple (builder , input .strides ),
227
+ input_type .layout ,
228
+ idxs_bc ,
229
+ * safe ,
230
+ )
231
+ if core_ndim == 0 :
232
+ # Retrieve scalar item at index
233
+ val = builder .load (ptr )
234
+ # val.set_metadata("alias.scope", input_scope_set)
235
+ # val.set_metadata("noalias", output_scope_set)
236
+ else :
237
+ # Retrieve array item at index
238
+ # This is a streamlined version of Numba's `GUArrayArg.load`
239
+ # TODO check layout arg!
240
+ core_arry_type = types .Array (
241
+ dtype = input_type .dtype , ndim = core_ndim , layout = input_type .layout
242
+ )
243
+ core_array = context .make_array (core_arry_type )(context , builder )
244
+ core_shape = cgutils .unpack_tuple (builder , input .shape )[- core_ndim :]
245
+ core_strides = cgutils .unpack_tuple (builder , input .strides )[- core_ndim :]
246
+ itemsize = context .get_abi_sizeof (context .get_data_type (input_type .dtype ))
247
+ context .populate_array (
248
+ core_array ,
249
+ # TODO why do we need to bitcast?
250
+ data = builder .bitcast (ptr , core_array .data .type ),
251
+ shape = cgutils .pack_array (builder , core_shape ),
252
+ strides = cgutils .pack_array (builder , core_strides ),
253
+ itemsize = context .get_constant (types .intp , itemsize ),
254
+ # TODO what is meminfo about?
255
+ meminfo = None ,
256
+ )
257
+ val = core_array ._getvalue ()
258
+
231
259
input_vals .append (val )
232
260
233
261
inner_codegen = context .get_function (scalar_func , scalar_signature )
@@ -350,17 +378,30 @@ def _vectorized(
350
378
351
379
batch_ndim = len (input_bc_patterns [0 ])
352
380
353
- if not all (input .ndim >= batch_ndim for input in inputs ):
354
- raise TypingError ("Vectorized inputs must have the same rank." )
381
+ if not all (
382
+ len (pattern ) == batch_ndim for pattern in input_bc_patterns + output_bc_patterns
383
+ ):
384
+ raise TypingError (
385
+ "Vectorized broadcastable patterns must have the same length."
386
+ )
355
387
356
- if not all (len (pattern ) >= batch_ndim for pattern in output_bc_patterns ):
357
- raise TypingError ("Invalid output broadcasting pattern." )
388
+ core_input_types = []
389
+ for input_type , bc_pattern in zip (inputs , input_bc_patterns ):
390
+ core_ndim = input_type .ndim - len (bc_pattern )
391
+ # TODO: Reconsider this
392
+ if core_ndim == 0 :
393
+ core_input_type = input_type .dtype
394
+ else :
395
+ core_input_type = types .Array (
396
+ dtype = input_type .dtype , ndim = core_ndim , layout = input_type .layout
397
+ )
398
+ core_input_types .append (core_input_type )
358
399
359
- scalar_signature = typingctx .resolve_function_type (
400
+ core_signature = typingctx .resolve_function_type (
360
401
scalar_func ,
361
402
[
362
403
* constant_inputs ,
363
- * [ in_type . dtype if in_type . ndim == 0 else in_type for in_type in inputs ] ,
404
+ * core_input_types ,
364
405
],
365
406
{},
366
407
)
@@ -415,7 +456,7 @@ def codegen(
415
456
ctx ,
416
457
builder ,
417
458
scalar_func ,
418
- scalar_signature ,
459
+ core_signature ,
419
460
iter_shape ,
420
461
constant_inputs ,
421
462
inputs ,
0 commit comments