7
7
from pytensor .graph import Apply , FunctionGraph , Op , Type , node_rewriter
8
8
from pytensor .graph .rewriting .basic import in2out
9
9
from pytensor .scalar import constant
10
- from pytensor .tensor import (
11
- NoneConst ,
12
- add ,
13
- and_ ,
14
- empty ,
15
- get_scalar_constant_value ,
16
- set_subtensor ,
17
- )
10
+ from pytensor .tensor import add , and_ , empty , get_scalar_constant_value , set_subtensor
18
11
from pytensor .tensor .exceptions import NotScalarConstantError
19
12
from pytensor .tensor .shape import Shape_i
13
+ from pytensor .tensor .subtensor import Subtensor , get_idx_list
20
14
from pytensor .tensor .type import DenseTensorType , TensorType
21
15
from pytensor .tensor .type_other import NoneTypeT
16
+ from pytensor .typed_list import GetItem , TypedListType , append , make_empty_list
22
17
23
18
24
19
def validate_loop_update_types (update ):
@@ -176,8 +171,7 @@ def __init__(
176
171
)
177
172
)
178
173
else :
179
- # We can't concatenate all types of states, such as RandomTypes
180
- self .trace_types .append (NoneConst .type )
174
+ self .trace_types .append (TypedListType (state_type ))
181
175
182
176
self .constant_types = [inp .type for inp in update_fg .inputs [self .n_states :]]
183
177
self .n_constants = len (self .constant_types )
@@ -312,10 +306,6 @@ def scan(fn, idx, initial_states, constants, max_iters):
312
306
if fgraph .clients [trace ]
313
307
]
314
308
315
- # Check that outputs that cannot be converted into sequences (such as RandomTypes) are not being referenced
316
- for trace_idx in used_traces_idxs :
317
- assert not isinstance (old_states [trace_idx ].type , NoneTypeT )
318
-
319
309
# Inputs to the new Loop
320
310
max_iters = node .inputs [0 ]
321
311
init_states = node .inputs [1 : 1 + op .n_states ]
@@ -324,6 +314,8 @@ def scan(fn, idx, initial_states, constants, max_iters):
324
314
(max_iters , * tuple (init_states [trace_idx ].shape )),
325
315
dtype = init_states [trace_idx ].dtype ,
326
316
)
317
+ if isinstance (init_states [trace_idx ].type , DenseTensorType )
318
+ else make_empty_list (init_states [trace_idx ].type )
327
319
for trace_idx in used_traces_idxs
328
320
]
329
321
constants = node .inputs [1 + op .n_states :]
@@ -387,6 +379,8 @@ def scan(fn, idx, initial_states, constants, max_iters):
387
379
inner_while_cond , * inner_next_states = update_fg .outputs
388
380
inner_next_traces = [
389
381
set_subtensor (prev_trace [inner_idx ], inner_next_states [trace_idx ])
382
+ if isinstance (prev_trace .type , DenseTensorType )
383
+ else append (prev_trace , inner_next_states [trace_idx ])
390
384
for trace_idx , prev_trace in zip (used_traces_idxs , inner_traces )
391
385
]
392
386
for t in inner_next_traces :
@@ -429,7 +423,7 @@ def scan(fn, idx, initial_states, constants, max_iters):
429
423
replacements = dict (zip (old_states , new_states ))
430
424
for trace_idx , new_trace in zip (used_traces_idxs , new_traces ):
431
425
# If there is no while condition, the whole trace will be used
432
- if op .has_while_condition :
426
+ if op .has_while_condition and isinstance ( new_trace . type , DenseTensorType ) :
433
427
new_trace = new_trace [:final_idx ]
434
428
replacements [old_traces [trace_idx ]] = new_trace
435
429
return replacements
@@ -446,3 +440,39 @@ def scan(fn, idx, initial_states, constants, max_iters):
446
440
"not_jax" ,
447
441
position = 1.0 ,
448
442
)
443
+
444
+
445
+ @node_rewriter ([Scan ])
446
+ def scan_view_last_state (fgraph , node ):
447
+ """Replace trace[-1] by the last state output of a Scan node"""
448
+ replacements = {}
449
+ for final_state , trace in zip (
450
+ node .outputs [: node .op .n_states ], node .outputs [node .op .n_states :]
451
+ ):
452
+ clients = fgraph .clients [trace ]
453
+ for client , _ in clients :
454
+ if client == "output" :
455
+ continue
456
+ if isinstance (client .op , (Subtensor , GetItem )):
457
+ if isinstance (client .op , Subtensor ):
458
+ idxs = get_idx_list (client .inputs , client .op .idx_list )
459
+ if len (idxs ) == 1 :
460
+ idx = idxs [0 ]
461
+ else :
462
+ idx = client .inputs [1 ]
463
+ try :
464
+ last_index = get_scalar_constant_value (idx ) == - 1
465
+ except NotScalarConstantError :
466
+ continue
467
+ if last_index :
468
+ replacements [client .default_output ()] = final_state
469
+ return replacements
470
+
471
+
472
+ optdb .register (
473
+ "scan_view_last_state" ,
474
+ in2out (scan_view_last_state ),
475
+ "fast_compile" ,
476
+ "fast_run" ,
477
+ position = 0.999 ,
478
+ )
0 commit comments