Skip to content

Commit ec64d40

Browse files
committed
Do not predefine custom Cholesky and SolveTriangular Ops
Also standardize pytensor linalg calls
1 parent 5a9c8cd commit ec64d40

File tree

4 files changed: 30 additions and 40 deletions.

4 files changed: 30 additions and 40 deletions.

pymc/distributions/multivariate.py

Lines changed: 20 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717

1818
import warnings
1919

20-
from functools import reduce
20+
from functools import partial, reduce
2121
from typing import Optional
2222

2323
import numpy as np
@@ -30,16 +30,17 @@
3030
from pytensor.raise_op import Assert
3131
from pytensor.sparse.basic import sp_sum
3232
from pytensor.tensor import TensorConstant, gammaln, sigmoid
33-
from pytensor.tensor.nlinalg import det, eigh, matrix_inverse, trace
33+
from pytensor.tensor.linalg import cholesky, det, eigh
34+
from pytensor.tensor.linalg import inv as matrix_inverse
35+
from pytensor.tensor.linalg import solve_triangular, trace
3436
from pytensor.tensor.random.basic import dirichlet, multinomial, multivariate_normal
3537
from pytensor.tensor.random.op import RandomVariable
3638
from pytensor.tensor.random.utils import (
3739
broadcast_params,
3840
supp_shape_from_ref_param_shape,
3941
)
40-
from pytensor.tensor.slinalg import Cholesky, SolveTriangular
4142
from pytensor.tensor.type import TensorType
42-
from scipy import linalg, stats
43+
from scipy import stats
4344

4445
import pymc as pm
4546

@@ -93,8 +94,8 @@
9394
"StickBreakingWeights",
9495
]
9596

96-
solve_lower = SolveTriangular(lower=True)
97-
solve_upper = SolveTriangular(lower=False)
97+
solve_lower = partial(solve_triangular, lower=True)
98+
solve_upper = partial(solve_triangular, lower=False)
9899

99100

100101
class SimplexContinuous(Continuous):
@@ -110,7 +111,7 @@ def simplex_cont_transform(op, rv):
110111
# moment. We work around that by using a cholesky op
111112
# that returns a nan as first entry instead of raising
112113
# an error.
113-
cholesky = Cholesky(lower=True, on_error="nan")
114+
nan_lower_cholesky = partial(cholesky, lower=True, on_error="nan")
114115

115116

116117
def quaddist_matrix(cov=None, chol=None, tau=None, lower=True, *args, **kwargs):
@@ -155,7 +156,7 @@ def quaddist_parse(value, mu, cov, mat_type="cov"):
155156
onedim = False
156157

157158
delta = value - mu
158-
chol_cov = cholesky(cov)
159+
chol_cov = nan_lower_cholesky(cov)
159160
if mat_type != "tau":
160161
dist, logdet, ok = quaddist_chol(delta, chol_cov)
161162
else:
@@ -847,9 +848,9 @@ def dist(cls, *args, **kwargs):
847848

848849
def posdef(AA):
849850
try:
850-
linalg.cholesky(AA)
851+
scipy.linalg.cholesky(AA)
851852
return True
852-
except linalg.LinAlgError:
853+
except scipy.linalg.LinAlgError:
853854
return False
854855

855856

