Skip to content

Fix rejection-based truncation of scalar variables #6923

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Sep 22, 2023

Reported in pymc-devs/pytensor#442


📚 Documentation preview 📚: https://pymc--6923.org.readthedocs.build/en/6923/

@codecov
Copy link

codecov bot commented Sep 22, 2023

Codecov Report

Merging #6923 (5fa5721) into main (df7b267) will increase coverage by 0.38%.
The diff coverage is 100.00%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #6923      +/-   ##
==========================================
+ Coverage   91.78%   92.17%   +0.38%     
==========================================
  Files         100      100              
  Lines       16845    16847       +2     
==========================================
+ Hits        15462    15528      +66     
+ Misses       1383     1319      -64     
Files Coverage Δ
pymc/distributions/truncated.py 99.41% <100.00%> (+<0.01%) ⬆️

... and 4 files with indirect coverage changes

@ricardoV94 ricardoV94 force-pushed the fix_truncated_rejection_sampling_scalars branch from 629fdef to ff30b9d Compare September 22, 2023 16:03
Copy link
Member

@lucianopaz lucianopaz left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks very nice @ricardoV94. I might need you to guide me a bit through the logic here because I haven't grasped the subtleties of trying to truncate SymbolicDistributions yet.

truncated_rv = pt.set_subtensor(
truncated_rv[reject_draws],
new_truncated_rv[reject_draws],
)
reject_draws = pt.or_((truncated_rv < lower), (truncated_rv > upper))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this also have some kind of pt.and_(not(reject_draws), ... so that the draws that were already accepted don't get resampled? I fail to see where that is happening.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The draws that were already accepted will always be within upper and lower. The set_subtensor only changes the indexes that were not already valid.

reject_draws = pt.or_((truncated_rv < lower), (truncated_rv > upper))

return (
(truncated_rv, reject_draws),
[(rng, next_rng)],
collect_default_updates([new_truncated_rv]),
until(~pt.any(reject_draws)),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will this until cut the scan short of the max_n_steps if the condition is met sooner?

Copy link
Member Author

@ricardoV94 ricardoV94 Oct 3, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, that's the whole point of the until (Also you can never define one without max_n_steps if that helps)

@ricardoV94 ricardoV94 changed the title Allow Truncation of SymbolicRandomVariables Allow Truncation of CustomDist Oct 3, 2023
@ricardoV94 ricardoV94 force-pushed the fix_truncated_rejection_sampling_scalars branch 2 times, most recently from d583fd6 to b39ceae Compare October 3, 2023 16:16
@ricardoV94 ricardoV94 marked this pull request as draft October 3, 2023 16:27
@ricardoV94 ricardoV94 force-pushed the fix_truncated_rejection_sampling_scalars branch 2 times, most recently from 0e0b8ef to 3a42050 Compare October 11, 2023 10:09
@ricardoV94 ricardoV94 changed the title Allow Truncation of CustomDist Fix rejection-based truncation of scalars variables Oct 11, 2023
@ricardoV94 ricardoV94 changed the title Fix rejection-based truncation of scalars variables Fix rejection-based truncation of scalar variables Oct 11, 2023
@ricardoV94
Copy link
Member Author

@lucianopaz, I realized I needed a bigger refactor, mostly because pymc-devs/pytensor#473 makes it hard to box other SymbolicRVs safely.

This PR is now just fixing the bug with the scalar case, and I'll open another one later with then new functionality

@ricardoV94 ricardoV94 force-pushed the fix_truncated_rejection_sampling_scalars branch from 3a42050 to b574db7 Compare October 11, 2023 10:13
@ricardoV94 ricardoV94 requested a review from lucianopaz October 11, 2023 10:13
@ricardoV94 ricardoV94 marked this pull request as ready for review October 11, 2023 10:14
@ricardoV94 ricardoV94 mentioned this pull request Oct 11, 2023
2 tasks
@ricardoV94 ricardoV94 force-pushed the fix_truncated_rejection_sampling_scalars branch from b574db7 to 5fa5721 Compare October 11, 2023 13:29
@ricardoV94 ricardoV94 merged commit 6b486b9 into pymc-devs:main Oct 24, 2023
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants