Skip to content

Commit 6ab30d2

Browse files
committed
Enable vectorization in llvm elemwise
1 parent 41c4081 commit 6ab30d2

File tree

1 file changed

+16
-3
lines changed

1 file changed

+16
-3
lines changed

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from numba import TypingError, types
1111
from numba.core import cgutils
1212
from numba.cpython.unsafe.tuple import tuple_setitem
13-
from numba.extending import register_jitable
1413
from numba.np import arrayobj
1514
from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple
1615

@@ -500,6 +499,14 @@ def extract_array(aryty, ary):
500499
layout = aryty.layout
501500
return (data, shape, strides, layout)
502501

502+
mod = builder.module
503+
domain = mod.add_metadata([], self_ref=True)
504+
input_scope = mod.add_metadata([domain], self_ref=True)
505+
output_scope = mod.add_metadata([domain], self_ref=True)
506+
input_scope_set = mod.add_metadata([input_scope, output_scope])
507+
508+
output_scope_set = mod.add_metadata([input_scope, output_scope])
509+
503510
inputs = [
504511
extract_array(aryty, ary)
505512
for aryty, ary in zip(in_types, inputs, strict=True)
@@ -544,6 +551,8 @@ def extract_array(aryty, ary):
544551
context, builder, *array_info, idxs_bc, *safe
545552
)
546553
val = builder.load(ptr)
554+
val.set_metadata("alias.scope", input_scope_set)
555+
val.set_metadata("noalias", output_scope_set)
547556
input_vals.append(val)
548557

549558
# Call scalar function
@@ -573,7 +582,9 @@ def extract_array(aryty, ary):
573582
ptr = cgutils.get_item_pointer2(
574583
context, builder, *outputs[i], idxs_bc
575584
)
576-
builder.store(value, ptr)
585+
store = builder.store(value, ptr)
586+
store.set_metadata("alias.scope", output_scope_set)
587+
store.set_metadata("noalias", input_scope_set)
577588

578589
# Close the loops and write accumulator values to the output arrays
579590
for depth, loop in enumerate(loop_stack[::-1]):
@@ -588,7 +599,9 @@ def extract_array(aryty, ary):
588599
ptr = cgutils.get_item_pointer2(
589600
context, builder, *outputs[output], idxs_bc
590601
)
591-
builder.store(builder.load(accu), ptr)
602+
store = builder.store(builder.load(accu), ptr)
603+
store.set_metadata("alias.scope", output_scope_set)
604+
store.set_metadata("noalias", input_scope_set)
592605
loop.__exit__(None, None, None)
593606
return
594607

0 commit comments

Comments
 (0)