|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
| 15 | +from contextlib import ExitStack as does_not_raise |
15 | 16 | from itertools import combinations
|
16 | 17 | from typing import Tuple
|
17 | 18 | import numpy as np
|
|
25 | 26 | import theano
|
26 | 27 | from pymc3.tests.models import simple_init
|
27 | 28 | from pymc3.tests.helpers import SeededTest
|
28 |
| -from pymc3.exceptions import IncorrectArgumentsError |
| 29 | +from pymc3.exceptions import IncorrectArgumentsError, SamplingError |
29 | 30 | from scipy import stats
|
30 | 31 | import pytest
|
31 | 32 |
|
@@ -785,6 +786,35 @@ def test_exec_nuts_init(method):
|
785 | 786 | assert "a" in start[0] and "b_log__" in start[0]
|
786 | 787 |
|
787 | 788 |
|
| 789 | +@pytest.mark.parametrize( |
| 790 | + "init, start, expectation", |
| 791 | + [ |
| 792 | + ("auto", None, pytest.raises(SamplingError)), |
| 793 | + ("jitter+adapt_diag", None, pytest.raises(SamplingError)), |
| 794 | + ("auto", {"x": 0}, does_not_raise()), |
| 795 | + ("jitter+adapt_diag", {"x": 0}, does_not_raise()), |
| 796 | + ("adapt_diag", None, does_not_raise()), |
| 797 | + ], |
| 798 | +) |
| 799 | +def test_default_sample_nuts_jitter(init, start, expectation, monkeypatch): |
| 800 | + # This test tries to check whether the starting points returned by init_nuts are actually |
| 801 | + # being used when pm.sample() is called without specifying an explicit start point (see |
| 802 | + # https://github.com/pymc-devs/pymc3/pull/4285). |
| 803 | + def _mocked_init_nuts(*args, **kwargs): |
| 804 | + if init == "adapt_diag": |
| 805 | + start_ = [{"x": np.array(0.79788456)}] |
| 806 | + else: |
| 807 | + start_ = [{"x": np.array(-0.04949886)}] |
| 808 | + _, step = pm.init_nuts(*args, **kwargs) |
| 809 | + return start_, step |
| 810 | + |
| 811 | + monkeypatch.setattr("pymc3.sampling.init_nuts", _mocked_init_nuts) |
| 812 | + with pm.Model() as m: |
| 813 | + x = pm.HalfNormal("x", transform=None) |
| 814 | + with expectation: |
| 815 | + pm.sample(tune=1, draws=0, chains=1, init=init, start=start) |
| 816 | + |
| 817 | + |
788 | 818 | @pytest.fixture(scope="class")
|
789 | 819 | def point_list_arg_bug_fixture() -> Tuple[pm.Model, pm.backends.base.MultiTrace]:
|
790 | 820 | with pm.Model() as pmodel:
|
|
0 commit comments