Skip to content

Commit d651753

Browse files
committed
Create scalarize rewrite pass
1 parent 1784965 commit d651753

File tree

6 files changed

+94
-41
lines changed

6 files changed

+94
-41
lines changed

pytensor/compile/mode.py

+32-5
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,11 @@ def apply(self, fgraph):
250250
# misc special cases for speed that break canonicalization
251251
optdb.register("uncanonicalize", EquilibriumDB(), "fast_run", position=3)
252252

253+
# Turn tensor operations to scalar operations where possible.
254+
# This is currently marked as numba-only, but this could be changed
255+
# in the future.
256+
optdb.register("scalarize", EquilibriumDB(), "numba_only", position=3.1)
257+
253258
# misc special cases for speed that are dependent on the device.
254259
optdb.register(
255260
"specialize_device", EquilibriumDB(), "fast_compile", "fast_run", position=48.6
@@ -459,20 +464,42 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
459464
# FunctionMaker, the Mode will be taken from this dictionary using the
460465
# string as the key
461466
# Use VM_linker to allow lazy evaluation by default.
462-
FAST_COMPILE = Mode(VMLinker(use_cloop=False, c_thunks=False), "fast_compile")
467+
FAST_COMPILE = Mode(
468+
VMLinker(use_cloop=False, c_thunks=False),
469+
RewriteDatabaseQuery(
470+
include=["fast_compile"],
471+
exclude=["numba_only"],
472+
),
473+
)
463474
if config.cxx:
464-
FAST_RUN = Mode("cvm", "fast_run")
475+
FAST_RUN = Mode(
476+
"cvm",
477+
RewriteDatabaseQuery(
478+
include=["fast_run"],
479+
exclude=["numba_only"],
480+
),
481+
)
465482
else:
466-
FAST_RUN = Mode("vm", "fast_run")
483+
FAST_RUN = Mode(
484+
"vm",
485+
RewriteDatabaseQuery(
486+
include=["fast_run"],
487+
exclude=["numba_only"],
488+
),
489+
)
467490

468491
JAX = Mode(
469492
JAXLinker(),
470-
RewriteDatabaseQuery(include=["fast_run", "jax"], exclude=["cxx_only", "BlasOpt"]),
493+
RewriteDatabaseQuery(
494+
include=["fast_run", "jax"],
495+
exclude=["cxx_only", "BlasOpt", "numba_only"],
496+
),
471497
)
498+
472499
NUMBA = Mode(
473500
NumbaLinker(),
474501
RewriteDatabaseQuery(
475-
include=["fast_run", "fast_run_numba", "fast_compile_numba"],
502+
include=["fast_run", "numba_only"],
476503
exclude=["cxx_only", "BlasOpt"],
477504
),
478505
)

pytensor/tensor/rewriting/basic.py

+37-14
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,23 @@ def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
186186
return node_rewriter
187187

188188

189+
def register_scalarize(
    node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags: str, **kwargs
):
    """Register `node_rewriter` in the "scalarize" rewrite database.

    Mirrors the sibling ``register_*`` helpers in this module: when called
    with a string it acts as a decorator factory that registers the decorated
    rewriter under that explicit name; otherwise it registers
    ``node_rewriter`` directly under its own name (or ``name=`` from
    ``kwargs``) with the ``fast_run`` and ``fast_compile`` tags plus any
    extra ``tags``.
    """
    if isinstance(node_rewriter, str):

        def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
            # BUG FIX: this previously delegated to `register_specialize`,
            # so a rewrite registered via `@register_scalarize("name")`
            # landed in the "specialize" database instead of "scalarize".
            # Recurse into this helper, matching the sibling register_*
            # functions (e.g. register_uncanonicalize below).
            return register_scalarize(inner_rewriter, node_rewriter, *tags, **kwargs)

        return register
    else:
        # Allow an explicit name override; default to the rewriter's own name.
        name = kwargs.pop("name", None) or node_rewriter.__name__
        compile.optdb["scalarize"].register(
            name, node_rewriter, "fast_run", "fast_compile", *tags, **kwargs
        )
        return node_rewriter
204+
205+
189206
def register_uncanonicalize(
190207
node_rewriter: Union[RewriteDatabase, NodeRewriter, str], *tags: str, **kwargs
191208
):
@@ -226,30 +243,36 @@ def register(inner_rewriter: Union[RewriteDatabase, Rewriter]):
226243

227244
@register_canonicalize
@register_specialize
@register_scalarize
@node_rewriter([TensorFromScalar])
def local_tensor_scalar_tensor(fgraph, node):
    """tensor_from_scalar(scalar_from_tensor(x)) -> x"""
    scalar_input = node.inputs[0]
    producer = scalar_input.owner
    if producer is not None and isinstance(producer.op, ScalarFromTensor):
        # Pure round-trip elimination; no stack traces need copying here.
        return [producer.inputs[0]]
239256

240257

241258
@register_canonicalize
@register_specialize
@register_scalarize
@node_rewriter([ScalarFromTensor])
def local_scalar_tensor_scalar(fgraph, node):
    """scalar_from_tensor(tensor_from_scalar(x)) -> x

    and scalar_from_tensor(TensorConstant(x)) -> x
    """
    tensor_input = node.inputs[0]
    producer = tensor_input.owner
    if producer is not None and isinstance(producer.op, TensorFromScalar):
        # Undo the tensor<->scalar round-trip; no stack traces need copying.
        return [producer.inputs[0]]
    if isinstance(tensor_input, TensorConstant):
        # ScalarFromTensor only accepts 0-d tensors, so a constant input
        # must itself be 0-d; lower it to an equivalent scalar constant.
        assert tensor_input.ndim == 0
        return [
            aes.constant(
                tensor_input.value.item(), tensor_input.name, tensor_input.dtype
            )
        ]
253276

254277

255278
@register_specialize("local_alloc_elemwise")

pytensor/tensor/rewriting/elemwise.py

+9-15
Original file line numberDiff line numberDiff line change
@@ -381,12 +381,8 @@ def is_dimshuffle_useless(new_order, input):
381381

382382

383383
@node_rewriter([Elemwise])
384-
def local_elemwise_lift_scalars(fgraph, node):
384+
def elemwise_to_scalar(fgraph, node):
385385
op = node.op
386-
387-
if not isinstance(op, Elemwise):
388-
return False
389-
390386
if not all(input.ndim == 0 for input in node.inputs):
391387
return False
392388

@@ -397,11 +393,11 @@ def local_elemwise_lift_scalars(fgraph, node):
397393
return [as_tensor_variable(out) for out in op.scalar_op.make_node(*scalars).outputs]
398394

399395

400-
compile.optdb["specialize"].register(
401-
"local_elemwise_lift_scalars",
402-
local_elemwise_lift_scalars,
403-
"fast_run_numba",
404-
"fast_compile_numba",
396+
compile.optdb["scalarize"].register(
397+
"local_elemwise_to_scalar",
398+
elemwise_to_scalar,
399+
"fast_run",
400+
"fast_compile",
405401
)
406402

407403

@@ -411,9 +407,6 @@ def push_elemwise_constants(fgraph, node):
411407
contained scalar op.
412408
"""
413409
op = node.op
414-
if not isinstance(op, Elemwise):
415-
return False
416-
417410
if any(op.inplace_pattern):
418411
return False
419412

@@ -467,8 +460,9 @@ def is_constant_scalar(x):
467460
compile.optdb["post_fusion"].register(
468461
"push_elemwise_constants",
469462
push_elemwise_constants,
470-
"fast_run_numba",
471-
"fast_compile_numba",
463+
"fast_run",
464+
"fast_compile",
465+
"numba_only",
472466
)
473467

474468

pytensor/tensor/rewriting/math.py

+13
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
encompasses_broadcastable,
8787
local_fill_sink,
8888
register_canonicalize,
89+
register_scalarize,
8990
register_specialize,
9091
register_specialize_device,
9192
register_stabilize,
@@ -1568,6 +1569,18 @@ def local_op_of_op(fgraph, node):
15681569
return [combined(node_inps.owner.inputs[0])]
15691570

15701571

1572+
@register_scalarize
@node_rewriter([Sum])
def local_sum_of_makevector(fgraph, node):
    """sum(make_vector(a, b, ...)) -> a + b + ...

    Replaces the reduction of an explicitly constructed vector by a chain
    of scalar additions, so downstream rewrites can stay in scalar land.
    """
    (array,) = node.inputs
    if not array.owner or not isinstance(array.owner.op, MakeVector):
        return False

    values = array.owner.inputs
    if len(values) == 0:
        # An empty MakeVector cannot be expressed as an `add` (which needs
        # at least one input); leave the Sum node alone.
        return False
    if len(values) == 1:
        # Sum of a length-1 vector is the element itself, cast to the
        # Sum output dtype to preserve the node's contract.
        return [values[0].astype(node.outputs[0].dtype)]

    summed = aes.add(*values)
    # NOTE(review): `aes.add` follows ordinary upcast rules, which may
    # differ from Sum's acc_dtype for low-precision integer inputs —
    # confirm dtype equivalence for int8/uint8 vectors.
    return [as_tensor_variable(summed)]
1582+
1583+
15711584
ALL_REDUCE = (
15721585
[
15731586
CAReduce,

pytensor/tensor/rewriting/subtensor.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -469,7 +469,8 @@ def local_subtensor_lift(fgraph, node):
469469
return [rbcast_subt_x]
470470

471471

472-
@register_specialize
472+
@register_stabilize("cxx_only")
473+
@register_canonicalize("cxx_only")
473474
@node_rewriter([Subtensor])
474475
def local_subtensor_merge(fgraph, node):
475476
"""

tests/link/numba/test_scan.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
from pytensor.scan.op import Scan
1111
from pytensor.scan.utils import until
1212
from pytensor.tensor import log, vector
13-
from pytensor.tensor.elemwise import Elemwise
1413
from pytensor.tensor.random.utils import RandomStream
1514
from tests import unittest_tools as utt
1615
from tests.link.numba.test_basic import compare_numba_and_py
@@ -437,8 +436,4 @@ def test_inner_graph_optimized():
437436
node for node in f.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
438437
]
439438
inner_scan_nodes = scan_node.op.fgraph.apply_nodes
440-
assert len(inner_scan_nodes) == 1
441-
(inner_scan_node,) = scan_node.op.fgraph.apply_nodes
442-
assert isinstance(inner_scan_node.op, Elemwise) and isinstance(
443-
inner_scan_node.op.scalar_op, Log1p
444-
)
439+
assert any(isinstance(node.op, Log1p) for node in inner_scan_nodes)

0 commit comments

Comments
 (0)