@@ -443,6 +443,7 @@ def _vectorize_bc(
443
443
scalar_func ,
444
444
input_bc_patterns ,
445
445
output_bc_patterns ,
446
+ output_dtypes ,
446
447
boundscheck = False ,
447
448
noalias_outputs = False ,
448
449
):
@@ -484,8 +485,9 @@ def codegen(context, builder, signature, args):
484
485
shape = cgutils .unpack_tuple (builder , iter_shape )
485
486
486
487
# Lower the code of the scalar function so that we can use it in the inner loop
488
+ # Caching is set to false to avoid a numba bug TODO ref?
487
489
inner = context .compile_subroutine (
488
- builder , scalar_func , scalar_signature
490
+ builder , scalar_func , scalar_signature , caching = False ,
489
491
).fndesc
490
492
491
493
# Extract shape and stride information from the array.
@@ -546,9 +548,15 @@ def extract_array(aryty, ary):
546
548
547
549
# Call scalar function
548
550
output_values = context .call_internal (
549
- builder , inner , scalar_signature , input_vals
551
+ builder ,
552
+ inner ,
553
+ scalar_signature ,
554
+ input_vals ,
550
555
)
551
- output_values = cgutils .unpack_tuple (builder , output_values )
556
+ if isinstance (scalar_signature .return_type , types .Tuple ):
557
+ output_values = cgutils .unpack_tuple (builder , output_values )
558
+ else :
559
+ output_values = [output_values ]
552
560
553
561
# Update output value or accumulators respectively
554
562
for i , ((accu , _ ), value ) in enumerate (
@@ -614,9 +622,6 @@ def impl_vectorized(*inputs):
614
622
615
623
iter_shape_repeated = tuple ([iter_shape_template [:] for _ in range (n_outputs )])
616
624
617
- # TODO Infer from signature
618
- output_dtypes = (np .float64 ,) * n_outputs
619
-
620
625
@numba .extending .register_jitable
621
626
def make_output (iter_shape , bc , dtype ):
622
627
shape = iter_shape
@@ -684,19 +689,29 @@ def numba_funcify_Elemwise(op, node, **kwargs):
684
689
685
690
assert not op .inplace_pattern
686
691
687
- @register_jitable
688
- def wrapper (in1 , in2 ):
689
- return (scalar_op_fn (in1 , in2 ),)
692
+ #scalar_wrapper = register_jitable(scalar_op_fn)
693
+ scalar_wrapper = scalar_op_fn
690
694
691
695
ndim = node .outputs [0 ].ndim
692
696
output_bc_patterns = tuple ([(False ,) * ndim for _ in node .outputs ])
693
697
input_bc_patterns = tuple ([input_var .broadcastable for input_var in node .inputs ])
694
698
695
- vectorized = _vectorize_bc (wrapper , input_bc_patterns , output_bc_patterns )
699
+ vectorized = _vectorize_bc (
700
+ scalar_wrapper ,
701
+ input_bc_patterns ,
702
+ output_bc_patterns ,
703
+ output_dtypes = tuple ([
704
+ variable .dtype
705
+ for variable in node .outputs
706
+ ]),
707
+ )
696
708
697
- @numba_njit
698
- def elemwise_wrapper (in1 , in2 ):
699
- return vectorized (in1 , in2 )[0 ]
709
+ if len (node .outputs ) == 1 :
710
+ @numba_njit
711
+ def elemwise_wrapper (* inputs ):
712
+ return vectorized (* inputs )[0 ]
713
+ else :
714
+ elemwise_wrapper = vectorized
700
715
701
716
return elemwise_wrapper
702
717
0 commit comments