@@ -447,11 +447,13 @@ def _vectorize_bc(
447
447
noalias_outputs = False ,
448
448
):
449
449
450
# Fast-math configuration handed to numba through jit_options["fastmath"].
# True switches on every LLVM fast-math flag at once.
#
# The finer-grained alternative is a set of individual flag names. It used to
# follow the assignment as a bare set literal, which was evaluated and
# discarded on every call without affecting `flags`; it is kept here as a
# comment for reference only:
#
#     "arcp"      Allow Reciprocal
#     "contract"  Allow floating-point contraction
#     "afn"       Approximate functions
#     "reassoc"
#     "nsz"       TODO Do we want this one?
flags = True
456
458
457
459
n_inputs = len (input_bc_patterns )
@@ -473,6 +475,9 @@ def loop_call(typingctx, *args):
473
475
sig = types .void (types .StarArgTuple ([* out_types , * in_types , iter_shape_type ]))
474
476
475
477
def codegen (context , builder , signature , args ):
478
+ for i in [0 ]:
479
+ arg = builder .function .args [i ]
480
+ arg .add_attribute ("noalias" )
476
481
safe = (boundscheck , False )
477
482
[args ] = args
478
483
args = cgutils .unpack_tuple (builder , args )
@@ -485,9 +490,10 @@ def codegen(context, builder, signature, args):
485
490
486
491
# Lower the code of the scalar function so that we can use it in the inner loop
487
492
# Caching is disabled to work around a numba bug (TODO: link the upstream issue)
488
- inner = context .compile_subroutine (
493
+ inner_func = context .compile_subroutine (
489
494
builder , scalar_func , scalar_signature , caching = False ,
490
- ).fndesc
495
+ )
496
+ inner = inner_func .fndesc
491
497
492
498
# Extract shape and stride information from the array.
493
499
# For later use in the loop body to do the indexing
@@ -499,13 +505,15 @@ def extract_array(aryty, ary):
499
505
layout = aryty .layout
500
506
return (data , shape , strides , layout )
501
507
502
- mod = builder .module
503
- domain = mod .add_metadata ([], self_ref = True )
504
- input_scope = mod .add_metadata ([domain ], self_ref = True )
505
- output_scope = mod .add_metadata ([domain ], self_ref = True )
506
- input_scope_set = mod .add_metadata ([input_scope , output_scope ])
507
-
508
- output_scope_set = mod .add_metadata ([input_scope , output_scope ])
508
+ # TODO I think this is better than the noalias attribute
509
+ # for the input, but self_ref isn't supported in a released
510
+ # llvmlite version yet
511
+ #mod = builder.module
512
+ #domain = mod.add_metadata([], self_ref=True)
513
+ #input_scope = mod.add_metadata([domain], self_ref=True)
514
+ #output_scope = mod.add_metadata([domain], self_ref=True)
515
+ #input_scope_set = mod.add_metadata([input_scope, output_scope])
516
+ #output_scope_set = mod.add_metadata([input_scope, output_scope])
509
517
510
518
inputs = [
511
519
extract_array (aryty , ary )
@@ -551,8 +559,8 @@ def extract_array(aryty, ary):
551
559
context , builder , * array_info , idxs_bc , * safe
552
560
)
553
561
val = builder .load (ptr )
554
- val .set_metadata ("alias.scope" , input_scope_set )
555
- val .set_metadata ("noalias" , output_scope_set )
562
+ # val.set_metadata("alias.scope", input_scope_set)
563
+ # val.set_metadata("noalias", output_scope_set)
556
564
input_vals .append (val )
557
565
558
566
# Call scalar function
@@ -572,8 +580,14 @@ def extract_array(aryty, ary):
572
580
zip (output_accumulator , output_values , strict = True )
573
581
):
574
582
if accu is not None :
575
- new_value = builder .fadd (builder .load (accu ), value )
576
- builder .store (new_value , accu )
583
+ load = builder .load (accu )
584
+ #load.set_metadata("alias.scope", output_scope_set)
585
+ #load.set_metadata("noalias", input_scope_set)
586
+ new_value = builder .fadd (load , value )
587
+ store = builder .store (new_value , accu )
588
+ # TODO ?
589
+ #store.set_metadata("alias.scope", output_scope_set)
590
+ #store.set_metadata("noalias", input_scope_set)
577
591
else :
578
592
idxs_bc = [
579
593
zero if bc else idx
@@ -582,9 +596,10 @@ def extract_array(aryty, ary):
582
596
ptr = cgutils .get_item_pointer2 (
583
597
context , builder , * outputs [i ], idxs_bc
584
598
)
585
- store = builder .store (value , ptr )
586
- store .set_metadata ("alias.scope" , output_scope_set )
587
- store .set_metadata ("noalias" , input_scope_set )
599
+ #store = builder.store(value, ptr)
600
+ store = arrayobj .store_item (context , builder , out_types [i ], value , ptr )
601
+ #store.set_metadata("alias.scope", output_scope_set)
602
+ #store.set_metadata("noalias", input_scope_set)
588
603
589
604
# Close the loops and write accumulator values to the output arrays
590
605
for depth , loop in enumerate (loop_stack [::- 1 ]):
@@ -599,16 +614,20 @@ def extract_array(aryty, ary):
599
614
ptr = cgutils .get_item_pointer2 (
600
615
context , builder , * outputs [output ], idxs_bc
601
616
)
602
- store = builder .store (builder .load (accu ), ptr )
603
- store .set_metadata ("alias.scope" , output_scope_set )
604
- store .set_metadata ("noalias" , input_scope_set )
617
+ load = builder .load (accu )
618
+ #load.set_metadata("alias.scope", output_scope_set)
619
+ #load.set_metadata("noalias", input_scope_set)
620
+ #store = builder.store(load, ptr)
621
+ store = arrayobj .store_item (context , builder , out_types [output ], load , ptr )
622
+ #store.set_metadata("alias.scope", output_scope_set)
623
+ #store.set_metadata("noalias", input_scope_set)
605
624
loop .__exit__ (None , None , None )
606
625
return
607
626
608
627
return sig , codegen
609
628
610
629
def vectorized(*inputs):
    """Pure-Python placeholder for the vectorized kernel.

    This function only exists as the target that the numba overload declared
    right after it is attached to. Calling it directly from the interpreter
    is an error, so it raises instead of silently returning ``None``.

    Raises
    ------
    NotImplementedError
        Always.
    """
    raise NotImplementedError(
        "vectorized is only implemented as a numba overload and cannot be "
        "called from pure Python"
    )
612
631
613
632
@numba .extending .overload (vectorized , jit_options = {"fastmath" : flags })
614
633
def impl_vectorized (* inputs ):
@@ -635,17 +654,32 @@ def impl_vectorized(*inputs):
635
654
636
655
iter_shape_repeated = tuple ([iter_shape_template [:] for _ in range (n_outputs )])
637
656
638
# Axis indices as a tuple so the jitted loop below can go through
# literal_unroll instead of a plain range().
ndim_range = tuple(range(ndim))

if ndim == 0:
    # Zero-dimensional case: the output is always a 0-d array and there is
    # nothing to validate, so the array check becomes a no-op.
    @numba.extending.register_jitable
    def make_output(iter_shape, bc, dtype):
        return np.empty((), dtype)

    @numba.extending.register_jitable
    def check_arrays(a, b, c):
        pass

else:
    # TODO workaround for https://github.com/numba/numba/issues/8654
    @numba.extending.register_jitable
    def make_output(iter_shape, bc, dtype):
        # Start from the iteration shape and collapse every axis flagged in
        # `bc` to length 1 before allocating the output.
        out_shape = iter_shape
        for axis in literal_unroll(ndim_range):
            if bc[axis]:
                out_shape = tuple_setitem(out_shape, axis, 1)
        return np.empty(out_shape, dtype)

    check_arrays = check_broadcasting
682
+
649
683
650
684
make_outputs = tuple_mapper (make_output )
651
685
@@ -667,8 +701,6 @@ def impl(*inputs):
667
701
)
668
702
669
703
outputs = make_outputs (iter_shape_rep , output_bc_patterns , output_dtypes )
670
- #outputs = (np.empty(inputs[0].shape),)
671
- #iter_shape = inputs[0].shape
672
704
673
705
i = 0
674
706
for input_ in literal_unroll (inputs ):
@@ -704,21 +736,24 @@ def numba_funcify_Elemwise(op, node, **kwargs):
704
736
scalar_inputs = [scalar (dtype = input .dtype ) for input in node .inputs ]
705
737
scalar_node = op .scalar_op .make_node (* scalar_inputs )
706
738
739
# Fast-math setting forwarded to numba_funcify for the scalar op.
# True enables every LLVM fast-math flag.
#
# The individual flag names used to follow the assignment as a bare set
# literal that was evaluated and discarded without affecting `flags`; they
# are kept here as a comment for reference only:
#
#     "arcp"      Allow Reciprocal
#     "contract"  Allow floating-point contraction
#     "afn"       Approximate functions
#     "reassoc"
flags = True
746
+
707
747
scalar_op_fn = numba_funcify (
708
- op .scalar_op , node = scalar_node , parent_node = node , ** kwargs
748
+ op .scalar_op , node = scalar_node , parent_node = node , fastmath = flags , ** kwargs
709
749
)
710
750
711
- assert not op .inplace_pattern
712
-
713
- #scalar_wrapper = register_jitable(scalar_op_fn)
714
- scalar_wrapper = scalar_op_fn
715
-
716
751
ndim = node .outputs [0 ].ndim
717
752
output_bc_patterns = tuple ([(False ,) * ndim for _ in node .outputs ])
718
753
input_bc_patterns = tuple ([input_var .broadcastable for input_var in node .inputs ])
719
754
720
755
vectorized = _vectorize_bc (
721
- scalar_wrapper ,
756
+ scalar_op_fn ,
722
757
input_bc_patterns ,
723
758
output_bc_patterns ,
724
759
output_dtypes = tuple ([
@@ -727,10 +762,22 @@ def numba_funcify_Elemwise(op, node, **kwargs):
727
762
]),
728
763
)
729
764
765
# TODO We should do this in vectorize instead
if op.inplace_pattern:
    # (output index, input index) pairs. Stored as a tuple rather than a list
    # because the jitted function below captures it as a closure constant and
    # iterates it with literal_unroll.
    pattern = tuple(op.inplace_pattern.items())

    @numba_njit
    def elemwise_inplace(*inputs):
        # Compute into fresh outputs, then copy each result into the input
        # it is declared to overwrite.
        outputs = vectorized(*inputs)
        for out_idx, in_idx in literal_unroll(pattern):
            inputs[in_idx][...] = outputs[out_idx]
        # The results must be returned: the single-output wrapper below
        # indexes the return value, which fails if this returns nothing.
        # NOTE(review): this returns the freshly allocated outputs, whose
        # values equal what was written into the inputs; confirm whether
        # callers need the overwritten input arrays themselves.
        return outputs

else:
    elemwise_inplace = vectorized

if len(node.outputs) == 1:

    @numba_njit
    def elemwise_wrapper(*inputs):
        # Unwrap the 1-tuple of outputs for single-output nodes.
        return elemwise_inplace(*inputs)[0]

else:
    # Route through elemwise_inplace so multi-output nodes also honour the
    # in-place pattern. It is the same object as `vectorized` when no
    # in-place pattern is set.
    elemwise_wrapper = elemwise_inplace
736
783
0 commit comments