
Commit 6b7dcd4

Rewrite dots as multiplication without summation
1 parent 92ebf60 commit 6b7dcd4

File tree
6 files changed: +166 -16 lines changed


pytensor/tensor/math.py

Lines changed: 35 additions & 9 deletions
@@ -29,7 +29,7 @@
     stack,
     switch,
 )
-from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import (
     CAReduce,
     Elemwise,
@@ -2220,7 +2220,7 @@ def outer(x, y):
         x = x.flatten()
     if y.ndim != 1:
         y = y.flatten()
-    return dot(x.dimshuffle(0, "x"), y.dimshuffle("x", 0))
+    return mul.outer(x, y)


 class All(FixedOpCAReduce):
@@ -2726,6 +2726,22 @@ def logsumexp(x, axis=None, keepdims=False):
     return log(sum(exp(x), axis=axis, keepdims=keepdims))


+# Predefine all batched variations of Dot
+_inner_prod = Blockwise(
+    _dot,
+    signature="(n),(n)->()",
+)
+
+_matrix_vec_prod = Blockwise(
+    _dot,
+    signature="(m,k),(k)->(m)",
+)
+
+_vec_matrix_prod = Blockwise(
+    _dot,
+    signature="(k),(k,n)->(n)",
+)
+
 _matrix_matrix_matmul = Blockwise(
     _dot,
     signature="(m,k),(k,n)->(m,n)",
@@ -2795,14 +2811,24 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None


 @_vectorize_node.register(Dot)
-def vectorize_node_dot_to_matmul(op, node, batched_x, batched_y):
+def vectorize_node_dot(op, node, batched_x, batched_y):
     old_x, old_y = node.inputs
-    if old_x.type.ndim == 2 and old_y.type.ndim == 2:
-        # If original input is equivalent to a matrix-matrix product,
-        # return specialized Matmul Op to avoid unnecessary new Ops.
-        return matmul(batched_x, batched_y).owner
-    else:
-        return vectorize_node_fallback(op, node, batched_x, batched_y)
+    old_x_ndim = old_x.type.ndim
+    old_y_ndim = old_y.type.ndim
+    match (old_x_ndim, old_y_ndim):
+        case (1, 1):
+            batch_op = _inner_prod
+        case (2, 1):
+            batch_op = _matrix_vec_prod
+        case (1, 2):
+            batch_op = _vec_matrix_prod
+        case (2, 2):
+            batch_op = _matrix_matrix_matmul
+        case _:
+            raise ValueError(
+                f"Core dot Op should have 1D or 2D inputs, got {old_x_ndim}D and {old_y_ndim}D."
+            )
+    return batch_op(batched_x, batched_y).owner


 def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
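A minimal sketch (not part of the diff) of what the new vectorize_node_dot dispatch produces: vectorizing a core 1D·1D dot should now yield the dedicated Blockwise inner product with signature "(n),(n)->()" rather than a generic fallback. Batch shapes below are illustrative assumptions.

import pytensor.tensor as pt
from pytensor.graph import vectorize_graph

x = pt.vector("x")
y = pt.vector("y")
core = pt.dot(x, y)  # scalar inner product built from the core Dot Op

batch_x = pt.matrix("batch_x")  # assumed batch shape (batch, n)
batch_y = pt.matrix("batch_y")  # assumed batch shape (batch, n)
batched = vectorize_graph(core, {x: batch_x, y: batch_y})
print(batched.owner.op)  # expected: a Blockwise of Dot with signature (n),(n)->()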

pytensor/tensor/rewriting/math.py

Lines changed: 65 additions & 0 deletions
@@ -44,6 +44,11 @@
     Prod,
     Sum,
     _conj,
+    _dot,
+    _inner_prod,
+    _matrix_matrix_matmul,
+    _matrix_vec_prod,
+    _vec_matrix_prod,
     add,
     digamma,
     dot,
@@ -242,6 +247,66 @@ def local_batched_matmul_to_core_matmul(fgraph, node):
     return None


+@register_canonicalize
+@register_specialize
+@node_rewriter(
+    [_dot, _inner_prod, _matrix_vec_prod, _vec_matrix_prod, _matrix_matrix_matmul]
+)
+def local_dot_to_mul(fgraph, node):
+    """Rewrite dots that correspond to multiplication without summation."""
+    a, b = node.inputs
+    a_st_shape = a.type.shape
+    b_st_shape = b.type.shape
+    if isinstance(node.op, Dot):
+        core_a_ndim = a.type.ndim
+        core_b_ndim = b.type.ndim
+    else:
+        # Blockwise variants of Dot
+        core_a_ndim = len(node.op.inputs_sig[0])
+        core_b_ndim = len(node.op.inputs_sig[1])
+
+    if core_a_ndim > 2 or core_b_ndim > 2:
+        # Shouldn't happen, but here just in case
+        return None
+
+    if core_b_ndim == 1:
+        if a_st_shape[-1] == 1 or b_st_shape[-1] == 1:
+            if core_a_ndim == 1:
+                # inner product: (..., 1) * (..., 1) -> (...)
+                # just squeeze the last dimensions of a and b
+                new_a = a.squeeze(-1)
+                new_b = b.squeeze(-1)
+            else:
+                # matrix-vector product: (..., m, 1) * (..., 1) -> (..., m)
+                # the last dimension of b is already aligned for the elemwise multiplication
+                # after we squeeze the last dimension of a
+                new_a = a.squeeze(-1)
+                new_b = b
+        else:
+            return None
+
+    else:
+        if a_st_shape[-1] == 1 or b_st_shape[-2] == 1:
+            if core_a_ndim == 1:
+                # vector-matrix product: (..., 1) * (..., 1, n) -> (..., n)
+                # the last dimension of a is already aligned for the elemwise multiplication
+                # after we squeeze the second-to-last dimension of b
+                new_a = a
+                new_b = b.squeeze(-2)
+            else:
+                # matrix-matrix product: (..., m, 1) * (..., 1, n) -> (..., m, n)
+                # the dimensions of a and b are already aligned for the elemwise multiplication
+                new_a = a
+                new_b = b
+        else:
+            return None
+
+    new_a = copy_stack_trace(a, new_a)
+    new_b = copy_stack_trace(b, new_b)
+    new_out = copy_stack_trace(node.out, mul(new_a, new_b))
+    return [new_out]
+
+
 def is_inverse_pair(node_op, prev_op, inv_pair):
     """
     Given two consecutive operations, check if they are the
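As a sanity check of the identity local_dot_to_mul relies on (plain NumPy, not part of this commit): when the contracted dimension has static length 1, no summation actually happens, so the dot reduces to a broadcasted elementwise product.

import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(3, 1))  # (m, 1)
b = rng.normal(size=(1, 4))  # (1, n)
v = rng.normal(size=(1,))    # (k,) with k == 1

# matrix-matrix with k == 1 is an outer product: broadcasting a * b gives (m, n)
np.testing.assert_allclose(a @ b, a * b)
# matrix-vector with k == 1: squeeze the length-1 column of a and broadcast against v
np.testing.assert_allclose(a @ v, a.squeeze(-1) * v)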

tests/compile/test_profiling.py

Lines changed: 8 additions & 2 deletions
@@ -6,7 +6,7 @@
 import numpy as np

 import pytensor.tensor as pt
-from pytensor.compile import ProfileStats
+from pytensor.compile import ProfileStats, get_mode
 from pytensor.compile.function import function
 from pytensor.configdefaults import config
 from pytensor.ifelse import ifelse
@@ -28,7 +28,10 @@ def test_profiling(self):
         x = [fvector(f"val{i}") for i in range(3)]

         z = []
-        z += [pt.outer(x[i], x[i + 1]).sum(axis=1) for i in range(len(x) - 1)]
+        z += [
+            pt.dot(x[i][:, None], x[i + 1][None, :]).sum(axis=1)
+            for i in range(len(x) - 1)
+        ]
         z += [x[i] + x[i + 1] for i in range(len(x) - 1)]

         p = ProfileStats(False, gpu_checks=False)
@@ -38,6 +41,9 @@ def test_profiling(self):
         else:
             m = None

+        # This test requires an unoptimized outer mul written as a dot
+        m = get_mode(m).excluding("local_dot_to_mul")
+
         f = function(x, z, profile=p, name="test_profiling", mode=m)

         inp = [np.arange(1024, dtype="float32") + 1 for i in range(len(x))]

tests/tensor/rewriting/test_math.py

Lines changed: 50 additions & 1 deletion
@@ -16,7 +16,8 @@
 from pytensor.compile.mode import Mode, get_default_mode, get_mode
 from pytensor.compile.ops import DeepCopyOp, deep_copy_op
 from pytensor.configdefaults import config
-from pytensor.graph.basic import Apply, equal_computations
+from pytensor.graph import vectorize_graph
+from pytensor.graph.basic import Apply, ancestors, equal_computations
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import (
     SequentialNodeRewriter,
@@ -4571,3 +4572,51 @@ def test_log_kv_stabilization():
         out.eval({x: 1000.0}, mode=mode),
         -1003.2180912984705,
     )
+
+
+@pytest.mark.parametrize(
+    "a_shape,b_shape",
+    [
+        ((1,), (1,)),
+        ((3, 1), (1,)),
+        ((1,), (1, 3)),
+        ((3, 1), (1, 3)),
+    ],
+)
+@pytest.mark.parametrize("batched", (False, True))
+def test_local_dot_to_mul(batched, a_shape, b_shape):
+    a = tensor("a", shape=a_shape)
+    b = tensor("b", shape=b_shape)
+
+    out = dot(a, b)
+    if batched:
+        batch_a = tensor("batch_a", shape=(1, 5, *a_shape))
+        batch_b = tensor("batch_b", shape=(7, 1, *b_shape))
+        out = vectorize_graph(out, {a: batch_a, b: batch_b})
+        a = batch_a
+        b = batch_b
+
+    assert (
+        sum(
+            isinstance(var.owner.op, (Blockwise | Dot))
+            for var in ancestors([out])
+            if var.owner
+        )
+        == 1
+    )
+
+    rewritten_out = rewrite_graph(out)
+    assert rewritten_out.type.shape == out.type.shape
+    assert not any(
+        isinstance(var.owner.op, (Blockwise | Dot))
+        for var in ancestors([rewritten_out])
+        if var.owner
+    )
+
+    a_test = np.random.normal(size=a.type.shape)
+    b_test = np.random.normal(size=b.type.shape)
+    test_mode = Mode(linker="py", optimizer=None)
+    np.testing.assert_allclose(
+        out.eval({a: a_test, b: b_test}, mode=test_mode),
+        rewritten_out.eval({a: a_test, b: b_test}, mode=test_mode),
+    )
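For reference, a hedged end-to-end sketch of what the new test exercises (the shapes below are chosen for illustration): a dot whose contracted dimension has static length 1 should leave no Dot or Blockwise node behind after canonicalization.

import pytensor
import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph

a = pt.tensor("a", shape=(3, 1))
b = pt.tensor("b", shape=(1, 4))
out = pt.dot(a, b)  # Dot with a static length-1 contracted dimension

rewritten = rewrite_graph(out)  # canonicalization should apply local_dot_to_mul
pytensor.dprint(rewritten)      # expected: an elemwise mul graph, no Dot/Blockwise left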

tests/tensor/test_basic.py

Lines changed: 2 additions & 2 deletions
@@ -770,9 +770,9 @@ def test_alloc_constant_folding(self):
            self.allocs,
            [
                # IncSubtensor1
-                (some_matrix[:60], 2),
+                (some_matrix[:60], 1),
                # AdvancedIncSubtensor1
-                (some_matrix[arange(60)], 2),
+                (some_matrix[arange(60)], 1),
                # AdvancedIncSubtensor
                (some_matrix[idx, idx], 1),
            ],

tests/tensor/test_blas.py

Lines changed: 6 additions & 2 deletions
@@ -1723,7 +1723,7 @@ class TestGer(unittest_tools.OptimizationTestMixin):

     def setup_method(self):
         self.mode = pytensor.compile.get_default_mode().including("fast_run")
-        self.mode = self.mode.excluding("c_blas", "scipy_blas")
+        self.mode = self.mode.excluding("c_blas", "scipy_blas", "local_dot_to_mul")
         dtype = self.dtype = "float64"  # optimization isn't dtype-dependent
         self.A = tensor(dtype=dtype, shape=(None, None))
         self.a = tensor(dtype=dtype, shape=())
@@ -1795,7 +1795,11 @@ def test_b_nonconst_does_not_triggers_ger(self):

     def test_outer(self):
         rng = np.random.default_rng(unittest_tools.fetch_seed())
-        f = self.function([self.x, self.y], outer(self.x, self.y))
+        f = self.function(
+            [self.x, self.y],
+            # Old outer used to be written like this
+            pt.dot(self.x[:, None], self.y[None, :]),
+        )
         self.assertFunctionContains(f, self.ger_destructive)
         f(
             rng.random(5).astype(self.dtype),
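The test_outer change above follows from the outer() change in pytensor/tensor/math.py: pt.outer now builds the broadcasted multiplication directly via mul.outer, so it no longer introduces the Dot that the GER rewrite targets, and the test spells out the old dot-based formulation explicitly. A quick way to see the new behaviour (a sketch, not part of the commit):

import pytensor.tensor as pt

x = pt.vector("x")
y = pt.vector("y")
print(pt.outer(x, y).owner.op)  # expected: an Elemwise mul, not a Dot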
