
Fix regression caused by #4211 #4285


Merged: 8 commits merged into pymc-devs:master on Dec 5, 2020

Conversation

ricardoV94
Member

@ricardoV94 commented Dec 2, 2020

PR #4211 caused a subtle regression. Default sampling no longer applies the jitter+adapt_diag initialization, because start is set to model.test_point by default in order to test whether the model likelihood is finite with pm.util.check_start_vals.

https://github.com/pymc-devs/pymc3/blob/988ab9dd54e83dbd336fb4463aa16c8c8ce7ccfd/pymc3/sampling.py#L417

This conflicts with this line further down, where start values that were left undefined would otherwise be replaced by those from the auto-assigned NUTS sampler:
https://github.com/pymc-devs/pymc3/blob/988ab9dd54e83dbd336fb4463aa16c8c8ce7ccfd/pymc3/sampling.py#L494

I altered the code so that model.test_point is still validated, without overriding the sampling start variable.
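
Roughly, the idea of the change is the following (a simplified sketch of the logic inside pm.sample, not the literal diff):

# before (simplified): the default overwrote `start`, so the later
# `if start is None:` branch that applies the init_nuts starting points
# could never run
#
#     if start is None:
#         start = model.test_point
#     check_start_vals(start, model)
#
# after (simplified): validate the default test point (or the user-supplied
# start) without assigning model.test_point to `start`
check_start_vals(start if start is not None else model.test_point, model)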

I found this issue while working on #4107.

@MarcoGorelli
Contributor

Hi @ricardoV94 - good catch! So, this doesn't affect check_start_vals, it just avoids assigning a value to start if it's None?

Is it possible to write a test case which fails on master but which passes with this change?

@ricardoV94
Member Author

Hi @ricardoV94 - good catch! So, this doesn't affect check_start_vals, it just avoids assigning a value to start if it's None?

Exactly

Is it possible to write a test case which fails on master but which passes with this change?

I am implementing more targeted tests in my work on #4107, which will include this issue as a subset (that's how I found it). Is it okay if I don't include the test until then? I just thought it might be important to fix this upstream ASAP, as it might affect other developers.

@MarcoGorelli
Contributor

Is it okay if I don't include the test until then? I just thought it might be important to fix this upstream ASAP, as it might affect other developers.

I don't know what's normally done in PyMC3, but that seems reasonable to me (unless there's a simple test you already have ready); let's just check that the current test suite passes.

cc @StephenHogg (just to keep you in the loop, and certainly not to try to cast any blame 😄 )

@codecov

codecov bot commented Dec 2, 2020

Codecov Report

Merging #4285 (6699d66) into master (7ff2f49) will increase coverage by 0.04%.
The diff coverage is 100.00%.


@@            Coverage Diff             @@
##           master    #4285      +/-   ##
==========================================
+ Coverage   87.59%   87.63%   +0.04%     
==========================================
  Files          88       88              
  Lines       14312    14316       +4     
==========================================
+ Hits        12536    12546      +10     
+ Misses       1776     1770       -6     
Impacted Files                        Coverage Δ
pymc3/sampling.py                     86.70% <100.00%> (+0.49%) ⬆️
pymc3/step_methods/hmc/base_hmc.py    90.83% <0.00%> (-0.84%) ⬇️
pymc3/sampling_jax.py                  0.00% <0.00%> (ø)
pymc3/distributions/continuous.py     93.31% <0.00%> (+0.03%) ⬆️
pymc3/parallel_sampling.py            86.79% <0.00%> (+0.94%) ⬆️

@twiecki added this to the 3.10 milestone on Dec 3, 2020
@ricardoV94
Member Author

ricardoV94 commented Dec 3, 2020

I don't know what's normally done in PyMC3, but that seems reasonable to me (unless there's a simple test you already have ready); let's just check that the current test suite passes.

I added a test that tries to make sure that the starting point returned by init_nuts() is actually used when start=None, but that it does not override a specific start point defined by the user. This test would not capture the exact previous regression (see [Note] below), but it should capture similar issues if they recur in the future. For example, if start is None is changed to start = model.test_point, the new test will fail.

It is a bit convoluted because it requires making the standard sampling fail on purpose with a random_seed. I had to do this because the automatically assigned start point does not seem to be accessible from the outside, and so I cannot test against it directly. If anyone has a better idea, let me know.

[Note]: I also had to move the second check_start_vals into the next if block for the test to pass. This seems fine to me, because if we are not using the new values as a start point, there is no point in testing them. Is this correct? This behavior assumes that it makes sense to use "jitter+adapt_diag" with a custom fixed starting point for sampling (which was already the default behavior, I did not change it!). Does it?

If it does not, the arguments start != None and init="auto" or init="jitter+... are at odds with each other.
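
For concreteness, the combination in question looks like this (a minimal illustrative example, not from the test suite):

import numpy as np
import pymc3 as pm

with pm.Model():
    pm.Normal("x")
    # an explicit fixed start point together with an init whose whole job is
    # to pick (jittered) starting points
    pm.sample(tune=10, draws=10, chains=1, init="jitter+adapt_diag", start={"x": np.array(0.5)})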

Member

@michaelosthege left a comment


Good catch!
Considering that our sampling code has a lot of branching to deal with different settings, it's probably not the last time that we touch this part of the code. But all these little increments help to improve the experience and get us to a well-defined API.

@StephenHogg
Contributor

Feel free to ignore - could the test be re-written to just patch the relevant sub-component of pm.sample so that it fails deterministically, rather than relying on a random seed? Should just be a matter of a context manager along the lines of:

with mock.patch('whatever.function') as m:

Then just set appropriate details for the mocked object. This would obviate the potential future issue you've identified.
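
For instance, something along these lines (a hypothetical sketch: it wraps the real init_nuts but replaces the start point it returns with one that is invalid for the model below, so sampling can only fail if that start point is actually used):

from unittest import mock

import numpy as np
import pymc3 as pm
import pytest
from pymc3.exceptions import SamplingError
from pymc3.sampling import init_nuts


def _broken_init_nuts(*args, **kwargs):
    # keep the real step method, only replace the returned start point
    _, step = init_nuts(*args, **kwargs)
    return [{"x": np.array(-1.0)}], step  # invalid: a HalfNormal has no mass at x < 0


with pm.Model():
    pm.HalfNormal("x", transform=None)
    with mock.patch("pymc3.sampling.init_nuts", side_effect=_broken_init_nuts):
        # should only raise if the start point coming from init_nuts is used
        with pytest.raises(SamplingError):
            pm.sample(tune=1, draws=0, chains=1, init="adapt_diag")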

@ricardoV94
Member Author

Feel free to ignore - could the test be re-written to just patch the relevant sub-component of pm.sample so that it fails deterministically, rather than relying on a random seed? Should just be a matter of a context manager along the lines of:

with mock.patch('whatever.function') as m:

Then just set appropriate details for the mocked object. This would obviate the potential future issue you've identified.

Thanks for the suggestion.

If I understand correctly, you would use the mock to directly control what init_nuts returns, instead of relying on a custom seed as I was doing. The main test logic would still rely on making sure the SamplingError is being raised when the returned start value is being used and fail otherwise. Is that the idea?

@StephenHogg
Contributor

Thanks for the suggestion.

If I understand correctly, you would use the mock to directly control what init_nuts returns, instead of relying on a custom seed as I was doing. The main test logic would still rely on making sure the SamplingError is being raised when the returned start value is being used and fail otherwise. Is that the idea?

Yes, that's it exactly - mocking can be extremely powerful like this. The rest of the code doesn't change at all, you just fake an error from one function.

Comment on lines -493 to +495
-        check_start_vals(start_, model)
+        if start is None:
+            start = start_
+            check_start_vals(start, model)
Contributor


does check_start_vals not need to run here if start is not None?

Member Author

@ricardoV94 Dec 3, 2020


No, because that was already done at the beginning of the function (lines 416-424), regardless of whether start was None or not.

Contributor


you're right, thanks!

@MarcoGorelli
Contributor

MarcoGorelli commented Dec 3, 2020

with mock.patch('whatever.function') as m:

Or, seeing as pymc3 uses pytest, you could use the built-in monkeypatch and do something like:

import numpy as np
import pymc3 as pm
from pymc3.sampling import init_nuts


def test_default_sample_nuts_jitter(init, start, expectation, monkeypatch):
    # This test tries to check whether the starting points returned by init_nuts are actually being
    # used when pm.sample() is called without specifying an explicit start point (see
    # https://github.com/pymc-devs/pymc3/pull/4285).
    def _mocked_init_nuts(*args, **kwargs):
        if init == 'adapt_diag':
            start_ = [{'x': np.array(0.79788456)}]
        else:
            start_ = [{'x': np.array(-0.04949886)}]
        _, step = init_nuts(*args, **kwargs)
        return start_, step

    monkeypatch.setattr('pymc3.sampling.init_nuts', _mocked_init_nuts)
    with pm.Model() as m:
        x = pm.HalfNormal("x", transform=None)
        with expectation:
            pm.sample(tune=1, draws=0, chains=1, init=init, start=start)
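
For completeness, the init, start and expectation arguments above would come from a pytest.mark.parametrize decorator. One hypothetical way to fill it in (illustrative cases only, not necessarily what ends up in the test suite) could be:

from contextlib import nullcontext

import numpy as np
import pytest
from pymc3.exceptions import SamplingError


@pytest.mark.parametrize(
    "init, start, expectation",
    [
        # start=None: the mocked init_nuts start point must be picked up; the
        # "adapt_diag" one is valid for a HalfNormal, the jittered one is not
        ("adapt_diag", None, nullcontext()),
        ("jitter+adapt_diag", None, pytest.raises(SamplingError)),
        # an explicit, valid user start must not be overridden by init_nuts
        ("jitter+adapt_diag", {"x": np.array(1.0)}, nullcontext()),
    ],
)
def test_default_sample_nuts_jitter(init, start, expectation, monkeypatch):
    ...  # body as in the snippet above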

@ricardoV94
Member Author

@MarcoGorelli that's fantastic! I will change it soon to work like that.

@MarcoGorelli
Contributor

Black formatting, once again...

If you enable pre-commit (see the pymc3 code style wiki page) this'll be less painful 😄

@Spaak merged commit 9311899 into pymc-devs:master on Dec 5, 2020
@Spaak
Member

Spaak commented Dec 5, 2020

Thanks @ricardoV94 @MarcoGorelli!
