Skip to content

Commit cae467e

Browse files
committed
Don't run useless fusion and inplace rewrites in JAX mode
1 parent 6762240 commit cae467e

File tree

1 file changed

+4
-1
lines changed

1 file changed

+4
-1
lines changed

pytensor/compile/mode.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,10 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
454454

455455
JAX = Mode(
456456
JAXLinker(),
457-
RewriteDatabaseQuery(include=["fast_run", "jax"], exclude=["cxx_only", "BlasOpt"]),
457+
RewriteDatabaseQuery(
458+
include=["fast_run", "jax"],
459+
exclude=["cxx_only", "BlasOpt", "fusion", "inplace"],
460+
),
458461
)
459462
NUMBA = Mode(
460463
NumbaLinker(),

0 commit comments

Comments
 (0)