Commit 69efc68

Handle inplace rewrites correctly in dispatch of OpFromGraph and Scan

JAX needs no special handling because it excludes inplace rewrites.

1 parent: ad1af2e

File tree: 8 files changed (+131 −45 lines)

  pytensor/compile/function/types.py
  pytensor/compile/mode.py
  pytensor/link/jax/dispatch/scan.py
  pytensor/link/numba/dispatch/basic.py
  pytensor/link/numba/dispatch/scan.py
  pytensor/link/pytorch/dispatch/basic.py
  pytensor/scan/op.py
  tests/link/numba/test_basic.py

pytensor/compile/function/types.py

Lines changed: 56 additions & 18 deletions

```diff
@@ -5,6 +5,7 @@
 import logging
 import time
 import warnings
+from collections.abc import Sequence
 from itertools import chain
 from typing import TYPE_CHECKING
 
@@ -168,6 +169,59 @@ def validate(self, fgraph):
         raise InconsistencyError(f"Trying to destroy a protected variable: {r}")
 
 
+def add_supervisor_to_fgraph(
+    fgraph: FunctionGraph,
+    input_specs: Sequence[SymbolicInput],
+    accept_inplace: bool = False,
+) -> None:
+    """Set up the Supervisor feature in a FunctionGraph, so that inplace rewrites can be used.
+
+    Parameters
+    ----------
+    fgraph: FunctionGraph
+        The FunctionGraph to set up the Supervisor feature in.
+    input_specs: Sequence of SymbolicInput
+        The input specifications for the FunctionGraph.
+        Inputs with the attribute `mutable=False` and which are not already destroyed by an inplace operation
+        (if `accept_inplace` is True) will be protected from inplace operations.
+        Otherwise, they will be allowed to be destroyed.
+    accept_inplace: bool
+        Whether to allow inplace operations to already be present in the graph.
+
+    Raises
+    ------
+    TypeError
+        If inplace operations are not allowed and the graph already contains inplace operations.
+
+    """
+
+    has_destroy_handler = hasattr(fgraph, "destroyers")
+    if not (has_destroy_handler and accept_inplace):
+        # Check if fgraph already contains destructive operations,
+        # in which case we need to add a DestroyHandler or raise an error
+        for node in fgraph.apply_nodes:
+            if node.op.destroy_map:
+                if not accept_inplace:
+                    raise TypeError(
+                        f"Graph must not contain inplace operations: {node}"
+                    )
+                else:
+                    has_destroy_handler = True
+                    fgraph.attach_feature(DestroyHandler())
+                    break
+
+    # Protect all immutable inputs from inplace operations.
+    fgraph.attach_feature(
+        Supervisor(
+            input
+            for spec, input in zip(input_specs, fgraph.inputs, strict=True)
+            if not (
+                spec.mutable or has_destroy_handler and fgraph.has_destroyers([input])
+            )
+        )
+    )
+
+
 def std_fgraph(
     input_specs: list[SymbolicInput],
     output_specs: list[SymbolicOutput],
@@ -229,24 +283,8 @@ def std_fgraph(
 
     found_updates.extend(map(SymbolicOutput, updates))
 
-    for node in fgraph.apply_nodes:
-        if node.op.destroy_map:
-            if not accept_inplace:
-                raise TypeError(f"Graph must not contain inplace operations: {node}")
-            else:
-                fgraph.attach_feature(DestroyHandler())
-                break
-
-    # We need to protect all immutable inputs from inplace operations.
-    fgraph.attach_feature(
-        Supervisor(
-            input
-            for spec, input in zip(input_specs, fgraph.inputs, strict=True)
-            if not (
-                spec.mutable
-                or (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([input]))
-            )
-        )
+    add_supervisor_to_fgraph(
+        fgraph=fgraph, input_specs=input_specs, accept_inplace=accept_inplace
     )
 
     # If named nodes are replaced, keep the name
```
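
The helper centralizes the Supervisor/DestroyHandler setup that the backend dispatchers below now reuse. A minimal usage sketch follows; the toy graph, the `clone=True` flag, and the choice of the NUMBA mode are illustrative assumptions, only `add_supervisor_to_fgraph` itself comes from this commit:

```python
import pytensor.tensor as pt
from pytensor import In
from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.compile.mode import NUMBA
from pytensor.graph.fg import FunctionGraph

x = pt.vector("x")
out = pt.exp(x)[0].set(1)  # a SetSubtensor that an inplace rewrite may target

# clone=True so the rewrites below do not mutate the user-facing graph
fgraph = FunctionGraph([x], [out], clone=True)

# Protect every input (mutable=False) before running backend rewrites;
# accept_inplace=True tolerates destructive ops already present in the graph.
add_supervisor_to_fgraph(
    fgraph=fgraph,
    input_specs=[In(i, borrow=True, mutable=False) for i in fgraph.inputs],
    accept_inplace=True,
)
NUMBA.optimizer(fgraph)  # inplace rewrites now run without destroying inputs
```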

pytensor/compile/mode.py

Lines changed: 5 additions & 1 deletion

```diff
@@ -138,7 +138,11 @@ def apply(self, fgraph):
                 break
         if not supervisor_added:
             warnings.warn(
-                f"A Supervisor feature is missing from {fgraph}.",
+                (
+                    f"A Supervisor feature is missing from {fgraph}.\n"
+                    "This is needed for inplace rewrites. Either exclude inplace rewrites or add a Supervisor feature.\n"
+                    "A Supervisor feature can be added via `pytensor.compile.function.types.add_supervisor_to_fgraph`."
+                ),
                 stacklevel=3,
             )
 
```

pytensor/link/jax/dispatch/scan.py

Lines changed: 4 additions & 2 deletions

```diff
@@ -1,7 +1,7 @@
 import jax
 import jax.numpy as jnp
 
-from pytensor.compile.mode import JAX
+from pytensor.compile.mode import JAX, get_mode
 from pytensor.link.jax.dispatch.basic import jax_funcify
 from pytensor.scan.op import Scan
 
@@ -19,7 +19,9 @@ def jax_funcify_Scan(op: Scan, **kwargs):
     )
 
     # Optimize inner graph (exclude any default rewrites that are incompatible with JAX mode)
-    rewriter = op.mode_instance.excluding(*JAX._optimizer.exclude).optimizer
+    rewriter = (
+        get_mode(op.mode).including("jax").excluding(*JAX._optimizer.exclude).optimizer
+    )
     rewriter(op.fgraph)
     scan_inner_func = jax_funcify(op.fgraph, **kwargs)
 
```
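
As the commit message says, JAX needs no Supervisor handling because the JAX mode excludes inplace rewrites altogether; the change above only switches to resolving the rewriter from `op.mode` via `get_mode` (with the "jax" rewrites included) instead of `op.mode_instance`.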

pytensor/link/numba/dispatch/basic.py

Lines changed: 9 additions & 2 deletions

```diff
@@ -16,9 +16,10 @@
 from numba.cpython.unsafe.tuple import tuple_setitem  # noqa: F401
 from numba.extending import box, overload
 
-from pytensor import config
+from pytensor import In, config
 from pytensor.compile import NUMBA
 from pytensor.compile.builders import OpFromGraph
+from pytensor.compile.function.types import add_supervisor_to_fgraph
 from pytensor.compile.ops import DeepCopyOp
 from pytensor.graph.basic import Apply
 from pytensor.graph.fg import FunctionGraph
@@ -430,7 +431,13 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs):
     # TODO: Not sure this is the right place to do this, should we have a rewrite that
     # explicitly triggers the optimization of the inner graphs of OpFromGraph?
     # The C-code defers it to the make_thunk phase
-    NUMBA.optimizer(op.fgraph)
+    fgraph = op.fgraph
+    add_supervisor_to_fgraph(
+        fgraph=fgraph,
+        input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
+        accept_inplace=True,
+    )
+    NUMBA.optimizer(fgraph)
     fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))
 
     if len(op.fgraph.outputs) == 1:
```
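
Here each inner input is wrapped as `In(x, borrow=True, mutable=False)`, which tells the Supervisor that no rewrite may destroy the OpFromGraph's inputs, while `accept_inplace=True` attaches a DestroyHandler for any destructive ops already in the inner graph. The Scan and PyTorch dispatchers below follow the same pattern.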

pytensor/link/numba/dispatch/scan.py

Lines changed: 12 additions & 3 deletions

```diff
@@ -4,7 +4,9 @@
 from numba import types
 from numba.extending import overload
 
-from pytensor.compile.mode import NUMBA
+from pytensor import In
+from pytensor.compile.function.types import add_supervisor_to_fgraph
+from pytensor.compile.mode import NUMBA, get_mode
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import (
     create_arg_string,
@@ -59,11 +61,18 @@ def numba_funcify_Scan(op, node, **kwargs):
     # explicitly triggers the optimization of the inner graphs of Scan?
     # The C-code defers it to the make_thunk phase
     rewriter = (
-        op.mode_instance.including("numba")
+        get_mode(op.mode)
+        .including("numba")
         .excluding(*NUMBA._optimizer.exclude)
         .optimizer
     )
-    rewriter(op.fgraph)
+    fgraph = op.fgraph
+    add_supervisor_to_fgraph(
+        fgraph=fgraph,
+        input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
+        accept_inplace=True,
+    )
+    rewriter(fgraph)
 
     scan_inner_func = numba_basic.numba_njit(numba_funcify(op.fgraph))
```

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 9 additions & 0 deletions

```diff
@@ -5,8 +5,10 @@
 import torch
 import torch.compiler
 
+from pytensor import In
 from pytensor.compile import PYTORCH
 from pytensor.compile.builders import OpFromGraph
+from pytensor.compile.function.types import add_supervisor_to_fgraph
 from pytensor.compile.ops import DeepCopyOp
 from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
@@ -185,6 +187,13 @@ def pytorch_funcify_OpFromGraph(op, node, **kwargs):
     kwargs.pop("storage_map", None)
     # Apply inner rewrites
     PYTORCH.optimizer(op.fgraph)
+    fgraph = op.fgraph
+    add_supervisor_to_fgraph(
+        fgraph=fgraph,
+        input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
+        accept_inplace=True,
+    )
+    PYTORCH.optimizer(fgraph)
     fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True)
     return fgraph_fn
 
```

pytensor/scan/op.py

Lines changed: 3 additions & 19 deletions

```diff
@@ -57,6 +57,7 @@
 from pytensor import tensor as pt
 from pytensor.compile.builders import construct_nominal_fgraph, infer_shape
 from pytensor.compile.function.pfunc import pfunc
+from pytensor.compile.function.types import add_supervisor_to_fgraph
 from pytensor.compile.io import In, Out
 from pytensor.compile.mode import Mode, get_mode
 from pytensor.compile.profiling import register_profiler_printer
@@ -834,8 +835,6 @@ def tensorConstructor(shape, dtype):
         self.n_outer_inputs = info.n_outer_inputs
         self.n_outer_outputs = info.n_outer_outputs
 
-        _ = self.prepare_fgraph(self.fgraph)
-
         if any(node.op.destroy_map for node in self.fgraph.apply_nodes):
             raise InconsistencyError(
                 "Inner-graphs must not contain in-place operations."
@@ -1394,23 +1393,8 @@ def prepare_fgraph(self, fgraph):
 
         fgraph.update_mapping = update_mapping
 
-        from pytensor.compile.function.types import Supervisor
-        from pytensor.graph.destroyhandler import DestroyHandler
-
-        for node in fgraph.apply_nodes:
-            if node.op.destroy_map:
-                fgraph.attach_feature(DestroyHandler())
-                break
-
-        fgraph.attach_feature(
-            Supervisor(
-                inp
-                for spec, inp in zip(wrapped_inputs, fgraph.inputs, strict=True)
-                if not (
-                    getattr(spec, "mutable", None)
-                    or (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([inp]))
-                )
-            )
+        add_supervisor_to_fgraph(
+            fgraph=fgraph, input_specs=wrapped_inputs, accept_inplace=True
         )
 
         return wrapped_inputs, wrapped_outputs
```
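
With `prepare_fgraph` delegating to the shared helper with `accept_inplace=True`, the eager `prepare_fgraph` call in `__init__` and the hand-rolled Supervisor/DestroyHandler setup become redundant and are removed.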

tests/link/numba/test_basic.py

Lines changed: 33 additions & 0 deletions

```diff
@@ -835,6 +835,39 @@ def test_OpFromGraph():
     compare_numba_and_py([x, y, z], [out], [xv, yv, zv])
 
 
+@pytest.mark.filterwarnings("error")
+def test_ofg_inner_inplace():
+    x = pt.vector("x")
+    set0 = x[0].set(1)  # SetSubtensor should not inplace on x
+    exp_x = pt.exp(x)
+    set1 = exp_x[0].set(1)  # SetSubtensor should inplace on exp_x
+    ofg0 = OpFromGraph([x], [set0])
+    ofg1 = OpFromGraph([x], [set1])
+
+    y, z = pt.vectors("y", "z")
+    fn = function([y, z], [ofg0(y), ofg1(z)], mode="NUMBA")
+
+    fn_ofg0 = fn.maker.fgraph.outputs[0].owner.op
+    assert isinstance(fn_ofg0, OpFromGraph)
+    fn_set0 = fn_ofg0.fgraph.outputs[0]
+    assert fn_set0.owner.op.destroy_map == {}
+
+    fn_ofg1 = fn.maker.fgraph.outputs[1].owner.op
+    assert isinstance(fn_ofg1, OpFromGraph)
+    fn_set1 = fn_ofg1.fgraph.outputs[0]
+    assert fn_set1.owner.op.destroy_map == {0: [0]}
+
+    x_test = np.array([0, 1, 1], dtype=config.floatX)
+    y_test = np.array([0, 1, 1], dtype=config.floatX)
+    res0, res1 = fn(x_test, y_test)
+    # Check inputs were not mutated
+    np.testing.assert_allclose(x_test, [0, 1, 1])
+    np.testing.assert_allclose(y_test, [0, 1, 1])
+    # Check outputs are correct
+    np.testing.assert_allclose(res0, [1, 1, 1])
+    np.testing.assert_allclose(res1, [1, np.e, np.e])
+
+
 @pytest.mark.filterwarnings("error")
 def test_cache_warning_suppressed():
     x = pt.vector("x", shape=(5,), dtype="float64")
```
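
For reference, `destroy_map == {0: [0]}` declares that output 0 destroys (reuses the storage of) input 0, so the inplace rewrite fired inside `ofg1`, whose SetSubtensor acts on the intermediate `exp_x`; the empty `destroy_map` on `ofg0`'s output shows the Supervisor prevented inplacing on a protected input, and the assertions on `x_test` and `y_test` confirm the caller's buffers were left intact.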
