Skip to content

Commit a750fd7

Browse files
committed
Allow non-TensorVariable types to be traced in new Scan Op
1 parent 60e80f5 commit a750fd7

File tree

4 files changed

+100
-29
lines changed

4 files changed

+100
-29
lines changed

pytensor/link/jax/dispatch/loop.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import jax
2+
from jax.tree_util import tree_flatten, tree_unflatten
23

34
from pytensor.compile.mode import get_mode
45
from pytensor.link.jax.dispatch.basic import jax_funcify
56
from pytensor.loop.op import Scan
7+
from pytensor.typed_list import TypedListType
68

79

810
@jax_funcify.register(Scan)
@@ -43,10 +45,17 @@ def scan_fn(carry, _):
4345
states, traces = jax.lax.scan(
4446
scan_fn, init=list(states), xs=None, length=max_iters
4547
)
46-
for i in range(len(states)):
47-
if i not in used_traces_idxs:
48-
traces.insert(i, None)
49-
50-
return *states, *traces
48+
final_traces = [None] * len(states)
49+
for idx, trace in zip(used_traces_idxs, traces):
50+
if isinstance(op.trace_types[idx], TypedListType):
51+
flattened_trace, treedef = tree_flatten(trace)
52+
transposed_trace = [
53+
tree_unflatten(treedef, l) for l in zip(*flattened_trace)
54+
]
55+
final_traces[idx] = transposed_trace
56+
else:
57+
final_traces[idx] = trace
58+
59+
return *states, *final_traces
5160

5261
return scan

pytensor/loop/op.py

Lines changed: 45 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,13 @@
77
from pytensor.graph import Apply, FunctionGraph, Op, Type, node_rewriter
88
from pytensor.graph.rewriting.basic import in2out
99
from pytensor.scalar import constant
10-
from pytensor.tensor import (
11-
NoneConst,
12-
add,
13-
and_,
14-
empty,
15-
get_scalar_constant_value,
16-
set_subtensor,
17-
)
10+
from pytensor.tensor import add, and_, empty, get_scalar_constant_value, set_subtensor
1811
from pytensor.tensor.exceptions import NotScalarConstantError
1912
from pytensor.tensor.shape import Shape_i
13+
from pytensor.tensor.subtensor import Subtensor, get_idx_list
2014
from pytensor.tensor.type import DenseTensorType, TensorType
2115
from pytensor.tensor.type_other import NoneTypeT
16+
from pytensor.typed_list import GetItem, TypedListType, append, make_empty_list
2217

2318

2419
def validate_loop_update_types(update):
@@ -176,8 +171,7 @@ def __init__(
176171
)
177172
)
178173
else:
179-
# We can't concatenate all types of states, such as RandomTypes
180-
self.trace_types.append(NoneConst.type)
174+
self.trace_types.append(TypedListType(state_type))
181175

182176
self.constant_types = [inp.type for inp in update_fg.inputs[self.n_states :]]
183177
self.n_constants = len(self.constant_types)
@@ -312,10 +306,6 @@ def scan(fn, idx, initial_states, constants, max_iters):
312306
if fgraph.clients[trace]
313307
]
314308

