Skip to content

Commit 4bbe4c6

Browse files
committed
Remove _numop attribute from linalg Ops
1 parent c88b70b commit 4bbe4c6

File tree

1 file changed

+9
-16
lines changed

1 file changed

+9
-16
lines changed

pytensor/tensor/nlinalg.py

+9-16
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
1-
import typing
21
from functools import partial
3-
from typing import Callable, Tuple
2+
from typing import Tuple
43

54
import numpy as np
65

@@ -271,7 +270,6 @@ class Eig(Op):
271270
272271
"""
273272

274-
_numop = staticmethod(np.linalg.eig)
275273
__props__: Tuple[str, ...] = ()
276274

277275
def make_node(self, x):
@@ -284,7 +282,7 @@ def make_node(self, x):
284282
def perform(self, node, inputs, outputs):
285283
(x,) = inputs
286284
(w, v) = outputs
287-
w[0], v[0] = (z.astype(x.dtype) for z in self._numop(x))
285+
w[0], v[0] = (z.astype(x.dtype) for z in np.linalg.eig(x))
288286

289287
def infer_shape(self, fgraph, node, shapes):
290288
n = shapes[0][0]
@@ -300,7 +298,6 @@ class Eigh(Eig):
300298
301299
"""
302300

303-
_numop = typing.cast(Callable, staticmethod(np.linalg.eigh))
304301
__props__ = ("UPLO",)
305302

306303
def __init__(self, UPLO="L"):
@@ -315,15 +312,15 @@ def make_node(self, x):
315312
# LAPACK. Rather than trying to reproduce the (rather
316313
# involved) logic, we just probe linalg.eigh with a trivial
317314
# input.
318-
w_dtype = self._numop([[np.dtype(x.dtype).type()]])[0].dtype.name
315+
w_dtype = np.linalg.eigh([[np.dtype(x.dtype).type()]])[0].dtype.name
319316
w = vector(dtype=w_dtype)
320317
v = matrix(dtype=w_dtype)
321318
return Apply(self, [x], [w, v])
322319

323320
def perform(self, node, inputs, outputs):
324321
(x,) = inputs
325322
(w, v) = outputs
326-
w[0], v[0] = self._numop(x, self.UPLO)
323+
w[0], v[0] = np.linalg.eigh(x, self.UPLO)
327324

328325
def grad(self, inputs, g_outputs):
329326
r"""The gradient function should return
@@ -446,7 +443,6 @@ class QRFull(Op):
446443
447444
"""
448445

449-
_numop = staticmethod(np.linalg.qr)
450446
__props__ = ("mode",)
451447

452448
def __init__(self, mode):
@@ -478,7 +474,7 @@ def make_node(self, x):
478474
def perform(self, node, inputs, outputs):
479475
(x,) = inputs
480476
assert x.ndim == 2, "The input of qr function should be a matrix."
481-
res = self._numop(x, self.mode)
477+
res = np.linalg.qr(x, self.mode)
482478
if self.mode != "r":
483479
outputs[0][0], outputs[1][0] = res
484480
else:
@@ -547,7 +543,6 @@ class SVD(Op):
547543
"""
548544

549545
# See doc in the docstring of the function just after this class.
550-
_numop = staticmethod(np.linalg.svd)
551546
__props__ = ("full_matrices", "compute_uv")
552547

553548
def __init__(self, full_matrices=True, compute_uv=True):
@@ -575,10 +570,10 @@ def perform(self, node, inputs, outputs):
575570
assert x.ndim == 2, "The input of svd function should be a matrix."
576571
if self.compute_uv:
577572
u, s, vt = outputs
578-
u[0], s[0], vt[0] = self._numop(x, self.full_matrices, self.compute_uv)
573+
u[0], s[0], vt[0] = np.linalg.svd(x, self.full_matrices, self.compute_uv)
579574
else:
580575
(s,) = outputs
581-
s[0] = self._numop(x, self.full_matrices, self.compute_uv)
576+
s[0] = np.linalg.svd(x, self.full_matrices, self.compute_uv)
582577

583578
def infer_shape(self, fgraph, node, shapes):
584579
(x_shape,) = shapes
@@ -730,7 +725,6 @@ class TensorInv(Op):
730725
PyTensor utilization of numpy.linalg.tensorinv;
731726
"""
732727

733-
_numop = staticmethod(np.linalg.tensorinv)
734728
__props__ = ("ind",)
735729

736730
def __init__(self, ind=2):
@@ -744,7 +738,7 @@ def make_node(self, a):
744738
def perform(self, node, inputs, outputs):
745739
(a,) = inputs
746740
(x,) = outputs
747-
x[0] = self._numop(a, self.ind)
741+
x[0] = np.linalg.tensorinv(a, self.ind)
748742

749743
def infer_shape(self, fgraph, node, shapes):
750744
sp = shapes[0][self.ind :] + shapes[0][: self.ind]
@@ -790,7 +784,6 @@ class TensorSolve(Op):
790784
791785
"""
792786

793-
_numop = staticmethod(np.linalg.tensorsolve)
794787
__props__ = ("axes",)
795788

796789
def __init__(self, axes=None):
@@ -809,7 +802,7 @@ def perform(self, node, inputs, outputs):
809802
b,
810803
) = inputs
811804
(x,) = outputs
812-
x[0] = self._numop(a, b, self.axes)
805+
x[0] = np.linalg.tensorsolve(a, b, self.axes)
813806

814807

815808
def tensorsolve(a, b, axes=None):

0 commit comments

Comments
 (0)