@@ -112,25 +112,46 @@ def numba_funcify_Scan(op, node, **kwargs):
112
112
# Inner-inputs are ordered as follows:
113
113
# sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs +
114
114
# shared-inputs + non-sequences.
115
+ temp_scalar_storage_alloc_stmts : List [str ] = []
116
+ inner_in_exprs_scalar : List [str ] = []
115
117
inner_in_exprs : List [str ] = []
116
118
117
119
def add_inner_in_expr (
118
- outer_in_name : str , tap_offset : Optional [int ], storage_size_var : Optional [str ]
120
+ outer_in_name : str ,
121
+ tap_offset : Optional [int ],
122
+ storage_size_var : Optional [str ],
123
+ vector_slice_opt : bool ,
119
124
):
120
125
"""Construct an inner-input expression."""
121
126
storage_name = outer_in_to_storage_name .get (outer_in_name , outer_in_name )
122
- indexed_inner_in_str = (
123
- storage_name
124
- if tap_offset is None
125
- else idx_to_str (
126
- storage_name , tap_offset , size = storage_size_var , allow_scalar = False
127
+ if vector_slice_opt :
128
+ indexed_inner_in_str_scalar = idx_to_str (
129
+ storage_name , tap_offset , size = storage_size_var , allow_scalar = True
130
+ )
131
+ temp_storage = f"{ storage_name } _temp_scalar_{ tap_offset } "
132
+ storage_dtype = outer_in_var .type .numpy_dtype .name
133
+ temp_scalar_storage_alloc_stmts .append (
134
+ f"{ temp_storage } = np.empty((), dtype=np.{ storage_dtype } )"
135
+ )
136
+ inner_in_exprs_scalar .append (
137
+ f"{ temp_storage } [()] = { indexed_inner_in_str_scalar } "
138
+ )
139
+ indexed_inner_in_str = temp_storage
140
+ else :
141
+ indexed_inner_in_str = (
142
+ storage_name
143
+ if tap_offset is None
144
+ else idx_to_str (
145
+ storage_name , tap_offset , size = storage_size_var , allow_scalar = False
146
+ )
127
147
)
128
- )
129
148
inner_in_exprs .append (indexed_inner_in_str )
130
149
131
150
for outer_in_name in outer_in_seqs_names :
132
151
# These outer-inputs are indexed without offsets or storage wrap-around
133
- add_inner_in_expr (outer_in_name , 0 , None )
152
+ outer_in_var = outer_in_names_to_vars [outer_in_name ]
153
+ is_vector = outer_in_var .ndim == 1
154
+ add_inner_in_expr (outer_in_name , 0 , None , vector_slice_opt = is_vector )
134
155
135
156
inner_in_names_to_input_taps : Dict [str , Tuple [int , ...]] = dict (
136
157
zip (
@@ -232,7 +253,13 @@ def add_output_storage_post_proc_stmt(
232
253
for in_tap in input_taps :
233
254
tap_offset = in_tap + tap_storage_size
234
255
assert tap_offset >= 0
235
- add_inner_in_expr (outer_in_name , tap_offset , storage_size_name )
256
+ is_vector = outer_in_var .ndim == 1
257
+ add_inner_in_expr (
258
+ outer_in_name ,
259
+ tap_offset ,
260
+ storage_size_name ,
261
+ vector_slice_opt = is_vector ,
262
+ )
236
263
237
264
output_taps = inner_in_names_to_output_taps .get (
238
265
outer_in_name , [tap_storage_size ]
@@ -253,7 +280,7 @@ def add_output_storage_post_proc_stmt(
253
280
254
281
else :
255
282
storage_size_stmt = ""
256
- add_inner_in_expr (outer_in_name , None , None )
283
+ add_inner_in_expr (outer_in_name , None , None , vector_slice_opt = False )
257
284
inner_out_to_outer_in_stmts .append (storage_name )
258
285
259
286
output_idx = outer_output_names .index (storage_name )
@@ -325,17 +352,19 @@ def add_output_storage_post_proc_stmt(
325
352
)
326
353
327
354
for name in outer_in_non_seqs_names :
328
- add_inner_in_expr (name , None , None )
355
+ add_inner_in_expr (name , None , None , vector_slice_opt = False )
329
356
330
357
if op .info .as_while :
331
358
# The inner function will return a boolean as the last value
332
359
inner_out_to_outer_in_stmts .append ("cond" )
333
360
334
361
assert len (inner_in_exprs ) == len (op .fgraph .inputs )
335
362
363
+ inner_scalar_in_args_to_temp_storage = "\n " .join (inner_in_exprs_scalar )
336
364
inner_in_args = create_arg_string (inner_in_exprs )
337
365
inner_outputs = create_tuple_string (inner_output_names )
338
366
input_storage_block = "\n " .join (storage_alloc_stmts )
367
+ input_temp_scalar_storage_block = "\n " .join (temp_scalar_storage_alloc_stmts )
339
368
output_storage_post_processing_block = "\n " .join (output_storage_post_proc_stmts )
340
369
inner_out_post_processing_block = "\n " .join (inner_out_post_processing_stmts )
341
370
@@ -348,9 +377,13 @@ def scan({", ".join(outer_in_names)}):
348
377
349
378
{ indent (input_storage_block , " " * 4 )}
350
379
380
+ { indent (input_temp_scalar_storage_block , " " * 4 )}
381
+
351
382
i = 0
352
383
cond = np.array(False)
353
384
while i < n_steps and not cond.item():
385
+ { indent (inner_scalar_in_args_to_temp_storage , " " * 8 )}
386
+
354
387
{ inner_outputs } = scan_inner_func({ inner_in_args } )
355
388
{ indent (inner_out_post_processing_block , " " * 8 )}
356
389
{ indent (inner_out_to_outer_out_stmts , " " * 8 )}
0 commit comments