|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 | from collections.abc import Sequence
|
15 |
| -from typing import cast |
| 15 | +from typing import Literal |
16 | 16 |
|
17 |
| -from arviz import InferenceData, dict_to_dataset |
18 |
| -from rich.console import Console |
19 |
| -from rich.progress import Progress |
20 |
| - |
21 |
| -import pymc |
| 17 | +from arviz import InferenceData |
| 18 | +from xarray import Dataset |
22 | 19 |
|
23 | 20 | from pymc.backends.arviz import (
|
24 |
| - _DefaultTrace, |
| 21 | + apply_function_over_dataset, |
25 | 22 | coords_and_dims_for_inferencedata,
|
26 |
| - dataset_to_point_list, |
27 | 23 | )
|
28 | 24 | from pymc.model import Model, modelcontext
|
29 |
| -from pymc.pytensorf import PointFunc |
30 |
| -from pymc.util import default_progress_theme |
31 | 25 |
|
32 | 26 | __all__ = ("compute_log_likelihood", "compute_log_prior")
|
33 | 27 |
|
@@ -117,10 +111,10 @@ def compute_log_density(
|
117 | 111 | var_names: Sequence[str] | None = None,
|
118 | 112 | extend_inferencedata: bool = True,
|
119 | 113 | model: Model | None = None,
|
120 |
| - kind="likelihood", |
| 114 | + kind: Literal["likelihood", "prior"] = "likelihood", |
121 | 115 | sample_dims: Sequence[str] = ("chain", "draw"),
|
122 | 116 | progressbar=True,
|
123 |
| -): |
| 117 | +) -> InferenceData | Dataset: |
124 | 118 | """
|
125 | 119 | Compute elemwise log_likelihood or log_prior of model given InferenceData with posterior group
|
126 | 120 | """
|
@@ -163,40 +157,20 @@ def compute_log_density(
|
163 | 157 | outs=model.logp(vars=vars, sum=False),
|
164 | 158 | on_unused_input="ignore",
|
165 | 159 | )
|
166 |
| - elemwise_logdens_fn = cast(PointFunc, elemwise_logdens_fn) |
167 | 160 | finally:
|
168 | 161 | model.rvs_to_values = original_rvs_to_values
|
169 | 162 | model.rvs_to_transforms = original_rvs_to_transforms
|
170 | 163 |
|
171 |
| - # Ignore Deterministics |
172 |
| - posterior_values = posterior[[rv.name for rv in model.free_RVs]] |
173 |
| - posterior_pts, stacked_dims = dataset_to_point_list(posterior_values, sample_dims) |
174 |
| - |
175 |
| - n_pts = len(posterior_pts) |
176 |
| - logdens_dict = _DefaultTrace(n_pts) |
177 |
| - |
178 |
| - with Progress(console=Console(theme=default_progress_theme)) as progress: |
179 |
| - task = progress.add_task("Computing log density...", total=n_pts, visible=progressbar) |
180 |
| - for idx in range(n_pts): |
181 |
| - logdenss_pts = elemwise_logdens_fn(posterior_pts[idx]) |
182 |
| - for rv_name, rv_logdens in zip(var_names, logdenss_pts): |
183 |
| - logdens_dict.insert(rv_name, rv_logdens, idx) |
184 |
| - progress.update(task, advance=1) |
185 |
| - |
186 |
| - logdens_trace = logdens_dict.trace_dict |
187 |
| - for key, array in logdens_trace.items(): |
188 |
| - logdens_trace[key] = array.reshape( |
189 |
| - (*[len(coord) for coord in stacked_dims.values()], *array.shape[1:]) |
190 |
| - ) |
191 |
| - |
192 | 164 | coords, dims = coords_and_dims_for_inferencedata(model)
|
193 |
| - logdens_dataset = dict_to_dataset( |
194 |
| - logdens_trace, |
195 |
| - library=pymc, |
| 165 | + |
| 166 | + logdens_dataset = apply_function_over_dataset( |
| 167 | + elemwise_logdens_fn, |
| 168 | + posterior[[rv.name for rv in model.free_RVs]], |
| 169 | + output_var_names=var_names, |
| 170 | + sample_dims=sample_dims, |
196 | 171 | dims=dims,
|
197 | 172 | coords=coords,
|
198 |
| - default_dims=list(sample_dims), |
199 |
| - skip_event_dims=True, |
| 173 | + progressbar=progressbar, |
200 | 174 | )
|
201 | 175 |
|
202 | 176 | if extend_inferencedata:
|
|
0 commit comments