Skip to content

Commit 4257a49

Browse files
committed
Allow non-TensorVariable types to be traced in new Scan Op
1 parent f312a0f commit 4257a49

File tree

2 files changed

+24
-23
lines changed

2 files changed

+24
-23
lines changed

pytensor/loop/op.py

+9-15
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,12 @@
77
from pytensor.graph import Apply, FunctionGraph, Op, Type, node_rewriter
88
from pytensor.graph.rewriting.basic import in2out
99
from pytensor.scalar import constant
10-
from pytensor.tensor import (
11-
NoneConst,
12-
add,
13-
and_,
14-
empty,
15-
get_scalar_constant_value,
16-
set_subtensor,
17-
)
10+
from pytensor.tensor import add, and_, empty, get_scalar_constant_value, set_subtensor
1811
from pytensor.tensor.exceptions import NotScalarConstantError
1912
from pytensor.tensor.shape import Shape_i
2013
from pytensor.tensor.type import DenseTensorType, TensorType
2114
from pytensor.tensor.type_other import NoneTypeT
15+
from pytensor.typed_list import TypedListType, append, make_empty_list
2216

2317

2418
def validate_loop_update_types(update):
@@ -176,8 +170,7 @@ def __init__(
176170
)
177171
)
178172
else:
179-
# We can't concatenate all types of states, such as RandomTypes
180-
self.trace_types.append(NoneConst.type)
173+
self.trace_types.append(TypedListType(state_type))
181174

182175
self.constant_types = [inp.type for inp in update_fg.inputs[self.n_states :]]
183176
self.n_constants = len(self.constant_types)
@@ -312,10 +305,6 @@ def scan(fn, idx, initial_states, constants, max_iters):
312305
if fgraph.clients[trace]
313306
]
314307

315-
# Check that outputs that cannot be converted into sequences (such as RandomTypes) are not being referenced
316-
for trace_idx in used_traces_idxs:
317-
assert not isinstance(old_states[trace_idx].type, NoneTypeT)
318-
319308
# Inputs to the new Loop
320309
max_iters = node.inputs[0]
321310
init_states = node.inputs[1 : 1 + op.n_states]
@@ -324,6 +313,8 @@ def scan(fn, idx, initial_states, constants, max_iters):
324313
(max_iters, *tuple(init_states[trace_idx].shape)),
325314
dtype=init_states[trace_idx].dtype,
326315
)
316+
if isinstance(init_states[trace_idx].type, DenseTensorType)
317+
else make_empty_list(init_states[trace_idx].type)
327318
for trace_idx in used_traces_idxs
328319
]
329320
constants = node.inputs[1 + op.n_states :]
@@ -376,6 +367,7 @@ def scan(fn, idx, initial_states, constants, max_iters):
376367
# Inner traces
377368
inner_states = update_fg.inputs[: op.n_states]
378369
inner_traces = [init_trace.type() for init_trace in init_traces]
370+
379371
for s, t in zip(inner_states, inner_traces):
380372
t.name = "trace"
381373
if s.name:
@@ -387,6 +379,8 @@ def scan(fn, idx, initial_states, constants, max_iters):
387379
inner_while_cond, *inner_next_states = update_fg.outputs
388380
inner_next_traces = [
389381
set_subtensor(prev_trace[inner_idx], inner_next_states[trace_idx])
382+
if isinstance(prev_trace.type, DenseTensorType)
383+
else append(prev_trace, inner_next_states[trace_idx])
390384
for trace_idx, prev_trace in zip(used_traces_idxs, inner_traces)
391385
]
392386
for t in inner_next_traces:
@@ -429,7 +423,7 @@ def scan(fn, idx, initial_states, constants, max_iters):
429423
replacements = dict(zip(old_states, new_states))
430424
for trace_idx, new_trace in zip(used_traces_idxs, new_traces):
431425
# If there is no while condition, the whole trace will be used
432-
if op.has_while_condition:
426+
if op.has_while_condition and isinstance(new_trace.type, DenseTensorType):
433427
new_trace = new_trace[:final_idx]
434428
replacements[old_traces[trace_idx]] = new_trace
435429
return replacements

tests/loop/test_op.py

+15-8
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
from pytensor.loop.op import Loop, Scan
88
from pytensor.tensor import constant, empty, lscalar, scalar, vector
99
from pytensor.tensor.random import normal
10+
from pytensor.tensor.random.type import RandomGeneratorType
1011
from pytensor.tensor.subtensor import Subtensor
11-
from pytensor.tensor.type_other import NoneTypeT
12+
from pytensor.typed_list import TypedListType
1213

1314

1415
def test_loop_basic():
@@ -152,10 +153,16 @@ def test_fori_random_scan():
152153
[constant(np.array(True)), *normal(rng=rng).owner.outputs[::-1]],
153154
)
154155

155-
_, new_rng, ys, rngs = Scan(update_fg=update_fg)(n_iters, dummy_init, rng_shared)
156-
assert isinstance(rngs.type, NoneTypeT)
157-
158-
fn = function([], ys, updates={rng_shared: new_rng})
159-
160-
np.testing.assert_array_equal(fn(), rng_test.normal(size=5))
161-
np.testing.assert_array_equal(fn(), rng_test.normal(size=5))
156+
last_y, last_rng, ys, rngs = Scan(update_fg=update_fg)(
157+
n_iters, dummy_init, rng_shared
158+
)
159+
assert isinstance(last_rng.type, RandomGeneratorType)
160+
assert isinstance(rngs.type, TypedListType)
161+
assert isinstance(rngs.type.ttype, RandomGeneratorType)
162+
163+
fn = function([], [ys, rngs], updates={rng_shared: last_rng})
164+
for i in range(2):
165+
ys_res, rngs_res = fn()
166+
for y_res, rng_res in zip(ys_res, rngs_res):
167+
np.testing.assert_almost_equal(y_res, rng_test.normal())
168+
assert rng_res.__getstate__() == rng_test.__getstate__()

0 commit comments

Comments
 (0)