Skip to content

Commit 41e9ea6

Browse files
brandonwillardtwiecki
authored andcommitted
Set the global mode during compilation
1 parent a0ff28e commit 41e9ea6

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

pytensor/compile/function/types.py

+14-7
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232

3333

3434
if TYPE_CHECKING:
35+
from pytensor.compile.mode import Mode
3536
from pytensor.link.vm import VM
3637

3738

@@ -1391,16 +1392,24 @@ def check_unused_inputs(inputs, outputs, on_unused_input):
13911392

13921393
@staticmethod
13931394
def prepare_fgraph(
1394-
inputs, outputs, additional_outputs, fgraph, rewriter, linker, profile
1395+
inputs,
1396+
outputs,
1397+
additional_outputs,
1398+
fgraph: FunctionGraph,
1399+
mode: "Mode",
1400+
profile,
13951401
):
13961402

1403+
rewriter = mode.optimizer
1404+
13971405
try:
13981406
start_rewriter = time.perf_counter()
13991407

14001408
rewriter_profile = None
14011409
rewrite_time = None
14021410

14031411
with config.change_flags(
1412+
mode=mode,
14041413
compute_test_value=config.compute_test_value_opt,
14051414
traceback__limit=config.traceback__compile_limit,
14061415
):
@@ -1440,7 +1449,7 @@ def prepare_fgraph(
14401449
stacklevel=3,
14411450
)
14421451

1443-
if not hasattr(linker, "accept"):
1452+
if not hasattr(mode.linker, "accept"):
14441453
raise ValueError(
14451454
"'linker' parameter of FunctionMaker should be "
14461455
f"a Linker with an accept method or one of {list(pytensor.compile.mode.predefined_linkers.keys())}"
@@ -1511,12 +1520,8 @@ def __init__(
15111520

15121521
self.fgraph = fgraph
15131522

1514-
rewriter, linker = mode.optimizer, copy.copy(mode.linker)
1515-
15161523
if not no_fgraph_prep:
1517-
self.prepare_fgraph(
1518-
inputs, outputs, found_updates, fgraph, rewriter, linker, profile
1519-
)
1524+
self.prepare_fgraph(inputs, outputs, found_updates, fgraph, mode, profile)
15201525

15211526
assert len(fgraph.outputs) == len(outputs + found_updates)
15221527

@@ -1528,6 +1533,8 @@ def __init__(
15281533
if not spec.borrow
15291534
]
15301535

1536+
linker = copy.copy(mode.linker)
1537+
15311538
if no_borrow:
15321539
self.linker = linker.accept(
15331540
fgraph,

0 commit comments

Comments
 (0)