Skip to content

Commit 43b84f5

Browse files
committed
Supervisor not needed for JAX rewrites
As it no longer includes inplace operations.
1 parent 450e7f6 commit 43b84f5

File tree

1 file changed

+1
-13
lines changed

1 file changed

+1
-13
lines changed

pymc/sampling/jax.py

Lines changed: 1 addition & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -30,7 +30,7 @@
3030
from arviz.data.base import make_attrs
3131
from jax.lax import scan
3232
from numpy.typing import ArrayLike
33-
from pytensor.compile import SharedVariable, Supervisor, mode
33+
from pytensor.compile import SharedVariable, mode
3434
from pytensor.graph.basic import graph_inputs
3535
from pytensor.graph.fg import FunctionGraph
3636
from pytensor.graph.replace import clone_replace
@@ -127,18 +127,6 @@ def get_jaxified_graph(
127127
graph = _replace_shared_variables(outputs) if outputs is not None else None
128128

129129
fgraph = FunctionGraph(inputs=inputs, outputs=graph, clone=True)
130-
# We need to add a Supervisor to the fgraph to be able to run the
131-
# JAX sequential optimizer without warnings. We made sure there
132-
# are no mutable input variables, so we only need to check for
133-
# "destroyers". This should be automatically handled by PyTensor
134-
# once https://github.com/aesara-devs/aesara/issues/637 is fixed.
135-
fgraph.attach_feature(
136-
Supervisor(
137-
input
138-
for input in fgraph.inputs
139-
if not (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([input]))
140-
)
141-
)
142130
mode.JAX.optimizer.rewrite(fgraph)
143131

144132
# We now jaxify the optimized fgraph

0 commit comments

Comments (0)