
Allow explicit RNG and Sparse input types in JAX functions #278


Open
ricardoV94 wants to merge 2 commits into main from backend_types

Conversation


@ricardoV94 ricardoV94 commented Apr 18, 2023

I will sort out Shared Variables in a separate PR


codecov-commenter commented Apr 18, 2023

Codecov Report

Merging #278 (a6bf472) into main (9be43d0) will increase coverage by 0.00%.
The diff coverage is 91.42%.

Additional details and impacted files


@@           Coverage Diff           @@
##             main     #278   +/-   ##
=======================================
  Coverage   80.46%   80.46%           
=======================================
  Files         156      156           
  Lines       45514    45522    +8     
  Branches    11148    11154    +6     
=======================================
+ Hits        36624    36631    +7     
+ Misses       6688     6686    -2     
- Partials     2202     2205    +3     
| Impacted Files | Coverage Δ |
|---|---|
| pytensor/link/jax/dispatch/basic.py | 82.81% <50.00%> (+4.68%) ⬆️ |
| pytensor/link/jax/dispatch/sparse.py | 88.57% <84.21%> (+5.23%) ⬆️ |
| pytensor/link/jax/dispatch/random.py | 94.76% <94.73%> (-1.32%) ⬇️ |
| pytensor/link/basic.py | 87.44% <100.00%> (+0.10%) ⬆️ |
| pytensor/link/jax/linker.py | 95.74% <100.00%> (+0.87%) ⬆️ |

... and 2 files with indirect coverage changes

@ricardoV94 ricardoV94 force-pushed the backend_types branch 5 times, most recently from cb4ab40 to 09f0d99 on July 7, 2023 10:54


@jax_typify.register(SparseTensorType)
def jax_typify_SparseTensorType(type):


Is the idea to have BCOO be the default sparse tensor type, or is it a stopgap? I think some algorithms prefer different types, so it'd be good long term to have different subclasses for SparseTensorType (BCOO, CSC, etc.)

ricardoV94 (Member Author)

From reading JAX docs it seems they are pushing for BCOO only at the moment

@bwengals Jul 17, 2023

Jax's sparse support isn't great though, I'm not sure they're the best lead to follow. Or, I guess this PR is about pytensor's Jax support only and not necessarily other backends?

@ricardoV94 (Member Author) Jul 17, 2023

This is just the JAX backend. AFAICT BCOO is the only format somewhat supported. Their other formats (CSC / CSR) don't allow for any of the other JAX transformations (vmap, grad, jit?). They put out a paper on BCOO, so I think it's really what they're planning, publicly at least.

PyTensor itself uses scipy formats, as does the Numba backend (haven't worked on it much though).
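
For context, a minimal sketch (not the PR's implementation) of the kind of coercion being discussed: converting a scipy CSR matrix into JAX's BCOO format, the format JAX currently focuses its sparse support on.

```python
import numpy as np
import scipy.sparse as sp
from jax.experimental import sparse as jsparse

# Build a small scipy CSR matrix and coerce it to JAX's batched-COO format.
csr_mat = sp.random(4, 4, density=0.25, format="csr", dtype=np.float32)
bcoo_mat = jsparse.BCOO.from_scipy_sparse(csr_mat)

# The dense contents round-trip unchanged.
assert np.allclose(np.asarray(bcoo_mat.todense()), csr_mat.toarray())
```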

@bwengals Jul 18, 2023

Ah gotcha. I didn't know about the paper, will try and find that. And that will make it more difficult if Jax has a particular way of handling this vs scipy or numba.

@jessegrabowski (Member)

fwiw pytensor only supports a subset of scipy formats (csr and csc; scipy has 7 formats listed). Numba supports the same formats pytensor does, but that's not a coincidence. My point is that there's room to redefine what pytensor's sparse formats should be, if it were advantageous to do so.

ricardoV94 (Member Author)

@jessegrabowski that's definitely true.

Still, it's unlikely that we will have a common set of types (RNG / Shared / Tuples / whatever) that works for all backends. This PR is more focused on how we can allow specialized backend-only types, not on deciding which specific types we want to provide to users by default in PyTensor.

@jessegrabowski (Member) left a comment

Some small comments, nothing that should hold up the PR.


with pytest.warns(
UserWarning,
match="RandomTypes are implicitly converted to random PRNGKey arrays",
jessegrabowski (Member)

Tangential to this PR, but what is the purpose of this warning? As a user I see it every time I compile a model with RVs to JAX, but there's nothing I can do to avoid this warning, yes? What danger does it alert me to?

@ricardoV94 (Member Author) Jul 24, 2023

This tells the user that the input type of the compiled function is no longer the standard PyTensor one (such as np.random.default_rng or scipy.sparse.csr) but a JAX-specific one. Every time the function is called, the input will be coerced into that type, so if performance is critical the user can pass the JAX-specific type directly.

Because this is very much an edge case (users creating PyTensor functions with explicit sparse or RNG variables), I think it's better to just have a warning than to try to provide keyword arguments or the specific types.

If adoption increases / users complain, then it may be worthwhile to revisit our approach.
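
An illustrative sketch of the two input forms being described; `fn` stands for a hypothetical PyTensor function compiled with mode="JAX" that takes an explicit (non-shared) RNG input, and is not constructed here.

```python
import numpy as np
import jax

np_rng = np.random.default_rng(123)  # coerced to a JAX PRNGKey on every call
jx_rng = jax.random.PRNGKey(123)     # already the JAX-native type, no coercion

# fn(np_rng)  # works, but triggers the implicit-conversion warning
# fn(jx_rng)  # avoids repeated conversion on performance-critical paths
```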

# Inputs - Value outputs
assert fn(np_rng)[1] == fn(np_rng)[1]
assert fn(jx_rng)[1] == fn(jx_rng)[1]
assert fn(np_rng)[1] != fn(jx_rng)[1]
jessegrabowski (Member)

Is this the desired behavior, or just how it works out?

@ricardoV94 (Member Author) Jul 24, 2023

Just how it works out. We are not going out of our way to try and find a 1-1 map between numpy Generators and JAX PRNGKeys (if that's even possible).

@pytest.mark.parametrize("sparse_type", ("csc", "csr"))
def test_sparse_io(sparse_type):
"""Test explicit (non-shared) input and output sparse types in JAX."""
sparse_mat = ps.matrix(format=sparse_type, name="csc", dtype="float32")
jessegrabowski (Member)

Why not floatX here (and below)?

ricardoV94 (Member Author)

JAX doesn't like float64, so this probably avoids annoying warnings from them
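
A small sketch of the float64 behavior alluded to here: by default JAX keeps computations in 32-bit and warns when float64 is requested, unless x64 mode is explicitly enabled (ideally before any JAX work is done).

```python
import jax
import jax.numpy as jnp

x = jnp.asarray([1.0], dtype="float64")  # warns and silently yields float32
print(x.dtype)  # float32

jax.config.update("jax_enable_x64", True)  # opt in to true 64-bit
y = jnp.asarray([1.0], dtype="float64")
print(y.dtype)  # float64
```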

ricardoV94 (Member Author)

In general, the floatX thing from the times of Theano is a very lazy way of testing that one specific, back-then-popular mode ("float32" / "FAST_COMPILE") doesn't accidentally break. This is a bit silly, because there are a thousand other configs that one could argue are also important to test and we do not (e.g., no optimization at all).

What I am trying to say is that the float32/float64 tests are not a reasonable way to actually check our code works with custom configurations. For legacy reasons we keep doing it in old tests, but I don't really try to introduce them all the time in new tests unless I foresee a reason why float32/float64 flexibility should be tested specifically.



@pytest.mark.parametrize("sparse_type", ("csc", "csr"))
def test_sparse_io(sparse_type):
jessegrabowski (Member)

This seems like it is testing the sparse transpose Op as well; the name should reflect that, or the transpose test should be split into a separate test.

ricardoV94 (Member Author)

The sparse transpose Op is just a tool, not the thing of interest in this test (we need to have some sparse operation, otherwise it would be testing a not-implemented sparse Copy Op).
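
A hedged sketch of what such a round trip might look like (this is not the PR's test body; `ps` is assumed to be `pytensor.sparse`, and how the backend returns sparse outputs is an assumption):

```python
import numpy as np
import scipy.sparse
import pytensor
import pytensor.sparse as ps

def roundtrip_sparse(sparse_format="csr"):
    x = ps.matrix(format=sparse_format, name="x", dtype="float32")
    # Transpose is only here so the graph contains a real sparse Op.
    fn = pytensor.function([x], ps.transpose(x), mode="JAX")
    x_val = scipy.sparse.random(
        3, 5, density=0.5, format=sparse_format, dtype=np.float32
    )
    res = fn(x_val)
    # The output may come back as a JAX BCOO rather than a scipy matrix;
    # handle both when densifying for the comparison.
    res_dense = res.todense() if hasattr(res, "todense") else np.asarray(res)
    np.testing.assert_allclose(np.asarray(res_dense), x_val.toarray().T)
```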




isinstance(inp.type, RandomType) and not isinstance(inp, SharedVariable)
for inp in fgraph.inputs
):
warnings.warn(
jessegrabowski (Member)

Similar comment to the one in the tests. Should a user using pytensor to compile to jax really be expected to provide jax-valued inputs? If so, why not just write in JAX directly (instead of passing through pytensor), and if not, why issue this warning?

@ricardoV94 (Member Author) Jul 24, 2023

I commented above; the point is that PyTensor is an abstraction that then tries to lower to specialized backends. It's not always possible to lower from a generic format (in our case numpy/scipy) into specialized ones (like JAX's), because they may have types that are fundamentally different. Most inputs tend to be numpy-like arrays that all our backends support natively, so these concerns don't show up. But once we move to stuff like Sparse and RNG variables, things break down.

In the spirit of allowing one graph - multiple backends, we have to at some point coerce types into the specific ones required by the backends. The other option would be for users to create specialized graphs with specialized types (which then couldn't be compiled to other backends). In that case you get closer to your point of "why not just write in JAX", which this approach actually tries to avoid.

For libraries like PyMC it's quite important that we don't force users to specify the backend in advance, especially because the standard workflow actually uses different backends for different tasks (e.g., random sampling and MCMC sampling).
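
An illustrative sketch of the "one graph, multiple backends" idea (the graph and names are examples, not taken from the PR): the same symbolic graph is compiled once with the default backend and once with the JAX backend.

```python
import pytensor
import pytensor.tensor as pt

# One symbolic graph...
x = pt.vector("x")
y = pt.exp(x).sum()

# ...compiled to two different backends.
fn_default = pytensor.function([x], y)            # default backend
fn_jax = pytensor.function([x], y, mode="JAX")    # lowered to JAX

print(fn_default([0.0, 1.0]), fn_jax([0.0, 1.0]))
```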

@ricardoV94 (Member Author) Jul 24, 2023

We could choose not to issue a warning (it's one of the things I am asking for feedback about).

However, it's becoming clear to me that when we come around to addressing type specialization for shared variables, there will probably be a need for a warning, because those variables will have to be copied over (and no longer be "shared" with the original function), since they can't overwrite the original variables.

We have such a warning for shared RNGs, which you have probably seen; it is conceptually a bit different from the one being introduced here for explicit/non-shared RNGs and Sparse variables. That one is a hackish thing for shared/implicit RNG variables only.

This PR is a first step to try and address these type issues coherently. It was just much easier to start with explicit/non-shared variables because they are function-specific.

@jessegrabowski (Member) left a comment

This one has been rotting for a while. I double-checked the code and it seems fine to merge, as long as it's still relevant?

4 participants