75
75
from pymc .step_methods .arraystep import BlockedStep , PopulationArrayStepShared
76
76
from pymc .step_methods .hmc import quadpotential
77
77
from pymc .util import (
78
- chains_and_samples ,
79
78
dataset_to_point_list ,
80
79
get_default_varnames ,
81
80
get_untransformed_name ,
@@ -1765,6 +1764,7 @@ def sample_posterior_predictive(
1765
1764
trace ,
1766
1765
model : Optional [Model ] = None ,
1767
1766
var_names : Optional [List [str ]] = None ,
1767
+ sample_dims : Optional [List [str ]] = None ,
1768
1768
random_seed : RandomState = None ,
1769
1769
progressbar : bool = True ,
1770
1770
return_inferencedata : bool = True ,
@@ -1785,6 +1785,10 @@ def sample_posterior_predictive(
1785
1785
generally be the model used to generate the ``trace``, but it doesn't need to be.
1786
1786
var_names : Iterable[str]
1787
1787
Names of variables for which to compute the posterior predictive samples.
1788
+ sample_dims : list of str, optional
1789
+ Dimensions over which to loop and generate posterior predictive samples.
1790
+ When `sample_dims` is ``None`` (default) both "chain" and "draw" are considered sample
1791
+ dimensions. Only taken into account when `trace` is InferenceData or Dataset.
1788
1792
random_seed : int, RandomState or Generator, optional
1789
1793
Seed for the random number generator.
1790
1794
progressbar : bool
@@ -1821,6 +1825,14 @@ def sample_posterior_predictive(
1821
1825
thinned_idata = idata.sel(draw=slice(None, None, 5))
1822
1826
with model:
1823
1827
idata.extend(pymc.sample_posterior_predictive(thinned_idata))
1828
+
1829
+ Generate 5 posterior predictive samples per posterior sample.
1830
+
1831
+ .. code:: python
1832
+
1833
+ expanded_data = idata.posterior.expand_dims(pred_id=5)
1834
+ with model:
1835
+ idata.extend(pymc.sample_posterior_predictive(expanded_data))
1824
1836
"""
1825
1837
1826
1838
_trace : Union [MultiTrace , PointList ]
@@ -1829,36 +1841,34 @@ def sample_posterior_predictive(
1829
1841
idata_kwargs = {}
1830
1842
else :
1831
1843
idata_kwargs = idata_kwargs .copy ()
1844
+ if sample_dims is None :
1845
+ sample_dims = ["chain" , "draw" ]
1832
1846
constant_data : Dict [str , np .ndarray ] = {}
1833
1847
trace_coords : Dict [str , np .ndarray ] = {}
1834
1848
if "coords" not in idata_kwargs :
1835
1849
idata_kwargs ["coords" ] = {}
1850
+ idata : Optional [InferenceData ] = None
1851
+ stacked_dims = None
1836
1852
if isinstance (trace , InferenceData ):
1837
- idata_kwargs ["coords" ].setdefault ("draw" , trace ["posterior" ]["draw" ])
1838
- idata_kwargs ["coords" ].setdefault ("chain" , trace ["posterior" ]["chain" ])
1839
1853
_constant_data = getattr (trace , "constant_data" , None )
1840
1854
if _constant_data is not None :
1841
1855
trace_coords .update ({str (k ): v .data for k , v in _constant_data .coords .items ()})
1842
1856
constant_data .update ({str (k ): v .data for k , v in _constant_data .items ()})
1843
- trace_coords .update ({str (k ): v .data for k , v in trace ["posterior" ].coords .items ()})
1844
- _trace = dataset_to_point_list (trace ["posterior" ])
1845
- nchain , len_trace = chains_and_samples (trace )
1846
- elif isinstance (trace , xarray .Dataset ):
1847
- idata_kwargs ["coords" ].setdefault ("draw" , trace ["draw" ])
1848
- idata_kwargs ["coords" ].setdefault ("chain" , trace ["chain" ])
1857
+ idata = trace
1858
+ trace = trace ["posterior" ]
1859
+ if isinstance (trace , xarray .Dataset ):
1849
1860
trace_coords .update ({str (k ): v .data for k , v in trace .coords .items ()})
1850
- _trace = dataset_to_point_list (trace )
1851
- nchain , len_trace = chains_and_samples ( trace )
1861
+ _trace , stacked_dims = dataset_to_point_list (trace , sample_dims )
1862
+ nchain = 1
1852
1863
elif isinstance (trace , MultiTrace ):
1853
1864
_trace = trace
1854
1865
nchain = _trace .nchains
1855
- len_trace = len (_trace )
1856
1866
elif isinstance (trace , list ) and all (isinstance (x , dict ) for x in trace ):
1857
1867
_trace = trace
1858
1868
nchain = 1
1859
- len_trace = len (_trace )
1860
1869
else :
1861
1870
raise TypeError (f"Unsupported type for `trace` argument: { type (trace )} ." )
1871
+ len_trace = len (_trace )
1862
1872
1863
1873
if isinstance (_trace , MultiTrace ):
1864
1874
samples = sum (len (v ) for v in _trace ._straces .values ())
@@ -1961,23 +1971,30 @@ def sample_posterior_predictive(
1961
1971
ppc_trace = ppc_trace_t .trace_dict
1962
1972
1963
1973
for k , ary in ppc_trace .items ():
1964
- ppc_trace [k ] = ary .reshape ((nchain , len_trace , * ary .shape [1 :]))
1974
+ if stacked_dims is not None :
1975
+ ppc_trace [k ] = ary .reshape (
1976
+ (* [len (coord ) for coord in stacked_dims .values ()], * ary .shape [1 :])
1977
+ )
1978
+ else :
1979
+ ppc_trace [k ] = ary .reshape ((nchain , len_trace , * ary .shape [1 :]))
1965
1980
1966
1981
if not return_inferencedata :
1967
1982
return ppc_trace
1968
1983
ikwargs : Dict [str , Any ] = dict (model = model , ** idata_kwargs )
1984
+ ikwargs .setdefault ("sample_dims" , sample_dims )
1985
+ if stacked_dims is not None :
1986
+ coords = ikwargs .get ("coords" , {})
1987
+ ikwargs ["coords" ] = {** stacked_dims , ** coords }
1969
1988
if predictions :
1970
1989
if extend_inferencedata :
1971
- ikwargs .setdefault ("idata_orig" , trace )
1990
+ ikwargs .setdefault ("idata_orig" , idata )
1972
1991
ikwargs .setdefault ("inplace" , True )
1973
1992
return pm .predictions_to_inference_data (ppc_trace , ** ikwargs )
1974
- converter = pm .backends .arviz .InferenceDataConverter (posterior_predictive = ppc_trace , ** ikwargs )
1975
- converter .nchains = nchain
1976
- converter .ndraws = len_trace
1977
- idata_pp = converter .to_inference_data ()
1978
- if extend_inferencedata :
1979
- trace .extend (idata_pp )
1980
- return trace
1993
+ idata_pp = pm .to_inference_data (posterior_predictive = ppc_trace , ** ikwargs )
1994
+
1995
+ if extend_inferencedata and idata is not None :
1996
+ idata .extend (idata_pp )
1997
+ return idata
1981
1998
return idata_pp
1982
1999
1983
2000
0 commit comments