
Commit de2d614

Fix posterior pred. sampling keep_size w/ arviz input.
Previously, the posterior predictive sampling functions did not handle the `keep_size` keyword argument correctly when given an xarray Dataset as the trace argument. These functions have also been extended to accept an arviz InferenceData object as input.
1 parent 90f48ed commit de2d614
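
To illustrate the change, here is a minimal usage sketch mirroring the new tests further down. The toy model and sampler settings are illustrative only, not part of the commit.

```python
import arviz as az
import numpy as np
import pymc3 as pm

# A toy model; after this commit the posterior predictive samplers accept an
# arviz InferenceData directly, and keep_size reshapes draws to (chains, draws, ...).
with pm.Model() as model:
    mu = pm.Normal("mu", 0.0, 1.0)
    a = pm.Normal("a", mu=mu, sigma=1.0, observed=np.zeros(2))
    trace = pm.sample(draws=500, chains=2)

with model:
    idata = az.from_pymc3(trace)
    ppc = pm.sample_posterior_predictive(idata, keep_size=True)

# Chain/draw structure is preserved, followed by the observed variable's shape.
assert ppc["a"].shape == (2, 500, 2)
```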

File tree

4 files changed: +147 −50 lines changed


pymc3/distributions/posterior_predictive.py

Lines changed: 34 additions & 15 deletions
```diff
@@ -4,10 +4,11 @@
 import logging
 from collections import UserDict
 from contextlib import AbstractContextManager
+
 if TYPE_CHECKING:
     import contextvars  # noqa: F401
     from typing import Set
-    from typing_extensions import Protocol
+    from typing_extensions import Protocol, Literal

 import numpy as np
 import theano
@@ -17,9 +18,13 @@
 from ..backends.base import MultiTrace #, TraceLike, TraceDict
 from .distribution import _DrawValuesContext, _DrawValuesContextBlocker, is_fast_drawable, _compile_theano_function, vectorized_ppc
 from ..model import Model, get_named_nodes_and_relations, ObservedRV, MultiObservedRV, modelcontext
+from arviz import InferenceData
+
+from ..backends.base import MultiTrace  # , TraceLike, TraceDict
 from ..exceptions import IncorrectArgumentsError
 from ..vartypes import theano_constant
-from ..util import dataset_to_point_dict
+from ..util import dataset_to_point_dict, chains_and_samples
+
 # Failing tests:
 # test_mixture_random_shape::test_mixture_random_shape
 #
@@ -121,12 +126,14 @@ def __getitem__(self, item):



-def fast_sample_posterior_predictive(trace: Union[MultiTrace, Dataset, List[Dict[str, np.ndarray]]],
-                                     samples: Optional[int]=None,
-                                     model: Optional[Model]=None,
-                                     var_names: Optional[List[str]]=None,
-                                     keep_size: bool=False,
-                                     random_seed=None) -> Dict[str, np.ndarray]:
+def fast_sample_posterior_predictive(
+    trace: Union[MultiTrace, Dataset, InferenceData, List[Dict[str, np.ndarray]]],
+    samples: Optional[int] = None,
+    model: Optional[Model] = None,
+    var_names: Optional[List[str]] = None,
+    keep_size: bool = False,
+    random_seed=None,
+) -> Dict[str, np.ndarray]:
     """Generate posterior predictive samples from a model given a trace.

     This is a vectorized alternative to the standard ``sample_posterior_predictive`` function.
@@ -137,7 +144,7 @@ def fast_sample_posterior_predictive(trace: Union[MultiTrace, Dataset, List[Dict

     Parameters
     ----------
-    trace: MultiTrace, xarray.Dataset, or List of points (dictionary)
+    trace: MultiTrace, xarray.Dataset, InferenceData, or List of points (dictionary)
         Trace generated from MCMC sampling.
     samples: int, optional
         Number of posterior predictive samples to generate. Defaults to one posterior predictive
@@ -170,21 +177,33 @@ def fast_sample_posterior_predictive(trace: Union[MultiTrace, Dataset, List[Dict
     ### greater than the number of samples in the trace parameter, we sample repeatedly. This
     ### makes the shape issues just a little easier to deal with.

-    if isinstance(trace, Dataset):
+    if isinstance(trace, InferenceData):
+        nchains, ndraws = chains_and_samples(trace)
+        trace = dataset_to_point_dict(trace.posterior)
+    elif isinstance(trace, Dataset):
+        nchains, ndraws = chains_and_samples(trace)
         trace = dataset_to_point_dict(trace)
+    elif isinstance(trace, MultiTrace):
+        nchains = trace.nchains
+        ndraws = len(trace)
+    else:
+        if keep_size:
+            # arguably this should be just a warning.
+            raise IncorrectArgumentsError(
+                "For keep_size, cannot identify chains and length from %s.", trace
+            )

     model = modelcontext(model)
     assert model is not None
     with model:

         if keep_size and samples is not None:
-            raise IncorrectArgumentsError("Should not specify both keep_size and samples arguments")
-        if keep_size and not isinstance(trace, MultiTrace):
-            # arguably this should be just a warning.
-            raise IncorrectArgumentsError("keep_size argument only applies when sampling from MultiTrace.")
+            raise IncorrectArgumentsError(
+                "Should not specify both keep_size and samples arguments"
+            )

         if isinstance(trace, list) and all((isinstance(x, dict) for x in trace)):
-            _trace = _TraceDict(point_list=trace)
+            _trace = _TraceDict(point_list=trace)
         elif isinstance(trace, MultiTrace):
             _trace = _TraceDict(multi_trace=trace)
         else:
```
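
The `nchains`/`ndraws` pair captured by the new dispatch is what later drives the `keep_size` reshape. Conceptually (a sketch of the behaviour, not code from this diff):

```python
import numpy as np

# With keep_size=True, the flat stack of posterior predictive draws for a
# variable is reshaped so the chain/draw structure sits in front of the
# variable's own shape.
nchains, ndraws = 2, 500
flat = np.empty((nchains * ndraws, 3))                   # draws stacked across chains
kept = flat.reshape((nchains, ndraws) + flat.shape[1:])  # -> (2, 500, 3)
assert kept.shape == (2, 500, 3)
```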

pymc3/sampling.py

Lines changed: 56 additions & 23 deletions
```diff
@@ -29,6 +29,7 @@
 import warnings

 import arviz
+from arviz import InferenceData
 import numpy as np
 import theano.gradient as tg
 from theano.tensor import Tensor
@@ -57,6 +58,7 @@
     is_transformed_name,
     get_default_varnames,
     dataset_to_point_dict,
+    chains_and_samples,
 )
 from .vartypes import discrete_types
 from .exceptions import IncorrectArgumentsError
@@ -91,6 +93,8 @@
 )

 ArrayLike = Union[np.ndarray, List[float]]
+PointType = Dict[str, np.ndarray]
+PointList = List[PointType]

 _log = logging.getLogger("pymc3")

@@ -248,10 +252,10 @@ def sample(
     callback=None,
     *,
     return_inferencedata=None,
-    idata_kwargs: dict=None,
+    idata_kwargs: dict = None,
     mp_ctx=None,
-    pickle_backend: str = 'pickle',
-    **kwargs
+    pickle_backend: str = "pickle",
+    **kwargs,
 ):
     """Draw samples from the posterior using the given step methods.

@@ -1584,7 +1588,7 @@ def sample_posterior_predictive(

     Parameters
     ----------
-    trace : backend, list, xarray.Dataset, or MultiTrace
+    trace : backend, list, xarray.Dataset, arviz.InferenceData, or MultiTrace
         Trace generated from MCMC sampling, or a list of dicts (eg. points or from find_MAP()),
         or xarray.Dataset (eg. InferenceData.posterior or InferenceData.prior)
     samples : int
@@ -1598,8 +1602,7 @@
         Variables for which to compute the posterior predictive samples.
         Deprecated: please use ``var_names`` instead.
     var_names : Iterable[str]
-        Alternative way to specify vars to sample, to make this function orthogonal with
-        others.
+        Names of variables for which to compute the posterior predictive samples.
     size : int
         The number of random draws from the distribution specified by the parameters in each
         sample of the trace. Not recommended unless more than ndraws times nchains posterior
@@ -1620,29 +1623,48 @@
         Dictionary with the variable names as keys, and values numpy arrays containing
         posterior predictive samples.
     """
-    if isinstance(trace, xarray.Dataset):
-        trace = dataset_to_point_dict(trace)

-    len_trace = len(trace)
-    try:
-        nchain = trace.nchains
-    except AttributeError:
-        nchain = 1
+    _trace: Union[MultiTrace, PointList]
+    if isinstance(trace, InferenceData):
+        _trace = dataset_to_point_dict(trace.posterior)
+    elif isinstance(trace, xarray.Dataset):
+        _trace = dataset_to_point_dict(trace)
+    else:
+        _trace = trace
+
+    nchain: int
+    len_trace: int
+    if isinstance(trace, (InferenceData, xarray.Dataset)):
+        nchain, len_trace = chains_and_samples(trace)
+    else:
+        len_trace = len(_trace)
+        try:
+            nchain = _trace.nchains
+        except AttributeError:
+            nchain = 1

     if keep_size and samples is not None:
-        raise IncorrectArgumentsError("Should not specify both keep_size and samples arguments")
+        raise IncorrectArgumentsError(
+            "Should not specify both keep_size and samples arguments"
+        )
     if keep_size and size is not None:
-        raise IncorrectArgumentsError("Should not specify both keep_size and size arguments")
+        raise IncorrectArgumentsError(
+            "Should not specify both keep_size and size arguments"
+        )

     if samples is None:
-        if isinstance(trace, MultiTrace):
-            samples = sum(len(v) for v in trace._straces.values())
-        elif isinstance(trace, list) and all((isinstance(x, dict) for x in trace)):
+        if isinstance(_trace, MultiTrace):
+            samples = sum(len(v) for v in _trace._straces.values())
+        elif isinstance(_trace, list) and all((isinstance(x, dict) for x in _trace)):
             # this is a list of points
-            samples = len(trace)
+            samples = len(_trace)
         else:
-            raise ValueError("Do not know how to compute number of samples for trace argument of type %s"%type(trace))
+            raise ValueError(
+                "Do not know how to compute number of samples for trace argument of type %s"
+                % type(_trace)
+            )

+    assert samples is not None
     if samples < len_trace * nchain:
         warnings.warn(
             "samples parameter is smaller than nchains times ndraws, some draws "
@@ -1675,10 +1697,21 @@
     try:
         for idx in indices:
             if nchain > 1:
-                chain_idx, point_idx = np.divmod(idx, len_trace)
-                param = trace._straces[chain_idx % nchain].point(point_idx)
+                # the trace object will either be a MultiTrace (and have _straces)...
+                if hasattr(_trace, "_straces"):
+                    chain_idx, point_idx = np.divmod(idx, len_trace)
+                    param = (
+                        cast(MultiTrace, _trace)
+                        ._straces[chain_idx % nchain]
+                        .point(point_idx)
+                    )
+                # ... or a PointList
+                else:
+                    param = cast(PointList, _trace)[idx % len_trace]
+            # there's only a single chain, but the index might hit it multiple times if
+            # the number of indices is greater than the length of the trace.
             else:
-                param = trace[idx % len_trace]
+                param = _trace[idx % len_trace]

             values = draw_values(vars, point=param, size=size)
             for k, v in zip(vars, values):
```
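
The branch above maps a flat sample index back onto the trace. For a MultiTrace the index is split into a (chain, draw) pair with `np.divmod`; for a point list it simply wraps around. A small worked example of that arithmetic (the numbers are illustrative):

```python
import numpy as np

nchain, len_trace = 2, 500   # two chains of 500 draws
idx = 733                    # a flat posterior predictive sample index

# MultiTrace case: split into chain and within-chain draw indices.
chain_idx, point_idx = np.divmod(idx, len_trace)
assert (chain_idx % nchain, point_idx) == (1, 233)

# PointList case: wrap around the list of points.
assert idx % len_trace == 233
```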

pymc3/tests/test_sampling.py

Lines changed: 24 additions & 0 deletions
```diff
@@ -381,6 +381,13 @@ def test_normal_scalar(self):
             ppc = pm.fast_sample_posterior_predictive(trace, keep_size=True)
             assert ppc["a"].shape == (nchains, ndraws)

+            # test keep_size parameter and idata input
+            idata = az.from_pymc3(trace)
+            ppc = pm.sample_posterior_predictive(idata, keep_size=True)
+            assert ppc["a"].shape == (nchains, ndraws)
+            ppc = pm.fast_sample_posterior_predictive(trace, keep_size=True)
+            assert ppc["a"].shape == (nchains, ndraws)
+
             # test default case
             ppc = pm.sample_posterior_predictive(trace, var_names=["a"])
             assert "a" in ppc
@@ -428,6 +435,15 @@ def test_normal_vector(self, caplog):
             assert "a" in ppc
             assert ppc["a"].shape == (12, 2)

+            # test keep_size parameter with inference data as input...
+            idata = az.from_pymc3(trace)
+            ppc = pm.sample_posterior_predictive(idata, keep_size=True)
+            assert ppc["a"].shape == (trace.nchains, len(trace), 2)
+            with pytest.warns(UserWarning):
+                ppc = pm.sample_posterior_predictive(trace, samples=12, var_names=["a"])
+            assert "a" in ppc
+            assert ppc["a"].shape == (12, 2)
+
             # test keep_size parameter
             ppc = pm.fast_sample_posterior_predictive(trace, keep_size=True)
             assert ppc["a"].shape == (trace.nchains, len(trace), 2)
@@ -436,6 +452,14 @@
             assert "a" in ppc
             assert ppc["a"].shape == (12, 2)

+            # test keep_size parameter with inference data as input
+            ppc = pm.fast_sample_posterior_predictive(idata, keep_size=True)
+            assert ppc["a"].shape == (trace.nchains, len(trace), 2)
+            with pytest.warns(UserWarning):
+                ppc = pm.fast_sample_posterior_predictive(trace, samples=12, var_names=["a"])
+            assert "a" in ppc
+            assert ppc["a"].shape == (12, 2)
+

             # size unsupported by fast_ version argument. [2019/08/19:rpg]
             ppc = pm.sample_posterior_predictive(trace, samples=10, var_names=["a"], size=4)
```

pymc3/util.py

Lines changed: 33 additions & 12 deletions
```diff
@@ -14,9 +14,10 @@

 import re
 import functools
-from typing import List, Dict
+from typing import List, Dict, Tuple, Union

 import xarray
+import arviz
 from numpy import asscalar, ndarray


@@ -182,22 +183,42 @@ def enhanced(*args, **kwargs):
         else:
             newwrapper = functools.partial(wrapper, *args, **kwargs)
             return newwrapper
+
     return enhanced

+
+# FIXME: this function is poorly named, because it returns a LIST of
+# points, not a dictionary of points.
 def dataset_to_point_dict(ds: xarray.Dataset) -> List[Dict[str, ndarray]]:
     # grab posterior samples for each variable
-    _samples = {
-        vn : ds[vn].values
-        for vn in ds.keys()
-    }
+    _samples: Dict[str, ndarray] = {vn: ds[vn].values for vn in ds.keys()}
     # make dicts
-    points = []
+    points: List[Dict[str, ndarray]] = []
+    vn: str
+    s: ndarray
     for c in ds.chain:
         for d in ds.draw:
-            points.append({
-                vn : s[c, d]
-                for vn, s in _samples.items()
-            })
+            points.append({vn: s[c, d] for vn, s in _samples.items()})
     # use the list of points
-    ds = points
-    return ds
+    return points
+
+
+def chains_and_samples(
+    data: Union[xarray.Dataset, arviz.InferenceData]
+) -> Tuple[int, int]:
+    """Extract and return number of chains and samples in xarray or arviz traces."""
+    dataset: xarray.Dataset
+    if isinstance(data, xarray.Dataset):
+        dataset = data
+    elif isinstance(data, arviz.InferenceData):
+        dataset = data.posterior
+    else:
+        raise ValueError(
+            "Argument must be xarray Dataset or arviz InferenceData. Got %s",
+            data.__class__,
+        )
+
+    coords = dataset.coords
+    nchains = coords["chain"].sizes["chain"]
+    nsamples = coords["draw"].sizes["draw"]
+    return nchains, nsamples
```
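
For reference, a small sketch of how the two helpers behave on a hand-built posterior Dataset; the data below are made up for illustration only.

```python
import numpy as np
import xarray

from pymc3.util import chains_and_samples, dataset_to_point_dict

# Two chains, three draws of a scalar variable "a".
posterior = xarray.Dataset(
    {"a": (("chain", "draw"), np.arange(6.0).reshape(2, 3))},
    coords={"chain": [0, 1], "draw": [0, 1, 2]},
)

assert chains_and_samples(posterior) == (2, 3)

points = dataset_to_point_dict(posterior)
assert len(points) == 6              # one point per (chain, draw) pair
assert float(points[0]["a"]) == 0.0  # chain 0, draw 0
```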
