Skip to content

- Fix regression caused by #4211 #4285

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into the base branch on
Dec 5, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,15 +414,15 @@ def sample(
"""
model = modelcontext(model)
if start is None:
start = model.test_point
check_start_vals(model.test_point, model)
else:
if isinstance(start, dict):
update_start_vals(start, model.test_point, model)
else:
for chain_start_vals in start:
update_start_vals(chain_start_vals, model.test_point, model)
check_start_vals(start, model)

check_start_vals(start, model)
if cores is None:
cores = min(4, _cpu_count())

Expand Down Expand Up @@ -490,9 +490,9 @@ def sample(
progressbar=progressbar,
**kwargs,
)
check_start_vals(start_, model)
if start is None:
start = start_
check_start_vals(start, model)
Comment on lines -493 to +495
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does check_start_vals not need to run here if start is not None?

Copy link
Member Author

@ricardoV94 ricardoV94 Dec 3, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, because that was already done at the beginning of the function (lines 416-424), and regardless of whether start was None or not.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you're right, thanks!

except (AttributeError, NotImplementedError, tg.NullTypeGradError):
# gradient computation failed
_log.info("Initializing NUTS failed. " "Falling back to elementwise auto-assignment.")
Expand Down
32 changes: 31 additions & 1 deletion pymc3/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import ExitStack as does_not_raise
from itertools import combinations
from typing import Tuple
import numpy as np
Expand All @@ -25,7 +26,7 @@
import theano
from pymc3.tests.models import simple_init
from pymc3.tests.helpers import SeededTest
from pymc3.exceptions import IncorrectArgumentsError
from pymc3.exceptions import IncorrectArgumentsError, SamplingError
from scipy import stats
import pytest

Expand Down Expand Up @@ -761,6 +762,35 @@ def test_exec_nuts_init(method):
assert "a" in start[0] and "b_log__" in start[0]


@pytest.mark.parametrize(
    "init, start, expectation",
    [
        ("auto", None, pytest.raises(SamplingError)),
        ("jitter+adapt_diag", None, pytest.raises(SamplingError)),
        ("auto", {"x": 0}, does_not_raise()),
        ("jitter+adapt_diag", {"x": 0}, does_not_raise()),
        ("adapt_diag", None, does_not_raise()),
    ],
)
def test_default_sample_nuts_jitter(init, start, expectation, monkeypatch):
    # Regression test: the starting points returned by init_nuts must actually be
    # used by pm.sample() when the caller does not pass an explicit start point (see
    # https://github.com/pymc-devs/pymc3/pull/4285).
    def _fake_init_nuts(*args, **kwargs):
        # Delegate to the real initializer for the step method, but plant a known
        # start value so we can detect whether it is propagated downstream.
        forced = 0.79788456 if init == "adapt_diag" else -0.04949886
        _, step_method = pm.init_nuts(*args, **kwargs)
        return [{"x": np.array(forced)}], step_method

    monkeypatch.setattr("pymc3.sampling.init_nuts", _fake_init_nuts)
    with pm.Model():
        pm.HalfNormal("x", transform=None)
        with expectation:
            pm.sample(tune=1, draws=0, chains=1, init=init, start=start)


@pytest.fixture(scope="class")
def point_list_arg_bug_fixture() -> Tuple[pm.Model, pm.backends.base.MultiTrace]:
with pm.Model() as pmodel:
Expand Down