Skip to content

Commit 4cccb46

Browse files
authored
Improve sampling coverage (#4270)
* improve coverage * redistribute tests * use np.shape
1 parent edbafaa commit 4cccb46

File tree

4 files changed

+32
-33
lines changed

4 files changed

+32
-33
lines changed

.github/workflows/pytest.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,10 @@ jobs:
3939
pymc3/tests/test_distributions_timeseries.py
4040
pymc3/tests/test_parallel_sampling.py
4141
pymc3/tests/test_random.py
42-
pymc3/tests/test_sampling.py
4342
pymc3/tests/test_shared.py
4443
pymc3/tests/test_smc.py
4544
- |
4645
pymc3/tests/test_examples.py
47-
pymc3/tests/test_gp.py
4846
pymc3/tests/test_mixture.py
4947
pymc3/tests/test_posteriors.py
5048
pymc3/tests/test_quadpotential.py
@@ -54,6 +52,8 @@ jobs:
5452
pymc3/tests/test_variational_inference.py
5553
- |
5654
pymc3/tests/test_distributions.py
55+
pymc3/tests/test_gp.py
56+
pymc3/tests/test_sampling.py
5757
runs-on: ${{ matrix.os }}
5858
env:
5959
TEST_SUBSET: ${{ matrix.test-subset }}

pymc3/sampling.py

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,7 @@
1414

1515
"""Functions for MCMC sampling."""
1616

17-
from typing import Dict, List, Optional, TYPE_CHECKING, cast, Union, Any
18-
19-
if TYPE_CHECKING:
20-
from typing import Tuple
17+
from typing import Dict, List, Optional, cast, Union, Any
2118
from typing import Iterable as TIterable
2219
from collections.abc import Iterable
2320
from collections import defaultdict
@@ -218,11 +215,7 @@ def assign_step_methods(model, step=None, methods=STEP_METHODS, step_kwargs=None
218215

219216

220217
def _print_step_hierarchy(s, level=0):
221-
if isinstance(s, (list, tuple)):
222-
_log.info(">" * level + "list")
223-
for i in s:
224-
_print_step_hierarchy(i, level + 1)
225-
elif isinstance(s, CompoundStep):
218+
if isinstance(s, CompoundStep):
226219
_log.info(">" * level + "CompoundStep")
227220
for i in s.methods:
228221
_print_step_hierarchy(i, level + 1)
@@ -458,7 +451,7 @@ def sample(
458451

459452
if return_inferencedata is None:
460453
v = packaging.version.parse(pm.__version__)
461-
if v.release[0] > 3 or v.release[1] >= 10:
454+
if v.release[0] > 3 or v.release[1] >= 10: # type: ignore
462455
warnings.warn(
463456
"In an upcoming release, pm.sample will return an `arviz.InferenceData` object instead of a `MultiTrace` by default. "
464457
"You can pass return_inferencedata=True or return_inferencedata=False to be safe and silence this warning.",
@@ -585,7 +578,7 @@ def sample(
585578
UserWarning,
586579
)
587580
_print_step_hierarchy(step)
588-
trace = _sample_population(**sample_args, parallelize=cores > 1)
581+
trace = _sample_population(parallelize=cores > 1, **sample_args)
589582
else:
590583
_log.info(f"Sequential sampling ({chains} chains in 1 job)")
591584
_print_step_hierarchy(step)
@@ -770,11 +763,9 @@ def _sample_population(
770763
trace : MultiTrace
771764
Contains samples of all chains
772765
"""
773-
# create the generator that iterates all chains in parallel
774-
chains = [chain + c for c in range(chains)]
775766
sampling = _prepare_iter_population(
776767
draws,
777-
chains,
768+
[chain + c for c in range(chains)],
778769
step,
779770
start,
780771
parallelize,
@@ -1582,10 +1573,7 @@ def insert(self, k: str, v, idx: int):
15821573
ids: int
15831574
The index of the sample we are inserting into the trace.
15841575
"""
1585-
if hasattr(v, "shape"):
1586-
value_shape = tuple(v.shape) # type: Tuple[int, ...]
1587-
else:
1588-
value_shape = ()
1576+
value_shape = np.shape(v)
15891577

15901578
# initialize if necessary
15911579
if k not in self.trace_dict:

pymc3/tests/test_sampling.py

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,9 @@
1313
# limitations under the License.
1414

1515
from itertools import combinations
16-
import packaging
1716
from typing import Tuple
1817
import numpy as np
19-
20-
try:
21-
import unittest.mock as mock # py3
22-
except ImportError:
23-
from unittest import mock
18+
import unittest.mock as mock
2419

2520
import numpy.testing as npt
2621
import arviz as az
@@ -180,13 +175,9 @@ def test_trace_report_bart(self):
180175
assert var_imp[0] > var_imp[1:].sum()
181176
npt.assert_almost_equal(var_imp.sum(), 1)
182177

183-
def test_return_inferencedata(self):
178+
def test_return_inferencedata(self, monkeypatch):
184179
with self.model:
185180
kwargs = dict(draws=100, tune=50, cores=1, chains=2, step=pm.Metropolis())
186-
v = packaging.version.parse(pm.__version__)
187-
if v.major > 3 or v.minor >= 10:
188-
with pytest.warns(FutureWarning, match="pass return_inferencedata"):
189-
result = pm.sample(**kwargs)
190181

191182
# trace with tuning
192183
with pytest.warns(UserWarning, match="will be included"):
@@ -203,12 +194,25 @@ def test_return_inferencedata(self):
203194
assert result.posterior.sizes["chain"] == 2
204195
assert len(result._groups_warmup) > 0
205196

206-
# inferencedata without tuning
207-
result = pm.sample(**kwargs, return_inferencedata=True, discard_tuned_samples=True)
197+
# inferencedata without tuning, with idata_kwargs
198+
prior = pm.sample_prior_predictive()
199+
result = pm.sample(
200+
**kwargs,
201+
return_inferencedata=True,
202+
discard_tuned_samples=True,
203+
idata_kwargs={"prior": prior},
204+
random_seed=-1
205+
)
206+
assert "prior" in result
208207
assert isinstance(result, az.InferenceData)
209208
assert result.posterior.sizes["draw"] == 100
210209
assert result.posterior.sizes["chain"] == 2
211210
assert len(result._groups_warmup) == 0
211+
212+
# check warning for version 3.10 onwards
213+
monkeypatch.setattr("pymc3.__version__", "3.10")
214+
with pytest.warns(FutureWarning, match="pass return_inferencedata"):
215+
result = pm.sample(**kwargs)
212216
pass
213217

214218
@pytest.mark.parametrize("cores", [1, 2])

pyproject.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
[tool.black]
22
line-length = 100
33

4+
[tool.coverage.report]
5+
exclude_lines = [
6+
"pragma: nocover",
7+
"raise NotImplementedError",
8+
"if TYPE_CHECKING:",
9+
]
10+
411
[tool.nbqa.mutate]
512
isort = 1
613
black = 1

0 commit comments

Comments
 (0)