Skip to content

Commit ca969f2

Browse files
committed
Remove duplicated Inv Op
1 parent e58bd91 commit ca969f2

File tree

3 files changed

+1
-53
lines changed

3 files changed

+1
-53
lines changed

pytensor/link/numba/dispatch/nlinalg.py

-13
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
Det,
1515
Eig,
1616
Eigh,
17-
Inv,
1817
MatrixInverse,
1918
MatrixPinv,
2019
QRFull,
@@ -125,18 +124,6 @@ def eigh(x):
125124
return eigh
126125

127126

128-
@numba_funcify.register(Inv)
129-
def numba_funcify_Inv(op, node, **kwargs):
130-
out_dtype = node.outputs[0].type.numpy_dtype
131-
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
132-
133-
@numba_basic.numba_njit(inline="always")
134-
def inv(x):
135-
return np.linalg.inv(inputs_cast(x)).astype(out_dtype)
136-
137-
return inv
138-
139-
140127
@numba_funcify.register(MatrixInverse)
141128
def numba_funcify_MatrixInverse(op, node, **kwargs):
142129
out_dtype = node.outputs[0].type.numpy_dtype

pytensor/tensor/nlinalg.py

+1-20
Original file line numberDiff line numberDiff line change
@@ -78,25 +78,6 @@ def pinv(x, hermitian=False):
7878
return MatrixPinv(hermitian=hermitian)(x)
7979

8080

81-
class Inv(Op):
82-
"""Computes the inverse of one or more matrices."""
83-
84-
def make_node(self, x):
85-
x = as_tensor_variable(x)
86-
return Apply(self, [x], [x.type()])
87-
88-
def perform(self, node, inputs, outputs):
89-
(x,) = inputs
90-
(z,) = outputs
91-
z[0] = np.linalg.inv(x).astype(x.dtype)
92-
93-
def infer_shape(self, fgraph, node, shapes):
94-
return shapes
95-
96-
97-
inv = Inv()
98-
99-
10081
class MatrixInverse(Op):
10182
r"""Computes the inverse of a matrix :math:`A`.
10283
@@ -169,7 +150,7 @@ def infer_shape(self, fgraph, node, shapes):
169150
return shapes
170151

171152

172-
matrix_inverse = MatrixInverse()
153+
inv = matrix_inverse = MatrixInverse()
173154

174155

175156
def matrix_dot(*args):

tests/link/numba/test_nlinalg.py

-20
Original file line numberDiff line numberDiff line change
@@ -352,26 +352,6 @@ def test_Eigh(x, uplo, exc):
352352
None,
353353
(),
354354
),
355-
(
356-
nlinalg.Inv,
357-
set_test_value(
358-
at.dmatrix(),
359-
(lambda x: x.T.dot(x))(rng.random(size=(3, 3)).astype("float64")),
360-
),
361-
None,
362-
(),
363-
),
364-
(
365-
nlinalg.Inv,
366-
set_test_value(
367-
at.lmatrix(),
368-
(lambda x: x.T.dot(x))(
369-
rng.integers(1, 10, size=(3, 3)).astype("int64")
370-
),
371-
),
372-
None,
373-
(),
374-
),
375355
(
376356
nlinalg.MatrixPinv,
377357
set_test_value(

0 commit comments

Comments
 (0)