Skip to content

Commit ebba64e

Browse files
committed
Fix pathfinder wrapper
1 parent b08610c commit ebba64e

File tree

6 files changed

+50
-66
lines changed

6 files changed

+50
-66
lines changed

conda-envs/environment-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,6 @@ dependencies:
1111
- xhistogram
1212
- statsmodels
1313
- pip:
14-
- pymc>=5.8.1 # CI was failing to resolve
14+
- pymc>=5.9.0 # CI was failing to resolve
1515
- blackjax
1616
- scikit-learn

conda-envs/windows-environment-test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,5 +10,5 @@ dependencies:
1010
- xhistogram
1111
- statsmodels
1212
- pip:
13-
- pymc>=5.8.1 # CI was failing to resolve
13+
- pymc>=5.9.0 # CI was failing to resolve
1414
- scikit-learn

pymc_experimental/inference/fit.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,10 @@ def fit(method, **kwargs):
3131
"""
3232
if method == "pathfinder":
3333
try:
34-
from pymc_experimental.inference.pathfinder import fit_pathfinder
34+
import blackjax
3535
except ImportError as exc:
3636
raise RuntimeError("Need BlackJAX to use `pathfinder`") from exc
37+
38+
from pymc_experimental.inference.pathfinder import fit_pathfinder
39+
3740
return fit_pathfinder(**kwargs)

pymc_experimental/inference/pathfinder.py

Lines changed: 36 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -12,39 +12,37 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
1615
import collections
1716
import sys
1817
from typing import Optional
1918

2019
import arviz as az
2120
import blackjax
2221
import jax
23-
import jax.numpy as jnp
24-
import jax.random as random
2522
import numpy as np
2623
import pymc as pm
27-
from pymc import modelcontext
24+
from packaging import version
25+
from pymc.backends.arviz import coords_and_dims_for_inferencedata
26+
from pymc.blocking import DictToArrayBijection, RaveledVars
27+
from pymc.model import modelcontext
2828
from pymc.sampling.jax import get_jaxified_graph
2929
from pymc.util import RandomSeed, _get_seeds_per_chain, get_default_varnames
3030

3131

3232
def convert_flat_trace_to_idata(
3333
samples,
34-
dims=None,
35-
coords=None,
3634
include_transformed=False,
3735
postprocessing_backend="cpu",
3836
model=None,
3937
):
4038

4139
model = modelcontext(model)
42-
init_position_dict = model.initial_point()
40+
ip = model.initial_point()
41+
ip_point_map_info = pm.blocking.DictToArrayBijection.map(ip).point_map_info
4342
trace = collections.defaultdict(list)
44-
astart = pm.blocking.DictToArrayBijection.map(init_position_dict)
4543
for sample in samples:
46-
raveld_vars = pm.blocking.RaveledVars(sample, astart.point_map_info)
47-
point = pm.blocking.DictToArrayBijection.rmap(raveld_vars, init_position_dict)
44+
raveld_vars = RaveledVars(sample, ip_point_map_info)
45+
point = DictToArrayBijection.rmap(raveld_vars, ip)
4846
for p, v in point.items():
4947
trace[p].append(v.tolist())
5048

@@ -57,19 +55,19 @@ def convert_flat_trace_to_idata(
5755
result = jax.vmap(jax.vmap(jax_fn))(
5856
*jax.device_put(list(trace.values()), jax.devices(postprocessing_backend)[0])
5957
)
60-
6158
trace = {v.name: r for v, r in zip(vars_to_sample, result)}
59+
coords, dims = coords_and_dims_for_inferencedata(model)
6260
idata = az.from_dict(trace, dims=dims, coords=coords)
6361

6462
return idata
6563

6664

6765
def fit_pathfinder(
68-
iterations=5_000,
66+
samples=1000,
6967
random_seed: Optional[RandomSeed] = None,
7068
postprocessing_backend="cpu",
71-
ftol=1e-4,
7269
model=None,
70+
**pathfinder_kwargs,
7371
):
7472
"""
7573
Fit the pathfinder algorithm as implemented in blackjax
@@ -78,15 +76,15 @@ def fit_pathfinder(
7876
7977
Parameters
8078
----------
81-
iterations : int
82-
Number of iterations to run.
79+
samples : int
80+
Number of samples to draw from the fitted approximation.
8381
random_seed : int
8482
Random seed to set.
8583
postprocessing_backend : str
8684
Where to compute transformations of the trace.
8785
"cpu" or "gpu".
88-
ftol : float
89-
Floating point tolerance
86+
pathfinder_kwargs:
87+
kwargs for blackjax.vi.pathfinder.approximate
9088
9189
Returns
9290
-------
@@ -96,53 +94,42 @@ def fit_pathfinder(
9694
---------
9795
https://arxiv.org/abs/2108.03782
9896
"""
99-
100-
(random_seed,) = _get_seeds_per_chain(random_seed, 1)
97+
# Temporarily helper
98+
if version.parse(blackjax.__version__).major < 1:
99+
raise ImportError("fit_pathfinder requires blackjax 1.0 or above")
101100

102101
model = modelcontext(model)
103102

104-
rvs = [rv.name for rv in model.value_vars]
105-
init_position_dict = model.initial_point()
106-
init_position = [init_position_dict[rv] for rv in rvs]
103+
ip = model.initial_point()
104+
ip_map = DictToArrayBijection.map(ip)
107105

108106
new_logprob, new_input = pm.pytensorf.join_nonshared_inputs(
109-
init_position_dict, (model.logp(),), model.value_vars, ()
107+
ip, (model.logp(),), model.value_vars, ()
110108
)
111109

112110
logprob_fn_list = get_jaxified_graph([new_input], new_logprob)
113111

114112
def logprob_fn(x):
115113
return logprob_fn_list(x)[0]
116114

117-
dim = sum(v.size for v in init_position_dict.values())
118-
119-
rng_key = random.PRNGKey(random_seed)
120-
w0 = random.multivariate_normal(rng_key, 2.0 + jnp.zeros(dim), jnp.eye(dim))
121-
path = blackjax.vi.pathfinder.init(rng_key, logprob_fn, w0, return_path=True, ftol=ftol)
122-
123-
pathfinder = blackjax.kernels.pathfinder(rng_key, logprob_fn, ftol=ftol)
124-
state = pathfinder.init(w0)
125-
126-
def inference_loop(rng_key, kernel, initial_state, num_samples):
127-
@jax.jit
128-
def one_step(state, rng_key):
129-
state, info = kernel(rng_key, state)
130-
return state, (state, info)
115+
[pathfinder_seed, sample_seed] = _get_seeds_per_chain(random_seed, 2)
131116

132-
keys = jax.random.split(rng_key, num_samples)
133-
return jax.lax.scan(one_step, initial_state, keys)
134-
135-
_, rng_key = random.split(rng_key)
136117
print("Running pathfinder...", file=sys.stdout)
137-
_, (_, samples) = inference_loop(rng_key, pathfinder.step, state, iterations)
138-
139-
dims = {
140-
var_name: [dim for dim in dims if dim is not None]
141-
for var_name, dims in model.named_vars_to_dims.items()
142-
}
118+
pathfinder_state, _ = blackjax.vi.pathfinder.approximate(
119+
rng_key=jax.random.key(pathfinder_seed),
120+
logdensity_fn=logprob_fn,
121+
initial_position=ip_map.data,
122+
**pathfinder_kwargs,
123+
)
124+
samples, _ = blackjax.vi.pathfinder.sample(
125+
rng_key=jax.random.key(sample_seed),
126+
state=pathfinder_state,
127+
num_samples=samples,
128+
)
143129

144130
idata = convert_flat_trace_to_idata(
145-
samples, postprocessing_backend=postprocessing_backend, coords=model.coords, dims=dims
131+
samples,
132+
postprocessing_backend=postprocessing_backend,
133+
model=model,
146134
)
147-
148135
return idata

pymc_experimental/tests/test_pathfinder.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,7 @@
2121
import pymc_experimental as pmx
2222

2323

24-
# TODO: Remove this filterwarning after pytensor uses jnp.prod instead of jnp.product
2524
@pytest.mark.skipif(sys.platform == "win32", reason="JAX not supported on windows.")
26-
@pytest.mark.skipif(
27-
sys.version_info < (3, 10), reason="pymc.sampling.jax does not currently support python < 3.10"
28-
)
29-
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
3025
def test_pathfinder():
3126
# Data of the Eight Schools Model
3227
J = 8
@@ -41,12 +36,11 @@ def test_pathfinder():
4136
theta = pm.Normal("theta", mu=0, sigma=1, shape=J)
4237
obs = pm.Normal("obs", mu=mu + tau * theta, sigma=sigma, shape=J, observed=y)
4338

44-
idata = pmx.fit(method="pathfinder", iterations=100)
39+
idata = pmx.fit(method="pathfinder", random_seed=41)
4540

46-
assert idata is not None
47-
assert "theta" in idata.posterior._variables.keys()
48-
assert "tau" in idata.posterior._variables.keys()
49-
assert "mu" in idata.posterior._variables.keys()
50-
assert idata.posterior["mu"].shape == (1, 100)
51-
assert idata.posterior["tau"].shape == (1, 100)
52-
assert idata.posterior["theta"].shape == (1, 100, 8)
41+
assert idata.posterior["mu"].shape == (1, 1000)
42+
assert idata.posterior["tau"].shape == (1, 1000)
43+
assert idata.posterior["theta"].shape == (1, 1000, 8)
44+
# FIXME: pathfinder doesn't find a reasonable mean! Fix the bug or choose a model that pathfinder can handle
45+
# np.testing.assert_allclose(idata.posterior["mu"].mean(), 5.0)
46+
np.testing.assert_allclose(idata.posterior["tau"].mean(), 4.15, atol=0.5)

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
pymc>=5.8.1
1+
pymc>=5.8.2
22
scikit-learn

0 commit comments

Comments
 (0)