Skip to content

Commit 2ea5a54

Browse files
Add jax implementation of pt.linalg.pinv (#294)
1 parent a99a7b2 commit 2ea5a54

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

pytensor/link/jax/dispatch/nlinalg.py

+18-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,16 @@
33
from pytensor.link.jax.dispatch import jax_funcify
44
from pytensor.tensor.blas import BatchedDot
55
from pytensor.tensor.math import Dot, MaxAndArgmax
6-
from pytensor.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull, SLogDet
6+
from pytensor.tensor.nlinalg import (
7+
SVD,
8+
Det,
9+
Eig,
10+
Eigh,
11+
MatrixInverse,
12+
MatrixPinv,
13+
QRFull,
14+
SLogDet,
15+
)
716

817

918
@jax_funcify.register(SVD)
@@ -77,6 +86,14 @@ def dot(x, y):
7786
return dot
7887

7988

@jax_funcify.register(MatrixPinv)
def jax_funcify_Pinv(op, **kwargs):
    """Return a JAX implementation of the ``MatrixPinv`` Op.

    The returned callable computes the Moore-Penrose pseudo-inverse of its
    input via :func:`jax.numpy.linalg.pinv`.

    Parameters
    ----------
    op
        The ``MatrixPinv`` Op instance being dispatched on.
    **kwargs
        Extra keyword arguments from the dispatch machinery (unused here).

    Returns
    -------
    callable
        A function mapping an array ``x`` to ``pinv(x)``.
    """

    def pinv(x):
        # jnp.linalg.pinv handles rectangular and rank-deficient inputs.
        return jnp.linalg.pinv(x)

    return pinv
8097
@jax_funcify.register(BatchedDot)
8198
def jax_funcify_BatchedDot(op, **kwargs):
8299
def batched_dot(a, b):

tests/link/jax/test_nlinalg.py

+9
Original file line numberDiff line numberDiff line change
@@ -134,3 +134,12 @@ def test_tensor_basics():
134134
out = at_max(y)
135135
fgraph = FunctionGraph([y], [out])
136136
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
137+
138+
139+
def test_pinv():
    """Check that ``pinv`` compiled through the JAX backend matches Python mode."""
    a = matrix("x")
    a_pinv = at_nlinalg.pinv(a)

    fgraph = FunctionGraph([a], [a_pinv])
    # Small full-rank 2x2 input keeps the comparison fast and exact-ish.
    test_value = np.array([[1.0, 2.0], [3.0, 4.0]])
    compare_jax_and_py(fgraph, [test_value])

0 commit comments

Comments
 (0)