Skip to content

Commit 6865ac9

Browse files
committed
Move dataset_to_point_list to arviz module
1 parent 8938851 commit 6865ac9

File tree

6 files changed

+81
-62
lines changed

6 files changed

+81
-62
lines changed

pymc/backends/arviz.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,17 @@
1616
import logging
1717
import warnings
1818

19-
from collections.abc import Iterable, Mapping
19+
from collections.abc import Iterable, Mapping, Sequence
2020
from typing import (
2121
TYPE_CHECKING,
2222
Any,
2323
Optional,
2424
Union,
25+
cast,
2526
)
2627

2728
import numpy as np
29+
import xarray
2830

2931
from arviz import InferenceData, concat, rcParams
3032
from arviz.data.base import CoordSpec, DimSpec, dict_to_dataset, requires
@@ -612,3 +614,26 @@ def predictions_to_inference_data(
612614
# data and return that.
613615
concat([new_idata, idata_orig], dim=None, copy=True, inplace=True)
614616
return new_idata
617+
618+
619+
def dataset_to_point_list(
    ds: xarray.Dataset | dict[str, xarray.DataArray], sample_dims: Sequence[str]
) -> tuple[list[dict[str, np.ndarray]], dict[str, Any]]:
    """Flatten the sample dimensions of ``ds`` into a list of per-sample points.

    Parameters
    ----------
    ds : xarray.Dataset or dict of str to xarray.DataArray
        Variables to convert. Every key must be a ``str`` and at least one
        variable must be present (the first one is used to look up the
        sample-dimension coordinates). Each variable is expected to contain
        all of ``sample_dims``.
    sample_dims : sequence of str
        Names of the dimensions that index samples, e.g. ``["chain", "draw"]``.
        They are moved to the front of every variable, in this order, and
        flattened into one leading axis. May be empty, in which case a single
        point holding the full arrays is returned.

    Returns
    -------
    points : list of dict
        One dict per combination of sample-dimension indices, mapping each
        variable name to the ``numpy.ndarray`` of its remaining dimensions.
    stacked_dims : dict
        Maps each name in ``sample_dims`` to ``ds[first_variable][name]``.

    Raises
    ------
    ValueError
        If any key of ``ds`` is not a ``str``.
    """
    # All keys of the dataset must be a str
    var_names = cast(list[str], list(ds.keys()))
    for vn in var_names:
        if not isinstance(vn, str):
            raise ValueError(f"Variable names must be str, but dataset key {vn} is a {type(vn)}.")
    num_sample_dims = len(sample_dims)
    stacked_dims = {dim_name: ds[var_names[0]][dim_name] for dim_name in sample_dims}
    # Put the sample dims first (in the requested order); ``...`` keeps the rest as-is.
    transposed_dict = {vn: da.transpose(*sample_dims, ...) for vn, da in ds.items()}
    # Collapse all sample dims into one leading axis of length prod(sample dim sizes).
    stacked_dict = {
        vn: da.values.reshape((-1, *da.shape[num_sample_dims:]))
        for vn, da in transposed_dict.items()
    }
    # ``np.prod([])`` is the float ``1.0``, which ``range`` rejects; force an
    # integer product so an empty ``sample_dims`` yields exactly one point.
    num_points = int(np.prod([len(coords) for coords in stacked_dims.values()], dtype=int))
    points = [{vn: stacked_dict[vn][i, ...] for vn in var_names} for i in range(num_points)]
    # use the list of points
    return cast(list[dict[str, np.ndarray]], points), stacked_dims

pymc/sampling/forward.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,15 +49,14 @@
4949

5050
import pymc as pm
5151

52-
from pymc.backends.arviz import _DefaultTrace
52+
from pymc.backends.arviz import _DefaultTrace, dataset_to_point_list
5353
from pymc.backends.base import MultiTrace
5454
from pymc.blocking import PointType
5555
from pymc.model import Model, modelcontext
5656
from pymc.pytensorf import compile_pymc
5757
from pymc.util import (
5858
RandomState,
5959
_get_seeds_per_chain,
60-
dataset_to_point_list,
6160
default_progress_theme,
6261
get_default_varnames,
6362
point_wrapper,

pymc/stats/log_density.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,14 @@
2020

2121
import pymc
2222

23-
from pymc.backends.arviz import _DefaultTrace, coords_and_dims_for_inferencedata
23+
from pymc.backends.arviz import (
24+
_DefaultTrace,
25+
coords_and_dims_for_inferencedata,
26+
dataset_to_point_list,
27+
)
2428
from pymc.model import Model, modelcontext
2529
from pymc.pytensorf import PointFunc
26-
from pymc.util import dataset_to_point_list, default_progress_theme
30+
from pymc.util import default_progress_theme
2731

2832
__all__ = ("compute_log_likelihood", "compute_log_prior")
2933

pymc/util.py

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import warnings
1717

1818
from collections.abc import Sequence
19-
from typing import Any, NewType, cast
19+
from typing import NewType, cast
2020

2121
import arviz
2222
import cloudpickle
@@ -31,6 +31,20 @@
3131

3232
from pymc.exceptions import BlockModelAccessError
3333

34+
35+
def __getattr__(name):
    """Module-level ``__getattr__`` providing a deprecation shim for moved names.

    Python calls this only when ``name`` is not found among the module's
    regular attributes.

    Parameters
    ----------
    name : str
        The attribute being looked up on this module.

    Returns
    -------
    The ``dataset_to_point_list`` function from ``pymc.backends.arviz`` when
    that name is requested (after emitting a ``FutureWarning``).

    Raises
    ------
    AttributeError
        For every other name.
    """
    if name == "dataset_to_point_list":
        warnings.warn(
            f"{name} has been moved to backends.arviz. Importing from util will fail in a future release.",
            FutureWarning,
            # Attribute the warning to the code that accessed the deprecated
            # name rather than to this shim.
            stacklevel=2,
        )
        # Imported at access time rather than at module import time
        # (presumably to avoid a circular import - confirm).
        from pymc.backends.arviz import dataset_to_point_list

        return dataset_to_point_list

    raise AttributeError(f"module {__name__} has no attribute {name}")
46+
47+
3448
VarName = NewType("VarName", str)
3549

3650
default_progress_theme = Theme(
@@ -247,29 +261,6 @@ def enhanced(*args, **kwargs):
247261
return enhanced
248262

249263

250-
# Pre-move copy of ``dataset_to_point_list``: every line below is a deletion in
# this diff, and the same body is added to pymc/backends/arviz.py by this commit.
# It turns ``ds`` into a list with one ``{variable name: ndarray}`` dict per
# sample, and also returns the sample-dimension coordinates of the first variable.
def dataset_to_point_list(
251-
ds: xarray.Dataset | dict[str, xarray.DataArray], sample_dims: Sequence[str]
252-
) -> tuple[list[dict[str, np.ndarray]], dict[str, Any]]:
253-
# All keys of the dataset must be a str
254-
var_names = cast(list[str], list(ds.keys()))
255-
for vn in var_names:
256-
if not isinstance(vn, str):
257-
raise ValueError(f"Variable names must be str, but dataset key {vn} is a {type(vn)}.")
258-
num_sample_dims = len(sample_dims)
259-
stacked_dims = {dim_name: ds[var_names[0]][dim_name] for dim_name in sample_dims}
# Next statement: move the sample dims to the front, in ``sample_dims`` order.
260-
transposed_dict = {vn: da.transpose(*sample_dims, ...) for vn, da in ds.items()}
# Next statement: collapse the leading sample dims into a single axis.
261-
stacked_dict = {
262-
vn: da.values.reshape((-1, *da.shape[num_sample_dims:]))
263-
for vn, da in transposed_dict.items()
264-
}
# Next statement: build one dict per index along the flattened sample axis.
265-
points = [
266-
{vn: stacked_dict[vn][i, ...] for vn in var_names}
267-
for i in range(np.prod([len(coords) for coords in stacked_dims.values()]))
268-
]
269-
# use the list of points
270-
return cast(list[dict[str, np.ndarray]], points), stacked_dims
271-
272-
273264
def drop_warning_stat(idata: arviz.InferenceData) -> arviz.InferenceData:
274265
"""Returns a new ``InferenceData`` object with the "warning" stat removed from sample stats groups.
275266

tests/backends/test_arviz.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import numpy as np
1717
import pytensor.tensor as pt
1818
import pytest
19+
import xarray
1920

2021
from arviz import InferenceData
2122
from arviz.tests.helpers import check_multiple_attrs
@@ -26,6 +27,7 @@
2627

2728
from pymc.backends.arviz import (
2829
InferenceDataConverter,
30+
dataset_to_point_list,
2931
predictions_to_inference_data,
3032
to_inference_data,
3133
)
@@ -776,3 +778,34 @@ def test_save_warmup_issue_1208_after_3_9(self):
776778
assert not fails
777779
assert idata.posterior.sizes["chain"] == 2
778780
assert idata.posterior.sizes["draw"] == 30
781+
782+
783+
class TestDatasetToPointList:
    """Checks for ``dataset_to_point_list``: container types, dim order and key validation."""

    @pytest.mark.parametrize("input_type", ("dict", "Dataset"))
    def test_dataset_to_point_list(self, input_type):
        """A plain dict and an ``xarray.Dataset`` both yield one point per (chain, draw)."""
        container_factory = {"dict": dict, "Dataset": xarray.Dataset}[input_type]
        variables = container_factory()
        # 2 chains x 3 draws -> 6 points expected.
        variables["A"] = xarray.DataArray([[1, 2, 3]] * 2, dims=("chain", "draw"))

        points, _ = dataset_to_point_list(variables, sample_dims=["chain", "draw"])

        assert isinstance(points, list)
        assert len(points) == 6
        first_point = points[0]
        assert isinstance(first_point, dict)
        assert isinstance(first_point["A"], np.ndarray)

    def test_transposed_dataset_to_point_list(self):
        """Sample dims stored last and in reverse order are still found and flattened."""
        # Shape is (team=5, draw=2, chain=3): 3 chains x 2 draws -> 6 points.
        raw_values = [[[1, 2, 3], [2, 3, 4]]] * 5
        variables = xarray.Dataset()
        variables["A"] = xarray.DataArray(raw_values, dims=("team", "draw", "chain"))

        points, _ = dataset_to_point_list(variables, sample_dims=["chain", "draw"])

        assert isinstance(points, list)
        assert len(points) == 6
        first_point = points[0]
        assert isinstance(first_point, dict)
        assert isinstance(first_point["A"], np.ndarray)

    def test_dataset_to_point_list_str_key(self):
        """A non-``str`` variable name is rejected with a ``ValueError``."""
        variables = xarray.Dataset()
        variables[3] = xarray.DataArray([1, 2, 3])
        with pytest.raises(ValueError, match="must be str"):
            dataset_to_point_list(variables, sample_dims=["chain", "draw"])

tests/test_util.py

Lines changed: 0 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
from pymc.util import (
2727
UNSET,
2828
_get_seeds_per_chain,
29-
dataset_to_point_list,
3029
drop_warning_stat,
3130
get_value_vars_from_user_vars,
3231
hash_key,
@@ -156,38 +155,6 @@ def fn(a=UNSET):
156155
assert "a=UNSET" in captured.out
157156

158157

159-
# Removed from tests/test_util.py by this commit (every line below is a deletion);
# an equivalent test is added as a method of TestDatasetToPointList in
# tests/backends/test_arviz.py.
# Checks that a plain dict and an xarray.Dataset both give 6 points (2 x 3 values
# over the "chain" and "draw" dims), each a dict holding an ndarray under "A".
@pytest.mark.parametrize("input_type", ("dict", "Dataset"))
160-
def test_dataset_to_point_list(input_type):
161-
if input_type == "dict":
162-
ds = {}
163-
elif input_type == "Dataset":
164-
ds = xarray.Dataset()
165-
ds["A"] = xarray.DataArray([[1, 2, 3]] * 2, dims=("chain", "draw"))
166-
pl, _ = dataset_to_point_list(ds, sample_dims=["chain", "draw"])
167-
assert isinstance(pl, list)
168-
assert len(pl) == 6
169-
assert isinstance(pl[0], dict)
170-
assert isinstance(pl[0]["A"], np.ndarray)
171-
172-
173-
# Removed from tests/test_util.py by this commit (every line below is a deletion);
# an equivalent test is added as a method of TestDatasetToPointList in
# tests/backends/test_arviz.py.
# The array has dims ("team", "draw", "chain") with sizes (5, 2, 3), so the sample
# dims come last and in reverse order; 3 chains x 2 draws still give 6 points.
def test_transposed_dataset_to_point_list():
174-
ds = xarray.Dataset()
175-
ds["A"] = xarray.DataArray([[[1, 2, 3], [2, 3, 4]]] * 5, dims=("team", "draw", "chain"))
176-
pl, _ = dataset_to_point_list(ds, sample_dims=["chain", "draw"])
177-
assert isinstance(pl, list)
178-
assert len(pl) == 6
179-
assert isinstance(pl[0], dict)
180-
assert isinstance(pl[0]["A"], np.ndarray)
181-
182-
183-
# Removed from tests/test_util.py by this commit (every line below is a deletion);
# an equivalent test is added as a method of TestDatasetToPointList in
# tests/backends/test_arviz.py.
# Uses the integer 3 as a variable name and expects a ValueError whose message
# matches "must be str".
def test_dataset_to_point_list_str_key():
184-
# Check that non-str keys are caught
185-
ds = xarray.Dataset()
186-
ds[3] = xarray.DataArray([1, 2, 3])
187-
with pytest.raises(ValueError, match="must be str"):
188-
dataset_to_point_list(ds, sample_dims=["chain", "draw"])
189-
190-
191158
def test_drop_warning_stat():
192159
idata = arviz.from_dict(
193160
sample_stats={

0 commit comments

Comments
 (0)