Skip to content

Commit 301f10d

Browse files
committed
Add rewrite to remove Blockwise of AdvancedIncSubtensor
1 parent 2e2c871 commit 301f10d

File tree

2 files changed

+153
-1
lines changed

2 files changed

+153
-1
lines changed

pytensor/tensor/rewriting/subtensor.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
register_infer_shape,
3030
switch,
3131
)
32+
from pytensor.tensor.blockwise import Blockwise
3233
from pytensor.tensor.elemwise import Elemwise
3334
from pytensor.tensor.exceptions import NotScalarConstantError
3435
from pytensor.tensor.math import Dot, add
@@ -1880,3 +1881,58 @@ def local_uint_constant_indices(fgraph, node):
18801881
copy_stack_trace(node.outputs, new_outs)
18811882

18821883
return new_outs
1884+
1885+
1886+
@register_canonicalize("shape_unsafe")
1887+
@register_stabilize("shape_unsafe")
1888+
@register_specialize("shape_unsafe")
1889+
@node_rewriter([Blockwise])
1890+
def local_blockwise_advanced_inc_subtensor(fgraph, node):
1891+
"""Rewrite blockwise advanced inc_subtensor whithout batched indexes as an inc_subtensor with prepended empty slices."""
1892+
if not isinstance(node.op.core_op, AdvancedIncSubtensor):
1893+
return None
1894+
1895+
x, y, *idxs = node.inputs
1896+
1897+
# It is currently not possible to Vectorize such AdvancedIncSubtensor, but we check again just in case
1898+
if any(
1899+
(
1900+
isinstance(idx, (SliceType, NoneTypeT))
1901+
or (idx.type.dtype == "bool" and idx.type.ndim > 0)
1902+
)
1903+
for idx in idxs
1904+
):
1905+
return None
1906+
1907+
op: Blockwise = node.op # type: ignore
1908+
batch_ndim = op.batch_ndim(node)
1909+
1910+
new_idxs = []
1911+
for idx in idxs:
1912+
if all(idx.type.broadcastable[:batch_ndim]):
1913+
new_idxs.append(idx.squeeze(tuple(range(batch_ndim))))
1914+
else:
1915+
# Rewrite does not apply
1916+
return None
1917+
1918+
x_batch_bcast = x.type.broadcastable[:batch_ndim]
1919+
y_batch_bcast = y.type.broadcastable[:batch_ndim]
1920+
if any(xb and not yb for xb, yb in zip(x_batch_bcast, y_batch_bcast)):
1921+
# Need to broadcast batch x dims
1922+
batch_shape = tuple(
1923+
x_dim if (not xb or yb) else y_dim
1924+
for xb, x_dim, yb, y_dim in zip(
1925+
x_batch_bcast,
1926+
tuple(x.shape)[:batch_ndim],
1927+
y_batch_bcast,
1928+
tuple(y.shape)[:batch_ndim],
1929+
)
1930+
)
1931+
core_shape = tuple(x.shape)[batch_ndim:]
1932+
x = alloc(x, *batch_shape, *core_shape)
1933+
1934+
new_idxs = [slice(None)] * batch_ndim + new_idxs
1935+
symbolic_idxs = x[tuple(new_idxs)].owner.inputs[1:]
1936+
new_out = op.core_op.make_node(x, y, *symbolic_idxs).outputs
1937+
copy_stack_trace(node.outputs, new_out)
1938+
return new_out

tests/tensor/rewriting/test_subtensor.py

