Skip to content

Add matrix_transpose and .mT helpers #702

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions pytensor/tensor/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1982,6 +1982,62 @@ def transpose(x, axes=None):
return ret


def matrix_transpose(x: "TensorLike") -> TensorVariable:
"""
Transposes each 2-dimensional matrix tensor along the last two dimensions of a higher-dimensional tensor.

Parameters
----------
x : array_like
Input tensor with shape (..., M, N), where `M` and `N` represent the dimensions
of the matrices. Each matrix is of shape (M, N).

Returns
-------
out : tensor
Transposed tensor with the shape (..., N, M), where each 2-dimensional matrix
in the input tensor has been transposed along the last two dimensions.

Examples
--------
>>> import pytensor as pt
>>> import numpy as np
>>> x = np.arange(24).reshape((2, 3, 4))
[[[ 0 1 2 3]
[ 4 5 6 7]
[ 8 9 10 11]]

[[12 13 14 15]
[16 17 18 19]
[20 21 22 23]]]


>>> pt.matrix_transpose(x).eval()
[[[ 0 4 8]
[ 1 5 9]
[ 2 6 10]
[ 3 7 11]]

[[12 16 20]
[13 17 21]
[14 18 22]
[15 19 23]]]


Notes
-----
This function transposes each 2-dimensional matrix within the input tensor along
the last two dimensions. If the input tensor has more than two dimensions, it
transposes each 2-dimensional matrix independently while preserving other dimensions.
"""
x = as_tensor_variable(x)
if x.ndim < 2:
raise ValueError(
f"Input array must be at least 2-dimensional, but it is {x.ndim}"
)
return swapaxes(x, -1, -2)


def split(x, splits_size, n_splits, axis=0):
the_split = Split(n_splits)
return the_split(x, axis, splits_size)
Expand Down Expand Up @@ -4302,6 +4358,7 @@ def ix_(*args):
"join",
"split",
"transpose",
"matrix_transpose",
"extract_constant",
"default",
"tensor_copy",
Expand Down
13 changes: 4 additions & 9 deletions pytensor/tensor/rewriting/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import cast

from pytensor.graph.rewriting.basic import copy_stack_trace, node_rewriter
from pytensor.tensor.basic import TensorVariable, diagonal, swapaxes
from pytensor.tensor.basic import TensorVariable, diagonal
from pytensor.tensor.blas import Dot22
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
Expand Down Expand Up @@ -43,11 +43,6 @@ def is_matrix_transpose(x: TensorVariable) -> bool:
return False


def _T(x: TensorVariable) -> TensorVariable:
"""Matrix transpose for potentially higher dimensionality tensors"""
return swapaxes(x, -1, -2)


@register_canonicalize
@node_rewriter([DimShuffle])
def transinv_to_invtrans(fgraph, node):
Expand Down Expand Up @@ -83,9 +78,9 @@ def inv_as_solve(fgraph, node):
):
x = r.owner.inputs[0]
if getattr(x.tag, "symmetric", None) is True:
return [_T(solve(x, _T(l)))]
return [solve(x, (l.mT)).mT]
else:
return [_T(solve(_T(x), _T(l)))]
return [solve((x.mT), (l.mT)).mT]


@register_stabilize
Expand Down Expand Up @@ -216,7 +211,7 @@ def psd_solve_with_chol(fgraph, node):
# __if__ no other Op makes use of the L matrix during the
# stabilization
Li_b = solve_triangular(L, b, lower=True, b_ndim=2)
x = solve_triangular(_T(L), Li_b, lower=False, b_ndim=2)
x = solve_triangular((L.mT), Li_b, lower=False, b_ndim=2)
return [x]


Expand Down
4 changes: 4 additions & 0 deletions pytensor/tensor/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,10 @@ def __trunc__(self):
def T(self):
return pt.basic.transpose(self)

@property
def mT(self):
return pt.basic.matrix_transpose(self)

def transpose(self, *axes):
"""Transpose this array.

Expand Down
18 changes: 18 additions & 0 deletions tests/tensor/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3813,6 +3813,7 @@ def test_transpose():
)

t1, t2, t3, t1b, t2b, t3b, t2c, t3c, t2d, t3d = f(x1v, x2v, x3v)

assert t1.shape == np.transpose(x1v).shape
assert t2.shape == np.transpose(x2v).shape
assert t3.shape == np.transpose(x3v).shape
Expand All @@ -3838,6 +3839,23 @@ def test_transpose():
assert ptb.transpose(dmatrix()).name is None


def test_matrix_transpose():
with pytest.raises(ValueError, match="Input array must be at least 2-dimensional"):
ptb.matrix_transpose(dvector("x1"))

x2 = dmatrix("x2")
x3 = dtensor3("x3")

var1 = ptb.matrix_transpose(x2)
expected_var1 = swapaxes(x2, -1, -2)

var2 = x3.mT
expected_var2 = swapaxes(x3, -1, -2)

assert equal_computations([var1], [expected_var1])
assert equal_computations([var2], [expected_var2])


def test_stacklists():
a, b, c, d = map(scalar, "abcd")
X = stacklists([[a, b], [c, d]])
Expand Down