|
27 | 27 | from pytensor.compile.sharedvalue import shared
|
28 | 28 | from pytensor.configdefaults import config
|
29 | 29 | from pytensor.gradient import NullTypeGradError, Rop, disconnected_grad, grad, hessian
|
| 30 | +from pytensor.graph import vectorize_graph |
30 | 31 | from pytensor.graph.basic import Apply, ancestors, equal_computations
|
31 | 32 | from pytensor.graph.fg import FunctionGraph
|
32 | 33 | from pytensor.graph.op import Op
|
@@ -1178,6 +1179,17 @@ def get_sum_of_grad(input0, input1):
|
1178 | 1179 |
|
1179 | 1180 | utt.verify_grad(get_sum_of_grad, inputs_test_values, rng=rng)
|
1180 | 1181 |
|
| 1182 | + def test_blockwise_scan(self): |
| 1183 | + x = pt.tensor("x", shape=()) |
| 1184 | + out, _ = scan(lambda x: x + 1, outputs_info=[x], n_steps=10) |
| 1185 | + x_vec = pt.tensor("x_vec", shape=(None,)) |
| 1186 | + out_vec = vectorize_graph(out, {x: x_vec}) |
| 1187 | + |
| 1188 | + fn = function([x_vec], out_vec) |
| 1189 | + o1 = fn([1, 2, 3]) |
| 1190 | + o2 = np.arange(2, 12) + np.arange(3).reshape(-1, 1) |
| 1191 | + assert np.allclose(o1, o2) |
| 1192 | + |
1181 | 1193 | def test_connection_pattern(self):
|
1182 | 1194 | """Test `Scan.connection_pattern` in the presence of recurrent outputs with multiple taps."""
|
1183 | 1195 |
|
|
0 commit comments