
Implement helper @as_jax_op to wrap JAX functions in PyTensor #537

Open
@ricardoV94

Description

This blog post walks through three different examples: https://www.pymc-labs.com/blog-posts/jax-functions-in-pymc-3-quick-examples/ and shows that the logic is always the same:

  1. Wrap the jitted forward pass in an Op
  2. Wrap the jitted VJP (vector-Jacobian product; PyTensor gradients are reverse mode, so it's vjp rather than jvp) in a GradOp to provide the gradient implementation
  3. Dispatch unjitted versions of the two Ops (via `jax_funcify`) for integration with `function(..., mode="JAX")`

Things that cannot be obtained automatically (or maybe they can?) and should be opt-in, as in `@as_op`:

  4. Input and output types
  5. `infer_shape`
