Skip to content

Commit 9311899

Browse files
authored
- Fix regression caused by #4211 (#4285)
* - Fix regression caused by #4211 * - Add test to make sure jitter is being applied to chains starting points by default * - Import appropriate empty context for python < 3.7 * - Apply black formatting * - Change the second check_start_vals to explicitly run on the newly assigned start variable. * - Improve test documentation and add a new condition * Use monkeypatch for more robust test * - Black formatting, once again...
1 parent 198d13e commit 9311899

File tree

2 files changed

+34
-4
lines changed

2 files changed

+34
-4
lines changed

pymc3/sampling.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -416,15 +416,15 @@ def sample(
416416
"""
417417
model = modelcontext(model)
418418
if start is None:
419-
start = model.test_point
419+
check_start_vals(model.test_point, model)
420420
else:
421421
if isinstance(start, dict):
422422
update_start_vals(start, model.test_point, model)
423423
else:
424424
for chain_start_vals in start:
425425
update_start_vals(chain_start_vals, model.test_point, model)
426+
check_start_vals(start, model)
426427

427-
check_start_vals(start, model)
428428
if cores is None:
429429
cores = min(4, _cpu_count())
430430

@@ -492,9 +492,9 @@ def sample(
492492
progressbar=progressbar,
493493
**kwargs,
494494
)
495-
check_start_vals(start_, model)
496495
if start is None:
497496
start = start_
497+
check_start_vals(start, model)
498498
except (AttributeError, NotImplementedError, tg.NullTypeGradError):
499499
# gradient computation failed
500500
_log.info("Initializing NUTS failed. " "Falling back to elementwise auto-assignment.")

pymc3/tests/test_sampling.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from contextlib import ExitStack as does_not_raise
1516
from itertools import combinations
1617
from typing import Tuple
1718
import numpy as np
@@ -25,7 +26,7 @@
2526
import theano
2627
from pymc3.tests.models import simple_init
2728
from pymc3.tests.helpers import SeededTest
28-
from pymc3.exceptions import IncorrectArgumentsError
29+
from pymc3.exceptions import IncorrectArgumentsError, SamplingError
2930
from scipy import stats
3031
import pytest
3132

@@ -785,6 +786,35 @@ def test_exec_nuts_init(method):
785786
assert "a" in start[0] and "b_log__" in start[0]
786787

787788

789+
@pytest.mark.parametrize(
790+
"init, start, expectation",
791+
[
792+
("auto", None, pytest.raises(SamplingError)),
793+
("jitter+adapt_diag", None, pytest.raises(SamplingError)),
794+
("auto", {"x": 0}, does_not_raise()),
795+
("jitter+adapt_diag", {"x": 0}, does_not_raise()),
796+
("adapt_diag", None, does_not_raise()),
797+
],
798+
)
799+
def test_default_sample_nuts_jitter(init, start, expectation, monkeypatch):
800+
# This test tries to check whether the starting points returned by init_nuts are actually
801+
# being used when pm.sample() is called without specifying an explicit start point (see
802+
# https://github.com/pymc-devs/pymc3/pull/4285).
803+
def _mocked_init_nuts(*args, **kwargs):
804+
if init == "adapt_diag":
805+
start_ = [{"x": np.array(0.79788456)}]
806+
else:
807+
start_ = [{"x": np.array(-0.04949886)}]
808+
_, step = pm.init_nuts(*args, **kwargs)
809+
return start_, step
810+
811+
monkeypatch.setattr("pymc3.sampling.init_nuts", _mocked_init_nuts)
812+
with pm.Model() as m:
813+
x = pm.HalfNormal("x", transform=None)
814+
with expectation:
815+
pm.sample(tune=1, draws=0, chains=1, init=init, start=start)
816+
817+
788818
@pytest.fixture(scope="class")
789819
def point_list_arg_bug_fixture() -> Tuple[pm.Model, pm.backends.base.MultiTrace]:
790820
with pm.Model() as pmodel:

0 commit comments

Comments
 (0)