Skip to content

Commit a0ff28e

Browse files
brandonwillardtwiecki
authored andcommitted
Add get_target_language function and remove tests.compile.test_modes
1 parent e22b1ef commit a0ff28e

File tree

3 files changed

+121
-60
lines changed

3 files changed

+121
-60
lines changed

pytensor/compile/mode.py

+25
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
import warnings
88
from typing import Optional, Tuple, Union
99

10+
from typing_extensions import Literal
11+
1012
from pytensor.compile.function.types import Supervisor
1113
from pytensor.configdefaults import config
1214
from pytensor.graph.destroyhandler import DestroyHandler
@@ -530,3 +532,26 @@ def register_mode(name, mode):
530532
if name in predefined_modes:
531533
raise ValueError(f"Mode name already taken: {name}")
532534
predefined_modes[name] = mode
535+
536+
537+
def get_target_language(mode=None) -> Tuple[Literal["py", "c", "numba", "jax"], ...]:
538+
"""Get the compilation target language."""
539+
540+
if mode is None:
541+
mode = get_default_mode()
542+
543+
linker = mode.linker
544+
545+
if isinstance(linker, NumbaLinker):
546+
return ("numba",)
547+
if isinstance(linker, JAXLinker):
548+
return ("jax",)
549+
if isinstance(linker, PerformLinker):
550+
return ("py",)
551+
if isinstance(linker, CLinker):
552+
return ("c",)
553+
554+
if isinstance(linker, (VMLinker, OpWiseCLinker)):
555+
return ("c", "py") if config.cxx else ("py",)
556+
557+
raise Exception(f"Unsupported Linker: {linker}")

tests/compile/test_mode.py

+96-2
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,20 @@
1+
import copy
2+
3+
import pytest
4+
15
from pytensor.compile.function import function
2-
from pytensor.compile.mode import AddFeatureOptimizer, Mode
6+
from pytensor.compile.mode import (
7+
AddFeatureOptimizer,
8+
Mode,
9+
get_default_mode,
10+
get_target_language,
11+
)
12+
from pytensor.configdefaults import config
313
from pytensor.graph.features import NoOutputFromInplace
414
from pytensor.graph.rewriting.db import RewriteDatabaseQuery, SequenceDB
15+
from pytensor.link.basic import LocalLinker
516
from pytensor.tensor.math import dot, tanh
6-
from pytensor.tensor.type import matrix
17+
from pytensor.tensor.type import matrix, vector
718

819

920
def test_Mode_basic():
@@ -48,3 +59,86 @@ def test_including():
4859

4960
new_mode = mode.including("fast_compile")
5061
assert set(new_mode._optimizer.include) == {"merge", "fast_compile"}
62+
63+
64+
class TestBunchOfModes:
65+
def test_modes(self):
66+
# this is a quick test after the LazyLinker branch merge
67+
# to check that all the current modes can still be used.
68+
linker_classes_involved = []
69+
70+
predef_modes = ["FAST_COMPILE", "FAST_RUN", "DEBUG_MODE"]
71+
72+
# Linkers to use with regular Mode
73+
if config.cxx:
74+
linkers = ["py", "c|py", "c|py_nogc", "vm", "vm_nogc", "cvm", "cvm_nogc"]
75+
else:
76+
linkers = ["py", "c|py", "c|py_nogc", "vm", "vm_nogc"]
77+
modes = predef_modes + [Mode(linker, "fast_run") for linker in linkers]
78+
79+
for mode in modes:
80+
x = matrix()
81+
y = vector()
82+
f = function([x, y], x + y, mode=mode)
83+
# test that it runs something
84+
f([[1, 2], [3, 4]], [5, 6])
85+
linker_classes_involved.append(f.maker.mode.linker.__class__)
86+
# print 'MODE:', mode, f.maker.mode.linker, 'stop'
87+
88+
# regression check:
89+
# there should be
90+
# - `VMLinker`
91+
# - OpWiseCLinker (FAST_RUN)
92+
# - PerformLinker (FAST_COMPILE)
93+
# - DebugMode's Linker (DEBUG_MODE)
94+
assert 4 == len(set(linker_classes_involved))
95+
96+
97+
class TestOldModesProblem:
98+
def test_modes(self):
99+
# Then, build a mode with the same linker, and a modified optimizer
100+
default_mode = get_default_mode()
101+
modified_mode = default_mode.including("specialize")
102+
103+
# The following line used to fail, with Python 2.4, in July 2012,
104+
# because an fgraph was associated to the default linker
105+
copy.deepcopy(modified_mode)
106+
107+
# More straightforward test
108+
linker = get_default_mode().linker
109+
assert not hasattr(linker, "fgraph") or linker.fgraph is None
110+
111+
112+
def test_get_target_language():
113+
with config.change_flags(mode=Mode(linker="py")):
114+
res = get_target_language()
115+
assert res == ("py",)
116+
117+
res = get_target_language(Mode(linker="py"))
118+
assert res == ("py",)
119+
120+
res = get_target_language(Mode(linker="c"))
121+
assert res == ("c",)
122+
123+
res = get_target_language(Mode(linker="c|py"))
124+
assert res == ("c", "py")
125+
126+
res = get_target_language(Mode(linker="vm"))
127+
assert res == ("c", "py")
128+
129+
with config.change_flags(cxx=""):
130+
res = get_target_language(Mode(linker="vm"))
131+
assert res == ("py",)
132+
133+
res = get_target_language(Mode(linker="jax"))
134+
assert res == ("jax",)
135+
136+
res = get_target_language(Mode(linker="numba"))
137+
assert res == ("numba",)
138+
139+
class MyLinker(LocalLinker):
140+
pass
141+
142+
test_mode = Mode(linker=MyLinker())
143+
with pytest.raises(Exception):
144+
get_target_language(test_mode)

tests/compile/test_modes.py

-58
This file was deleted.

0 commit comments

Comments
 (0)