-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Make coords and data always mutable #7047
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #7047 +/- ##
===========================================
- Coverage 92.30% 39.54% -52.77%
===========================================
Files 100 101 +1
Lines 16895 16835 -60
===========================================
- Hits 15595 6657 -8938
- Misses 1300 10178 +8878
|
01deacd
to
28ded31
Compare
28ded31
to
55693f4
Compare
4445317
to
546d59e
Compare
546d59e
to
e35c402
Compare
Love it. |
Woohoo! |
We need to adapt the pymc-examples (and maybe the NBs here too?). |
Possibly |
It shouldn't fail immediately, just issue a warning. Only when we remove the kwarg will it fail hard |
Hi @ricardoV94! Here is an example: with pm.Model():
pm.Data("b", [True, False], dtype=bool)
Do you have any idea why this is happening? Thanks in advance for your help! |
Closes #6972
This PR provides a new model transform that freezes RV dims that depend on coords as well as mutable data, for those worried about performance issues or incompatibilities with JAX dynamic shape limitations. I expect most users won't need this. The default C-backend doesn't really exploit static shapes. I believe the simpler API for users is a net win.
Note that JAX dynamic shape stuff is not relevant when using JAX samplers because we already replace any shared variables by constants anyways. It's only relevant when compiling PyTensor functions with mode="JAX"
The big picture here is that we define the most general model first, and later specialize if needed. Going from a model with constant shapes to another with different constant shapes is generally not possible because PyTensor eagerly computes static shape outputs for intermediate nodes, and rebuilding with different constant types is not always supported.
Starting with more general models could be quite helpful for producing predictive models automatically.
Note: If there's resistance, this PR can be narrowed down in scope to just remove the distinction between coords_mutable and coords, but still leave MutableData vs ConstantData