
Commit cc443b6

Rewrite batched dots that do not reduce as multiplication

1 parent 92ebf60 commit cc443b6

3 files changed: +146 -9 lines changed

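For context (an illustration only, not part of the committed diff): when the contracted dimension of a dot product has static length 1, the "sum" in the dot reduces a single element, so the result is just a broadcasted elementwise multiplication once that length-1 dimension is squeezed or left in place for alignment. A quick numpy check of the idea:

import numpy as np

rng = np.random.default_rng(0)

# Inner product of two length-1 vectors: squeeze both and multiply.
u = rng.normal(size=(1,))
v = rng.normal(size=(1,))
np.testing.assert_allclose(np.dot(u, v), u.squeeze(-1) * v.squeeze(-1))

# Matrix product with a length-1 contracted dimension: (m, 1) @ (1, n)
# never actually sums anything, so it equals the broadcasted elementwise product.
a = rng.normal(size=(3, 1))
b = rng.normal(size=(1, 4))
np.testing.assert_allclose(a @ b, a * b)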

pytensor/tensor/math.py (+34 -8)

@@ -29,7 +29,7 @@
     stack,
     switch,
 )
-from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
+from pytensor.tensor.blockwise import Blockwise
 from pytensor.tensor.elemwise import (
     CAReduce,
     Elemwise,

@@ -2726,6 +2726,22 @@ def logsumexp(x, axis=None, keepdims=False):
     return log(sum(exp(x), axis=axis, keepdims=keepdims))


+# Predefine all batched variations of Dot
+_inner_prod = Blockwise(
+    _dot,
+    signature="(n),(n)->()",
+)
+
+_matrix_vec_prod = Blockwise(
+    _dot,
+    signature="(m,k),(k)->(m)",
+)
+
+_vec_matrix_prod = Blockwise(
+    _dot,
+    signature="(k),(k,n)->(n)",
+)
+
 _matrix_matrix_matmul = Blockwise(
     _dot,
     signature="(m,k),(k,n)->(m,n)",

@@ -2795,14 +2811,24 @@ def matmul(x1: "ArrayLike", x2: "ArrayLike", dtype: Optional["DTypeLike"] = None


 @_vectorize_node.register(Dot)
-def vectorize_node_dot_to_matmul(op, node, batched_x, batched_y):
+def vectorize_node_dot(op, node, batched_x, batched_y):
     old_x, old_y = node.inputs
-    if old_x.type.ndim == 2 and old_y.type.ndim == 2:
-        # If original input is equivalent to a matrix-matrix product,
-        # return specialized Matmul Op to avoid unnecessary new Ops.
-        return matmul(batched_x, batched_y).owner
-    else:
-        return vectorize_node_fallback(op, node, batched_x, batched_y)
+    old_x_ndim = old_x.type.ndim
+    old_y_ndim = old_y.type.ndim
+    match (old_x_ndim, old_y_ndim):
+        case (1, 1):
+            batch_op = _inner_prod
+        case (2, 1):
+            batch_op = _matrix_vec_prod
+        case (1, 2):
+            batch_op = _vec_matrix_prod
+        case (2, 2):
+            batch_op = _matrix_matrix_matmul
+        case _:
+            raise ValueError(
+                f"Core dot Op should have 1D or 2D inputs, got {old_x_ndim}D and {old_y_ndim}D."
+            )
+    return batch_op(batched_x, batched_y).owner


 def nan_to_num(x, nan=0.0, posinf=None, neginf=None):
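An illustration of the new dispatch (not part of the diff; it assumes the vectorize_graph API imported in the test file below and the _inner_prod variant defined above): vectorizing a vector-vector dot with batched 2D inputs should now yield the predefined Blockwise inner product instead of a generic fallback.

import pytensor.tensor as pt
from pytensor.graph import vectorize_graph
from pytensor.tensor.math import _inner_prod

x = pt.vector("x")
y = pt.vector("y")
core_out = pt.dot(x, y)  # core case (1, 1): a plain Dot node

batch_x = pt.matrix("batch_x")  # (batch, n)
batch_y = pt.matrix("batch_y")  # (batch, n)
batched_out = vectorize_graph(core_out, {x: batch_x, y: batch_y})

# With this commit, vectorize_node_dot should pick the predefined _inner_prod
# Blockwise for the (1, 1) core case.
assert batched_out.owner.op == _inner_prod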

pytensor/tensor/rewriting/math.py (+60)

@@ -44,6 +44,10 @@
     Prod,
     Sum,
     _conj,
+    _inner_prod,
+    _matrix_matrix_matmul,
+    _matrix_vec_prod,
+    _vec_matrix_prod,
     add,
     digamma,
     dot,

@@ -242,6 +246,62 @@ def local_batched_matmul_to_core_matmul(fgraph, node):
     return None


+@register_canonicalize
+@register_specialize
+@node_rewriter([_inner_prod, _matrix_vec_prod, _vec_matrix_prod, _matrix_matrix_matmul])
+def local_blockwise_dot_to_mul(fgraph, node):
+    """Rewrite blockwise dots that correspond to multiplication without summation.
+
+    We don't touch the regular dot, to not interfere with the BLAS optimizations.
+    """
+    a, b = node.inputs
+    a_static_shape = a.type.shape
+    b_static_shape = b.type.shape
+    core_a_ndim = len(node.op.inputs_sig[0])
+    core_b_ndim = len(node.op.inputs_sig[1])
+
+    if core_a_ndim > 2 or core_b_ndim > 2:
+        # Shouldn't happen, but here just in case
+        return None
+
+    if core_b_ndim == 1:
+        if a_static_shape[-1] == 1 or b_static_shape[-1] == 1:
+            if core_a_ndim == 1:
+                # inner product: (..., 1) * (..., 1) -> (...)
+                # just squeeze the last dimensions of a and b
+                new_a = a.squeeze(-1)
+                new_b = b.squeeze(-1)
+            else:
+                # matrix-vector product: (..., m, 1) * (..., 1) -> (..., m)
+                # the last dimension of b is already aligned for the elemwise multiplication
+                # after we squeeze the last dimension of a
+                new_a = a.squeeze(-1)
+                new_b = b
+        else:
+            return None
+
+    else:
+        if a_static_shape[-1] == 1 or b_static_shape[-2] == 1:
+            if core_a_ndim == 1:
+                # vector-matrix product: (..., 1) * (..., 1, n) -> (..., n)
+                # the last dimension of a is already aligned for the elemwise multiplication
+                # after we squeeze the second-to-last dimension of b
+                new_a = a
+                new_b = b.squeeze(-2)
+            else:
+                # matrix-matrix product: (..., m, 1) * (..., 1, n) -> (..., m, n)
+                # the dimensions of a and b are already aligned for the elemwise multiplication
+                new_a = a
+                new_b = b
+        else:
+            return None
+
+    new_a = copy_stack_trace(a, new_a)
+    new_b = copy_stack_trace(b, new_b)
+    new_out = copy_stack_trace(node.out, mul(new_a, new_b))
+    return [new_out]
+
+
 def is_inverse_pair(node_op, prev_op, inv_pair):
     """
     Given two consecutive operations, check if they are the
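A plain-numpy sanity check of the four core cases handled by local_blockwise_dot_to_mul above (illustration only, not part of the diff; the batch dimensions mimic the Blockwise broadcasting):

import numpy as np

rng = np.random.default_rng(42)
batch = (7, 5)

# inner product core, n == 1: squeeze both inputs and multiply
a = rng.normal(size=(*batch, 1))
b = rng.normal(size=(*batch, 1))
np.testing.assert_allclose(
    np.einsum("...n,...n->...", a, b), a.squeeze(-1) * b.squeeze(-1)
)

# matrix-vector core, k == 1: squeeze the last dimension of a
a = rng.normal(size=(*batch, 3, 1))
b = rng.normal(size=(*batch, 1))
np.testing.assert_allclose(
    np.einsum("...mk,...k->...m", a, b), a.squeeze(-1) * b
)

# vector-matrix core, k == 1: squeeze the second-to-last dimension of b
a = rng.normal(size=(*batch, 1))
b = rng.normal(size=(*batch, 1, 4))
np.testing.assert_allclose(
    np.einsum("...k,...kn->...n", a, b), a * b.squeeze(-2)
)

# matrix-matrix core, k == 1: shapes are already aligned for broadcasting
a = rng.normal(size=(*batch, 3, 1))
b = rng.normal(size=(*batch, 1, 4))
np.testing.assert_allclose(a @ b, a * b)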

tests/tensor/rewriting/test_math.py (+52 -1)

@@ -16,7 +16,8 @@
 from pytensor.compile.mode import Mode, get_default_mode, get_mode
 from pytensor.compile.ops import DeepCopyOp, deep_copy_op
 from pytensor.configdefaults import config
-from pytensor.graph.basic import Apply, equal_computations
+from pytensor.graph import vectorize_graph
+from pytensor.graph.basic import Apply, ancestors, equal_computations
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.rewriting.basic import (
     SequentialNodeRewriter,

@@ -4571,3 +4572,53 @@ def test_log_kv_stabilization():
         out.eval({x: 1000.0}, mode=mode),
         -1003.2180912984705,
     )
+
+
+@pytest.mark.parametrize(
+    "a_shape,b_shape",
+    [
+        ((1,), (1,)),
+        ((3, 1), (1,)),
+        ((1,), (1, 3)),
+        ((3, 1), (1, 3)),
+    ],
+    ids=str,
+)
+@pytest.mark.parametrize("batched", (False, True))
+def test_local_dot_to_mul(batched, a_shape, b_shape):
+    a = tensor("a", shape=a_shape)
+    b = tensor("b", shape=b_shape)
+
+    out = dot(a, b)
+    if batched:
+        batch_a = tensor("batch_a", shape=(1, 5, *a_shape))
+        batch_b = tensor("batch_b", shape=(7, 1, *b_shape))
+        out = vectorize_graph(out, {a: batch_a, b: batch_b})
+        a = batch_a
+        b = batch_b
+
+    assert (
+        sum(
+            isinstance(var.owner.op, (Blockwise | Dot))
+            for var in ancestors([out])
+            if var.owner
+        )
+        == 1
+    )
+
+    # For now rewrite only applies to Batched Dots
+    rewritten_out = rewrite_graph(out)
+    assert rewritten_out.type.shape == out.type.shape
+    assert sum(
+        isinstance(var.owner.op, (Blockwise | Dot))
+        for var in ancestors([rewritten_out])
+        if var.owner
+    ) == (0 if batched else 1)
+
+    a_test = np.random.normal(size=a.type.shape).astype(a.type.dtype)
+    b_test = np.random.normal(size=b.type.shape).astype(b.type.dtype)
+    test_mode = Mode(linker="py", optimizer=None)
+    np.testing.assert_allclose(
+        out.eval({a: a_test, b: b_test}, mode=test_mode),
+        rewritten_out.eval({a: a_test, b: b_test}, mode=test_mode),
+    )
