7
7
from pytensor .graph import Apply , FunctionGraph , Op , Type , node_rewriter
8
8
from pytensor .graph .rewriting .basic import in2out
9
9
from pytensor .scalar import constant
10
- from pytensor .tensor import (
11
- NoneConst ,
12
- add ,
13
- and_ ,
14
- empty ,
15
- get_scalar_constant_value ,
16
- set_subtensor ,
17
- )
10
+ from pytensor .tensor import add , and_ , empty , get_scalar_constant_value , set_subtensor
18
11
from pytensor .tensor .exceptions import NotScalarConstantError
19
12
from pytensor .tensor .shape import Shape_i
20
13
from pytensor .tensor .type import DenseTensorType , TensorType
21
14
from pytensor .tensor .type_other import NoneTypeT
15
+ from pytensor .typed_list import TypedListType , append , make_empty_list
22
16
23
17
24
18
def validate_loop_update_types (update ):
@@ -176,8 +170,7 @@ def __init__(
176
170
)
177
171
)
178
172
else :
179
- # We can't concatenate all types of states, such as RandomTypes
180
- self .trace_types .append (NoneConst .type )
173
+ self .trace_types .append (TypedListType (state_type ))
181
174
182
175
self .constant_types = [inp .type for inp in update_fg .inputs [self .n_states :]]
183
176
self .n_constants = len (self .constant_types )
@@ -312,10 +305,6 @@ def scan(fn, idx, initial_states, constants, max_iters):
312
305
if fgraph .clients [trace ]
313
306
]
314
307
315
- # Check that outputs that cannot be converted into sequences (such as RandomTypes) are not being referenced
316
- for trace_idx in used_traces_idxs :
317
- assert not isinstance (old_states [trace_idx ].type , NoneTypeT )
318
-
319
308
# Inputs to the new Loop
320
309
max_iters = node .inputs [0 ]
321
310
init_states = node .inputs [1 : 1 + op .n_states ]
@@ -324,6 +313,8 @@ def scan(fn, idx, initial_states, constants, max_iters):
324
313
(max_iters , * tuple (init_states [trace_idx ].shape )),
325
314
dtype = init_states [trace_idx ].dtype ,
326
315
)
316
+ if isinstance (init_states [trace_idx ].type , DenseTensorType )
317
+ else make_empty_list (init_states [trace_idx ].type )
327
318
for trace_idx in used_traces_idxs
328
319
]
329
320
constants = node .inputs [1 + op .n_states :]
@@ -376,6 +367,7 @@ def scan(fn, idx, initial_states, constants, max_iters):
376
367
# Inner traces
377
368
inner_states = update_fg .inputs [: op .n_states ]
378
369
inner_traces = [init_trace .type () for init_trace in init_traces ]
370
+
379
371
for s , t in zip (inner_states , inner_traces ):
380
372
t .name = "trace"
381
373
if s .name :
@@ -387,6 +379,8 @@ def scan(fn, idx, initial_states, constants, max_iters):
387
379
inner_while_cond , * inner_next_states = update_fg .outputs
388
380
inner_next_traces = [
389
381
set_subtensor (prev_trace [inner_idx ], inner_next_states [trace_idx ])
382
+ if isinstance (prev_trace .type , DenseTensorType )
383
+ else append (prev_trace , inner_next_states [trace_idx ])
390
384
for trace_idx , prev_trace in zip (used_traces_idxs , inner_traces )
391
385
]
392
386
for t in inner_next_traces :
@@ -429,7 +423,7 @@ def scan(fn, idx, initial_states, constants, max_iters):
429
423
replacements = dict (zip (old_states , new_states ))
430
424
for trace_idx , new_trace in zip (used_traces_idxs , new_traces ):
431
425
# If there is no while condition, the whole trace will be used
432
- if op .has_while_condition :
426
+ if op .has_while_condition and isinstance ( new_trace . type , DenseTensorType ) :
433
427
new_trace = new_trace [:final_idx ]
434
428
replacements [old_traces [trace_idx ]] = new_trace
435
429
return replacements
0 commit comments