Description
`pytensor.function()` returns a `Function` object with a complicated `__call__` method that stores inputs and allocates outputs in list-like containers that are heavily tuned to the C backend. This means that a generic function compiled to JAX or Numba will in general not work within a longer JAX / Numba workflow (e.g., calling `vmap` or `grad` on a compiled function).
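For instance (a minimal illustration, assuming the JAX backend is available; the exact error will differ), plain evaluation works but JAX transforms do not, because the returned object is stateful machinery rather than a pure traceable callable:

```python
import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
f = pytensor.function([x], x ** 2, mode="JAX")

f(np.arange(3.0))  # works: plain evaluation returns [0., 1., 4.]
# jax.vmap(f) / jax.grad(f) fail: f.__call__ is not a pure function
# that JAX can trace through.
```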
We could provide simpler `jax_function` and `numba_function` helpers that return such a plain callable directly. In PyMC we implemented something like that for JAX: https://github.com/pymc-devs/pymc/blob/31c30dc1beea26e4bff52a93037540923feaaa84/pymc/sampling/jax.py#L108-L132
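A minimal sketch of what such a helper could look like, modeled on the PyMC snippet linked above. Note that `jax_function` is the proposed, not-yet-existing name, and the exact rewrite pipeline is an assumption based on that code:

```python
import jax
from pytensor.compile.mode import get_mode
from pytensor.graph.fg import FunctionGraph
from pytensor.link.jax.dispatch import jax_funcify


def jax_function(inputs, outputs):
    """Compile a PyTensor graph into a pure, jittable JAX callable (sketch).

    Shared variables and updates are ignored here; they need the separate
    handling discussed below.
    """
    fgraph = FunctionGraph(inputs=inputs, outputs=outputs, clone=True)
    # Apply the same graph rewrites the JAX linker would apply
    get_mode("JAX").optimizer.rewrite(fgraph)
    # jax_funcify dispatches the FunctionGraph to a JAX-compatible function
    # that takes the inputs positionally and returns a tuple of outputs
    jax_fn = jax_funcify(fgraph)
    return jax.jit(jax_fn)
```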
There is one obvious limitation, which concerns the handling of shared variables and updates. Shared variables are global variables that are passed as inputs to the actual inner function but are not provided explicitly by the user. Updates replace the original value of a shared variable with a (user-hidden) output of the function every time it is called.
A simple JAX/Numba PyTensor function with global variables and updates looks like this:
```python
import numpy as np
import pytensor
import pytensor.tensor as pt

shared_y = pytensor.shared(np.ones((5,)))
x = pt.vector("x")
fn = pytensor.function([x], x + shared_y, updates={shared_y: shared_y + 1}, mode="JAX")
```
It roughly translates to the following pseudocode:
```python
shared_y = np.ones((5,))  # module-level global backing the shared variable

def fn(x):
    @jax.jit
    def inner_fn(x, y):
        return x + y, y + 1

    global shared_y
    out, update_y = inner_fn(x, shared_y)
    shared_y = update_y  # apply the update as a side effect
    return out
```
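Calling the wrapper then mutates the global state on every invocation (illustrative values, given `shared_y` starts as ones):

```python
fn(np.zeros(5))  # returns [1., 1., ...]; shared_y is now all 2s
fn(np.zeros(5))  # returns [2., 2., ...]; shared_y is now all 3s
```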
I don't think either JAX or Numba supports stateful jitted functions, so users would need to work with the pure `inner_fn` directly.
The proposal here is to give users easy access to the compiled (jitted or not) `inner_fn`.
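With the pure `inner_fn` exposed at the top level (as the proposal suggests), the usual JAX transforms compose as expected, with the shared state threaded explicitly instead of hidden in a global. A sketch of the intended usage:

```python
import jax
import jax.numpy as jnp

x_val = jnp.zeros(5)  # example inputs
y_val = jnp.ones(5)

# State is an explicit argument and an explicit return value
out, new_y = inner_fn(x_val, y_val)

# Gradients w.r.t. x through the first output now work
grad_fn = jax.grad(lambda x, y: inner_fn(x, y)[0].sum())
dx = grad_fn(x_val, y_val)

# ...and so does batching over x
batched_fn = jax.vmap(inner_fn, in_axes=(0, None))
```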