Skip to content

Commit 448be82

Browse files
committed
Remove specialize_device database
1 parent cae467e commit 448be82

File tree

4 files changed

+8
-31
lines changed

4 files changed

+8
-31
lines changed

pytensor/compile/mode.py

-5
Original file line numberDiff line numberDiff line change
@@ -248,11 +248,6 @@ def apply(self, fgraph):
248248
# misc special cases for speed that break canonicalization
249249
optdb.register("uncanonicalize", EquilibriumDB(), "fast_run", position=3)
250250

251-
# misc special cases for speed that are dependent on the device.
252-
optdb.register(
253-
"specialize_device", EquilibriumDB(), "fast_compile", "fast_run", position=48.6
254-
) # must be after gpu stuff at 48.5
255-
256251
# especially constant merge
257252
optdb.register("merge2", MergeOptimizer(), "fast_run", "merge", position=49)
258253

pytensor/tensor/rewriting/basic.py

-19
Original file line numberDiff line numberDiff line change
@@ -205,25 +205,6 @@ def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
205205
return node_rewriter
206206

207207

208-
def register_specialize_device(
209-
node_rewriter: Union[RewriteDatabase, Rewriter, str], *tags: str, **kwargs
210-
):
211-
if isinstance(node_rewriter, str):
212-
213-
def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
214-
return register_specialize_device(
215-
inner_rewriter, node_rewriter, *tags, **kwargs
216-
)
217-
218-
return register
219-
else:
220-
name = (kwargs and kwargs.pop("name", None)) or node_rewriter.__name__
221-
compile.optdb["specialize_device"].register(
222-
name, node_rewriter, "fast_run", *tags, **kwargs
223-
)
224-
return node_rewriter
225-
226-
227208
@register_canonicalize
228209
@register_specialize
229210
@node_rewriter([TensorFromScalar])

pytensor/tensor/rewriting/math.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,6 @@
8888
local_fill_sink,
8989
register_canonicalize,
9090
register_specialize,
91-
register_specialize_device,
9291
register_stabilize,
9392
register_uncanonicalize,
9493
register_useless,
@@ -2078,12 +2077,14 @@ def local_pow_specialize(fgraph, node):
20782077
return False
20792078

20802079

2081-
@register_specialize_device
2080+
@register_specialize
20822081
@node_rewriter([at_pow])
2083-
def local_pow_specialize_device(fgraph, node):
2084-
"""
2085-
This rewrite is not the same on all device. We do it only on cpu here.
2082+
def local_pow_to_nested_squaring(fgraph, node):
2083+
"""Convert a large power exponent to multiple squaring operations.
2084+
2085+
Note: This sounds like the kind of thing any half-decent compiler can do by itself?
20862086
"""
2087+
20872088
if node.op == at_pow:
20882089
# the idea here is that we have pow(x, y)
20892090
odtype = node.outputs[0].dtype

tests/tensor/rewriting/test_math.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -1672,12 +1672,12 @@ def test_local_pow_specialize():
16721672
utt.assert_allclose(f(val_no0), val_no0 ** (-0.5))
16731673

16741674

1675-
def test_local_pow_specialize_device_more_aggressive_on_cpu():
1675+
def test_local_pow_to_nested_squaring():
16761676
mode = config.mode
16771677
if mode == "FAST_COMPILE":
16781678
mode = "FAST_RUN"
16791679
mode = get_mode(mode)
1680-
mode = mode.excluding("fusion").excluding("gpu")
1680+
mode = mode.excluding("fusion")
16811681

16821682
v = vector()
16831683
val = np.arange(10, dtype=config.floatX)

0 commit comments

Comments
 (0)