1
- import typing
2
1
from functools import partial
3
- from typing import Callable , Tuple
2
+ from typing import Tuple
4
3
5
4
import numpy as np
6
5
@@ -271,7 +270,6 @@ class Eig(Op):
271
270
272
271
"""
273
272
274
- _numop = staticmethod (np .linalg .eig )
275
273
__props__ : Tuple [str , ...] = ()
276
274
277
275
def make_node (self , x ):
@@ -284,7 +282,7 @@ def make_node(self, x):
284
282
def perform (self , node , inputs , outputs ):
285
283
(x ,) = inputs
286
284
(w , v ) = outputs
287
- w [0 ], v [0 ] = (z .astype (x .dtype ) for z in self . _numop (x ))
285
+ w [0 ], v [0 ] = (z .astype (x .dtype ) for z in np . linalg . eig (x ))
288
286
289
287
def infer_shape (self , fgraph , node , shapes ):
290
288
n = shapes [0 ][0 ]
@@ -300,7 +298,6 @@ class Eigh(Eig):
300
298
301
299
"""
302
300
303
- _numop = typing .cast (Callable , staticmethod (np .linalg .eigh ))
304
301
__props__ = ("UPLO" ,)
305
302
306
303
def __init__ (self , UPLO = "L" ):
@@ -315,15 +312,15 @@ def make_node(self, x):
315
312
# LAPACK. Rather than trying to reproduce the (rather
316
313
# involved) logic, we just probe linalg.eigh with a trivial
317
314
# input.
318
- w_dtype = self . _numop ([[np .dtype (x .dtype ).type ()]])[0 ].dtype .name
315
+ w_dtype = np . linalg . eigh ([[np .dtype (x .dtype ).type ()]])[0 ].dtype .name
319
316
w = vector (dtype = w_dtype )
320
317
v = matrix (dtype = w_dtype )
321
318
return Apply (self , [x ], [w , v ])
322
319
323
320
def perform (self , node , inputs , outputs ):
324
321
(x ,) = inputs
325
322
(w , v ) = outputs
326
- w [0 ], v [0 ] = self . _numop (x , self .UPLO )
323
+ w [0 ], v [0 ] = np . linalg . eigh (x , self .UPLO )
327
324
328
325
def grad (self , inputs , g_outputs ):
329
326
r"""The gradient function should return
@@ -446,7 +443,6 @@ class QRFull(Op):
446
443
447
444
"""
448
445
449
- _numop = staticmethod (np .linalg .qr )
450
446
__props__ = ("mode" ,)
451
447
452
448
def __init__ (self , mode ):
@@ -478,7 +474,7 @@ def make_node(self, x):
478
474
def perform (self , node , inputs , outputs ):
479
475
(x ,) = inputs
480
476
assert x .ndim == 2 , "The input of qr function should be a matrix."
481
- res = self . _numop (x , self .mode )
477
+ res = np . linalg . qr (x , self .mode )
482
478
if self .mode != "r" :
483
479
outputs [0 ][0 ], outputs [1 ][0 ] = res
484
480
else :
@@ -547,7 +543,6 @@ class SVD(Op):
547
543
"""
548
544
549
545
# See doc in the docstring of the function just after this class.
550
- _numop = staticmethod (np .linalg .svd )
551
546
__props__ = ("full_matrices" , "compute_uv" )
552
547
553
548
def __init__ (self , full_matrices = True , compute_uv = True ):
@@ -575,10 +570,10 @@ def perform(self, node, inputs, outputs):
575
570
assert x .ndim == 2 , "The input of svd function should be a matrix."
576
571
if self .compute_uv :
577
572
u , s , vt = outputs
578
- u [0 ], s [0 ], vt [0 ] = self . _numop (x , self .full_matrices , self .compute_uv )
573
+ u [0 ], s [0 ], vt [0 ] = np . linalg . svd (x , self .full_matrices , self .compute_uv )
579
574
else :
580
575
(s ,) = outputs
581
- s [0 ] = self . _numop (x , self .full_matrices , self .compute_uv )
576
+ s [0 ] = np . linalg . svd (x , self .full_matrices , self .compute_uv )
582
577
583
578
def infer_shape (self , fgraph , node , shapes ):
584
579
(x_shape ,) = shapes
@@ -730,7 +725,6 @@ class TensorInv(Op):
730
725
PyTensor utilization of numpy.linalg.tensorinv;
731
726
"""
732
727
733
- _numop = staticmethod (np .linalg .tensorinv )
734
728
__props__ = ("ind" ,)
735
729
736
730
def __init__ (self , ind = 2 ):
@@ -744,7 +738,7 @@ def make_node(self, a):
744
738
def perform (self , node , inputs , outputs ):
745
739
(a ,) = inputs
746
740
(x ,) = outputs
747
- x [0 ] = self . _numop (a , self .ind )
741
+ x [0 ] = np . linalg . tensorinv (a , self .ind )
748
742
749
743
def infer_shape (self , fgraph , node , shapes ):
750
744
sp = shapes [0 ][self .ind :] + shapes [0 ][: self .ind ]
@@ -790,7 +784,6 @@ class TensorSolve(Op):
790
784
791
785
"""
792
786
793
- _numop = staticmethod (np .linalg .tensorsolve )
794
787
__props__ = ("axes" ,)
795
788
796
789
def __init__ (self , axes = None ):
@@ -809,7 +802,7 @@ def perform(self, node, inputs, outputs):
809
802
b ,
810
803
) = inputs
811
804
(x ,) = outputs
812
- x [0 ] = self . _numop (a , b , self .axes )
805
+ x [0 ] = np . linalg . tensorsolve (a , b , self .axes )
813
806
814
807
815
808
def tensorsolve (a , b , axes = None ):
0 commit comments