
How to wrap a JAX function for use in PyMC (the automatic way) #755

@jdehning

Notebook proposal

Title: How to wrap a JAX function for use in PyMC (the automatic way)

Why should this notebook be added to pymc-examples?

The new `@as_jax_op` wrapper (still in a draft PR, but already functional) needs some examples that showcase what it can do.

I propose two parts. First, an example of solving an ODE, similar to what I wrote here, but with diffrax as the only external dependency.
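A minimal sketch of the kind of differentiable JAX function the ODE example would wrap, assuming only `jax` is installed. It integrates the logistic equation with a hand-written explicit Euler loop (a stand-in for a diffrax solver, which the notebook itself would use), and `jax.grad` differentiates through the entire solve. The names `solve_logistic` and `final_value` are hypothetical.

```python
import jax
import jax.numpy as jnp


def solve_logistic(r, y0=0.1, dt=0.01, n_steps=500):
    """Integrate dy/dt = r * y * (1 - y) with fixed-step explicit Euler.

    Returns the trajectory y(t_1), ..., y(t_n) as a 1-D array.
    """
    # Use the default float dtype so the scan carry type stays consistent
    y_init = jnp.asarray(y0, dtype=float)

    def euler_step(y, _):
        y_next = y + dt * r * y * (1.0 - y)
        return y_next, y_next  # (new carry, value to record)

    _, trajectory = jax.lax.scan(euler_step, y_init, xs=None, length=n_steps)
    return trajectory


def final_value(r):
    # Scalar output, so jax.grad applies directly
    return solve_logistic(r)[-1]


trajectory = solve_logistic(1.5)
dfinal_dr = jax.grad(final_value)(1.5)  # d y(t_end) / d r
```

Because the gradient comes from JAX itself rather than from a hand-derived expression, this is the kind of function an automatic wrapper can expose to PyMC without extra gradient code.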

Second, a rewrite of the existing notebook on how to wrap a JAX function, using `@as_jax_op` instead of defining the Ops manually.
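For comparison, the manual approach roughly requires one PyTensor Op that evaluates the jitted JAX function and a second Op that evaluates its gradient. The two JAX pieces those Ops wrap can be sketched as follows; `model` and `params` are hypothetical placeholders, and this is not the implementation of `@as_jax_op`.

```python
import jax
import jax.numpy as jnp


def model(params):
    # Hypothetical JAX function standing in for the function to be wrapped
    return jnp.sum(jnp.sin(params) ** 2)


# Forward pass: what a hand-written Op's perform method would call
model_jit = jax.jit(model)

# Backward pass: a vector-Jacobian product, which an Op's grad needs.
# PyTensor would supply the output gradient; here we use 1.0 for a scalar output.
params = jnp.array([0.1, 0.5, 1.0])
value, vjp_fn = jax.vjp(model, params)
(grad_params,) = vjp_fn(jnp.ones_like(value))
```

Writing and wiring up both Ops by hand is the boilerplate an automatic wrapper is meant to remove.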

Suggested categories:

  • Level: Intermediate
  • Diataxis type: How-to guide

Related notebooks

Relates to
https://www.pymc.io/projects/examples/en/latest/howto/wrapping_jax_function.html, but simplifies building the Op. I would keep the existing notebook, as it explains in more depth what happens behind the scenes.

Labels: proposal (new notebook proposal still up for discussion)