|
9 | 9 | from pytensor.compile.mode import Mode, get_default_mode, get_mode
|
10 | 10 | from pytensor.compile.ops import DeepCopyOp
|
11 | 11 | from pytensor.configdefaults import config
|
12 |
| -from pytensor.graph import FunctionGraph |
| 12 | +from pytensor.graph import FunctionGraph, vectorize_graph |
13 | 13 | from pytensor.graph.basic import Constant, Variable, ancestors
|
14 | 14 | from pytensor.graph.rewriting.basic import check_stack_trace
|
15 | 15 | from pytensor.graph.rewriting.db import RewriteDatabaseQuery
|
|
18 | 18 | from pytensor.raise_op import Assert
|
19 | 19 | from pytensor.tensor import inplace
|
20 | 20 | from pytensor.tensor.basic import Alloc, MakeVector, _convert_to_int8, make_vector
|
| 21 | +from pytensor.tensor.blockwise import Blockwise |
21 | 22 | from pytensor.tensor.elemwise import DimShuffle, Elemwise
|
22 | 23 | from pytensor.tensor.math import Dot, add, dot, exp, sqr
|
23 | 24 | from pytensor.tensor.rewriting.subtensor import (
|
@@ -2314,3 +2315,98 @@ def test_local_uint_constant_indices():
|
2314 | 2315 | new_index = subtensor_node.inputs[1]
|
2315 | 2316 | assert isinstance(new_index, Constant)
|
2316 | 2317 | assert new_index.type.dtype == "uint8"
|
| 2318 | + |
| 2319 | + |
@pytest.mark.parametrize("set_instead_of_inc", (True, False))
def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
    """Blockwise advanced ``set/inc_subtensor`` should be rewritten away.

    Builds a core ``set_subtensor``/``inc_subtensor`` graph on 1d inputs,
    vectorizes it into a ``Blockwise`` under several batching patterns
    (x batched, y batched, both batched with and without broadcasting),
    and checks that compiling with the FAST_RUN rewrites removes the
    ``Blockwise`` while preserving the numerical result.
    """
    core_x = tensor("x", shape=(6,))
    core_y = tensor("y", shape=(3,))
    core_idxs = [0, 2, 4]
    if set_instead_of_inc:
        core_graph = set_subtensor(core_x[core_idxs], core_y)
    else:
        core_graph = inc_subtensor(core_x[core_idxs], core_y)

    def compile_without_blockwise(x, y):
        # Vectorizing the core graph must introduce a Blockwise node...
        out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
        assert isinstance(out.owner.op, Blockwise)
        # ...which the FAST_RUN rewrites are expected to remove again.
        fn = pytensor.function([x, y], out, mode="FAST_RUN")
        assert not any(
            isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
        )
        return fn

    def check_numerical(fn, test_x, test_y, expected_out):
        # Apply the reference update with numpy on the last axis; `...`
        # stands in for any number of leading batch dimensions.
        if set_instead_of_inc:
            expected_out[..., core_idxs] = test_y
        else:
            expected_out[..., core_idxs] += test_y
        np.testing.assert_allclose(fn(test_x, test_y), expected_out)

    # Only x is batched
    x = tensor("x", shape=(5, 2, 6))
    y = tensor("y", shape=(3,))
    fn = compile_without_blockwise(x, y)
    test_x = np.ones(x.type.shape, dtype=x.type.dtype)
    test_y = np.array([5, 6, 7]).astype(dtype=core_y.type.dtype)
    check_numerical(fn, test_x, test_y, test_x.copy())

    # Only y is batched
    x = tensor("x", shape=(6,))
    y = tensor("y", shape=(2, 3))
    fn = compile_without_blockwise(x, y)
    test_x = np.ones(x.type.shape, dtype=x.type.dtype)
    test_y = np.array([[3, 3, 3], [5, 6, 7]]).astype(dtype=core_y.type.dtype)
    # The batched y broadcasts the core x up to a (2, 6) output.
    check_numerical(fn, test_x, test_y, np.ones((2, *x.type.shape)))

    # Both x and y are batched, and do not need to be broadcasted
    x = tensor("x", shape=(2, 6))
    y = tensor("y", shape=(2, 3))
    fn = compile_without_blockwise(x, y)
    test_x = np.ones(x.type.shape, dtype=x.type.dtype)
    test_y = np.array([[5, 6, 7], [3, 3, 3]]).astype(dtype=core_y.type.dtype)
    check_numerical(fn, test_x, test_y, test_x.copy())

    # Both x and y are batched, but must be broadcasted
    x = tensor("x", shape=(5, 1, 6))
    y = tensor("y", shape=(1, 2, 3))
    fn = compile_without_blockwise(x, y)
    test_x = np.ones(x.type.shape, dtype=x.type.dtype)
    test_y = np.array([[[5, 6, 7], [3, 3, 3]]]).astype(dtype=core_y.type.dtype)
    # Batch dims of x and y broadcast against each other; the core (last)
    # dim of x is carried through unchanged.
    final_shape = (
        *np.broadcast_shapes(x.type.shape[:-1], y.type.shape[:-1]),
        x.type.shape[-1],
    )
    check_numerical(fn, test_x, test_y, np.broadcast_to(test_x, final_shape).copy())
0 commit comments