Skip to content

Commit 04580b8

Browse files
committed
Run black
1 parent 02ea662 commit 04580b8

File tree

1 file changed

+36
-29
lines changed

1 file changed

+36
-29
lines changed

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 36 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -447,7 +447,7 @@ def _vectorize_bc(
447447
noalias_outputs=False,
448448
):
449449

450-
flags = True
450+
flags = True
451451
{
452452
"arcp", # Allow Reciprocal
453453
"contract", # Allow floating-point contraction
@@ -491,7 +491,10 @@ def codegen(context, builder, signature, args):
491491
# Lower the code of the scalar function so that we can use it in the inner loop
492492
# Caching is set to false to avoid a numba bug TODO ref?
493493
inner_func = context.compile_subroutine(
494-
builder, scalar_func, scalar_signature, caching=False,
494+
builder,
495+
scalar_func,
496+
scalar_signature,
497+
caching=False,
495498
)
496499
inner = inner_func.fndesc
497500

@@ -508,12 +511,12 @@ def extract_array(aryty, ary):
508511
# TODO I think this is better than the noalias attribute
509512
# for the input, but self_ref isn't supported in a released
510513
# llvmlite version yet
511-
#mod = builder.module
512-
#domain = mod.add_metadata([], self_ref=True)
513-
#input_scope = mod.add_metadata([domain], self_ref=True)
514-
#output_scope = mod.add_metadata([domain], self_ref=True)
515-
#input_scope_set = mod.add_metadata([input_scope, output_scope])
516-
#output_scope_set = mod.add_metadata([input_scope, output_scope])
514+
# mod = builder.module
515+
# domain = mod.add_metadata([], self_ref=True)
516+
# input_scope = mod.add_metadata([domain], self_ref=True)
517+
# output_scope = mod.add_metadata([domain], self_ref=True)
518+
# input_scope_set = mod.add_metadata([input_scope, output_scope])
519+
# output_scope_set = mod.add_metadata([input_scope, output_scope])
517520

518521
inputs = [
519522
extract_array(aryty, ary)
@@ -559,8 +562,8 @@ def extract_array(aryty, ary):
559562
context, builder, *array_info, idxs_bc, *safe
560563
)
561564
val = builder.load(ptr)
562-
#val.set_metadata("alias.scope", input_scope_set)
563-
#val.set_metadata("noalias", output_scope_set)
565+
# val.set_metadata("alias.scope", input_scope_set)
566+
# val.set_metadata("noalias", output_scope_set)
564567
input_vals.append(val)
565568

566569
# Call scalar function
@@ -581,13 +584,13 @@ def extract_array(aryty, ary):
581584
):
582585
if accu is not None:
583586
load = builder.load(accu)
584-
#load.set_metadata("alias.scope", output_scope_set)
585-
#load.set_metadata("noalias", input_scope_set)
587+
# load.set_metadata("alias.scope", output_scope_set)
588+
# load.set_metadata("noalias", input_scope_set)
586589
new_value = builder.fadd(load, value)
587590
store = builder.store(new_value, accu)
588591
# TODO ?
589-
#store.set_metadata("alias.scope", output_scope_set)
590-
#store.set_metadata("noalias", input_scope_set)
592+
# store.set_metadata("alias.scope", output_scope_set)
593+
# store.set_metadata("noalias", input_scope_set)
591594
else:
592595
idxs_bc = [
593596
zero if bc else idx
@@ -596,10 +599,12 @@ def extract_array(aryty, ary):
596599
ptr = cgutils.get_item_pointer2(
597600
context, builder, *outputs[i], idxs_bc
598601
)
599-
#store = builder.store(value, ptr)
600-
store = arrayobj.store_item(context, builder, out_types[i], value, ptr)
601-
#store.set_metadata("alias.scope", output_scope_set)
602-
#store.set_metadata("noalias", input_scope_set)
602+
# store = builder.store(value, ptr)
603+
store = arrayobj.store_item(
604+
context, builder, out_types[i], value, ptr
605+
)
606+
# store.set_metadata("alias.scope", output_scope_set)
607+
# store.set_metadata("noalias", input_scope_set)
603608

604609
# Close the loops and write accumulator values to the output arrays
605610
for depth, loop in enumerate(loop_stack[::-1]):
@@ -615,12 +620,14 @@ def extract_array(aryty, ary):
615620
context, builder, *outputs[output], idxs_bc
616621
)
617622
load = builder.load(accu)
618-
#load.set_metadata("alias.scope", output_scope_set)
619-
#load.set_metadata("noalias", input_scope_set)
620-
#store = builder.store(load, ptr)
621-
store = arrayobj.store_item(context, builder, out_types[output], load, ptr)
622-
#store.set_metadata("alias.scope", output_scope_set)
623-
#store.set_metadata("noalias", input_scope_set)
623+
# load.set_metadata("alias.scope", output_scope_set)
624+
# load.set_metadata("noalias", input_scope_set)
625+
# store = builder.store(load, ptr)
626+
store = arrayobj.store_item(
627+
context, builder, out_types[output], load, ptr
628+
)
629+
# store.set_metadata("alias.scope", output_scope_set)
630+
# store.set_metadata("noalias", input_scope_set)
624631
loop.__exit__(None, None, None)
625632
return
626633

@@ -672,6 +679,7 @@ def make_output(iter_shape, bc, dtype):
672679

673680
check_arrays = check_broadcasting
674681
else:
682+
675683
@numba.extending.register_jitable
676684
def make_output(iter_shape, bc, dtype):
677685
return np.empty((), dtype)
@@ -680,7 +688,6 @@ def make_output(iter_shape, bc, dtype):
680688
def check_arrays(a, b, c):
681689
pass
682690

683-
684691
make_outputs = tuple_mapper(make_output)
685692

686693
def impl(*inputs):
@@ -756,10 +763,7 @@ def numba_funcify_Elemwise(op, node, **kwargs):
756763
scalar_op_fn,
757764
input_bc_patterns,
758765
output_bc_patterns,
759-
output_dtypes=tuple([
760-
variable.dtype
761-
for variable in node.outputs
762-
]),
766+
output_dtypes=tuple([variable.dtype for variable in node.outputs]),
763767
)
764768

765769
# TODO We should do this in vectorize instead
@@ -771,13 +775,16 @@ def elemwise_inplace(*inputs):
771775
outputs = vectorized(*inputs)
772776
for out_idx, in_idx in literal_unroll(pattern):
773777
inputs[in_idx][...] = outputs[out_idx]
778+
774779
else:
775780
elemwise_inplace = vectorized
776781

777782
if len(node.outputs) == 1:
783+
778784
@numba_njit
779785
def elemwise_wrapper(*inputs):
780786
return elemwise_inplace(*inputs)[0]
787+
781788
else:
782789
elemwise_wrapper = vectorized
783790

0 commit comments

Comments
 (0)