Skip to content

Commit 7ad71f8

Browse files
committed
Refactor utility to apply a PointFunc over a dataset
1 parent 6865ac9 commit 7ad71f8

File tree

2 files changed

+64
-41
lines changed

2 files changed

+64
-41
lines changed

pymc/backends/arviz.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,15 @@
3232
from arviz.data.base import CoordSpec, DimSpec, dict_to_dataset, requires
3333
from pytensor.graph.basic import Constant
3434
from pytensor.tensor.sharedvar import SharedVariable
35+
from rich.progress import Console, Progress
36+
from rich.theme import Theme
37+
from xarray import Dataset
3538

3639
import pymc
3740

3841
from pymc.model import Model, modelcontext
39-
from pymc.pytensorf import extract_obs_data
40-
from pymc.util import get_default_varnames
42+
from pymc.pytensorf import PointFunc, extract_obs_data
43+
from pymc.util import default_progress_theme, get_default_varnames
4144

4245
if TYPE_CHECKING:
4346
from pymc.backends.base import MultiTrace
@@ -637,3 +640,49 @@ def dataset_to_point_list(
637640
]
638641
# use the list of points
639642
return cast(list[dict[str, np.ndarray]], points), stacked_dims
643+
644+
645+
def apply_function_over_dataset(
    fn: PointFunc,
    dataset: Dataset,
    *,
    output_var_names: Sequence[str],
    coords,
    dims,
    sample_dims: Sequence[str] = ("chain", "draw"),
    progressbar: bool = True,
    progressbar_theme: Theme | None = default_progress_theme,
) -> Dataset:
    """Apply a compiled point function to every sample point of a Dataset.

    The dataset is flattened along ``sample_dims`` into a list of points,
    ``fn`` is evaluated once per point, and the stacked outputs are
    reshaped back to the original sample dimensions and packed into a new
    arviz-compatible :class:`~xarray.Dataset`.

    Parameters
    ----------
    fn : PointFunc
        Compiled function that maps one point (dict of input values) to a
        sequence of output arrays, one per entry in ``output_var_names``.
    dataset : Dataset
        Input dataset whose variables provide the inputs of ``fn``.
    output_var_names : sequence of str
        Names under which the outputs of ``fn`` are stored, in the same
        order as ``fn`` returns them.
    coords, dims
        Coordinate and dimension mappings forwarded to ``dict_to_dataset``.
    sample_dims : sequence of str, default ("chain", "draw")
        Dimensions of ``dataset`` to flatten over when building points.
    progressbar : bool, default True
        Whether the progress bar is visible while iterating over points.
    progressbar_theme : Theme, optional
        Rich theme used for the progress bar console.

    Returns
    -------
    Dataset
        One variable per name in ``output_var_names``, with leading
        dimensions ``sample_dims``.
    """
    posterior_pts, stacked_dims = dataset_to_point_list(dataset, sample_dims)

    n_pts = len(posterior_pts)
    out_dict = _DefaultTrace(n_pts)

    with Progress(console=Console(theme=progressbar_theme)) as progress:
        task = progress.add_task("Computing ...", total=n_pts, visible=progressbar)
        for idx in range(n_pts):
            out = fn(posterior_pts[idx])
            # The first successful call proves the input dtypes are valid,
            # so skip PyTensor's input validation on subsequent calls.
            fn.f.trust_input = True
            for var_name, val in zip(output_var_names, out):
                out_dict.insert(var_name, val, idx)

            progress.advance(task)

    out_trace = out_dict.trace_dict
    for key, val in out_trace.items():
        # Unstack the flattened sample axis back into the original
        # sample dimensions (e.g. chain, draw), keeping event dims intact.
        out_trace[key] = val.reshape(
            (
                *[len(coord) for coord in stacked_dims.values()],
                *val.shape[1:],
            )
        )

    return dict_to_dataset(
        out_trace,
        library=pymc,
        dims=dims,
        coords=coords,
        default_dims=list(sample_dims),
        skip_event_dims=True,
    )

pymc/stats/log_density.py

Lines changed: 13 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -12,22 +12,16 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
from collections.abc import Sequence
15-
from typing import cast
15+
from typing import Literal
1616

17-
from arviz import InferenceData, dict_to_dataset
18-
from rich.console import Console
19-
from rich.progress import Progress
20-
21-
import pymc
17+
from arviz import InferenceData
18+
from xarray import Dataset
2219

2320
from pymc.backends.arviz import (
24-
_DefaultTrace,
21+
apply_function_over_dataset,
2522
coords_and_dims_for_inferencedata,
26-
dataset_to_point_list,
2723
)
2824
from pymc.model import Model, modelcontext
29-
from pymc.pytensorf import PointFunc
30-
from pymc.util import default_progress_theme
3125

3226
__all__ = ("compute_log_likelihood", "compute_log_prior")
3327

@@ -117,10 +111,10 @@ def compute_log_density(
117111
var_names: Sequence[str] | None = None,
118112
extend_inferencedata: bool = True,
119113
model: Model | None = None,
120-
kind="likelihood",
114+
kind: Literal["likelihood", "prior"] = "likelihood",
121115
sample_dims: Sequence[str] = ("chain", "draw"),
122116
progressbar=True,
123-
):
117+
) -> InferenceData | Dataset:
124118
"""
125119
Compute elemwise log_likelihood or log_prior of model given InferenceData with posterior group
126120
"""
@@ -163,40 +157,20 @@ def compute_log_density(
163157
outs=model.logp(vars=vars, sum=False),
164158
on_unused_input="ignore",
165159
)
166-
elemwise_logdens_fn = cast(PointFunc, elemwise_logdens_fn)
167160
finally:
168161
model.rvs_to_values = original_rvs_to_values
169162
model.rvs_to_transforms = original_rvs_to_transforms
170163

171-
# Ignore Deterministics
172-
posterior_values = posterior[[rv.name for rv in model.free_RVs]]
173-
posterior_pts, stacked_dims = dataset_to_point_list(posterior_values, sample_dims)
174-
175-
n_pts = len(posterior_pts)
176-
logdens_dict = _DefaultTrace(n_pts)
177-
178-
with Progress(console=Console(theme=default_progress_theme)) as progress:
179-
task = progress.add_task("Computing log density...", total=n_pts, visible=progressbar)
180-
for idx in range(n_pts):
181-
logdenss_pts = elemwise_logdens_fn(posterior_pts[idx])
182-
for rv_name, rv_logdens in zip(var_names, logdenss_pts):
183-
logdens_dict.insert(rv_name, rv_logdens, idx)
184-
progress.update(task, advance=1)
185-
186-
logdens_trace = logdens_dict.trace_dict
187-
for key, array in logdens_trace.items():
188-
logdens_trace[key] = array.reshape(
189-
(*[len(coord) for coord in stacked_dims.values()], *array.shape[1:])
190-
)
191-
192164
coords, dims = coords_and_dims_for_inferencedata(model)
193-
logdens_dataset = dict_to_dataset(
194-
logdens_trace,
195-
library=pymc,
165+
166+
logdens_dataset = apply_function_over_dataset(
167+
elemwise_logdens_fn,
168+
posterior[[rv.name for rv in model.free_RVs]],
169+
output_var_names=var_names,
170+
sample_dims=sample_dims,
196171
dims=dims,
197172
coords=coords,
198-
default_dims=list(sample_dims),
199-
skip_event_dims=True,
173+
progressbar=progressbar,
200174
)
201175

202176
if extend_inferencedata:

0 commit comments

Comments
 (0)