Allow explicit RNG and Sparse input types in JAX functions #278
Conversation
Codecov Report
Additional details and impacted files

@@           Coverage Diff           @@
##             main     #278    +/-  ##
=========================================
  Coverage   80.46%   80.46%
=========================================
  Files         156      156
  Lines       45514    45522     +8
  Branches    11148    11154     +6
=========================================
+ Hits        36624    36631     +7
+ Misses       6688     6686     -2
- Partials     2202     2205     +3
@jax_typify.register(SparseTensorType)
def jax_typify_SparseTensorType(type):
Is the idea to have BCOO be the default sparse tensor type, or is it a stopgap? I think some algorithms prefer different types, so it'd be good long term to have different subclasses for SparseTensorType (BCOO, CSC, etc.)
From reading JAX docs it seems they are pushing for BCOO only at the moment
JAX's sparse support isn't great though, so I'm not sure they're the best lead to follow. Or I guess this PR is about PyTensor's JAX support only, not necessarily other backends?
This is just the JAX backend. AFAICT BCOO is the only thing somewhat supported. Their other formats (CSC and CSR) don't allow for any of the other JAX transformations (vmap, grad, jit?). They pushed a paper on BCOO, so I think it's really what they're planning, publicly at least.
PyTensor itself uses scipy formats, as does the Numba backend (haven't worked on it much though).
Ah gotcha. I didn't know about the paper, will try to find that. And that will make it more difficult if JAX has a particular way of handling this vs scipy or numba.
fwiw pytensor only supports a subset of scipy formats (csr and csc; scipy has 7 formats listed). Numba supports the same formats pytensor does, but that's not a coincidence. My point is that there's room to redefine what pytensor's sparse formats should be, if it were advantageous to do so.
@jessegrabowski that's definitely true.
Still, it's unlikely that we will have a common set of types (RNG / Shared / Tuples / whatever) that works for all backends. This PR is more focused on how we can allow specialized backend-only types, and not about deciding which specific types we want to provide to users as defaults in PyTensor.
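(For context, a rough sketch of what coercing PyTensor's scipy-backed sparse inputs into JAX's BCOO format can look like; this is illustrative, not the typify implementation from this PR, and to_bcoo is a made-up helper name.)

import jax.experimental.sparse as jsparse
import scipy.sparse


def to_bcoo(x):
    # Coerce a scipy CSR/CSC matrix into JAX's batched-COO (BCOO) format,
    # the only sparse format JAX currently pushes for.
    return jsparse.BCOO.from_scipy_sparse(x)


mat = scipy.sparse.random(4, 4, density=0.25, format="csr")
bcoo_mat = to_bcoo(mat)  # a jax.experimental.sparse.BCOO instance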
Some small comments, nothing that should hold up the PR.
with pytest.warns(
    UserWarning,
    match="RandomTypes are implicitly converted to random PRNGKey arrays",
Tangential to this PR, but what is the purpose of this warning? As a user I see it every time I compile a model with RVs to JAX, but there's nothing I can do to avoid this warning, yes? What danger does it alert me to?
This tells the user that the input type of the compiled function is no longer the standard PyTensor one (such as np.random.default_rng or scipy.sparse.csr) but a JAX-specific one. Every time the function is called the input will be coerced into that type, so if performance is critical the user can pass that type directly.
Because this is very much an edge case (users creating PyTensor functions with explicit sparse or RNG variables) I think it's better to just have a warning than to try to provide keyword arguments or the specific types.
If adoption increases / users complain then it may be worthwhile to revisit our approach.
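(As a concrete illustration of that coercion, a sketch; fn stands for any JAX-compiled PyTensor function with an explicit RNG input and is not defined here.)

import jax
import numpy as np

np_rng = np.random.default_rng(0)  # standard PyTensor-side input: coerced to a
                                   # PRNGKey array (with the warning) on every call
jx_rng = jax.random.PRNGKey(0)     # JAX-native input: no per-call coercion

# fn(np_rng)  # works, but pays the conversion on each call
# fn(jx_rng)  # preferred when performance is critical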
# Inputs - Value outputs
assert fn(np_rng)[1] == fn(np_rng)[1]
assert fn(jx_rng)[1] == fn(jx_rng)[1]
assert fn(np_rng)[1] != fn(jx_rng)[1]
Is this the desired behavior, or just how it works out?
Just how it works. We are not going out of our way to try to find a 1-1 mapping between numpy Generators and JAX PRNGKeys (if that's even possible).
@pytest.mark.parametrize("sparse_type", ("csc", "csr"))
def test_sparse_io(sparse_type):
    """Test explicit (non-shared) input and output sparse types in JAX."""
    sparse_mat = ps.matrix(format=sparse_type, name="csc", dtype="float32")
Why not floatX here (and below)?
JAX doesn't like float64, so this probably avoids annoying warnings from them
In general the floatX thing from the times of Theano is a very lazy way of testing that one specific and back-then-popular configuration ("float32" / "FAST_COMPILE") doesn't accidentally break. This is a bit silly, because there are a thousand other configs that one could argue are also important to test and we do not (e.g., no optimization at all).
What I am trying to say is that the float32/float64 tests are not a reasonable way to actually check that our code works with custom configurations. For legacy reasons we keep doing it in old tests, but I don't try to introduce them all the time in new tests unless I foresee a reason why float32/float64 flexibility should be tested specifically.
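(For reference, the floatX pattern being discussed looks roughly like this; purely illustrative.)

import pytensor
import pytensor.tensor as pt

# Legacy tests build inputs with the globally configured float dtype so that a
# float32 test run (pytensor.config.floatX == "float32") exercises that path.
x = pt.matrix("x", dtype=pytensor.config.floatX)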
@pytest.mark.parametrize("sparse_type", ("csc", "csr"))
def test_sparse_io(sparse_type):
This seems like it is testing the sparse transpose Op as well; the name should reflect that, or the transpose test should be split into a separate test.
The sparse transpose Op is just a tool, not the thing of interest in this test (we need to have some sparse operation, otherwise it would be testing a not-implemented sparse Copy Op).
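(A rough sketch of the kind of round-trip such a test performs; the helper name and assertion are illustrative and not copied from the PR.)

import numpy as np
import scipy.sparse
import pytensor
import pytensor.sparse as ps


def sparse_roundtrip_sketch(sparse_type="csr"):
    x = ps.matrix(format=sparse_type, name="x", dtype="float32")
    y = ps.transpose(x)  # any sparse Op will do; transpose is just a convenient one
    fn = pytensor.function([x], y, mode="JAX")

    val = scipy.sparse.random(3, 5, density=0.5, format=sparse_type, dtype="float32")
    out = fn(val)  # input is coerced to the JAX sparse type (with a warning)
    # .todense() works for both scipy matrices and JAX BCOO outputs
    np.testing.assert_allclose(np.asarray(out.todense()), val.T.toarray())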
        isinstance(inp.type, RandomType) and not isinstance(inp, SharedVariable)
        for inp in fgraph.inputs
    ):
        warnings.warn(
Similar comment to the one in the tests. Should a user using PyTensor to compile to JAX really be expected to provide JAX-valued inputs? If so, why not just write in JAX directly (instead of passing through PyTensor), and if not, why issue this warning?
I commented above; the point is that PyTensor is an abstraction which then tries to lower to specialized backends. It's not always possible to lower from a generic format (in our case numpy/scipy) into specialized ones (like JAX) because they may have types that are fundamentally different. Most inputs tend to be numpy-like arrays that our backends all support natively, so these concerns don't show up. But once we move to stuff like Sparse and RNG variables things break down.
In the spirit of allowing one graph - multiple backends, we have to coerce types at some point into the specific ones required by the backends. The other option would be for users to create the specialized graphs with specialized types (which then couldn't be compiled to other backends). In that case you get closer to your point of "why not just write in JAX", which this approach actually tries to avoid.
For libraries like PyMC it's quite important that we don't force users to specify the backend in advance, especially because the standard workflow actually uses different backends for different tasks (e.g. random sampling and MCMC sampling).
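(To make the "one graph, multiple backends" point concrete, a small sketch; the graph is arbitrary and the mode strings are PyTensor compile modes.)

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.exp(x).sum()

fn_default = pytensor.function([x], y)          # default (C/Python) backend
fn_jax = pytensor.function([x], y, mode="JAX")  # same graph, lowered to JAX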
We could skip the warning (it's one of the things I am asking for feedback about).
However, it's getting clear to me that when we come around to addressing the type specialization for shared variables there will probably be a need for a warning, because those variables will have to be copied over (and no longer be "shared" with the original function), since they can't override the original variables.
We have such a warning for shared RNGs, which you have probably seen, and it is conceptually a bit different from the one being introduced here for explicit/non-shared RNGs and Sparse variables. That one is a hackish thing for shared/implicit RNG variables only.
This PR is a first step to try to address these type issues coherently. It was just much easier to start with explicit/non-shared variables because they are function-specific.
This one has been rotting for a while. I double-checked the code and it seems fine to merge, as long as it's still relevant?
I will sort out Shared Variables in a separate PR.