Skip to content

Fix Scan JAX dispatcher #232

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Apr 4, 2023
Merged

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Mar 2, 2023

Closes #20

Should be working now

NotImplemented:

  • While Scans (JAX can only JIT while scans where only the last state is needed)
  • MIT-MOT Scans (did not have time yet to dig how those are supposed to work)

@ricardoV94 ricardoV94 force-pushed the fix_jax_scan_dispatch branch 2 times, most recently from c6a0e5b to 76d27b2 Compare March 2, 2023 16:35
@ricardoV94 ricardoV94 requested a review from Armavica March 2, 2023 16:59
@ricardoV94 ricardoV94 marked this pull request as ready for review March 2, 2023 16:59
@ricardoV94 ricardoV94 marked this pull request as draft March 4, 2023 19:08
@ricardoV94
Copy link
Member Author

Want to simplify a bit the logic with some of the scan helpers now that I know they exist. Can still be reviewed and tested upon!

@ricardoV94 ricardoV94 force-pushed the fix_jax_scan_dispatch branch from 76d27b2 to 8b89b2a Compare March 6, 2023 11:53
@ricardoV94
Copy link
Member Author

Cleaned up

@ricardoV94 ricardoV94 marked this pull request as ready for review March 6, 2023 11:53
@ricardoV94 ricardoV94 force-pushed the fix_jax_scan_dispatch branch 4 times, most recently from 86d84a3 to b0aa0b9 Compare March 6, 2023 12:39
Copy link
Member

@Armavica Armavica left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

First round of questions, I am stopping here to not overwhelm you and because I think that some of them might help me understand the rest :)

@ricardoV94 ricardoV94 force-pushed the fix_jax_scan_dispatch branch from b0aa0b9 to d5ba7fe Compare March 17, 2023 12:24
@ricardoV94
Copy link
Member Author

@Armavica I addressed some of your suggestions. Feel free to take another stab!

@ricardoV94 ricardoV94 requested a review from Armavica March 17, 2023 12:24
@ricardoV94 ricardoV94 force-pushed the fix_jax_scan_dispatch branch from d5ba7fe to 7c35a6c Compare March 22, 2023 07:22
Copy link
Member

@Armavica Armavica left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, I think I understand most of it now. I am just stuck on a few details.

@ricardoV94 ricardoV94 force-pushed the fix_jax_scan_dispatch branch from 7c35a6c to 1ba5887 Compare April 4, 2023 16:07
@ricardoV94 ricardoV94 requested a review from Armavica April 4, 2023 16:09
@codecov-commenter
Copy link

Codecov Report

Merging #232 (1ba5887) into main (cbef7d5) will increase coverage by 0.03%.
The diff coverage is 100.00%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #232      +/-   ##
==========================================
+ Coverage   80.43%   80.46%   +0.03%     
==========================================
  Files         170      170              
  Lines       45394    45412      +18     
  Branches    11088    11087       -1     
==========================================
+ Hits        36512    36541      +29     
+ Misses       6654     6639      -15     
- Partials     2228     2232       +4     
Impacted Files Coverage Δ
pytensor/link/jax/dispatch/scan.py 100.00% <100.00%> (+84.74%) ⬆️
pytensor/scalar/basic.py 79.78% <100.00%> (+0.01%) ⬆️
pytensor/tensor/sharedvar.py 83.33% <100.00%> (+0.72%) ⬆️

... and 9 files with indirect coverage changes

Copy link
Member

@Armavica Armavica left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great, thank you for your patience and your explanations!

@ricardoV94 ricardoV94 merged commit 88cc33b into pymc-devs:main Apr 4, 2023
@ricardoV94 ricardoV94 added the enhancement New feature or request label Apr 7, 2023
@ricardoV94 ricardoV94 deleted the fix_jax_scan_dispatch branch June 21, 2023 08:55
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
backend compatibility bug Something isn't working enhancement New feature or request jax scan
Projects
None yet
Development

Successfully merging this pull request may close these issues.

BUG: Scan inner graphs are not optimized in NUMBA / JAX backends
3 participants