Lines changed: 97 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pytensor.compile.mode import Mode, get_default_mode, get_mode
1010
from pytensor.compile.ops import DeepCopyOp
1111
from pytensor.configdefaults import config
12-
from pytensor.graph import FunctionGraph
12+
from pytensor.graph import FunctionGraph, vectorize_graph
1313
from pytensor.graph.basic import Constant, Variable, ancestors
1414
from pytensor.graph.rewriting.basic import check_stack_trace
1515
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
@@ -18,6 +18,7 @@
1818
from pytensor.raise_op import Assert
1919
from pytensor.tensor import inplace
2020
from pytensor.tensor.basic import Alloc, MakeVector, _convert_to_int8, make_vector
21+
from pytensor.tensor.blockwise import Blockwise
2122
from pytensor.tensor.elemwise import DimShuffle, Elemwise
2223
from pytensor.tensor.math import Dot, add, dot, exp, sqr
2324
from pytensor.tensor.rewriting.subtensor import (
@@ -2314,3 +2315,98 @@ def test_local_uint_constant_indices():
23142315
new_index = subtensor_node.inputs[1]
23152316
assert isinstance(new_index, Constant)
23162317
assert new_index.type.dtype == "uint8"
2318+
2319+
2320+
@pytest.mark.parametrize("set_instead_of_inc", (True, False))
2321+
def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
2322+
core_x = tensor("x", shape=(6,))
2323+
core_y = tensor("y", shape=(3,))
2324+
core_idxs = [0, 2, 4]
2325+
if set_instead_of_inc:
2326+
core_graph = set_subtensor(core_x[core_idxs], core_y)
2327+
else:
2328+
core_graph = inc_subtensor(core_x[core_idxs], core_y)
2329+
2330+
# Only x is batched
2331+
x = tensor("x", shape=(5, 2, 6))
2332+
y = tensor("y", shape=(3,))
2333+
out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
2334+
assert isinstance(out.owner.op, Blockwise)
2335+
2336+
fn = pytensor.function([x, y], out, mode="FAST_RUN")
2337+
assert not any(
2338+
isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
2339+
)
2340+
2341+
test_x = np.ones(x.type.shape, dtype=x.type.dtype)
2342+
test_y = np.array([5, 6, 7]).astype(dtype=core_y.type.dtype)
2343+
expected_out = test_x.copy()
2344+
if set_instead_of_inc:
2345+
expected_out[:, :, core_idxs] = test_y
2346+
else:
2347+
expected_out[:, :, core_idxs] += test_y
2348+
np.testing.assert_allclose(fn(test_x, test_y), expected_out)
2349+
2350+
# Only y is batched
2351+
x = tensor("y", shape=(6,))
2352+
y = tensor("y", shape=(2, 3))
2353+
out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
2354+
assert isinstance(out.owner.op, Blockwise)
2355+
2356+
fn = pytensor.function([x, y], out, mode="FAST_RUN")
2357+
assert not any(
2358+
isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
2359+
)
2360+
2361+
test_x = np.ones(x.type.shape, dtype=x.type.dtype)
2362+
test_y = np.array([[3, 3, 3], [5, 6, 7]]).astype(dtype=core_y.type.dtype)
2363+
expected_out = np.ones((2, *x.type.shape))
2364+
if set_instead_of_inc:
2365+
expected_out[:, core_idxs] = test_y
2366+
else:
2367+
expected_out[:, core_idxs] += test_y
2368+
np.testing.assert_allclose(fn(test_x, test_y), expected_out)
2369+
2370+
# Both x and y are batched, and do not need to be broadcasted
2371+
x = tensor("y", shape=(2, 6))
2372+
y = tensor("y", shape=(2, 3))
2373+
out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
2374+
assert isinstance(out.owner.op, Blockwise)
2375+
2376+
fn = pytensor.function([x, y], out, mode="FAST_RUN")
2377+
assert not any(
2378+
isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
2379+
)
2380+
2381+
test_x = np.ones(x.type.shape, dtype=x.type.dtype)
2382+
test_y = np.array([[5, 6, 7], [3, 3, 3]]).astype(dtype=core_y.type.dtype)
2383+
expected_out = test_x.copy()
2384+
if set_instead_of_inc:
2385+
expected_out[:, core_idxs] = test_y
2386+
else:
2387+
expected_out[:, core_idxs] += test_y
2388+
np.testing.assert_allclose(fn(test_x, test_y), expected_out)
2389+
2390+
# Both x and y are batched, but must be broadcasted
2391+
x = tensor("y", shape=(5, 1, 6))
2392+
y = tensor("y", shape=(1, 2, 3))
2393+
out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
2394+
assert isinstance(out.owner.op, Blockwise)
2395+
2396+
fn = pytensor.function([x, y], out, mode="FAST_RUN")
2397+
assert not any(
2398+
isinstance(node.op, Blockwise) for node in fn.maker.fgraph.apply_nodes
2399+
)
2400+
2401+
test_x = np.ones(x.type.shape, dtype=x.type.dtype)
2402+
test_y = np.array([[[5, 6, 7], [3, 3, 3]]]).astype(dtype=core_y.type.dtype)
2403+
final_shape = (
2404+
*np.broadcast_shapes(x.type.shape[:-1], y.type.shape[:-1]),
2405+
x.type.shape[-1],
2406+
)
2407+
expected_out = np.broadcast_to(test_x, final_shape).copy()
2408+
if set_instead_of_inc:
2409+
expected_out[:, :, core_idxs] = test_y
2410+
else:
2411+
expected_out[:, :, core_idxs] += test_y
2412+
np.testing.assert_allclose(fn(test_x, test_y), expected_out)

0 commit comments

Comments
 (0)