Fix regression caused by #4211 #4285
Conversation
Hi @ricardoV94 - good catch! So, this doesn't affect …? Is it possible to write a test case which fails on …?
Exactly
I am implementing more targeted tests in my work on #4107, which will include this issue as a subset (that's how I found it). Is it okay if I don't include the test until then? I just thought it might be important to fix this upstream ASAP, as it might affect other developers.
I don't know what's normally done in PyMC3, but that seems reasonable to me (unless there's a simple test you already have ready) - let's just check that the current test suite passes. cc @StephenHogg (just to keep you in the loop, and certainly not to try to cast any blame 😄)
Codecov Report
@@ Coverage Diff @@
## master #4285 +/- ##
==========================================
+ Coverage 87.59% 87.63% +0.04%
==========================================
Files 88 88
Lines 14312 14316 +4
==========================================
+ Hits 12536 12546 +10
+ Misses 1776 1770 -6
I added a test that tries to make sure that the starting point returned by init_nuts is actually used.

It is a bit convoluted because it requires making the standard sampling fail on purpose with a random_seed. I had to do this because the automatically assigned start point does not seem to be accessible from the outside, and so I cannot test against it directly. If anyone has a better idea, let me know.

[Note]: I also had to move the second … If it does not, the arguments …
Good catch!
Considering that our sampling code has a lot of branching to deal with different settings, it's probably not the last time that we touch this part of the code. But all these little increments help to improve the experience and get us to a well-defined API.
Feel free to ignore - could the test be re-written to just patch the relevant sub-component of …? Then just set appropriate details for the mocked object. This would obviate the potential future issue you've identified.
Thanks for the suggestion. If I understand correctly, you would use the mock to directly control what …?
Yes, that's it exactly - mocking can be extremely powerful like this. The rest of the code doesn't change at all, you just fake an error from one function.
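For concreteness, a minimal sketch of the unittest.mock idea being discussed here, with a made-up test name; it is not the test that was eventually added to this PR. It patches init_nuts where pm.sample looks it up and makes it raise, so the test only passes if the default code path really goes through init_nuts.

from unittest import mock

import pymc3 as pm
import pytest


def test_default_sampling_goes_through_init_nuts():
    # Hypothetical sketch: fake an error from init_nuts and check that it surfaces,
    # which shows pm.sample() used the (patched) function on its default path.
    with pm.Model():
        pm.HalfNormal("x", transform=None)
        with mock.patch("pymc3.sampling.init_nuts", side_effect=RuntimeError("boom")):
            with pytest.raises(RuntimeError, match="boom"):
                pm.sample(tune=1, draws=1, chains=1, init="jitter+adapt_diag")

Controlling the return value instead of raising, as in the monkeypatch-based version further down, goes one step further and checks that the returned start point is actually used.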
check_start_vals(start_, model)
if start is None:
    start = start_
check_start_vals(start, model)
does check_start_vals not need to run here if start is not None?
No, because that was already done at the beginning of the function (lines 416-424), and regardless of whether start was None or not.
you're right, thanks!
Or, seeing as pymc3 uses pytest, you could use the monkeypatch fixture:

def test_default_sample_nuts_jitter(init, start, expectation, monkeypatch):
    # This test tries to check whether the starting points returned by init_nuts are actually being
    # used when pm.sample() is called without specifying an explicit start point (see
    # https://github.com/pymc-devs/pymc3/pull/4285).
    def _mocked_init_nuts(*args, **kwargs):
        if init == 'adapt_diag':
            start_ = [{'x': np.array(0.79788456)}]
        else:
            start_ = [{'x': np.array(-0.04949886)}]
        _, step = init_nuts(*args, **kwargs)
        return start_, step

    monkeypatch.setattr('pymc3.sampling.init_nuts', _mocked_init_nuts)
    with pm.Model() as m:
        x = pm.HalfNormal("x", transform=None)
        with expectation:
            pm.sample(tune=1, draws=0, chains=1, init=init, start=start)
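The init, start, and expectation arguments are not defined in the snippet above; presumably they come from a pytest parametrization along these lines (the cases below are illustrative guesses, not necessarily what was merged). expectation has to be a context manager, e.g. contextlib.nullcontext() when no error is expected, or pytest.raises(...) when one is.

from contextlib import nullcontext

import numpy as np
import pytest


@pytest.mark.parametrize(
    "init, start, expectation",
    [
        ("adapt_diag", None, nullcontext()),                  # no explicit start: mocked init_nuts value should be used
        ("jitter+adapt_diag", None, nullcontext()),
        ("adapt_diag", {"x": np.array(1.0)}, nullcontext()),  # explicit start: should take precedence
    ],
)
def test_default_sample_nuts_jitter(init, start, expectation, monkeypatch):
    ...  # body as in the snippet above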
@MarcoGorelli that's fantastic! I will change it soon to work like that.
If you enable pre-commit (see the pymc3 code style wiki page) this'll be less painful 😄
Thanks @ricardoV94 @MarcoGorelli!
PR #4211 caused a subtle regression. Default sampling no longer applies the jitter+adapt_diag initialization, because start is set to model.test_point by default in order to test if the model likelihood is finite with pm.util.check_start_vals.
https://github.com/pymc-devs/pymc3/blob/988ab9dd54e83dbd336fb4463aa16c8c8ce7ccfd/pymc3/sampling.py#L417

This conflicts with this line further down the road, where the undefined start values would be given those from the auto-assigned NUTS sampler:
https://github.com/pymc-devs/pymc3/blob/988ab9dd54e83dbd336fb4463aa16c8c8ce7ccfd/pymc3/sampling.py#L494

I altered the code so that model.test_point is still tested, without overriding the sample start variable.

I found this issue when working on #4107.
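To make the interaction concrete, here is a small self-contained sketch of the failure mode and the fix, with made-up helper names (this is not the pymc3 source): once start is overwritten with the test point just to validate it, the later "if start is None" branch can never adopt the jitter+adapt_diag start values.

import math


def check_start_vals(point):
    # Stand-in for pm.util.check_start_vals: reject non-finite values.
    if any(not math.isfinite(v) for v in point.values()):
        raise ValueError(f"Bad start point: {point}")


def buggy_sample(start, test_point, init_nuts):
    if start is None:
        start = test_point            # regression: overrides start just to check it
    check_start_vals(start)
    start_ = init_nuts()              # e.g. jitter+adapt_diag start values
    if start is None:                 # never true any more, so the jitter is silently dropped
        start = start_
    return start


def fixed_sample(start, test_point, init_nuts):
    check_start_vals(start if start is not None else test_point)  # validate without overriding
    start_ = init_nuts()
    if start is None:
        start = start_                # auto-assigned start is used again
    return start


jittered = {"x": 0.79788456}
assert buggy_sample(None, {"x": 0.5}, lambda: jittered) == {"x": 0.5}  # wrong: jitter ignored
assert fixed_sample(None, {"x": 0.5}, lambda: jittered) == jittered    # jitter applied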