17
17
18
18
import warnings
19
19
20
- from functools import reduce
20
+ from functools import partial , reduce
21
21
from typing import Optional
22
22
23
23
import numpy as np
30
30
from pytensor .raise_op import Assert
31
31
from pytensor .sparse .basic import sp_sum
32
32
from pytensor .tensor import TensorConstant , gammaln , sigmoid
33
- from pytensor .tensor .nlinalg import det , eigh , matrix_inverse , trace
33
+ from pytensor .tensor .linalg import cholesky , det , eigh
34
+ from pytensor .tensor .linalg import inv as matrix_inverse
35
+ from pytensor .tensor .linalg import solve_triangular , trace
34
36
from pytensor .tensor .random .basic import dirichlet , multinomial , multivariate_normal
35
37
from pytensor .tensor .random .op import RandomVariable
36
38
from pytensor .tensor .random .utils import (
37
39
broadcast_params ,
38
40
supp_shape_from_ref_param_shape ,
39
41
)
40
- from pytensor .tensor .slinalg import Cholesky , SolveTriangular
41
42
from pytensor .tensor .type import TensorType
42
- from scipy import linalg , stats
43
+ from scipy import stats
43
44
44
45
import pymc as pm
45
46
93
94
"StickBreakingWeights" ,
94
95
]
95
96
96
- solve_lower = SolveTriangular ( lower = True )
97
- solve_upper = SolveTriangular ( lower = False )
97
+ solve_lower = partial ( solve_triangular , lower = True )
98
+ solve_upper = partial ( solve_triangular , lower = False )
98
99
99
100
100
101
class SimplexContinuous (Continuous ):
@@ -110,7 +111,7 @@ def simplex_cont_transform(op, rv):
110
111
# moment. We work around that by using a cholesky op
111
112
# that returns a nan as first entry instead of raising
112
113
# an error.
113
- cholesky = Cholesky ( lower = True , on_error = "nan" )
114
+ nan_lower_cholesky = partial ( cholesky , lower = True , on_error = "nan" )
114
115
115
116
116
117
def quaddist_matrix (cov = None , chol = None , tau = None , lower = True , * args , ** kwargs ):
@@ -155,7 +156,7 @@ def quaddist_parse(value, mu, cov, mat_type="cov"):
155
156
onedim = False
156
157
157
158
delta = value - mu
158
- chol_cov = cholesky (cov )
159
+ chol_cov = nan_lower_cholesky (cov )
159
160
if mat_type != "tau" :
160
161
dist , logdet , ok = quaddist_chol (delta , chol_cov )
161
162
else :
@@ -847,9 +848,9 @@ def dist(cls, *args, **kwargs):
847
848
848
849
def posdef (AA ):
849
850
try :
850
- linalg .cholesky (AA )
851
+ scipy . linalg .cholesky (AA )
851
852
return True
852
- except linalg .LinAlgError :
853
+ except scipy . linalg .LinAlgError :
853
854
return False
854
855
855
856
@@ -1073,7 +1074,7 @@ def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False, initv
1073
1074
if initval is not None :
1074
1075
# Inverse transform
1075
1076
initval = np .dot (np .dot (np .linalg .inv (L ), initval ), np .linalg .inv (L .T ))
1076
- initval = linalg .cholesky (initval , lower = True )
1077
+ initval = scipy . linalg .cholesky (initval , lower = True )
1077
1078
diag_testval = initval [diag_idx ] ** 2
1078
1079
tril_testval = initval [tril_idx ]
1079
1080
else :
@@ -1785,7 +1786,7 @@ def dist(
1785
1786
* args ,
1786
1787
** kwargs ,
1787
1788
):
1788
- cholesky = Cholesky ( lower = True , on_error = "raise" )
1789
+ lower_cholesky = partial ( cholesky , lower = True , on_error = "raise" )
1789
1790
1790
1791
# Among-row matrices
1791
1792
if len ([i for i in [rowcov , rowchol ] if i is not None ]) != 1 :
@@ -1795,7 +1796,7 @@ def dist(
1795
1796
if rowcov is not None :
1796
1797
if rowcov .ndim != 2 :
1797
1798
raise ValueError ("rowcov must be two dimensional." )
1798
- rowchol_cov = cholesky (rowcov )
1799
+ rowchol_cov = lower_cholesky (rowcov )
1799
1800
else :
1800
1801
if rowchol .ndim != 2 :
1801
1802
raise ValueError ("rowchol must be two dimensional." )
@@ -1810,7 +1811,7 @@ def dist(
1810
1811
colcov = pt .as_tensor_variable (colcov )
1811
1812
if colcov .ndim != 2 :
1812
1813
raise ValueError ("colcov must be two dimensional." )
1813
- colchol_cov = cholesky (colcov )
1814
+ colchol_cov = lower_cholesky (colcov )
1814
1815
else :
1815
1816
if colchol .ndim != 2 :
1816
1817
raise ValueError ("colchol must be two dimensional." )
@@ -1851,10 +1852,10 @@ def logp(value, mu, rowchol, colchol):
1851
1852
1852
1853
# Find exponent piece by piece
1853
1854
right_quaddist = solve_lower (rowchol , delta )
1854
- quaddist = pt .nlinalg .matrix_dot (right_quaddist .T , right_quaddist )
1855
+ quaddist = pt .linalg .matrix_dot (right_quaddist .T , right_quaddist )
1855
1856
quaddist = solve_lower (colchol , quaddist )
1856
1857
quaddist = solve_upper (colchol .T , quaddist )
1857
- trquaddist = pt .nlinalg .trace (quaddist )
1858
+ trquaddist = pt .linalg .trace (quaddist )
1858
1859
1859
1860
coldiag = pt .diag (colchol )
1860
1861
rowdiag = pt .diag (rowchol )
@@ -1887,7 +1888,7 @@ def rng_fn(self, rng, mu, sigma, *covs, size=None):
1887
1888
size = size if size else covs [- 1 ]
1888
1889
covs = covs [:- 1 ] if covs [- 1 ] == size else covs
1889
1890
1890
- cov = reduce (linalg .kron , covs )
1891
+ cov = reduce (scipy . linalg .kron , covs )
1891
1892
1892
1893
if sigma :
1893
1894
cov = cov + sigma ** 2 * np .eye (cov .shape [0 ])
@@ -1930,7 +1931,7 @@ class KroneckerNormal(Continuous):
1930
1931
:math:`[(v_1, Q_1), (v_2, Q_2), ...]` such that
1931
1932
:math:`K_i = Q_i \text{diag}(v_i) Q_i'`. For example::
1932
1933
1933
- v_i, Q_i = pt.nlinalg .eigh(K_i)
1934
+ v_i, Q_i = pt.linalg .eigh(K_i)
1934
1935
sigma : scalar, optional
1935
1936
Standard deviation of the Gaussian white noise.
1936
1937
@@ -2228,7 +2229,7 @@ def logp(value, mu, W, alpha, tau):
2228
2229
D = W .sum (axis = 0 )
2229
2230
Dinv_sqrt = pt .diag (1 / pt .sqrt (D ))
2230
2231
DWD = pt .dot (pt .dot (Dinv_sqrt , W ), Dinv_sqrt )
2231
- lam = pt .slinalg .eigvalsh (DWD , pt .eye (DWD .shape [0 ]))
2232
+ lam = pt .linalg .eigvalsh (DWD , pt .eye (DWD .shape [0 ]))
2232
2233
2233
2234
d , _ = W .shape
2234
2235
0 commit comments