Commit 570e6e8

speed up posterior predictive sampling (#6208)
* refactor `dataset_to_point_list` for higher performance
* allow specifying `sample_dims` in posterior predictive
1 parent a025059 · commit 570e6e8

11 files changed (+101, -82)
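
In user-facing terms, the second bullet means `pm.sample_posterior_predictive` no longer insists on `("chain", "draw")` as the sample dimensions of the posterior. A minimal sketch of the new usage, with a hypothetical model (the variable names are illustrative, not taken from the commit):

import numpy as np
import pymc as pm

with pm.Model() as model:
    mu = pm.Normal("mu")
    pm.Normal("y", mu=mu, sigma=1.0, observed=np.random.randn(100))
    idata = pm.sample()

    # Default behaviour: "chain" and "draw" are the sample dimensions, as before.
    idata.extend(pm.sample_posterior_predictive(idata))

    # New in this commit: collapse chain/draw into one custom sample dimension.
    stacked = idata.posterior.stack(sample=["chain", "draw"])
    pp = pm.sample_posterior_predictive(stacked, sample_dims=["sample"])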

conda-envs/environment-dev.yml (+1, -1)

@@ -7,7 +7,7 @@ dependencies:
 # Base dependencies
 - aeppl=0.0.38
 - aesara=2.8.7
-- arviz>=0.12.0
+- arviz>=0.13.0
 - blas
 - cachetools>=4.2.1
 - cloudpickle

conda-envs/environment-test.yml (+1, -1)

@@ -7,7 +7,7 @@ dependencies:
 # Base dependencies
 - aeppl=0.0.38
 - aesara=2.8.7
-- arviz>=0.12.0
+- arviz>=0.13.0
 - blas
 - cachetools>=4.2.1
 - cloudpickle

conda-envs/windows-environment-dev.yml (+1, -1)

@@ -7,7 +7,7 @@ dependencies:
 # Base dependencies (see install guide for Windows)
 - aeppl=0.0.38
 - aesara=2.8.7
-- arviz>=0.12.0
+- arviz>=0.13.0
 - blas
 - cachetools>=4.2.1
 - cloudpickle

conda-envs/windows-environment-test.yml (+1, -1)

@@ -7,7 +7,7 @@ dependencies:
 # Base dependencies (see install guide for Windows)
 - aeppl=0.0.38
 - aesara=2.8.7
-- arviz>=0.12.0
+- arviz>=0.13.0
 - blas
 - cachetools>=4.2.1
 - cloudpickle

pymc/backends/arviz.py (+21, -22)

@@ -7,6 +7,7 @@
     Any,
     Dict,
     Iterable,
+    List,
     Mapping,
     Optional,
     Sequence,
@@ -15,7 +16,6 @@
 )
 
 import numpy as np
-import xarray as xr
 
 from aesara.graph.basic import Constant
 from aesara.tensor.sharedvar import SharedVariable
@@ -162,6 +162,7 @@ def __init__(
         predictions=None,
         coords: Optional[CoordSpec] = None,
         dims: Optional[DimSpec] = None,
+        sample_dims: Optional[List] = None,
         model=None,
         save_warmup: Optional[bool] = None,
         include_transformed: bool = False,
@@ -225,6 +226,9 @@ def __init__(
                 for var_name, dims in self.model.RV_dims.items()
             }
             self.dims = {**model_dims, **self.dims}
+        if sample_dims is None:
+            sample_dims = ["chain", "draw"]
+        self.sample_dims = sample_dims
 
         self.observations = find_observations(self.model)
 
@@ -423,36 +427,27 @@ def log_likelihood_to_xarray(self):
             ),
         )
 
-    def translate_posterior_predictive_dict_to_xarray(self, dct, kind) -> xr.Dataset:
-        """Take Dict of variables to numpy ndarrays (samples) and translate into dataset."""
-        data = {}
-        warning_vars = []
-        for k, ary in dct.items():
-            if (ary.shape[0] == self.nchains) and (ary.shape[1] == self.ndraws):
-                data[k] = ary
-            else:
-                data[k] = np.expand_dims(ary, 0)
-                warning_vars.append(k)
-        if warning_vars:
-            warnings.warn(
-                f"The shape of variables {', '.join(warning_vars)} in {kind} group is not compatible "
-                "with number of chains and draws. The automatic dimension naming might not have worked. "
-                "This can also mean that some draws or even whole chains are not represented.",
-                UserWarning,
-            )
-        return dict_to_dataset(data, library=pymc, coords=self.coords, dims=self.dims)
+        return dict_to_dataset(
+            data, library=pymc, coords=self.coords, dims=self.dims, default_dims=self.sample_dims
+        )
 
     @requires(["posterior_predictive"])
     def posterior_predictive_to_xarray(self):
         """Convert posterior_predictive samples to xarray."""
-        return self.translate_posterior_predictive_dict_to_xarray(
-            self.posterior_predictive, "posterior_predictive"
+        data = self.posterior_predictive
+        dims = {var_name: self.sample_dims + self.dims.get(var_name, []) for var_name in data}
+        return dict_to_dataset(
+            data, library=pymc, coords=self.coords, dims=dims, default_dims=self.sample_dims
         )
 
     @requires(["predictions"])
     def predictions_to_xarray(self):
         """Convert predictions (out of sample predictions) to xarray."""
-        return self.translate_posterior_predictive_dict_to_xarray(self.predictions, "predictions")
+        data = self.predictions
+        dims = {var_name: self.sample_dims + self.dims.get(var_name, []) for var_name in data}
+        return dict_to_dataset(
+            data, library=pymc, coords=self.coords, dims=dims, default_dims=self.sample_dims
+        )
 
     def priors_to_xarray(self):
         """Convert prior samples (and if possible prior predictive too) to xarray."""
@@ -541,6 +536,7 @@ def to_inference_data(
     log_likelihood: Union[bool, Iterable[str]] = True,
     coords: Optional[CoordSpec] = None,
     dims: Optional[DimSpec] = None,
+    sample_dims: Optional[List] = None,
     model: Optional["Model"] = None,
     save_warmup: Optional[bool] = None,
     include_transformed: bool = False,
@@ -594,6 +590,7 @@ def to_inference_data(
         log_likelihood=log_likelihood,
         coords=coords,
         dims=dims,
+        sample_dims=sample_dims,
         model=model,
         save_warmup=save_warmup,
         include_transformed=include_transformed,
@@ -608,6 +605,7 @@ def predictions_to_inference_data(
     model: Optional["Model"] = None,
     coords: Optional[CoordSpec] = None,
     dims: Optional[DimSpec] = None,
+    sample_dims: Optional[List] = None,
    idata_orig: Optional[InferenceData] = None,
    inplace: bool = False,
 ) -> InferenceData:
@@ -653,6 +651,7 @@ def predictions_to_inference_data(
         model=model,
         coords=coords,
         dims=dims,
+        sample_dims=sample_dims,
         log_likelihood=False,
    )
     if hasattr(idata_orig, "posterior"):
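
Instead of special-casing chain/draw shapes and warning on mismatches, the converter now prepends `sample_dims` to each variable's own dims and passes them to `dict_to_dataset`. A minimal sketch of that per-variable `dims` construction, with made-up variable names and dims (placeholders, not from the commit):

sample_dims = ["chain", "draw"]   # the default set in __init__
var_dims = {"y": ["obs_id"]}      # plays the role of self.dims
data = {"y": ..., "z": ...}       # placeholder posterior predictive samples by name

dims = {var_name: sample_dims + var_dims.get(var_name, []) for var_name in data}
assert dims == {"y": ["chain", "draw", "obs_id"], "z": ["chain", "draw"]}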

pymc/sampling.py (+39, -22)

@@ -75,7 +75,6 @@
 from pymc.step_methods.arraystep import BlockedStep, PopulationArrayStepShared
 from pymc.step_methods.hmc import quadpotential
 from pymc.util import (
-    chains_and_samples,
     dataset_to_point_list,
     get_default_varnames,
     get_untransformed_name,
@@ -1765,6 +1764,7 @@ def sample_posterior_predictive(
     trace,
     model: Optional[Model] = None,
     var_names: Optional[List[str]] = None,
+    sample_dims: Optional[List[str]] = None,
     random_seed: RandomState = None,
     progressbar: bool = True,
     return_inferencedata: bool = True,
@@ -1785,6 +1785,10 @@ def sample_posterior_predictive(
         generally be the model used to generate the ``trace``, but it doesn't need to be.
     var_names : Iterable[str]
         Names of variables for which to compute the posterior predictive samples.
+    sample_dims : list of str, optional
+        Dimensions over which to loop and generate posterior predictive samples.
+        When `sample_dims` is ``None`` (default) both "chain" and "draw" are considered sample
+        dimensions. Only taken into account when `trace` is InferenceData or Dataset.
     random_seed : int, RandomState or Generator, optional
         Seed for the random number generator.
     progressbar : bool
@@ -1821,6 +1825,14 @@ def sample_posterior_predictive(
         thinned_idata = idata.sel(draw=slice(None, None, 5))
         with model:
             idata.extend(pymc.sample_posterior_predictive(thinned_idata))
+
+    Generate 5 posterior predictive samples per posterior sample.
+
+    .. code:: python
+
+        expanded_data = idata.posterior.expand_dims(pred_id=5)
+        with model:
+            idata.extend(pymc.sample_posterior_predictive(expanded_data))
     """
 
     _trace: Union[MultiTrace, PointList]
@@ -1829,36 +1841,34 @@ def sample_posterior_predictive(
         idata_kwargs = {}
     else:
         idata_kwargs = idata_kwargs.copy()
+    if sample_dims is None:
+        sample_dims = ["chain", "draw"]
     constant_data: Dict[str, np.ndarray] = {}
     trace_coords: Dict[str, np.ndarray] = {}
     if "coords" not in idata_kwargs:
         idata_kwargs["coords"] = {}
+    idata: Optional[InferenceData] = None
+    stacked_dims = None
     if isinstance(trace, InferenceData):
-        idata_kwargs["coords"].setdefault("draw", trace["posterior"]["draw"])
-        idata_kwargs["coords"].setdefault("chain", trace["posterior"]["chain"])
         _constant_data = getattr(trace, "constant_data", None)
         if _constant_data is not None:
             trace_coords.update({str(k): v.data for k, v in _constant_data.coords.items()})
             constant_data.update({str(k): v.data for k, v in _constant_data.items()})
-        trace_coords.update({str(k): v.data for k, v in trace["posterior"].coords.items()})
-        _trace = dataset_to_point_list(trace["posterior"])
-        nchain, len_trace = chains_and_samples(trace)
-    elif isinstance(trace, xarray.Dataset):
-        idata_kwargs["coords"].setdefault("draw", trace["draw"])
-        idata_kwargs["coords"].setdefault("chain", trace["chain"])
+        idata = trace
+        trace = trace["posterior"]
+    if isinstance(trace, xarray.Dataset):
         trace_coords.update({str(k): v.data for k, v in trace.coords.items()})
-        _trace = dataset_to_point_list(trace)
-        nchain, len_trace = chains_and_samples(trace)
+        _trace, stacked_dims = dataset_to_point_list(trace, sample_dims)
+        nchain = 1
     elif isinstance(trace, MultiTrace):
         _trace = trace
         nchain = _trace.nchains
-        len_trace = len(_trace)
     elif isinstance(trace, list) and all(isinstance(x, dict) for x in trace):
         _trace = trace
         nchain = 1
-        len_trace = len(_trace)
     else:
         raise TypeError(f"Unsupported type for `trace` argument: {type(trace)}.")
+    len_trace = len(_trace)
 
     if isinstance(_trace, MultiTrace):
         samples = sum(len(v) for v in _trace._straces.values())
@@ -1961,23 +1971,30 @@ def sample_posterior_predictive(
         ppc_trace = ppc_trace_t.trace_dict
 
     for k, ary in ppc_trace.items():
-        ppc_trace[k] = ary.reshape((nchain, len_trace, *ary.shape[1:]))
+        if stacked_dims is not None:
+            ppc_trace[k] = ary.reshape(
+                (*[len(coord) for coord in stacked_dims.values()], *ary.shape[1:])
+            )
+        else:
+            ppc_trace[k] = ary.reshape((nchain, len_trace, *ary.shape[1:]))
 
     if not return_inferencedata:
         return ppc_trace
     ikwargs: Dict[str, Any] = dict(model=model, **idata_kwargs)
+    ikwargs.setdefault("sample_dims", sample_dims)
+    if stacked_dims is not None:
+        coords = ikwargs.get("coords", {})
+        ikwargs["coords"] = {**stacked_dims, **coords}
     if predictions:
         if extend_inferencedata:
-            ikwargs.setdefault("idata_orig", trace)
+            ikwargs.setdefault("idata_orig", idata)
             ikwargs.setdefault("inplace", True)
         return pm.predictions_to_inference_data(ppc_trace, **ikwargs)
-    converter = pm.backends.arviz.InferenceDataConverter(posterior_predictive=ppc_trace, **ikwargs)
-    converter.nchains = nchain
-    converter.ndraws = len_trace
-    idata_pp = converter.to_inference_data()
-    if extend_inferencedata:
-        trace.extend(idata_pp)
-        return trace
+    idata_pp = pm.to_inference_data(posterior_predictive=ppc_trace, **ikwargs)
+
+    if extend_inferencedata and idata is not None:
+        idata.extend(idata_pp)
+        return idata
     return idata_pp
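
With arbitrary `sample_dims`, the flat list of points produced by `dataset_to_point_list` has to be folded back into one axis per sample dimension before conversion to InferenceData. A minimal sketch of that reshape with made-up sizes (`stacked_dims` stands in for the coordinate mapping returned alongside the point list):

import numpy as np

stacked_dims = {"sample": range(8), "pred_id": range(5)}  # stand-in coordinates
ary = np.empty((8 * 5, 3))  # draws arrive flattened: (n_points, *var_shape)

ary = ary.reshape((*[len(coord) for coord in stacked_dims.values()], *ary.shape[1:]))
assert ary.shape == (8, 5, 3)  # one axis per sample dim, then the variable shape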

pymc/tests/test_sampling.py (+16)

@@ -1621,6 +1621,22 @@ def test_aesara_function_kwargs(self):
 
         assert np.all(pp["y"] == np.arange(5) * 2)
 
+    def test_sample_dims(self, point_list_arg_bug_fixture):
+        pmodel, trace = point_list_arg_bug_fixture
+        with pmodel:
+            post = pm.to_inference_data(trace).posterior.stack(sample=["chain", "draw"])
+            pp = pm.sample_posterior_predictive(post, var_names=["d"], sample_dims=["sample"])
+            assert "sample" in pp.posterior_predictive
+            assert len(pp.posterior_predictive["sample"]) == len(post["sample"])
+            post = post.expand_dims(pred_id=5)
+            pp = pm.sample_posterior_predictive(
+                post, var_names=["d"], sample_dims=["sample", "pred_id"]
+            )
+            assert "sample" in pp.posterior_predictive
+            assert "pred_id" in pp.posterior_predictive
+            assert len(pp.posterior_predictive["sample"]) == len(post["sample"])
+            assert len(pp.posterior_predictive["pred_id"]) == 5
+
 
 class TestDraw(SeededTest):
     def test_univariate(self):

pymc/tests/test_util.py (+2, -2)

@@ -154,7 +154,7 @@ def fn(a=UNSET):
 def test_dataset_to_point_list():
     ds = xarray.Dataset()
     ds["A"] = xarray.DataArray([[1, 2, 3]] * 2, dims=("chain", "draw"))
-    pl = dataset_to_point_list(ds)
+    pl, _ = dataset_to_point_list(ds, sample_dims=["chain", "draw"])
     assert isinstance(pl, list)
     assert len(pl) == 6
     assert isinstance(pl[0], dict)
@@ -163,4 +163,4 @@ def test_dataset_to_point_list():
     # Check that non-str keys are caught
     ds[3] = xarray.DataArray([1, 2, 3])
     with pytest.raises(ValueError, match="must be str"):
-        dataset_to_point_list(ds)
+        dataset_to_point_list(ds, sample_dims=["chain", "draw"])

pymc/util.py (+17, -30)

@@ -14,9 +14,8 @@
 
 import functools
 
-from typing import Dict, Hashable, List, Tuple, Union, cast
+from typing import Any, Dict, List, Tuple, cast
 
-import arviz
 import cloudpickle
 import numpy as np
 import xarray
@@ -231,38 +230,26 @@ def enhanced(*args, **kwargs):
     return enhanced
 
 
-def dataset_to_point_list(ds: xarray.Dataset) -> List[Dict[str, np.ndarray]]:
+def dataset_to_point_list(
+    ds: xarray.Dataset, sample_dims: List
+) -> Tuple[List[Dict[str, np.ndarray]], Dict[str, Any]]:
     # All keys of the dataset must be a str
-    for vn in ds.keys():
+    var_names = list(ds.keys())
+    for vn in var_names:
         if not isinstance(vn, str):
             raise ValueError(f"Variable names must be str, but dataset key {vn} is a {type(vn)}.")
-    # make dicts
-    points: List[Dict[Hashable, np.ndarray]] = []
-    da: "xarray.DataArray"
-    for c in ds.chain:
-        for d in ds.draw:
-            points.append({vn: da.sel(chain=c, draw=d).values for vn, da in ds.items()})
+    num_sample_dims = len(sample_dims)
+    stacked_dims = {dim_name: ds[dim_name] for dim_name in sample_dims}
+    ds = ds.transpose(*sample_dims, ...)
+    stacked_dict = {
+        vn: da.values.reshape((-1, *da.shape[num_sample_dims:])) for vn, da in ds.items()
+    }
+    points = [
+        {vn: stacked_dict[vn][i, ...] for vn in var_names}
+        for i in range(np.product([len(coords) for coords in stacked_dims.values()]))
+    ]
     # use the list of points
-    return cast(List[Dict[str, np.ndarray]], points)
-
-
-def chains_and_samples(data: Union[xarray.Dataset, arviz.InferenceData]) -> Tuple[int, int]:
-    """Extract and return number of chains and samples in xarray or arviz traces."""
-    dataset: xarray.Dataset
-    if isinstance(data, xarray.Dataset):
-        dataset = data
-    elif isinstance(data, arviz.InferenceData):
-        dataset = data["posterior"]
-    else:
-        raise ValueError(
-            "Argument must be xarray Dataset or arviz InferenceData. Got %s",
-            data.__class__,
-        )
-
-    coords = dataset.coords
-    nchains = coords["chain"].sizes["chain"]
-    nsamples = coords["draw"].sizes["draw"]
-    return nchains, nsamples
+    return cast(List[Dict[str, np.ndarray]], points), stacked_dims
 
 
 def hashable(a=None) -> int:
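
The performance win comes from replacing the nested chain/draw loop, which issued one `ds.sel(chain=c, draw=d)` call per point, with a single `transpose` plus NumPy `reshape`, so building each point is just slicing a pre-stacked array. A minimal sketch of the new call signature and return value, assuming the post-commit API:

import numpy as np
import xarray

from pymc.util import dataset_to_point_list

ds = xarray.Dataset()
ds["A"] = xarray.DataArray(np.arange(6).reshape(2, 3), dims=("chain", "draw"))

# Now takes sample_dims and also returns the stacked dimensions' coordinates.
points, stacked_dims = dataset_to_point_list(ds, sample_dims=["chain", "draw"])
assert len(points) == 6                  # one point dict per (chain, draw) pair
assert list(stacked_dims) == ["chain", "draw"]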

requirements-dev.txt (+1, -1)

@@ -3,7 +3,7 @@
 
 aeppl==0.0.38
 aesara==2.8.7
-arviz>=0.12.0
+arviz>=0.13.0
 cachetools>=4.2.1
 cloudpickle
 fastprogress>=0.2.0
