Skip to content

Fix JAX Scan for output ndim > 1 #288

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
May 10, 2023
Merged

Conversation

jessegrabowski
Copy link
Member

Motivation for these changes

Closes #287 Currently, JAX compilation of scans with n-dimensional outputs fails.

Implementation details

Checks input and output dimensions, and adds a batch dimension to initial inputs where necessary.

Marked as draft because there are still some problems:

  1. Currently, just check if output.ndim > 1 and insert a single dimension to the initial input if so. I think this is the only case possible given the shape checking that is done by scan itself. Since all outputs must be 1 dimension smaller than the inputs, I don't think it's possible to reach a case where more than 1 batch dimension needs to be inserted. Nevertheless, the machinery to do this exists in this code if necessary.

  2. Big problem: nd sequences and mit_sots do not work. This is captured by two tests in this PR: test_nd_scan_mit_sot and test_nd_scan_sit_sot_with_seq. The error raised in both cases is NotImplementedError: JAX does not support slicing arrays with a dynamic slice length.. I'm not sure why the slices are dynamic though, we should be able to pre-compute all the required slices? It just doesn't seem much different from the normal 1d array case.

Checklist

Major / Breaking Changes

I hope none, all tests in test/linker/jax/test_scan pass

New features

JAX support for more flavors of scan

Bugfixes

Scan no longer errors when outputs are ndim > 1 and mode = JAX

Documentation

None

Maintenance

  • ...

@jessegrabowski
Copy link
Member Author

I played around a bit more with this and I'm stumped on problem 2. I've been thinking more about problem 1, and I think it's true that my solution is too complex. It should suffice to just check which (if any) scan outputs have dimension > 0, and add a single batch dimension if so. I also noticed the solution as-is is failing the benchmark checks, which might be related to having the range function inside the jax JIT-compiled code. I will push some commits trying to improve this later today.

Otherwise, I would like some thoughts/guidance on how to attack problem 2. I don't think any solution would be acceptable unless matrix-valued sequences (read: data) are supported. mit-sot support is also pretty mandatory in my mind (i.e. for writing VAR(p) in a very natural way)

@ricardoV94 ricardoV94 assigned ricardoV94 and unassigned ricardoV94 May 8, 2023
@ricardoV94 ricardoV94 self-requested a review May 8, 2023 16:14
@ricardoV94
Copy link
Member

ricardoV94 commented May 9, 2023

The limitation with the dynamic slices is a real thing with JAX. But not too bad when creating models with static type shapes (which 99% of PyMC models are). You can see that we provide static shapes in some of our JAX scan tests to make them work:

x0 = at.vector("x0", dtype="float64", shape=(3,))

You should be able to do the same in your new tests?

@ricardoV94 ricardoV94 added bug Something isn't working jax scan labels May 9, 2023
@ricardoV94 ricardoV94 changed the title Add JAX support for scan when output ndim > 1 Fix JAX Scan for output ndim > 1 May 9, 2023
@jessegrabowski
Copy link
Member Author

The limitation with the dynamic slices is a real thing with JAX. But not too bad when creating models with static type shapes (which 99% of PyMC models are). You can see that we provide static shapes in some of our JAX scan tests to make them work:

x0 = at.vector("x0", dtype="float64", shape=(3,))

You should be able to do the same in your new tests?

All tests now pass. I'm going to move this out of draft.

@jessegrabowski jessegrabowski marked this pull request as ready for review May 9, 2023 09:41
@ricardoV94
Copy link
Member

ricardoV94 commented May 9, 2023

You seem to have a minor dtype issue in the float32 run. Usually we set the test values dtype to config.floatX

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Awesome, thanks for fixing it.

@jessegrabowski
Copy link
Member Author

The benchmark is failing again, is it just flakey or could it be related to the extra operations this PR adds?

@codecov-commenter
Copy link

Codecov Report

Merging #288 (48a237b) into main (ab1850d) will increase coverage by 0.01%.
The diff coverage is 100.00%.

Additional details and impacted files

Impacted file tree graph

@@            Coverage Diff             @@
##             main     #288      +/-   ##
==========================================
+ Coverage   80.36%   80.37%   +0.01%     
==========================================
  Files         153      153              
  Lines       44875    44891      +16     
  Branches    10991    10992       +1     
==========================================
+ Hits        36064    36082      +18     
+ Misses       6606     6605       -1     
+ Partials     2205     2204       -1     
Impacted Files Coverage Δ
pytensor/link/jax/dispatch/scan.py 100.00% <100.00%> (ø)

... and 2 files with indirect coverage changes

@ricardoV94 ricardoV94 merged commit 9ae07ab into pymc-devs:main May 10, 2023
@jessegrabowski jessegrabowski deleted the jax_nd_scan branch May 10, 2023 11:57
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug Something isn't working jax scan
Projects
None yet
Development

Successfully merging this pull request may close these issues.

BUG: JAX Scan fails for outputs with dims > 1
3 participants