1
+ import functools
1
2
from typing import List , Tuple
2
3
3
4
import numpy as np
4
5
5
- from pytensor import Variable , as_symbolic
6
+ from pytensor import Variable , as_symbolic , clone_replace
6
7
from pytensor .graph import FunctionGraph
8
+ from pytensor .graph .basic import Constant , truncated_graph_inputs
7
9
from pytensor .loop .op import Scan
8
10
from pytensor .scan .utils import until
9
- from pytensor .tensor import as_tensor , empty_like
11
+ from pytensor .tensor import as_tensor , constant , empty_like , minimum
10
12
11
13
12
14
def scan (
@@ -20,6 +22,8 @@ def scan(
20
22
if sequences is None and n_steps is None :
21
23
raise ValueError ("Must provide n_steps when scanning without sequences" )
22
24
25
+ # TODO: init_states should be made opaque to the inner function,
26
+ # since any relationship to the outer graph no longer holds
23
27
if init_states is None :
24
28
init_states = []
25
29
else :
@@ -34,20 +38,31 @@ def scan(
34
38
sequences = [sequences ]
35
39
sequences = [as_tensor (s ) for s in sequences ]
36
40
41
+ if sequences :
42
+ leading_dims = [seq .shape [0 ] for seq in sequences ]
43
+ shortest_dim = functools .reduce (minimum , leading_dims )
44
+ if n_steps is None :
45
+ n_steps = shortest_dim
46
+ else :
47
+ n_steps = minimum (n_steps , shortest_dim )
48
+
37
49
if non_sequences is None :
38
50
non_sequences = []
39
51
else :
40
52
if not isinstance (non_sequences , (tuple , list )):
41
53
non_sequences = [non_sequences ]
42
54
non_sequences = [as_symbolic (n ) for n in non_sequences ]
43
55
56
+ # Create subsequence inputs for the inner function
57
+ idx = constant (0 , dtype = "int64" , name = "idx" )
58
+ symbolic_idx = idx .type (name = "idx" )
59
+ subsequences = [s [symbolic_idx ] for s in sequences ]
44
60
# Note: Old scan order is sequences + init + non_sequences
45
- inner_sequences = [s [0 ] for s in sequences ]
46
- inner_inputs = [i .type () for i in init_states + inner_sequences + non_sequences ]
47
- inner_outputs = fn (* inner_inputs )
48
- if not isinstance (inner_outputs , (tuple , list )):
49
- inner_outputs = [inner_outputs ]
50
- next_states = [out for out in inner_outputs if not isinstance (out , until )]
61
+ fn_inputs = init_states + subsequences + non_sequences
62
+ fn_outputs = fn (* fn_inputs )
63
+ if not isinstance (fn_outputs , (tuple , list )):
64
+ fn_outputs = [fn_outputs ]
65
+ next_states = [out for out in fn_outputs if not isinstance (out , until )]
51
66
52
67
if len (next_states ) > len (init_states ):
53
68
if not init_states :
@@ -61,27 +76,43 @@ def scan(
61
76
prev_states = []
62
77
for i , (init_state , next_state ) in enumerate (zip (init_states , next_states )):
63
78
if init_state is None :
79
+ # next_state may reference idx, let's replace that by the initial value
80
+ [next_state ] = clone_replace (
81
+ output = [next_state ], replace = {symbolic_idx : idx }
82
+ )
64
83
init_state = empty_like (next_state )
65
84
init_state .name = "empty_init_state"
66
- inner_inputs .insert (i , init_state .type ())
67
85
prev_states .append (init_state )
68
86
69
- until_condition = [out .condition for out in inner_outputs if isinstance (out , until )]
87
+ until_condition = [out .condition for out in fn_outputs if isinstance (out , until )]
70
88
if not until_condition :
71
89
until_condition = [as_tensor (np .array (True ))]
72
90
if len (until_condition ) > 1 :
73
91
raise ValueError ("Only one until condition can be returned" )
74
92
75
- update_fg = FunctionGraph (
76
- inputs = inner_inputs , outputs = until_condition + next_states
93
+ fgraph_inputs = [symbolic_idx ] + prev_states + sequences + non_sequences
94
+ fgraph_outputs = until_condition + [symbolic_idx + 1 ] + next_states
95
+
96
+ all_fgraph_inputs = truncated_graph_inputs (
97
+ fgraph_outputs , ancestors_to_include = fgraph_inputs
98
+ )
99
+ extra_fgraph_inputs = [
100
+ inp
101
+ for inp in all_fgraph_inputs
102
+ if (not isinstance (inp , Constant ) and inp not in fgraph_inputs )
103
+ ]
104
+ fgraph_inputs = fgraph_inputs + extra_fgraph_inputs
105
+ update_fg = FunctionGraph (inputs = fgraph_inputs , outputs = fgraph_outputs )
106
+
107
+ scan_op = Scan (update_fg = update_fg )
108
+ scan_outs = scan_op (
109
+ n_steps , idx , * prev_states , * sequences , * non_sequences , * extra_fgraph_inputs
77
110
)
78
- scan_op = Scan (update_fg = update_fg , n_sequences = len (sequences ))
79
- scan_outs = scan_op (n_steps , * prev_states , * sequences , * non_sequences )
80
111
assert isinstance (scan_outs , list )
81
112
last_states = scan_outs [: scan_op .n_states ]
82
113
traces = scan_outs [scan_op .n_states :]
83
-
84
- return last_states , traces
114
+ # Don't return the inner index state
115
+ return last_states [ 1 :] , traces [ 1 :]
85
116
86
117
87
118
def map (
0 commit comments