Description
This blog post (https://www.pymc-labs.com/blog-posts/jax-functions-in-pymc-3-quick-examples/) walks through three different examples and shows that the logic is always the same:
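- Wrap the jitted forward pass in an `Op`
- Wrap the jitted vjp (vector-Jacobian product, which is what PyTensor's reverse-mode `Op.grad` needs) in a second `GradOp` that provides the gradient implementation
- Register the unjitted versions of the two Ops with the JAX dispatcher (`jax_funcify`) so they integrate with `function(..., mode="JAX")`, since the JAX backend jits the whole graph itself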
Things that cannot be obtained automatically (or maybe they can?) and should therefore be opt-in, as with `@as_op`:
- Input and output types
- `infer_shape`