1 file changed: +1 −13 lines
@@ -30,7 +30,7 @@
 from arviz.data.base import make_attrs
 from jax.lax import scan
 from numpy.typing import ArrayLike
-from pytensor.compile import SharedVariable, Supervisor, mode
+from pytensor.compile import SharedVariable, mode
 from pytensor.graph.basic import graph_inputs
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.replace import clone_replace
@@ -127,18 +127,6 @@ def get_jaxified_graph(
     graph = _replace_shared_variables(outputs) if outputs is not None else None
 
     fgraph = FunctionGraph(inputs=inputs, outputs=graph, clone=True)
-    # We need to add a Supervisor to the fgraph to be able to run the
-    # JAX sequential optimizer without warnings. We made sure there
-    # are no mutable input variables, so we only need to check for
-    # "destroyers". This should be automatically handled by PyTensor
-    # once https://github.com/aesara-devs/aesara/issues/637 is fixed.
-    fgraph.attach_feature(
-        Supervisor(
-            input
-            for input in fgraph.inputs
-            if not (hasattr(fgraph, "destroyers") and fgraph.has_destroyers([input]))
-        )
-    )
     mode.JAX.optimizer.rewrite(fgraph)
 
     # We now jaxify the optimized fgraph
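For context, here is a minimal usage sketch of the function this hunk touches. It is not part of the diff: the `pymc.sampling.jax` module path and the tuple-valued return are assumptions based on the surrounding code, and may differ across PyMC versions.

```python
import numpy as np
import pytensor.tensor as pt

from pymc.sampling.jax import get_jaxified_graph  # module path assumed

# Build a small PyTensor graph: y = sum(exp(x))
x = pt.vector("x")
y = pt.exp(x).sum()

# Internally this builds a FunctionGraph, runs mode.JAX.optimizer.rewrite
# on it (the call kept by this change), and jaxifies the optimized graph.
jax_fn = get_jaxified_graph(inputs=[x], outputs=[y])

# Assumes the jaxified function returns a tuple of outputs.
(result,) = jax_fn(np.arange(3.0))
print(result)  # exp(0) + exp(1) + exp(2) ≈ 11.107
```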