Skip to content

Commit e58bd91

Browse files
committed (author and date not captured in this extraction)
Remove _numop attribute from linalg Ops
1 parent 58de169 commit e58bd91

File tree

1 file changed

+8
-14
lines changed

1 file changed

+8
-14
lines changed

pytensor/tensor/nlinalg.py

+8-14
Original file line numberDiff line numberDiff line change
@@ -270,7 +270,6 @@ class Eig(Op):
270270
271271
"""
272272

273-
_numop = staticmethod(np.linalg.eig)
274273
__props__: Tuple[str, ...] = ()
275274

276275
def make_node(self, x):
@@ -283,7 +282,7 @@ def make_node(self, x):
283282
def perform(self, node, inputs, outputs):
284283
(x,) = inputs
285284
(w, v) = outputs
286-
w[0], v[0] = (z.astype(x.dtype) for z in self._numop(x))
285+
w[0], v[0] = (z.astype(x.dtype) for z in np.linalg.eig(x))
287286

288287
def infer_shape(self, fgraph, node, shapes):
289288
n = shapes[0][0]
@@ -299,7 +298,6 @@ class Eigh(Eig):
299298
300299
"""
301300

302-
_numop = staticmethod(np.linalg.eigh)
303301
__props__ = ("UPLO",)
304302

305303
def __init__(self, UPLO="L"):
@@ -314,15 +312,15 @@ def make_node(self, x):
314312
# LAPACK. Rather than trying to reproduce the (rather
315313
# involved) logic, we just probe linalg.eigh with a trivial
316314
# input.
317-
w_dtype = self._numop([[np.dtype(x.dtype).type()]])[0].dtype.name
315+
w_dtype = np.linalg.eigh([[np.dtype(x.dtype).type()]])[0].dtype.name
318316
w = vector(dtype=w_dtype)
319317
v = matrix(dtype=w_dtype)
320318
return Apply(self, [x], [w, v])
321319

322320
def perform(self, node, inputs, outputs):
323321
(x,) = inputs
324322
(w, v) = outputs
325-
w[0], v[0] = self._numop(x, self.UPLO)
323+
w[0], v[0] = np.linalg.eigh(x, self.UPLO)
326324

327325
def grad(self, inputs, g_outputs):
328326
r"""The gradient function should return
@@ -445,7 +443,6 @@ class QRFull(Op):
445443
446444
"""
447445

448-
_numop = staticmethod(np.linalg.qr)
449446
__props__ = ("mode",)
450447

451448
def __init__(self, mode):
@@ -477,7 +474,7 @@ def make_node(self, x):
477474
def perform(self, node, inputs, outputs):
478475
(x,) = inputs
479476
assert x.ndim == 2, "The input of qr function should be a matrix."
480-
res = self._numop(x, self.mode)
477+
res = np.linalg.qr(x, self.mode)
481478
if self.mode != "r":
482479
outputs[0][0], outputs[1][0] = res
483480
else:
@@ -546,7 +543,6 @@ class SVD(Op):
546543
"""
547544

548545
# See doc in the docstring of the function just after this class.
549-
_numop = staticmethod(np.linalg.svd)
550546
__props__ = ("full_matrices", "compute_uv")
551547

552548
def __init__(self, full_matrices=True, compute_uv=True):
@@ -574,10 +570,10 @@ def perform(self, node, inputs, outputs):
574570
assert x.ndim == 2, "The input of svd function should be a matrix."
575571
if self.compute_uv:
576572
u, s, vt = outputs
577-
u[0], s[0], vt[0] = self._numop(x, self.full_matrices, self.compute_uv)
573+
u[0], s[0], vt[0] = np.linalg.svd(x, self.full_matrices, self.compute_uv)
578574
else:
579575
(s,) = outputs
580-
s[0] = self._numop(x, self.full_matrices, self.compute_uv)
576+
s[0] = np.linalg.svd(x, self.full_matrices, self.compute_uv)
581577

582578
def infer_shape(self, fgraph, node, shapes):
583579
(x_shape,) = shapes
@@ -729,7 +725,6 @@ class TensorInv(Op):
729725
PyTensor utilization of numpy.linalg.tensorinv;
730726
"""
731727

732-
_numop = staticmethod(np.linalg.tensorinv)
733728
__props__ = ("ind",)
734729

735730
def __init__(self, ind=2):
@@ -743,7 +738,7 @@ def make_node(self, a):
743738
def perform(self, node, inputs, outputs):
744739
(a,) = inputs
745740
(x,) = outputs
746-
x[0] = self._numop(a, self.ind)
741+
x[0] = np.linalg.tensorinv(a, self.ind)
747742

748743
def infer_shape(self, fgraph, node, shapes):
749744
sp = shapes[0][self.ind :] + shapes[0][: self.ind]
@@ -789,7 +784,6 @@ class TensorSolve(Op):
789784
790785
"""
791786

792-
_numop = staticmethod(np.linalg.tensorsolve)
793787
__props__ = ("axes",)
794788

795789
def __init__(self, axes=None):
@@ -808,7 +802,7 @@ def perform(self, node, inputs, outputs):
808802
b,
809803
) = inputs
810804
(x,) = outputs
811-
x[0] = self._numop(a, b, self.axes)
805+
x[0] = np.linalg.tensorsolve(a, b, self.axes)
812806

813807

814808
def tensorsolve(a, b, axes=None):

0 commit comments

Comments (0)