Make VI compatible with JAX backend #7103
Conversation
Codecov Report: All modified and coverable lines are covered by tests ✅

@@            Coverage Diff             @@
##             main    #7103      +/-   ##
==========================================
- Coverage   91.87%   91.11%    -0.76%
==========================================
  Files         100      100
  Lines       16874    16858       -16
==========================================
- Hits        15503    15361      -142
- Misses       1371     1497      +126
pymc/pytensorf.py (outdated)

@@ -47,6 +46,7 @@
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.op import Op
 from pytensor.scalar.basic import Cast
+from pytensor.scalar.basic import identity as scalar_identity
You don't need to create a new Elemwise; there's already one defined in tensor.math (or basic), called tensor_copy.
@@ -387,7 +386,7 @@ def hessian_diag(f, vars=None):
     return empty_gradient

-identity = Elemwise(scalar_identity, name="identity")
+identity = tensor_copy
Nitpick: just import it directly in the VI module; no need to define it in pytensorf?
It might be used by someone else, I assume.
I don't think so, but even if we keep it we should add a deprecation warning.
Windows tests seem to be very weird and I can't reproduce the failure on a Linux machine; is shape inference platform dependent?
Windows behaves differently with regard to integers. The default integer type is int32, which sometimes causes problems due to some rewrite or check that doesn't expect that (shape in PyTensor is supposed to be int64). Just a guess from previous experience; I can have a look on my Windows machine next week.
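A quick way to see the platform difference being described is to check the size of a C long with the standard library; NumPy's classic default integer historically followed the C long, which is one plausible route for int32 values to sneak into graphs on Windows (the exact NumPy behavior depends on version and platform, so treat this as an illustrative sketch, not a diagnosis of the failing tests):

```python
import ctypes
import sys

# On 64-bit Windows a C long is 4 bytes (so a C-long-backed default
# integer is int32), while on 64-bit Linux/macOS it is 8 bytes (int64).
long_size = ctypes.sizeof(ctypes.c_long)
print(f"platform={sys.platform}, sizeof(C long)={long_size} bytes")
```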
I see one of the issues got resolved with the sort op recently. Any updates for Windows?
I don't think anyone has investigated the problem yet.
How about marking these tests as xfail then?
Let me or someone else investigate on a Windows machine; it seems like an important failure on Windows. In the meantime, can you rebase and pin PyMC to the next PyTensor version to see if the current xfail can be removed?
Force-pushed from 2330568 to 30a2d73
@ricardoV94 updated the dependency on pytensor and commented on one of the xfails in the tests. Hope the Windows tests get resolved with the newer pytensor. In addition, mypy started to complain about pytensor.
@ferrine feel free to rebase, we have already bumped the dependency on main |
def test_vi_sampling_jax(method):
    with pm.Model() as model:
        x = pm.Normal("x")
        pm.fit(10, method=method, fn_kwargs=dict(mode="JAX"))
To be consistent with pm.sample and the nuts_sampler= arg, should we have a dedicated argument for the VI backend instead of kwargs?
I vote yes, this API looks super weird.
What looks weird? This is the compilation mode; it would be exactly the same if you wanted to use Numba or JAX for the PyMC NUTS sampler or for prior/posterior predictive.
The only thing I would change is the name fn_kwargs, which I think is called compile_kwargs in those other functions.
Wouldn't this be what the user would have to do if they wanted to run VI on JAX?
I don't understand the question; this PR is just doing minor tweaks so the PyMC VI module can compile to JAX. It's not linking to specific JAX VI libraries.
We used this for sample_posterior_predictive on projects just last week; we were sampling new variables that had heavy matmuls, and runtime went down from hours to minutes.
Great idea, we should definitely add it there too.
pm.sample is still useful since you can sample discrete variables with JAX this way.
That makes sense; I'm not opposed to adding it there. Maybe we can add a warning that the sampler is still running in Python and they will likely want to use nuts_sampler.
This is still doing Python loops; it's exactly the same argument you need for pm.sample.
It's different from linking to a JAX VI library, which is what would be equivalent to the nuts_sampler kwarg that Chris mentioned in the first comment.
> This is still doing Python loops; it's exactly the same argument you need for pm.sample.

Oh, I somehow assumed that VI was implemented mostly in PyTensor?
As for this, I'd prefer to keep this PR focused on backend compatibility and address possible API changes later in a new issue + PR. Agreed that there is an inconsistency we need to resolve, but doing it here would only delay merging at least some working solution to main, and this PR has already gone through many issues.
Agreed @ferrine. My only suggestion is to switch fn_kwargs to compile_kwargs, which we use in the other sample methods.
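Since a rename was suggested here and a deprecation warning was suggested earlier in the thread, a minimal stdlib-only sketch of the usual pattern may be useful. The function name and return value below are hypothetical placeholders, not PyMC's actual pm.fit signature:

```python
import warnings

def fit(n, method=None, compile_kwargs=None, fn_kwargs=None):
    """Hypothetical sketch: rename `fn_kwargs` to `compile_kwargs` while
    keeping the old spelling working with a DeprecationWarning."""
    if fn_kwargs is not None:
        warnings.warn(
            "`fn_kwargs` is deprecated, use `compile_kwargs` instead",
            DeprecationWarning,
            stacklevel=2,
        )
        if compile_kwargs is not None:
            raise ValueError("pass only one of `fn_kwargs`/`compile_kwargs`")
        compile_kwargs = fn_kwargs
    # Return the resolved arguments so the shim is easy to inspect.
    return {"n": n, "method": method, "compile_kwargs": compile_kwargs or {}}

# The old spelling still works but now warns:
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = fit(10, method="advi", fn_kwargs={"mode": "JAX"})
print(result["compile_kwargs"])
```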
Codecov Report: All modified and coverable lines are covered by tests ✅

@@            Coverage Diff             @@
##             main    #7103      +/-   ##
==========================================
- Coverage   92.34%   91.11%    -1.23%
==========================================
  Files         102      100        -2
  Lines       17032    16858      -174
==========================================
- Hits        15728    15361      -367
- Misses       1304     1497      +193
Just rebased, let's see how it goes.
Rebased the old PR to see if any issues got resolved.
-scalar_identity = IdentityOp(scalar.upgrade_to_float, name="scalar_identity")
-identity = Elemwise(scalar_identity, name="identity")
+identity = tensor_copy
Just do from pytensor... import tensor_copy as identity?
TypeError: The broadcast pattern of the output of scan
(Matrix(float64, shape=(?, 1))) is inconsistent with the one provided in `output_info`
(Vector(float64, shape=(?,))). The output on axis 0 is `True`, but it is `False` on axis
This is actually something wrong: the number of dimensions of a recurring output differs from the initial state. The difference between None and 1 is more annoying, but this one looks like a genuine error.
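To illustrate why such a mismatch is an error and not just a broadcasting annoyance, here is a toy, pure-Python sketch of the consistency check (not PyTensor's actual scan implementation; `ndim`, `scan_check`, and `bad_step` are made-up names for illustration):

```python
def ndim(x):
    """Nesting depth of a nested list: vector -> 1, matrix -> 2."""
    depth = 0
    while isinstance(x, list):
        depth += 1
        x = x[0]
    return depth

def scan_check(step_fn, outputs_info, n_steps):
    """Loop a step function, rejecting recurrent outputs whose number of
    dimensions differs from the initial state in `outputs_info`."""
    expected = ndim(outputs_info)
    state = outputs_info
    for _ in range(n_steps):
        state = step_fn(state)
        if ndim(state) != expected:
            raise TypeError(
                f"recurrent output has ndim {ndim(state)}, "
                f"but `outputs_info` has ndim {expected}"
            )
    return state

# A step that wraps its vector state in an extra axis reproduces the
# Matrix-vs-Vector mismatch from the traceback above: (?,) becomes (?, 1).
bad_step = lambda s: [s]
try:
    scan_check(bad_step, [0.0, 0.0], n_steps=3)
except TypeError as exc:
    print("caught:", exc)
```

The key point is that the initial state is fed back as the next step's input, so a dimension mismatch means step N+1 receives a differently-shaped input than step N did.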
📚 Documentation preview 📚: https://pymc--7103.org.readthedocs.build/en/7103/