@@ -1073,7 +1074,7 @@ def WishartBartlett(name, S, nu, is_cholesky=False, return_cholesky=False, initv
10731074
if initval is not None:
10741075
# Inverse transform
10751076
initval = np.dot(np.dot(np.linalg.inv(L), initval), np.linalg.inv(L.T))
1076-
initval = linalg.cholesky(initval, lower=True)
1077+
initval = scipy.linalg.cholesky(initval, lower=True)
10771078
diag_testval = initval[diag_idx] ** 2
10781079
tril_testval = initval[tril_idx]
10791080
else:
@@ -1785,7 +1786,7 @@ def dist(
17851786
*args,
17861787
**kwargs,
17871788
):
1788-
cholesky = Cholesky(lower=True, on_error="raise")
1789+
lower_cholesky = partial(cholesky, lower=True, on_error="raise")
17891790

17901791
# Among-row matrices
17911792
if len([i for i in [rowcov, rowchol] if i is not None]) != 1:
@@ -1795,7 +1796,7 @@ def dist(
17951796
if rowcov is not None:
17961797
if rowcov.ndim != 2:
17971798
raise ValueError("rowcov must be two dimensional.")
1798-
rowchol_cov = cholesky(rowcov)
1799+
rowchol_cov = lower_cholesky(rowcov)
17991800
else:
18001801
if rowchol.ndim != 2:
18011802
raise ValueError("rowchol must be two dimensional.")
@@ -1810,7 +1811,7 @@ def dist(
18101811
colcov = pt.as_tensor_variable(colcov)
18111812
if colcov.ndim != 2:
18121813
raise ValueError("colcov must be two dimensional.")
1813-
colchol_cov = cholesky(colcov)
1814+
colchol_cov = lower_cholesky(colcov)
18141815
else:
18151816
if colchol.ndim != 2:
18161817
raise ValueError("colchol must be two dimensional.")
@@ -1851,10 +1852,10 @@ def logp(value, mu, rowchol, colchol):
18511852

18521853
# Find exponent piece by piece
18531854
right_quaddist = solve_lower(rowchol, delta)
1854-
quaddist = pt.nlinalg.matrix_dot(right_quaddist.T, right_quaddist)
1855+
quaddist = pt.linalg.matrix_dot(right_quaddist.T, right_quaddist)
18551856
quaddist = solve_lower(colchol, quaddist)
18561857
quaddist = solve_upper(colchol.T, quaddist)
1857-
trquaddist = pt.nlinalg.trace(quaddist)
1858+
trquaddist = pt.linalg.trace(quaddist)
18581859

18591860
coldiag = pt.diag(colchol)
18601861
rowdiag = pt.diag(rowchol)
@@ -1887,7 +1888,7 @@ def rng_fn(self, rng, mu, sigma, *covs, size=None):
18871888
size = size if size else covs[-1]
18881889
covs = covs[:-1] if covs[-1] == size else covs
18891890

1890-
cov = reduce(linalg.kron, covs)
1891+
cov = reduce(scipy.linalg.kron, covs)
18911892

18921893
if sigma:
18931894
cov = cov + sigma**2 * np.eye(cov.shape[0])
@@ -1930,7 +1931,7 @@ class KroneckerNormal(Continuous):
19301931
:math:`[(v_1, Q_1), (v_2, Q_2), ...]` such that
19311932
:math:`K_i = Q_i \text{diag}(v_i) Q_i'`. For example::
19321933
1933-
v_i, Q_i = pt.nlinalg.eigh(K_i)
1934+
v_i, Q_i = pt.linalg.eigh(K_i)
19341935
sigma : scalar, optional
19351936
Standard deviation of the Gaussian white noise.
19361937
@@ -2228,7 +2229,7 @@ def logp(value, mu, W, alpha, tau):
22282229
D = W.sum(axis=0)
22292230
Dinv_sqrt = pt.diag(1 / pt.sqrt(D))
22302231
DWD = pt.dot(pt.dot(Dinv_sqrt, W), Dinv_sqrt)
2231-
lam = pt.slinalg.eigvalsh(DWD, pt.eye(DWD.shape[0]))
2232+
lam = pt.linalg.eigvalsh(DWD, pt.eye(DWD.shape[0]))
22322233

22332234
d, _ = W.shape
22342235

pymc/gp/gp.py

Lines changed: 6 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -14,26 +14,28 @@
1414

1515
import warnings
1616

17+
from functools import partial
18+
1719
import numpy as np
1820
import pytensor.tensor as pt
1921

20-
from pytensor.tensor.nlinalg import eigh
22+
from pytensor.tensor.linalg import cholesky, eigh, solve_triangular
2123

2224
import pymc as pm
2325

2426
from pymc.gp.cov import BaseCovariance, Constant
2527
from pymc.gp.mean import Zero
2628
from pymc.gp.util import (
2729
JITTER_DEFAULT,
28-
cholesky,
2930
conditioned_vars,
3031
replace_with_values,
31-
solve_lower,
32-
solve_upper,
3332
stabilize,
3433
)
3534
from pymc.math import cartesian, kron_diag, kron_dot, kron_solve_lower, kron_solve_upper
3635

36+
solve_lower = partial(solve_triangular, lower=True)
37+
solve_upper = partial(solve_triangular, lower=False)
38+
3739
__all__ = ["Latent", "Marginal", "TP", "MarginalApprox", "LatentKron", "MarginalKron"]
3840

3941

pymc/gp/util.py

Lines changed: 0 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -18,11 +18,6 @@
1818
import pytensor.tensor as pt
1919

2020
from pytensor.compile import SharedVariable
21-
from pytensor.tensor.slinalg import ( # noqa: W0611; pylint: disable=unused-import
22-
SolveTriangular,
23-
cholesky,
24-
solve,
25-
)
2621
from pytensor.tensor.var import TensorConstant
2722
from scipy.cluster.vq import kmeans
2823

@@ -35,9 +30,6 @@
3530

3631
JITTER_DEFAULT = 1e-6
3732

38-
solve_lower = SolveTriangular(lower=True)
39-
solve_upper = SolveTriangular(lower=False)
40-
4133

4234
def replace_with_values(vars_needed, replacements=None, model=None):
4335
R"""

pymc/math.py

Lines changed: 4 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -75,14 +75,9 @@
7575
where,
7676
zeros_like,
7777
)
78-
from pytensor.tensor.special import log_softmax, softmax
79-
80-
try:
81-
from pytensor.tensor.basic import extract_diag
82-
except ImportError:
83-
from pytensor.tensor.nlinalg import extract_diag
84-
78+
from pytensor.tensor.linalg import solve_triangular
8579
from pytensor.tensor.nlinalg import matrix_inverse
80+
from pytensor.tensor.special import log_softmax, softmax
8681
from scipy.linalg import block_diag as scipy_block_diag
8782

8883
from pymc.pytensorf import floatX, ix_, largest_common_dtype
@@ -230,8 +225,8 @@ def kron_vector_op(v):
230225

231226
# Define kronecker functions that work on 1D and 2D arrays
232227
kron_dot = partial(kron_matrix_op, op=pt.dot)
233-
kron_solve_lower = partial(kron_matrix_op, op=pt.slinalg.SolveTriangular(lower=True))
234-
kron_solve_upper = partial(kron_matrix_op, op=pt.slinalg.SolveTriangular(lower=False))
228+
kron_solve_lower = partial(kron_matrix_op, op=partial(solve_triangular, lower=True))
229+
kron_solve_upper = partial(kron_matrix_op, op=partial(solve_triangular, lower=False))
235230

236231

237232
def flat_outer(a, b):

0 commit comments

Comments
 (0)