Skip to content

Commit 15b6082

Browse files
Merge pull request #72 from michaelosthege/bump026
Bump for 0.2.6
2 parents 35f17ad + a97d84d commit 15b6082

File tree

4 files changed

+36
-2
lines changed

4 files changed

+36
-2
lines changed

mcbackend/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@
2020
pass
2121

2222

23-
__version__ = "0.2.5"
23+
__version__ = "0.2.6"

mcbackend/core.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from .meta import ChainMeta, RunMeta, Variable
1111
from .npproto.utils import ndarray_to_numpy
12+
from .utils import as_array_from_ragged
1213

1314
InferenceData = TypeVar("InferenceData")
1415
try:
@@ -252,7 +253,15 @@ def to_inferencedata(self, *, equalize_chain_lengths: bool = True, **kwargs) ->
252253
warmup_sample_stats[svar.name].append(stats[tune])
253254
sample_stats[svar.name].append(stats[~tune])
254255

255-
kwargs.setdefault("save_warmup", True)
256+
if not equalize_chain_lengths:
257+
# Convert ragged arrays to object-dtyped ndarray because NumPy >=1.24.0 no longer does that automatically
258+
warmup_posterior = {k: as_array_from_ragged(v) for k, v in warmup_posterior.items()}
259+
warmup_sample_stats = {
260+
k: as_array_from_ragged(v) for k, v in warmup_sample_stats.items()
261+
}
262+
posterior = {k: as_array_from_ragged(v) for k, v in posterior.items()}
263+
sample_stats = {k: as_array_from_ragged(v) for k, v in sample_stats.items()}
264+
256265
idata = from_dict(
257266
warmup_posterior=warmup_posterior,
258267
warmup_sample_stats=warmup_sample_stats,
@@ -263,6 +272,7 @@ def to_inferencedata(self, *, equalize_chain_lengths: bool = True, **kwargs) ->
263272
attrs=self.meta.attributes,
264273
constant_data=self.constant_data,
265274
observed_data=self.observed_data,
275+
save_warmup=True,
266276
**kwargs,
267277
)
268278
return idata

mcbackend/test_utils.py

+13
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import pytest
1111

1212
import mcbackend
13+
from mcbackend import utils as mutils
1314
from mcbackend.meta import ChainMeta, DataVariable, RunMeta, Variable
1415
from mcbackend.npproto import utils
1516

@@ -407,3 +408,15 @@ def test__big_variables(self):
407408
speed = self.measure_big_variables()
408409
assert speed.draws_per_second > 500 or speed.mib_per_second > 5
409410
pass
411+
412+
413+
def test_as_array_from_ragged():
414+
even = mutils.as_array_from_ragged(
415+
[
416+
numpy.ones(2),
417+
numpy.ones(3),
418+
]
419+
)
420+
assert isinstance(even, numpy.ndarray)
421+
assert even.dtype == numpy.dtype(object)
422+
pass

mcbackend/utils.py

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""Contains helper functions that are independent of McBackend components."""
2+
from typing import Sequence
3+
4+
import numpy as np
5+
6+
7+
def as_array_from_ragged(arrs: Sequence[np.ndarray]) -> np.ndarray:
8+
shapes = {np.shape(arr) for arr in arrs}
9+
if len(shapes) > 1:
10+
return np.array(arrs, dtype=object)
11+
return np.array(arrs)

0 commit comments

Comments
 (0)