
Commit 2f27cea

Implement new Loop and Scan operators
1 parent b8831aa commit 2f27cea

File tree: 4 files changed, +338 -0 lines


pytensor/loop/__init__.py

Whitespace-only changes.

pytensor/loop/op.py

+247
@@ -0,0 +1,247 @@
from typing import Optional

import numpy as np

from pytensor import In, Out
from pytensor.compile import optdb, pfunc
from pytensor.graph import Apply, FunctionGraph, Op, Type, node_rewriter
from pytensor.graph.rewriting.basic import in2out
from pytensor.scalar import constant
from pytensor.tensor import NoneConst, and_, empty, scalar, set_subtensor
from pytensor.tensor.type import DenseTensorType, TensorType
from pytensor.tensor.type_other import NoneTypeT


def validate_loop_update_types(update):
    assert update.outputs[0].type.dtype == "bool"
    for input_state, output_state in zip(update.inputs, update.outputs[1:]):
        assert input_state.type == output_state.type
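

# Loop runs its compiled update graph until the boolean first output is False
# (do-while semantics: the update from the final iteration is kept); Scan
# additionally records each intermediate state in a trace, and must be
# rewritten into a Loop before compilation (see scan_to_loop below).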
class Loop(Op):
    """Represent a do-while loop."""

    def __init__(
        self,
        update: FunctionGraph,  # (*state, *consts) -> (bool, *state)
        reverse: Optional[FunctionGraph] = None,
    ):
        validate_loop_update_types(update)
        self.state_types = [out.type for out in update.outputs[1:]]
        self.const_types = [inp.type for inp in update.inputs[len(self.state_types) :]]
        self.update = update
        self.reverse = reverse
        self._update_fn = None

    @property
    def update_fn(self):
        """Lazily compile the inner update function graph."""
        if self._update_fn is not None:
            return self._update_fn

        fgraph = self.update
        wrapped_inputs = [In(x, borrow=True) for x in fgraph.inputs]
        wrapped_outputs = [Out(x, borrow=False) for x in fgraph.outputs]

        self._update_fn = pfunc(
            wrapped_inputs,
            wrapped_outputs,
            mode="FAST_RUN",  # TODO: Figure this out
            accept_inplace=False,
            on_unused_input="ignore",
            fgraph=fgraph,
        )
        return self._update_fn

    def make_node(self, *inputs):
        assert len(inputs) == len(self.state_types) + len(self.const_types)

        states = inputs[: len(self.state_types)]
        states = [
            inp_type.filter_variable(inp)
            for inp_type, inp in zip(self.state_types, states)
        ]

        consts = inputs[len(self.state_types) :]
        consts = [
            inp_type.filter_variable(inp)
            for inp_type, inp in zip(self.const_types, consts)
        ]

        return Apply(
            self,
            [*states, *consts],
            [state_type() for state_type in self.state_types],
        )

    def infer_shape(self, fgraph, node, input_shapes):
        return input_shapes[: len(self.state_types)]

    def perform(self, node, inputs, output_storage):
        update_fn = self.update_fn

        states = inputs[: len(self.state_types)]
        consts = inputs[len(self.state_types) :]
        while True:
            go_on, *states = update_fn(*states, *consts)
            if not go_on:
                break

        for i, state in enumerate(states):
            output_storage[i][0] = state

    def L_op(self, *args):
        if not self.reverse:
            raise NotImplementedError()
        # Use L_op of self.reverse.update
        ...

    def R_op(self, *args):
        # Use R_op of self.update
        ...


class Scan(Op):
    """Represent a scan.

    This Op must always be converted to a Loop during compilation.
    """

    def __init__(
        self,
        update: FunctionGraph,  # (*state, *consts) -> (bool, *state)
        reverse: Optional[FunctionGraph] = None,
    ):
        validate_loop_update_types(update)
        self.state_types = [out.type for out in update.outputs[1:]]
        self.trace_types: list[Type] = []
        for state_type in self.state_types:
            # Accommodate SparseTensors and Scalars
            if isinstance(state_type, DenseTensorType):
                self.trace_types.append(
                    DenseTensorType(
                        shape=(None, *state_type.shape), dtype=state_type.dtype
                    )
                )
            else:
                # We can't concatenate all types of states, such as RandomTypes
                self.trace_types.append(NoneConst.type)
        self.const_types = [inp.type for inp in update.inputs[len(self.state_types) :]]
        self.update = update
        self.reverse = reverse
        self._update_fn = None

    def make_node(self, n_steps, *inputs):
        assert len(inputs) == len(self.state_types) + len(self.const_types)

        n_steps = TensorType(dtype="int64", shape=()).filter_variable(n_steps)

        states = inputs[: len(self.state_types)]
        states = [
            inp_type.filter_variable(inp)
            for inp_type, inp in zip(self.state_types, states)
        ]

        consts = inputs[len(self.state_types) :]
        consts = [
            inp_type.filter_variable(inp)
            for inp_type, inp in zip(self.const_types, consts)
        ]

        return Apply(
            self,
            [n_steps, *states, *consts],
            [output_type() for output_type in self.state_types + self.trace_types],
        )

    def infer_shape(self, fgraph, node, input_shapes):
        n_steps = node.inputs[0]
        state_shapes = input_shapes[1 : len(self.state_types) + 1]
        trace_shapes = [
            (n_steps, *state_shape) if state_shape is not None else None
            for state_shape in state_shapes
        ]
        return state_shapes + trace_shapes

    def perform(self, node, inputs, output_storage):
        raise RuntimeError("Scan Op should not be present in compiled graph")

    def L_op(self, *args):
        # Use trace outputs
        ...

    def R_op(self, *args):
        # Use R_op of self.update
        ...
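

# The rewrite below replaces each Scan node with a Loop node: it extends the
# update graph with a step index and one preallocated trace per trace output
# that is actually used, writes each new state into trace[prev_idx] via
# set_subtensor, stops as soon as the inner condition fails or next_idx
# reaches n_steps, and finally cuts each trace back to trace[:final_idx].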
@node_rewriter([Scan])
def scan_to_loop(fgraph, node):
    """Rewrite a Scan Op into a Loop Op"""
    op: Scan = node.op  # type: ignore

    n_steps = node.inputs[0]

    n_state_vars = len(op.state_types)
    old_states = node.outputs[:n_state_vars]
    old_traces = node.outputs[n_state_vars:]

    # Only include the intermediate states that are used elsewhere
    used_traces_idxs = [
        i
        for i, trace in enumerate(node.outputs[n_state_vars:])
        if fgraph.clients[trace]
    ]

    # Check that outputs that cannot be converted into sequences (such as
    # RandomTypes) are not being referenced
    for trace_idx in used_traces_idxs:
        assert not isinstance(old_traces[trace_idx].type, NoneTypeT)

    update_fg = op.update
    prev_idx = scalar(dtype="int64", name="prev_idx")
    prev_states = update_fg.inputs[:n_state_vars]
    prev_traces = [old_traces[i].type() for i in used_traces_idxs]
    consts = update_fg.inputs[n_state_vars:]

    go_on, *next_states = update_fg.outputs
    next_idx = prev_idx + 1
    next_idx.name = "next_idx"
    next_traces = [
        set_subtensor(prev_trace[prev_idx], next_states[trace_idx])
        for trace_idx, prev_trace in zip(used_traces_idxs, prev_traces)
    ]
    go_on = and_(go_on, next_idx < n_steps)
    go_on.name = "go_on"

    new_update_fg = FunctionGraph(
        inputs=[prev_idx, *prev_states, *prev_traces, *consts],
        outputs=[go_on, next_idx, *next_states, *next_traces],
    )

    # TODO: Implement Reverse?
    loop_op = Loop(update=new_update_fg)

    init_idx = constant(np.array(0, dtype="int64"))
    init_states = node.inputs[1 : len(op.state_types) + 1]
    init_traces = [
        empty((n_steps, *tuple(init_states[trace_idx].shape)))
        for trace_idx in used_traces_idxs
    ]
    final_idx, *new_outs = loop_op(init_idx, *init_states, *init_traces)
    new_states = new_outs[:n_state_vars]
    new_traces = new_outs[n_state_vars:]

    replacements = dict(zip(old_states, new_states))
    for trace_idx, new_trace in zip(used_traces_idxs, new_traces):
        replacements[old_traces[trace_idx]] = new_trace[:final_idx]
    return replacements


# TODO: Create new Loop database
optdb.register(
    "scan_to_loop",
    in2out(scan_to_loop),
    "fast_compile",
    "fast_run",
    position=-0.1,  # TODO: When?
)
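
For intuition, here is a hedged plain-Python sketch (not part of the commit) of the computation that scan_to_loop sets up: the trace buffer is preallocated with empty((n_steps, ...)), each iteration writes the new state into trace[prev_idx] (the set_subtensor above) and increments the index, iteration stops when the inner condition fails or next_idx reaches n_steps, and the trace is finally sliced to trace[:final_idx]. The helper name scan_as_loop_reference and its update callback are illustrative only.

import numpy as np

def scan_as_loop_reference(update, init_state, n_steps):
    # `update` maps state -> (go_on, next_state), mirroring the inner
    # FunctionGraph of a single-state Scan with no constants.
    idx = 0
    state = init_state
    trace = np.empty(n_steps)  # counterpart of empty((n_steps, ...))
    while True:
        go_on, state = update(state)
        trace[idx] = state  # counterpart of set_subtensor(prev_trace[prev_idx], ...)
        idx += 1
        go_on = go_on and (idx < n_steps)  # counterpart of and_(go_on, next_idx < n_steps)
        if not go_on:
            break
    return state, trace[:idx]  # counterpart of new_trace[:final_idx]

# Matches test_for_scan below: x starts at 0 and gains 2 per step for 10 steps
final, trace = scan_as_loop_reference(lambda x: (True, x + 2.0), 0.0, 10)
assert final == 20.0
assert np.array_equal(trace, np.arange(2.0, 22.0, 2.0))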

tests/loop/__init__.py

Whitespace-only changes.

tests/loop/test_op.py

+91
@@ -0,0 +1,91 @@
import numpy as np

from pytensor import function, shared
from pytensor.graph import FunctionGraph
from pytensor.loop.op import Loop, Scan
from pytensor.tensor import constant, lscalar, scalar
from pytensor.tensor.random import normal
from pytensor.tensor.type_other import NoneTypeT

def test_loop_basic():
    i = lscalar("i")
    x = scalar("x")
    update_fg = FunctionGraph([i, x], [(i + 1) < 10, i + 1, x + 2])

    _, y = Loop(update=update_fg)(np.array(0, dtype="int64"), x)
    assert y.eval({x: 0}) == 20
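    # (i + 1) < 10 is evaluated in the same iteration that updates the state,
    # and the final update is kept when it fails (do-while semantics), so x is
    # incremented 10 times: 0 -> 20.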


def test_for_scan():
    x = scalar("x")
    update_fg = FunctionGraph([x], [constant(np.array(True)), x + 2])

    n_steps = 10
    y, ys = Scan(update=update_fg)(n_steps, x)

    fn = function([x], [y, ys])

    loop_nodes = tuple(
        node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Loop)
    )
    assert len(loop_nodes) == 1
    (loop_node,) = loop_nodes
    assert len(loop_node.outputs) == 3
    assert loop_node.outputs[0].type.shape == ()
    assert loop_node.outputs[1].type.shape == ()
    assert loop_node.outputs[2].type.shape == (None,)  # This could be known

    y_eval, ys_eval = fn(0)
    np.testing.assert_array_equal(ys_eval, np.arange(2, 22, 2))
    np.testing.assert_array_equal(ys_eval[-1], y_eval)


def test_while_scan():
    i = lscalar("i")
    x = scalar("x")
    update_fg = FunctionGraph([i, x], [(i + 1) < 10, i + 1, x + 2])

    n_steps = 1000
    _, y, _, ys = Scan(update=update_fg)(n_steps, np.array(0, dtype="int64"), x)

    fn = function([x], [y, ys])

    loop_nodes = tuple(
        node for node in fn.maker.fgraph.apply_nodes if isinstance(node.op, Loop)
    )
    assert len(loop_nodes) == 1
    (loop_node,) = loop_nodes
    assert len(loop_node.outputs) == 4
    assert loop_node.outputs[0].type.shape == ()
    assert loop_node.outputs[1].type.shape == ()
    assert loop_node.outputs[2].type.shape == ()
    assert loop_node.outputs[3].type.shape == (None,)  # This could be known

    y_eval, ys_eval = fn(0)
    np.testing.assert_array_equal(ys_eval, np.arange(2, 22, 2))
    np.testing.assert_array_equal(ys_eval[-1], y_eval)
def test_scan_random():
    rng_test = np.random.default_rng(123)
    rng_shared = shared(np.random.default_rng(123))
    n_steps = 5

    x = scalar("x")  # TODO: x shouldn't be needed when the initial_state does not matter!
    rng = rng_shared.type()
    update_fg = FunctionGraph(
        [x, rng], [constant(np.array(True)), *normal(rng=rng).owner.outputs[::-1]]
    )

    _, new_rng, ys, rngs = Scan(update=update_fg)(n_steps, x, rng_shared)
    assert isinstance(rngs.type, NoneTypeT)

    fn = function([x], ys, updates={rng_shared: new_rng})

    np.testing.assert_array_equal(fn(0), rng_test.normal(size=5))
    np.testing.assert_array_equal(fn(0), rng_test.normal(size=5))
