Skip to content

Commit 5e612ab

Browse files
Add matrix_transpose and .mT property helpers (#702)
1 parent 0a13fbd commit 5e612ab

File tree

4 files changed

+83
-9
lines changed

4 files changed

+83
-9
lines changed

pytensor/tensor/basic.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1982,6 +1982,62 @@ def transpose(x, axes=None):
19821982
return ret
19831983

19841984

1985+
def matrix_transpose(x: "TensorLike") -> TensorVariable:
    """
    Transpose the last two dimensions of a tensor.

    The input is treated as a stack of matrices of shape (M, N); every
    matrix is transposed while all leading dimensions are left untouched.

    Parameters
    ----------
    x : array_like
        Input tensor with shape (..., M, N), where `M` and `N` represent the
        dimensions of the matrices. Each matrix is of shape (M, N).

    Returns
    -------
    out : tensor
        Transposed tensor with the shape (..., N, M), where each 2-dimensional
        matrix in the input tensor has been transposed along the last two
        dimensions.

    Raises
    ------
    ValueError
        If `x` has fewer than two dimensions.

    Examples
    --------
    >>> import pytensor.tensor as pt
    >>> import numpy as np
    >>> x = np.arange(24).reshape((2, 3, 4))
    >>> print(x)
    [[[ 0  1  2  3]
      [ 4  5  6  7]
      [ 8  9 10 11]]
    <BLANKLINE>
     [[12 13 14 15]
      [16 17 18 19]
      [20 21 22 23]]]

    >>> print(pt.matrix_transpose(x).eval())
    [[[ 0  4  8]
      [ 1  5  9]
      [ 2  6 10]
      [ 3  7 11]]
    <BLANKLINE>
     [[12 16 20]
      [13 17 21]
      [14 18 22]
      [15 19 23]]]

    Notes
    -----
    This function transposes each 2-dimensional matrix within the input tensor
    along the last two dimensions. If the input tensor has more than two
    dimensions, it transposes each 2-dimensional matrix independently while
    preserving other dimensions.
    """
    x = as_tensor_variable(x)
    # Vectors and scalars have no "last two" axes to swap, so reject them
    # explicitly instead of letting swapaxes fail with a less helpful error.
    if x.ndim < 2:
        raise ValueError(
            f"Input array must be at least 2-dimensional, but it is {x.ndim}"
        )
    return swapaxes(x, -1, -2)
2039+
2040+
19852041
def split(x, splits_size, n_splits, axis=0):
    """Build a ``Split`` op for `n_splits` pieces and apply it to `x`.

    The op is called as ``Split(n_splits)(x, axis, splits_size)``; note that
    the op takes `axis` before `splits_size`, the reverse of this function's
    own argument order.
    """
    return Split(n_splits)(x, axis, splits_size)
@@ -4302,6 +4358,7 @@ def ix_(*args):
43024358
"join",
43034359
"split",
43044360
"transpose",
4361+
"matrix_transpose",
43054362
"extract_constant",
43064363
"default",
43074364
"tensor_copy",

pytensor/tensor/rewriting/linalg.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from typing import cast
33

44
from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
5-
from pytensor.tensor.basic import TensorVariable, diagonal, swapaxes
5+
from pytensor.tensor.basic import TensorVariable, diagonal
66
from pytensor.tensor.blas import Dot22
77
from pytensor.tensor.blockwise import Blockwise
88
from pytensor.tensor.elemwise import DimShuffle
@@ -43,11 +43,6 @@ def is_matrix_transpose(x: TensorVariable) -> bool:
4343
return False
4444

4545

46-
def _T(x: TensorVariable) -> TensorVariable:
47-
"""Matrix transpose for potentially higher dimensionality tensors"""
48-
return swapaxes(x, -1, -2)
49-
50-
5146
@register_canonicalize
5247
@node_rewriter([DimShuffle])
5348
def transinv_to_invtrans(fgraph, node):
@@ -83,9 +78,9 @@ def inv_as_solve(fgraph, node):
8378
):
8479
x = r.owner.inputs[0]
8580
if getattr(x.tag, "symmetric", None) is True:
86-
return [_T(solve(x, _T(l)))]
81+
return [solve(x, (l.mT)).mT]
8782
else:
88-
return [_T(solve(_T(x), _T(l)))]
83+
return [solve((x.mT), (l.mT)).mT]
8984

9085

9186
@register_stabilize
@@ -216,7 +211,7 @@ def psd_solve_with_chol(fgraph, node):
216211
# __if__ no other Op makes use of the L matrix during the
217212
# stabilization
218213
Li_b = solve_triangular(L, b, lower=True, b_ndim=2)
219-
x = solve_triangular(_T(L), Li_b, lower=False, b_ndim=2)
214+
x = solve_triangular((L.mT), Li_b, lower=False, b_ndim=2)
220215
return [x]
221216

222217

pytensor/tensor/variable.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,10 @@ def __trunc__(self):
232232
def T(self):
233233
return pt.basic.transpose(self)
234234

235+
@property
def mT(self):
    """Matrix transpose of this variable.

    Delegates to ``pt.basic.matrix_transpose``.
    """
    transposed = pt.basic.matrix_transpose(self)
    return transposed
238+
235239
def transpose(self, *axes):
236240
"""Transpose this array.
237241

tests/tensor/test_basic.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3813,6 +3813,7 @@ def test_transpose():
38133813
)
38143814

38153815
t1, t2, t3, t1b, t2b, t3b, t2c, t3c, t2d, t3d = f(x1v, x2v, x3v)
3816+
38163817
assert t1.shape == np.transpose(x1v).shape
38173818
assert t2.shape == np.transpose(x2v).shape
38183819
assert t3.shape == np.transpose(x3v).shape
@@ -3838,6 +3839,23 @@ def test_transpose():
38383839
assert ptb.transpose(dmatrix()).name is None
38393840

38403841

3842+
def test_matrix_transpose():
    """Check ``matrix_transpose`` and the ``.mT`` property.

    A vector input must be rejected with a ``ValueError``; for a matrix and a
    3-d tensor the resulting graph must be equivalent to swapping the last
    two axes with ``swapaxes``.
    """
    with pytest.raises(ValueError, match="Input array must be at least 2-dimensional"):
        ptb.matrix_transpose(dvector("x1"))

    matrix_in = dmatrix("x2")
    tensor3_in = dtensor3("x3")

    # (actual graph, reference graph) pairs: the function form on a matrix,
    # then the property form on a 3-d tensor.
    cases = [
        (ptb.matrix_transpose(matrix_in), swapaxes(matrix_in, -1, -2)),
        (tensor3_in.mT, swapaxes(tensor3_in, -1, -2)),
    ]

    for actual, reference in cases:
        assert equal_computations([actual], [reference])
3857+
3858+
38413859
def test_stacklists():
38423860
a, b, c, d = map(scalar, "abcd")
38433861
X = stacklists([[a, b], [c, d]])

0 commit comments

Comments
 (0)