315-
# Check that outputs that cannot be converted into sequences (such as RandomTypes) are not being referenced
316-
for trace_idx in used_traces_idxs:
317-
assert not isinstance(old_states[trace_idx].type, NoneTypeT)
318-
319309
# Inputs to the new Loop
320310
max_iters = node.inputs[0]
321311
init_states = node.inputs[1 : 1 + op.n_states]
@@ -324,6 +314,8 @@ def scan(fn, idx, initial_states, constants, max_iters):
324314
(max_iters, *tuple(init_states[trace_idx].shape)),
325315
dtype=init_states[trace_idx].dtype,
326316
)
317+
if isinstance(init_states[trace_idx].type, DenseTensorType)
318+
else make_empty_list(init_states[trace_idx].type)
327319
for trace_idx in used_traces_idxs
328320
]
329321
constants = node.inputs[1 + op.n_states :]
@@ -387,6 +379,8 @@ def scan(fn, idx, initial_states, constants, max_iters):
387379
inner_while_cond, *inner_next_states = update_fg.outputs
388380
inner_next_traces = [
389381
set_subtensor(prev_trace[inner_idx], inner_next_states[trace_idx])
382+
if isinstance(prev_trace.type, DenseTensorType)
383+
else append(prev_trace, inner_next_states[trace_idx])
390384
for trace_idx, prev_trace in zip(used_traces_idxs, inner_traces)
391385
]
392386
for t in inner_next_traces:
@@ -429,7 +423,7 @@ def scan(fn, idx, initial_states, constants, max_iters):
429423
replacements = dict(zip(old_states, new_states))
430424
for trace_idx, new_trace in zip(used_traces_idxs, new_traces):
431425
# If there is no while condition, the whole trace will be used
432-
if op.has_while_condition:
426+
if op.has_while_condition and isinstance(new_trace.type, DenseTensorType):
433427
new_trace = new_trace[:final_idx]
434428
replacements[old_traces[trace_idx]] = new_trace
435429
return replacements
@@ -446,3 +440,39 @@ def scan(fn, idx, initial_states, constants, max_iters):
446440
"not_jax",
447441
position=1.0,
448442
)
443+
444+
445+
@node_rewriter([Scan])
def scan_view_last_state(fgraph, node):
    """Replace ``trace[-1]`` by the last state output of a Scan node.

    The outputs of a Scan node are its ``n_states`` final states followed by
    the corresponding traces.  Every client of a trace that indexes it with a
    single constant ``-1`` (a ``Subtensor`` on tensor traces or a ``GetItem``
    on typed-list traces) is mapped to the matching final state.

    Parameters
    ----------
    fgraph
        Function graph that contains ``node``; used to look up trace clients.
    node
        The ``Scan`` apply node being rewritten.

    Returns
    -------
    dict
        Mapping from the output of each matching indexing client to the
        final state that replaces it.  Empty when nothing matches.
    """
    n_states = node.op.n_states
    replacements = {}
    for final_state, trace in zip(node.outputs[:n_states], node.outputs[n_states:]):
        for client, _ in fgraph.clients[trace]:
            # Traces that are graph outputs have the string "output" as client
            if client == "output":
                continue

            if isinstance(client.op, Subtensor):
                idxs = get_idx_list(client.inputs, client.op.idx_list)
                if len(idxs) != 1:
                    # Indexing over several dimensions does not select a whole
                    # state.  Skip it explicitly: otherwise `idx` would be
                    # unbound, or stale from a previously visited client.
                    continue
                idx = idxs[0]
            elif isinstance(client.op, GetItem):
                idx = client.inputs[1]
            else:
                continue

            try:
                last_index = get_scalar_constant_value(idx) == -1
            except NotScalarConstantError:
                # Symbolic (or non-scalar, e.g. slice) index: cannot tell
                continue

            if last_index:
                replacements[client.default_output()] = final_state
    return replacements
470+
471+
472+
optdb.register(
473+
"scan_view_last_state",
474+
in2out(scan_view_last_state),
475+
"fast_compile",
476+
"fast_run",
477+
position=0.999,
478+
)

tests/link/jax/test_loop.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import numpy as np
22
import pytest
3+
from jax.tree_util import tree_leaves
34

45
from pytensor import function, shared
56
from pytensor.graph import FunctionGraph
@@ -70,7 +71,7 @@ def test_scan_with_sequence_and_carried_state():
7071
def test_scan_with_rvs():
7172
rng = shared(np.random.default_rng(123))
7273

