Skip to content

Commit 27dbd0c

Browse files
juanitorduzfonnesbeck
authored andcommitted
Harmonize HSGP.prior dimension names and order (pymc-devs#7562)
* suggest dimension improvement * make mypy happy
1 parent 4139efa commit 27dbd0c

File tree

1 file changed

+11
-6
lines changed

1 file changed

+11
-6
lines changed

pymc/gp/hsgp_approx.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -428,8 +428,8 @@ def prior(
428428
self,
429429
name: str,
430430
X: TensorLike,
431+
dims: str | None = None,
431432
hsgp_coeffs_dims: str | None = None,
432-
gp_dims: str | None = None,
433433
*args,
434434
**kwargs,
435435
):
@@ -444,10 +444,11 @@ def prior(
444444
Name of the random variable
445445
X: array-like
446446
Function input values.
447+
dims: str, default None
448+
Dimension name for the GP random variable.
447449
hsgp_coeffs_dims: str, default None
448450
Dimension name for the HSGP basis vectors.
449-
gp_dims: str, default None
450-
Dimension name for the GP random variable.
451+
451452
"""
452453
phi, sqrt_psd = self.prior_linearized(X)
453454
self._sqrt_psd = sqrt_psd
@@ -469,7 +470,7 @@ def prior(
469470
)
470471
f = self.mean_func(X) + phi @ self._beta
471472

472-
self.f = pm.Deterministic(name, f, dims=gp_dims)
473+
self.f = pm.Deterministic(name, f, dims=dims)
473474
return self.f
474475

475476
def _build_conditional(self, Xnew):
@@ -695,7 +696,9 @@ def prior_linearized(self, X: TensorLike):
695696
psd = self.scale * self.cov_func.power_spectral_density_approx(J)
696697
return (phi_cos, phi_sin), psd
697698

698-
def prior(self, name: str, X: TensorLike, dims: str | None = None): # type: ignore[override]
699+
def prior( # type: ignore[override]
700+
self, name: str, X: TensorLike, dims: str | None = None, hsgp_coeffs_dims: str | None = None
701+
):
699702
R"""
700703
Return the (approximate) GP prior distribution evaluated over the input locations `X`.
701704
@@ -709,11 +712,13 @@ def prior(self, name: str, X: TensorLike, dims: str | None = None): # type: ign
709712
Function input values.
710713
dims: None
711714
Dimension name for the GP random variable.
715+
hsgp_coeffs_dims: str | None = None
716+
Dimension name for the HSGPPeriodic basis vectors.
712717
"""
713718
(phi_cos, phi_sin), psd = self.prior_linearized(X)
714719

715720
m = self._m
716-
self._beta = pm.Normal(f"{name}_hsgp_coeffs_", size=(m * 2 - 1))
721+
self._beta = pm.Normal(f"{name}_hsgp_coeffs_", size=(m * 2 - 1), dims=hsgp_coeffs_dims)
717722
# The first eigenfunction for the sine component is zero
718723
# and so does not contribute to the approximation.
719724
f = (

0 commit comments

Comments
 (0)