
Commit 9a7bad2

Rewrite dots as multiplication without summation
1 parent 92ebf60 commit 9a7bad2

File tree

8 files changed: +210 -36 lines changed


pytensor/tensor/math.py

+35-9
@@ -29,7 +29,7 @@
     stack,
     switch,
 )
-from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import (
     CAReduce,
     Elemwise,
@@ -2220,7 +2220,7 @@ def outer(x, y):
         x = x.flatten()
     if y.ndim != 1:
         y = y.flatten()
-    return dot(x.dimshuffle(0, "x"), y.dimshuffle("x", 0))
+    return mul.outer(x, y)


 class All(FixedOpCAReduce):
@@ -2726,6 +2726,22 @@ def logsumexp(x, axis=None, keepdims=False):
     return log(sum(exp(x), axis=axis, keepdims=keepdims))


+# Predefine all batched variations of Dot
+_inner_prod = Blockwise(
+    _dot,
+    signature="(n),(n)->()",
+)
+
+_matrix_vec_prod = Blockwise(
+    _dot,
+    signature="(m,k),(k)->(m)",
+)
+
+_vec_matrix_prod = Blockwise(
+    _dot,
+    signature="(k),(k,n)->(n)",
+)
+
 _matrix_matrix_matmul = Blockwise(
     _dot,
     signature="(m,k),(k,n)->(m,n)",
@@ -2795,14 +2811,24 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None


 @_vectorize_node.register(Dot)
-def vectorize_node_dot_to_matmul(op, node, batched_x, batched_y):
+def vectorize_node_dot(op, node, batched_x, batched_y):
     old_x, old_y = node.inputs
-    if old_x.type.ndim == 2 and old_y.type.ndim == 2:
-        # If original input is equivalent to a matrix-matrix product,
-        # return specialized Matmul Op to avoid unnecessary new Ops.
-        return matmul(batched_x, batched_y).owner
-    else:
-        return vectorize_node_fallback(op, node, batched_x, batched_y)
+    old_x_ndim = old_x.type.ndim
+    old_y_ndim = old_y.type.ndim
+    match (old_x_ndim, old_y_ndim):
+        case (1, 1):
+            batch_op = _inner_prod
+        case (2, 1):
+            batch_op = _matrix_vec_prod
+        case (1, 2):
+            batch_op = _vec_matrix_prod
+        case (2, 2):
+            batch_op = _matrix_matrix_matmul
+        case _:
+            raise ValueError(
+                f"Core dot Op should have 1D or 2D inputs, got {old_x_ndim}D and {old_y_ndim}D."
+            )
+    return batch_op(batched_x, batched_y).owner


 def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
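
For orientation (editor's note, not part of the diff): the four Blockwise wrappers above give the core _dot a gufunc-style signature, so any leading dimensions beyond the core 1D/2D shapes are treated as batch dimensions. A minimal NumPy sketch of the semantics those signatures describe, using einsum as the reference:

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(7, 5, 3))     # batched vector, core shape (3,)
y = rng.normal(size=(7, 5, 3))
A = rng.normal(size=(7, 5, 4, 3))  # batched matrix, core shape (4, 3)
B = rng.normal(size=(7, 5, 3, 2))  # batched matrix, core shape (3, 2)

# "(n),(n)->()"        batched inner product, result shape (7, 5)
inner = np.einsum("...n,...n->...", x, y)
# "(m,k),(k)->(m)"     batched matrix-vector product, result shape (7, 5, 4)
mat_vec = np.einsum("...mk,...k->...m", A, x)
# "(k),(k,n)->(n)"     batched vector-matrix product, result shape (7, 5, 2)
vec_mat = np.einsum("...k,...kn->...n", x, B)
# "(m,k),(k,n)->(m,n)" batched matrix-matrix product, result shape (7, 5, 4, 2)
mat_mat = np.einsum("...mk,...kn->...mn", A, B)

assert inner.shape == (7, 5)
assert mat_vec.shape == (7, 5, 4)
assert vec_mat.shape == (7, 5, 2)
assert mat_mat.shape == (7, 5, 4, 2)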

pytensor/tensor/rewriting/math.py

+65
@@ -44,6 +44,11 @@
     Prod,
     Sum,
     _conj,
+    _dot,
+    _inner_prod,
+    _matrix_matrix_matmul,
+    _matrix_vec_prod,
+    _vec_matrix_prod,
     add,
     digamma,
     dot,
@@ -242,6 +247,66 @@ def local_batched_matmul_to_core_matmul(fgraph, node):
     return None


+@register_canonicalize
+@register_specialize
+@node_rewriter(
+    [_dot, _inner_prod, _matrix_vec_prod, _vec_matrix_prod, _matrix_matrix_matmul]
+)
+def local_dot_to_mul(fgraph, node):
+    """Rewrite dots that correspond to multiplication without summation."""
+    a, b = node.inputs
+    a_st_shape = a.type.shape
+    b_st_shape = b.type.shape
+    if isinstance(node.op, Dot):
+        core_a_ndim = a.type.ndim
+        core_b_ndim = b.type.ndim
+    else:
+        # Blockwise variants of Dot
+        core_a_ndim = len(node.op.inputs_sig[0])
+        core_b_ndim = len(node.op.inputs_sig[1])
+
+    if core_a_ndim > 2 or core_b_ndim > 2:
+        # Shouldn't happen, but here just in case
+        return None
+
+    if core_b_ndim == 1:
+        if a_st_shape[-1] == 1 or b_st_shape[-1] == 1:
+            if core_a_ndim == 1:
+                # inner product: (..., 1) * (..., 1) -> (...)
+                # just squeeze the last dimensions of a and b
+                new_a = a.squeeze(-1)
+                new_b = b.squeeze(-1)
+            else:
+                # matrix-vector product: (..., m, 1) * (..., 1) -> (..., m)
+                # the last dimension of b is already aligned for the elemwise multiplication
+                # after we squeeze the last dimension of a
+                new_a = a.squeeze(-1)
+                new_b = b
+        else:
+            return None
+
+    else:
+        if a_st_shape[-1] == 1 or b_st_shape[-2] == 1:
+            if core_a_ndim == 1:
+                # vector-matrix product: (..., 1) * (..., 1, n) -> (..., n)
+                # the last dimension of a is already aligned for the elemwise multiplication
+                # after we squeeze the second-to-last dimension of b
+                new_a = a
+                new_b = b.squeeze(-2)
+            else:
+                # matrix-matrix product: (..., m, 1) * (..., 1, n) -> (..., m, n)
+                # the dimensions of a and b are already aligned for the elemwise multiplication
+                new_a = a
+                new_b = b
+        else:
+            return None
+
+    new_a = copy_stack_trace(a, new_a)
+    new_b = copy_stack_trace(b, new_b)
+    new_out = copy_stack_trace(node.out, mul(new_a, new_b))
+    return [new_out]
+
+
 def is_inverse_pair(node_op, prev_op, inv_pair):
     """
     Given two consecutive operations, check if they are the
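
To make the rewrite's case analysis concrete, here is a small NumPy check (illustrative only, not part of the commit) that a dot whose contracted dimension has length 1 reduces to a broadcasted elementwise multiplication, mirroring the four branches above:

import numpy as np

rng = np.random.default_rng(0)
m, n = 4, 3
a_vec = rng.normal(size=(1,))    # 1D inputs whose only dimension is the contracted one
b_vec = rng.normal(size=(1,))
a_col = rng.normal(size=(m, 1))  # 2D input with a length-1 contracted (last) dimension
b_row = rng.normal(size=(1, n))  # 2D input with a length-1 contracted (first) dimension

# inner product (1),(1)->(): squeeze both inputs and multiply
np.testing.assert_allclose(np.dot(a_vec, b_vec), a_vec.squeeze(-1) * b_vec.squeeze(-1))
# matrix-vector (m,1),(1)->(m): squeeze the last dimension of a
np.testing.assert_allclose(np.dot(a_col, b_vec), a_col.squeeze(-1) * b_vec)
# vector-matrix (1),(1,n)->(n): squeeze the second-to-last dimension of b
np.testing.assert_allclose(np.dot(a_vec, b_row), a_vec * b_row.squeeze(-2))
# matrix-matrix (m,1),(1,n)->(m,n): shapes already broadcast to (m, n)
np.testing.assert_allclose(np.dot(a_col, b_row), a_col * b_row)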

tests/compile/test_profiling.py

+8-2
@@ -6,7 +6,7 @@
 import numpy as np

 import pytensor.tensor as pt
-from pytensor.compile import ProfileStats
+from pytensor.compile import ProfileStats, get_mode
 from pytensor.compile.function import function
 from pytensor.configdefaults import config
 from pytensor.ifelse import ifelse
@@ -28,7 +28,10 @@ def test_profiling(self):
         x = [fvector(f"val{i}") for i in range(3)]

         z = []
-        z += [pt.outer(x[i], x[i + 1]).sum(axis=1) for i in range(len(x) - 1)]
+        z += [
+            pt.dot(x[i][:, None], x[i + 1][None, :]).sum(axis=1)
+            for i in range(len(x) - 1)
+        ]
         z += [x[i] + x[i + 1] for i in range(len(x) - 1)]

         p = ProfileStats(False, gpu_checks=False)
@@ -38,6 +41,9 @@ def test_profiling(self):
         else:
             m = None

+        # This test requires an unoptimized outer mul written as a dot
+        m = get_mode(m).excluding("local_dot_to_mul")
+
         f = function(x, z, profile=p, name="test_profiling", mode=m)

         inp = [np.arange(1024, dtype="float32") + 1 for i in range(len(x))]

tests/tensor/rewriting/test_math.py

+50-1
@@ -16,7 +16,8 @@
 from pytensor.compile.mode import Mode, get_default_mode, get_mode
 from pytensor.compile.ops import DeepCopyOp, deep_copy_op
 from pytensor.configdefaults import config
-from pytensor.graph.basic import Apply, equal_computations
+from pytensor.graph import vectorize_graph
+from pytensor.graph.basic import Apply, ancestors, equal_computations
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import (
     SequentialNodeRewriter,
@@ -4571,3 +4572,51 @@ def test_log_kv_stabilization():
         out.eval({x: 1000.0}, mode=mode),
         -1003.2180912984705,
     )
+
+
+@pytest.mark.parametrize(
+    "a_shape,b_shape",
+    [
+        ((1,), (1,)),
+        ((3, 1), (1,)),
+        ((1,), (1, 3)),
+        ((3, 1), (1, 3)),
+    ],
+)
+@pytest.mark.parametrize("batched", (False, True))
+def test_local_dot_to_mul(batched, a_shape, b_shape):
+    a = tensor("a", shape=a_shape)
+    b = tensor("b", shape=b_shape)
+
+    out = dot(a, b)
+    if batched:
+        batch_a = tensor("batch_a", shape=(1, 5, *a_shape))
+        batch_b = tensor("batch_b", shape=(7, 1, *b_shape))
+        out = vectorize_graph(out, {a: batch_a, b: batch_b})
+        a = batch_a
+        b = batch_b
+
+    assert (
+        sum(
+            isinstance(var.owner.op, (Blockwise | Dot))
+            for var in ancestors([out])
+            if var.owner
+        )
+        == 1
+    )
+
+    rewritten_out = rewrite_graph(out)
+    assert rewritten_out.type.shape == out.type.shape
+    assert not any(
+        isinstance(var.owner.op, (Blockwise | Dot))
+        for var in ancestors([rewritten_out])
+        if var.owner
+    )
+
+    a_test = np.random.normal(size=a.type.shape).astype(a.type.dtype)
+    b_test = np.random.normal(size=b.type.shape).astype(b.type.dtype)
+    test_mode = Mode(linker="py", optimizer=None)
+    np.testing.assert_allclose(
+        out.eval({a: a_test, b: b_test}, mode=test_mode),
+        rewritten_out.eval({a: a_test, b: b_test}, mode=test_mode),
+    )
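
An aside (editor's note, not part of the test file): in the batched variant the batch dimensions (1, 5) and (7, 1) also broadcast against each other. A NumPy analogue for the a_shape=(3, 1), b_shape=(1, 3) case, showing that the batched dot and the plain broadcasted product agree:

import numpy as np

rng = np.random.default_rng(0)
batch_a = rng.normal(size=(1, 5, 3, 1))
batch_b = rng.normal(size=(7, 1, 1, 3))

# Batched matrix-matrix product: batch dims broadcast to (7, 5), core dims give (3, 3)
batched_dot = np.matmul(batch_a, batch_b)  # shape (7, 5, 3, 3)
# The contracted dimension has length 1, so plain broadcasting gives the same values
batched_mul = batch_a * batch_b            # shape (7, 5, 3, 3)
np.testing.assert_allclose(batched_dot, batched_mul)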

tests/tensor/test_basic.py

+2-2
@@ -770,9 +770,9 @@ def test_alloc_constant_folding(self):
             self.allocs,
             [
                 # IncSubtensor1
-                (some_matrix[:60], 2),
+                (some_matrix[:60], 1),
                 # AdvancedIncSubtensor1
-                (some_matrix[arange(60)], 2),
+                (some_matrix[arange(60)], 1),
                 # AdvancedIncSubtensor
                 (some_matrix[idx, idx], 1),
             ],

tests/tensor/test_blas.py

+25-10
@@ -40,7 +40,7 @@
     ger,
     ger_destructive,
 )
-from pytensor.tensor.math import Dot, dot, mean, mul, outer, sigmoid
+from pytensor.tensor.math import Dot, dot, mean, mul, sigmoid
 from pytensor.tensor.rewriting.blas import local_dot22_to_dot22scalar, local_gemm_to_ger
 from pytensor.tensor.type import (
     cmatrix,
@@ -1721,9 +1721,12 @@ def clone(self, op):
 class TestGer(unittest_tools.OptimizationTestMixin):
     shared = staticmethod(shared)

+    def outer_via_dot(self, x, y):
+        return pt.dot(x[:, None], y[None, :])
+
     def setup_method(self):
         self.mode = pytensor.compile.get_default_mode().including("fast_run")
-        self.mode = self.mode.excluding("c_blas", "scipy_blas")
+        self.mode = self.mode.excluding("c_blas", "scipy_blas", "local_dot_to_mul")
         dtype = self.dtype = "float64"  # optimization isn't dtype-dependent
         self.A = tensor(dtype=dtype, shape=(None, None))
         self.a = tensor(dtype=dtype, shape=())
@@ -1795,7 +1798,7 @@ def test_b_nonconst_does_not_triggers_ger(self):

     def test_outer(self):
         rng = np.random.default_rng(unittest_tools.fetch_seed())
-        f = self.function([self.x, self.y], outer(self.x, self.y))
+        f = self.function([self.x, self.y], self.outer_via_dot(self.x, self.y))
         self.assertFunctionContains(f, self.ger_destructive)
         f(
             rng.random(5).astype(self.dtype),
@@ -1804,7 +1807,9 @@

     def test_A_plus_outer(self):
         rng = np.random.default_rng(unittest_tools.fetch_seed())
-        f = self.function([self.A, self.x, self.y], self.A + outer(self.x, self.y))
+        f = self.function(
+            [self.A, self.x, self.y], self.A + self.outer_via_dot(self.x, self.y)
+        )
         self.assertFunctionContains(f, self.ger)
         f(
             rng.random((5, 4)).astype(self.dtype),
@@ -1820,7 +1825,7 @@ def test_A_plus_outer(self):
     def test_A_plus_scaled_outer(self):
         rng = np.random.default_rng(unittest_tools.fetch_seed())
         f = self.function(
-            [self.A, self.x, self.y], self.A + 0.1 * outer(self.x, self.y)
+            [self.A, self.x, self.y], self.A + 0.1 * self.outer_via_dot(self.x, self.y)
         )
         self.assertFunctionContains(f, self.ger)
         f(
@@ -1839,7 +1844,7 @@ def test_scaled_A_plus_scaled_outer(self):
         f = self.function(
             [self.A, self.x, self.y],
             np.asarray(0.2, self.dtype) * self.A
-            + np.asarray(0.1, self.dtype) * outer(self.x, self.y),
+            + np.asarray(0.1, self.dtype) * self.outer_via_dot(self.x, self.y),
         )
         # Why gemm? This make the graph simpler did we test that it
         # make it faster?
@@ -1863,7 +1868,7 @@ def given_dtype(self, dtype, M, N, *, destructive=True):
         x = tensor(dtype=dtype, shape=(None,))
         y = tensor(dtype=dtype, shape=(None,))

-        f = self.function([A, x, y], A + 0.1 * outer(x, y))
+        f = self.function([A, x, y], A + 0.1 * self.outer_via_dot(x, y))
         self.assertFunctionContains(
             f, self.ger_destructive if destructive else self.ger
         )
@@ -1923,7 +1928,12 @@ def test_inplace(self):
             [self.x, self.y],
             [],
             updates=[
-                (A, A + pt.constant(0.1, dtype=self.dtype) * outer(self.x, self.y))
+                (
+                    A,
+                    A
+                    + pt.constant(0.1, dtype=self.dtype)
+                    * self.outer_via_dot(self.x, self.y),
+                )
             ],
         )
         self.assertFunctionContains(f, self.ger_destructive)
@@ -2264,10 +2274,15 @@ def cmp_ger(self, a_shp, b_shp, c_shp, rng):
         b_dev = b.get_value(borrow=False, return_internal_type=True)
         c_dev = c.get_value(borrow=False, return_internal_type=True)

-        f_n = function([], [], updates=[(a, (a + l * outer(b, c)))], mode=self.mode)
+        f_n = function(
+            [], [], updates=[(a, (a + l * self.outer_via_dot(b, c)))], mode=self.mode
+        )

         f_t = function(
-            [], [], updates=[(a_t, (a_t + l * outer(b, c).T))], mode=self.mode
+            [],
+            [],
+            updates=[(a_t, (a_t + l * self.outer_via_dot(b, c).T))],
+            mode=self.mode,
         )

         # Try with all stride patterns, and all transposed patterns
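
A side note on the outer_via_dot helper (the reasoning is implied by the diff rather than stated in it): since outer is now built from mul.outer and the new local_dot_to_mul rewrite would eliminate a dot whose contracted dimension has length 1, these GER/BLAS tests construct the outer product explicitly as a dot and exclude the rewrite, so the graph still contains the Dot node the BLAS optimizations look for. Numerically the helper is just an outer product, as this NumPy sketch shows:

import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=5)
y = rng.normal(size=4)

# A (5, 1) @ (1, 4) dot is the same outer product as np.outer(x, y)
np.testing.assert_allclose(np.dot(x[:, None], y[None, :]), np.outer(x, y))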