73-
[next_rng, _], [_, xs] = scan(
74+
[final_rng, _], [rngs, xs] = scan(
7475
fn=lambda prev_rng: normal(rng=prev_rng).owner.outputs,
7576
init_states=[rng, None],
7677
n_steps=10,
@@ -83,11 +84,19 @@ def test_scan_with_rvs():
8384
assert not set(tuple(np.array(res1))) ^ set(tuple(np.array(res2)))
8485

8586
# Now with updates
86-
fn = function([], xs, mode="JAX", updates={rng: next_rng})
87+
fn = function([], xs, mode="JAX", updates={rng: final_rng})
8788
res1 = fn()
8889
res2 = fn()
8990
assert not set(tuple(np.array(res1))) & set(tuple(np.array(res2)))
9091

92+
# Test traced rngs
93+
fn = function([], [rngs, final_rng], mode="JAX")
94+
rngs_res, final_rng_res = fn()
95+
assert isinstance(rngs_res, list) and len(rngs_res) == 10
96+
assert [np.array(v).tolist() for v in tree_leaves(rngs_res[-1])] == [
97+
np.array(v).tolist() for v in tree_leaves(final_rng_res)
98+
]
99+
91100

92101
def test_while_scan_fails():
93102
_, [xs] = scan(

tests/loop/test_op.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,13 @@
44
from pytensor import function, shared
55
from pytensor.compile import DeepCopyOp
66
from pytensor.graph import FunctionGraph
7-
from pytensor.loop.op import Loop, Scan
7+
from pytensor.graph.rewriting.basic import in2out
8+
from pytensor.loop.op import Loop, Scan, scan_view_last_state
89
from pytensor.tensor import constant, empty, lscalar, scalar, vector
910
from pytensor.tensor.random import normal
11+
from pytensor.tensor.random.type import RandomGeneratorType
1012
from pytensor.tensor.subtensor import Subtensor
11-
from pytensor.tensor.type_other import NoneTypeT
13+
from pytensor.typed_list import TypedListType
1214

1315

1416
def test_loop_basic():
@@ -152,10 +154,31 @@ def test_fori_random_scan():
152154
[constant(np.array(True)), *normal(rng=rng).owner.outputs[::-1]],
153155
)
154156

155-
_, new_rng, ys, rngs = Scan(update_fg=update_fg)(n_iters, dummy_init, rng_shared)
156-
assert isinstance(rngs.type, NoneTypeT)
157+
last_y, last_rng, ys, rngs = Scan(update_fg=update_fg)(
158+
n_iters, dummy_init, rng_shared
159+
)
160+
assert isinstance(last_rng.type, RandomGeneratorType)
161+
assert isinstance(rngs.type, TypedListType)
162+
assert isinstance(rngs.type.ttype, RandomGeneratorType)
163+
164+
fn = function([], [ys, rngs], updates={rng_shared: last_rng})
165+
for i in range(2):
166+
ys_res, rngs_res = fn()
167+
for y_res, rng_res in zip(ys_res, rngs_res):
168+
np.testing.assert_almost_equal(y_res, rng_test.normal())
169+
assert rng_res.__getstate__() == rng_test.__getstate__()
157170

158-
fn = function([], ys, updates={rng_shared: new_rng})
159171

160-
np.testing.assert_array_equal(fn(), rng_test.normal(size=5))
161-
np.testing.assert_array_equal(fn(), rng_test.normal(size=5))
172+
def test_scan_view_last_state():
    """Indexing a Scan trace with ``-1`` is rewritten to the final state output."""
    x = scalar("x")
    # Inner update graph outputs: the while condition and the next state
    inner_fg = FunctionGraph([x], [x > 5, x + 2])

    max_steps = 10
    scan_op = Scan(update_fg=inner_fg)
    final_state, trace = scan_op(max_steps, x)

    last_of_trace = trace[-1]
    fgraph = FunctionGraph(outputs=[last_of_trace, trace], clone=False)
    # Before the rewrite the first output is still the indexing expression
    assert fgraph.outputs[0] is not final_state

    rewriter = in2out(scan_view_last_state)
    rewriter.apply(fgraph)

    # After the rewrite the indexing is replaced; the trace itself is untouched
    assert fgraph.outputs[0] is final_state
    assert fgraph.outputs[1] is trace

0 commit comments

Comments
 (0)