Skip to content

Commit 0c7e087

Browse files
Added test for mT
1 parent 8e61562 commit 0c7e087

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

tests/tensor/test_basic.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3849,11 +3849,10 @@ def test_matrix_transpose():
38493849
var1 = ptb.matrix_transpose(x2)
38503850
expected_var1 = swapaxes(x2, -1, -2)
38513851

3852-
var2 = ptb.matrix_transpose(x3)
3852+
var2 = x3.mT
38533853
expected_var2 = swapaxes(x3, -1, -2)
38543854

38553855
assert equal_computations([var1], [expected_var1])
3856-
# TODO: Replace np.asarray([np.transpose(x3v[0]), np.transpose(x3v[1])]) with np.matrix_transpose(x3v) once numpy adds support for it (https://github.com/numpy/numpy/pull/24099)
38573856
assert equal_computations([var2], [expected_var2])
38583857

38593858

0 commit comments

Comments
 (0)