File tree 1 file changed +7
-5
lines changed
pytensor/link/numba/dispatch
1 file changed +7
-5
lines changed Original file line number Diff line number Diff line change @@ -222,14 +222,16 @@ def add_output_storage_post_proc_stmt(
222
222
# the storage array.
223
223
# This is needed when the output storage array does not have a length
224
224
# equal to the number of taps plus `n_steps`.
225
+ # If the storage size only allows one entry, there's nothing to rotate
225
226
output_storage_post_proc_stmts .append (
226
227
dedent (
227
228
f"""
228
- if (i + { tap_size } ) > { storage_size } :
229
+ if 1 < { storage_size } < (i + { tap_size } ):
229
230
{ outer_in_name } _shift = (i + { tap_size } ) % ({ storage_size } )
230
- { outer_in_name } _left = { outer_in_name } [:{ outer_in_name } _shift]
231
- { outer_in_name } _right = { outer_in_name } [{ outer_in_name } _shift:]
232
- { outer_in_name } = np.concatenate(({ outer_in_name } _right, { outer_in_name } _left))
231
+ if { outer_in_name } _shift > 0:
232
+ { outer_in_name } _left = { outer_in_name } [:{ outer_in_name } _shift]
233
+ { outer_in_name } _right = { outer_in_name } [{ outer_in_name } _shift:]
234
+ { outer_in_name } = np.concatenate(({ outer_in_name } _right, { outer_in_name } _left))
233
235
"""
234
236
).strip ()
235
237
)
@@ -417,4 +419,4 @@ def scan({", ".join(outer_in_names)}):
417
419
418
420
scan_op_fn = compile_function_src (scan_op_src , "scan" , {** globals (), ** global_env })
419
421
420
- return numba_basic .numba_njit (scan_op_fn )
422
+ return numba_basic .numba_njit (scan_op_fn , boundscheck = False )
You can’t perform that action at this time.
0 commit comments