Notebook proposal
Title: How to wrap a JAX function for use in PyMC (the automatic way)
Why should this notebook be added to pymc-examples?
The new @as_jax_op wrapper (still in a draft PR, but already functional) needs examples that showcase what it can do.
I propose two parts. First, an example of solving an ODE, similar to what I wrote here, but with diffrax as the only external dependency.
Second, a rewrite of the existing notebook on how to wrap a JAX function, using @as_jax_op instead of defining the Ops manually.
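For illustration, the first part would wrap a pure JAX function of the ODE parameters that can be jit-compiled and differentiated, since the wrapper presumably relies on JAX for the gradients. The sketch below assumes only `jax` is installed. It swaps Diffrax for a hand-written fixed-step explicit Euler integrator, and `solve_decay` is a hypothetical name, so it is not the proposed notebook's code and does not call @as_jax_op itself.

```python
import jax
import jax.numpy as jnp


def solve_decay(k, y0, t_max=1.0, n_steps=1000):
    """Integrate dy/dt = -k * y from t=0 to t_max with explicit Euler; return y(t_max).

    Hypothetical stand-in for a Diffrax solve: a fixed-step scheme, not an adaptive one.
    """
    k = jnp.asarray(k, dtype=jnp.float32)
    y = jnp.asarray(y0, dtype=jnp.float32)
    dt = t_max / n_steps  # Python float, so n_steps stays static under jit

    def step(y, _):
        # One Euler step of the decay ODE
        return y + dt * (-k * y), None

    y_final, _ = jax.lax.scan(step, y, xs=None, length=n_steps)
    return y_final


# Compiled forward solve and its derivative with respect to the decay rate k
solve_jit = jax.jit(solve_decay)
dsolve_dk = jax.jit(jax.grad(solve_decay, argnums=0))

print(solve_jit(0.5, 1.0))  # close to exp(-0.5)
print(dsolve_dk(0.5, 1.0))  # close to -exp(-0.5)
```

Because the whole solve is written with `jax.lax.scan`, JAX can trace it once and differentiate through every step, which is the property a JAX-to-PyTensor wrapper needs from the function it wraps.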
Suggested categories:
- Level: Intermediate
- Diataxis type: How-to guide
Related notebooks
Relates to https://www.pymc.io/projects/examples/en/latest/howto/wrapping_jax_function.html, but simplifies building the Op. I would keep the existing notebook, since it explains in more depth what happens behind the scenes.