Skip to content

Commit 864ebd1

Browse files
authored
Handle compute_map=None in Scan (#1435)
1 parent 92eef5e commit 864ebd1

File tree

2 files changed

+15
-2
lines changed

2 files changed

+15
-2
lines changed

pytensor/scan/op.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1647,8 +1647,9 @@ def rval(
16471647
p=p, i=node_input_storage, o=node_output_storage, n=node, allow_gc=allow_gc
16481648
):
16491649
r = p(n, [x[0] for x in i], o)
1650-
for o in node.outputs:
1651-
compute_map[o][0] = True
1650+
if compute_map is not None:
1651+
for o in node.outputs:
1652+
compute_map[o][0] = True
16521653
if allow_gc:
16531654
self.fn.free()
16541655
return r

tests/scan/test_basic.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
from pytensor.compile.sharedvalue import shared
2828
from pytensor.configdefaults import config
2929
from pytensor.gradient import NullTypeGradError, Rop, disconnected_grad, grad, hessian
30+
from pytensor.graph import vectorize_graph
3031
from pytensor.graph.basic import Apply, ancestors, equal_computations
3132
from pytensor.graph.fg import FunctionGraph
3233
from pytensor.graph.op import Op
@@ -1178,6 +1179,17 @@ def get_sum_of_grad(input0, input1):
11781179

11791180
utt.verify_grad(get_sum_of_grad, inputs_test_values, rng=rng)
11801181

1182+
def test_blockwise_scan(self):
1183+
x = pt.tensor("x", shape=())
1184+
out, _ = scan(lambda x: x + 1, outputs_info=[x], n_steps=10)
1185+
x_vec = pt.tensor("x_vec", shape=(None,))
1186+
out_vec = vectorize_graph(out, {x: x_vec})
1187+
1188+
fn = function([x_vec], out_vec)
1189+
o1 = fn([1, 2, 3])
1190+
o2 = np.arange(2, 12) + np.arange(3).reshape(-1, 1)
1191+
assert np.allclose(o1, o2)
1192+
11811193
def test_connection_pattern(self):
11821194
"""Test `Scan.connection_pattern` in the presence of recurrent outputs with multiple taps."""
11831195

0 commit comments

Comments
 (0)