Skip to content

Commit 49bf1fe

Browse files
committed
Remove duplicated Inv Op
1 parent e58bd91 commit 49bf1fe

File tree

2 files changed

+1
-33
lines changed

2 files changed

+1
-33
lines changed

pytensor/link/numba/dispatch/nlinalg.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
Det,
1515
Eig,
1616
Eigh,
17-
Inv,
1817
MatrixInverse,
1918
MatrixPinv,
2019
QRFull,
@@ -125,18 +124,6 @@ def eigh(x):
125124
return eigh
126125

127126

128-
@numba_funcify.register(Inv)
129-
def numba_funcify_Inv(op, node, **kwargs):
130-
out_dtype = node.outputs[0].type.numpy_dtype
131-
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
132-
133-
@numba_basic.numba_njit(inline="always")
134-
def inv(x):
135-
return np.linalg.inv(inputs_cast(x)).astype(out_dtype)
136-
137-
return inv
138-
139-
140127
@numba_funcify.register(MatrixInverse)
141128
def numba_funcify_MatrixInverse(op, node, **kwargs):
142129
out_dtype = node.outputs[0].type.numpy_dtype

pytensor/tensor/nlinalg.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -78,25 +78,6 @@ def pinv(x, hermitian=False):
7878
return MatrixPinv(hermitian=hermitian)(x)
7979

8080

81-
class Inv(Op):
82-
"""Computes the inverse of one or more matrices."""
83-
84-
def make_node(self, x):
85-
x = as_tensor_variable(x)
86-
return Apply(self, [x], [x.type()])
87-
88-
def perform(self, node, inputs, outputs):
89-
(x,) = inputs
90-
(z,) = outputs
91-
z[0] = np.linalg.inv(x).astype(x.dtype)
92-
93-
def infer_shape(self, fgraph, node, shapes):
94-
return shapes
95-
96-
97-
inv = Inv()
98-
99-
10081
class MatrixInverse(Op):
10182
r"""Computes the inverse of a matrix :math:`A`.
10283
@@ -169,7 +150,7 @@ def infer_shape(self, fgraph, node, shapes):
169150
return shapes
170151

171152

172-
matrix_inverse = MatrixInverse()
153+
inv = matrix_inverse = MatrixInverse()
173154

174155

175156
def matrix_dot(*args):

0 commit comments

Comments
 (0)