Fix JAX Scan for output ndim > 1 #288
Conversation
I played around a bit more with this and I'm stumped on problem 2. I've been thinking more about problem 1, and I think it's true that my solution is too complex. It should suffice to just check which (if any) scan outputs have dimension > 0, and add a single batch dimension if so. I also noticed the solution as-is is failing the benchmark checks, which might be related to having the Otherwise, I would like some thoughts/guidance on how to attack problem 2. I don't think any solution would be acceptable unless matrix-valued sequences (read: data) are supported. mit-sot support is also pretty mandatory in my mind (i.e. for writing VAR(p) in a very natural way) |
The dynamic-slice limitation is a real constraint in JAX, but it is not too bad when models are created with static type shapes (which 99% of PyMC models are). You can see that we provide static shapes in some of our JAX scan tests to make them work (pytensor/tests/link/jax/test_scan.py, line 40 in c84bd0b).
You should be able to do the same in your new tests? |
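The underlying JAX restriction can be reproduced outside PyTensor. The following is a minimal sketch in plain JAX, not code from this PR or from the PyTensor test suite; the function names are hypothetical. Under `jax.jit`, a slice whose length depends on a traced value is rejected. The same slice works when the length is marked static, and `jax.lax.dynamic_slice` accepts a dynamic start as long as the slice size is fixed.

```python
from functools import partial

import jax
import jax.numpy as jnp

x = jnp.arange(10.0)


@jax.jit
def head_traced(x, n):
    # n is a tracer here, so the slice length is unknown at trace time
    return x[:n]


@partial(jax.jit, static_argnums=1)
def head_static(x, n):
    # n is a concrete Python int, so the output shape is known
    return x[:n]


@jax.jit
def window(x, start):
    # A dynamic start is fine as long as the slice size (3) is static
    return jax.lax.dynamic_slice(x, (start,), (3,))


try:
    head_traced(x, 3)
    traced_ok = True
except Exception:  # the exact exception type is deliberately not assumed
    traced_ok = False

static_out = head_static(x, 3)
window_out = window(x, 4)
```

This is presumably why supplying static shapes helps: once a slice length is known at compile time, it no longer depends on a traced value.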
All tests now pass. I'm going to move this out of draft. |
You seem to have a minor dtype issue in the float32 run. Usually we set the test values dtype to |
Awesome, thanks for fixing it.
The benchmark is failing again. Is it just flaky, or could it be related to the extra operations this PR adds? |
Codecov Report
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #288      +/-   ##
==========================================
+ Coverage   80.36%   80.37%   +0.01%
==========================================
  Files         153      153
  Lines       44875    44891      +16
  Branches    10991    10992       +1
==========================================
+ Hits        36064    36082      +18
+ Misses       6606     6605       -1
+ Partials     2205     2204       -1
|
Motivation for these changes
Closes #287
Currently, JAX compilation of scans with n-dimensional outputs fails.
Implementation details
Checks input and output dimensions, and adds a batch dimension to initial inputs where necessary.
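To illustrate the idea of giving an initial state a leading batch dimension, here is a minimal sketch in plain JAX. It is not the implementation in this PR: `matrix_scan` and its doubling step are hypothetical, and the PR itself works on PyTensor's scan dispatch rather than calling `jax.lax.scan` like this. The sketch runs a scan whose state is a matrix, then adds one leading axis to the initial state so it can be stacked with the per-step outputs.

```python
import jax
import jax.numpy as jnp


def matrix_scan(init, n_steps):
    """Scan over a matrix-valued state and return the full trace.

    The returned array has shape (n_steps + 1, *init.shape): the initial
    state followed by one state per step.
    """

    def step(carry, _):
        new = 2.0 * carry  # toy update rule, purely illustrative
        return new, new

    # xs=None with an explicit length means the step receives no sequence input
    _, trace = jax.lax.scan(step, init, xs=None, length=n_steps)

    # init has ndim 2 while trace has ndim 3, so add a single leading
    # (batch/time) dimension to init before concatenating
    return jnp.concatenate([init[None, ...], trace], axis=0)


init = jnp.eye(2)
full = matrix_scan(init, 3)
```

Here `full` has shape `(4, 2, 2)`: one leading axis is enough because the stacked outputs have exactly one more dimension than the state.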
Marked as draft because there are still some problems:
1. Currently, just check if output.ndim > 1 and insert a single dimension to the initial input if so. I think this is the only case possible given the shape checking that is done by scan itself. Since all outputs must be 1 dimension smaller than the inputs, I don't think it's possible to reach a case where more than 1 batch dimension needs to be inserted. Nevertheless, the machinery to do this exists in this code if necessary.
2. Big problem: nd sequences and mit_sots do not work. This is captured by two tests in this PR: test_nd_scan_mit_sot and test_nd_scan_sit_sot_with_seq. The error raised in both cases is NotImplementedError: JAX does not support slicing arrays with a dynamic slice length. I'm not sure why the slices are dynamic though, we should be able to pre-compute all the required slices? It just doesn't seem much different from the normal 1d array case.
Checklist
Major / Breaking Changes
I hope none; all tests in tests/link/jax/test_scan.py pass
New features
JAX support for more flavors of scan
Bugfixes
Scan no longer errors when outputs have ndim > 1 and mode = JAX
Documentation
None
Maintenance