Skip to content

Commit ca09602

Browse files
committed
Minor numba Scan tweaks
1 parent 1364e31 commit ca09602

File tree

1 file changed

+7
-5
lines changed
  • pytensor/link/numba/dispatch

1 file changed

+7
-5
lines changed

pytensor/link/numba/dispatch/scan.py

+7-5
Original file line numberDiff line numberDiff line change
@@ -222,14 +222,16 @@ def add_output_storage_post_proc_stmt(
222222
# the storage array.
223223
# This is needed when the output storage array does not have a length
224224
# equal to the number of taps plus `n_steps`.
225+
# If the storage size only allows one entry, there's nothing to rotate
225226
output_storage_post_proc_stmts.append(
226227
dedent(
227228
f"""
228-
if (i + {tap_size}) > {storage_size}:
229+
if 1 < {storage_size} < (i + {tap_size}):
229230
{outer_in_name}_shift = (i + {tap_size}) % ({storage_size})
230-
{outer_in_name}_left = {outer_in_name}[:{outer_in_name}_shift]
231-
{outer_in_name}_right = {outer_in_name}[{outer_in_name}_shift:]
232-
{outer_in_name} = np.concatenate(({outer_in_name}_right, {outer_in_name}_left))
231+
if {outer_in_name}_shift > 0:
232+
{outer_in_name}_left = {outer_in_name}[:{outer_in_name}_shift]
233+
{outer_in_name}_right = {outer_in_name}[{outer_in_name}_shift:]
234+
{outer_in_name} = np.concatenate(({outer_in_name}_right, {outer_in_name}_left))
233235
"""
234236
).strip()
235237
)
@@ -417,4 +419,4 @@ def scan({", ".join(outer_in_names)}):
417419

418420
scan_op_fn = compile_function_src(scan_op_src, "scan", {**globals(), **global_env})
419421

420-
return numba_basic.numba_njit(scan_op_fn)
422+
return numba_basic.numba_njit(scan_op_fn, boundscheck=False)

0 commit comments

Comments
 (0)