Update example in "Adding JAX and Numba support for Ops" #687

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged Apr 10, 2024 (8 commits). Changes shown from 4 commits.
doc/extending/creating_a_numba_jax_op.rst (148 changes: 96 additions & 52 deletions)

This tutorial will explain how JAX and Numba implementations are created for an :class:`Op`. It will
focus specifically on the JAX case, but the same mechanisms are used for Numba as well.

Step 1: Identify the PyTensor :class:`Op` you'd like to implement in JAX
--------------------------------------------------------------------------

Find the source for the PyTensor :class:`Op` you'd like to be supported in JAX, and
identify the function signature and return values. These can be determined by
looking at the :meth:`Op.make_node` implementation. In general, one needs to be familiar
with PyTensor :class:`Op`\s in order to provide a conversion implementation, so first read
:ref:`creating_an_op` if you are not familiar.

For example, the :class:`FillDiagonal`\ :class:`Op` currently has an :meth:`Op.make_node` as follows:

.. code:: python

    def make_node(self, a, val):
        a = ptb.as_tensor_variable(a)
        val = ptb.as_tensor_variable(val)
        if a.ndim < 2:
            raise TypeError(
                "%s: first parameter must have at least"
                " two dimensions" % self.__class__.__name__
            )
        elif val.ndim != 0:
            raise TypeError(
                f"{self.__class__.__name__}: second parameter must be a scalar"
            )
        val = ptb.cast(val, dtype=upcast(a.dtype, val.dtype))
        if val.dtype != a.dtype:
            raise TypeError(
                "%s: type of second parameter must be the same as"
                " the first's" % self.__class__.__name__
            )
        return Apply(self, [a, val], [a.type()])


The :class:`Apply` instance that's returned specifies the exact types of inputs that
our JAX implementation will receive and the exact types of outputs it's expected to
return--both in terms of their data types and number of dimensions/shapes.
The actual inputs our implementation will receive are necessarily numeric values
or NumPy :class:`ndarray`\s; all that :meth:`Op.make_node` tells us is the
general signature of the underlying computation.

More specifically, the :class:`Apply` implies that the inputs come from two values that are
automatically converted to PyTensor variables via :func:`as_tensor_variable`, and
the type checks that follow imply that the first one must be a tensor with at least two
dimensions (i.e., a matrix) and the second must be a scalar. According to this
logic, the inputs could have any data type (e.g. floats, ints), so our JAX
implementation must be able to handle all the possible data types.

It also tells us that there's only one return value, that it has a data type
determined by ``a.type()``, i.e., the data type and static shape of the first input.
This implies that the result is a tensor with the same number of dimensions and
data type as that input.

Next, we can look at the :meth:`Op.perform` implementation to see exactly
how the inputs and outputs are used to compute the outputs for an :class:`Op`
in Python. This method is effectively what needs to be implemented in JAX.
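
For :class:`FillDiagonal`, a minimal sketch of what :meth:`Op.perform` does (simplified
here; the actual implementation in PyTensor also works around some NumPy corner cases)
is:

.. code:: python

    import numpy as np


    def perform(self, node, inputs, output_storage):
        a, val = inputs
        # Copy the input so the caller's array is never modified in place
        a = a.copy()
        # Write `val` along the main diagonal
        np.fill_diagonal(a, val)
        # Store the result in the pre-allocated output container
        output_storage[0][0] = a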


Step 2: Find the relevant JAX method (or something close)
----------------------------------------------------------
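
Here's an example for :class:`IfElse`. When JAX already offers a close equivalent,
the conversion can be a thin wrapper around it; the snippet below is a sketch that
assumes the :class:`IfElse` conversion is built on :func:`jax.lax.cond`, so the exact
code in the PyTensor sources may differ in its details:

.. code:: python

    import jax

    from pytensor.ifelse import IfElse
    from pytensor.link.jax.dispatch import jax_funcify


    @jax_funcify.register(IfElse)
    def jax_funcify_IfElse(op, **kwargs):
        # Number of outputs produced by each branch of the `IfElse`
        n_outs = op.n_outs

        def ifelse(cond, *args, n_outs=n_outs):
            # The first `n_outs` arguments belong to the "then" branch,
            # the remaining ones to the "else" branch.
            res = jax.lax.cond(
                cond,
                lambda: args[:n_outs],
                lambda: args[n_outs:],
            )
            return res if n_outs > 1 else res[0]

        return ifelse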

In this case, we have to use custom logic to implement the JAX version of
:class:`FillDiagonal` since JAX has no equivalent implementation. We have to use
:func:`jax.numpy.diag_indices` to find the indices of the diagonal elements and then set
them to the value we want.
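
As a standalone illustration of that pattern (the array and fill value here are
arbitrary):

.. code:: python

    import jax.numpy as jnp

    x = jnp.zeros((3, 3))
    # Indices of the main diagonal of the trailing two dimensions
    i, j = jnp.diag_indices(min(x.shape[-2:]))
    # JAX arrays are immutable, so `.at[...].set(...)` returns an updated copy
    y = x.at[..., i, j].set(5.0)
    # y == [[5., 0., 0.], [0., 5., 0.], [0., 0., 5.]]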

Step 3: Register the function with the `jax_funcify` dispatcher
---------------------------------------------------------------

With the PyTensor `Op` replicated in JAX, we'll need to register the
function with the PyTensor JAX `Linker`. This is done through the use of
`singledispatch`. If you don't know how `singledispatch` works, see the
`Python documentation <https://docs.python.org/3/library/functools.html#functools.singledispatch>`_.
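
As a minimal, PyTensor-independent illustration (the function names here are made up
for the example), `singledispatch` selects an implementation based on the type of the
first argument:

.. code:: python

    from functools import singledispatch


    @singledispatch
    def describe(op):
        raise NotImplementedError(f"No implementation registered for {type(op)}")


    @describe.register(int)
    def describe_int(op):
        return "got an int"


    describe(3)    # -> "got an int"
    describe(3.0)  # raises NotImplementedError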

The relevant dispatch functions created by `singledispatch` are :func:`pytensor.link.numba.dispatch.numba_funcify` and
:func:`pytensor.link.jax.dispatch.jax_funcify`.

Here's an example for the `FillDiagonal`\ `Op`:

.. code:: python

    import jax.numpy as jnp

    from pytensor.tensor.extra_ops import FillDiagonal
    from pytensor.link.jax.dispatch import jax_funcify


    @jax_funcify.register(FillDiagonal)
    def jax_funcify_FillDiagonal(op, **kwargs):
        def filldiagonal(value, diagonal):
            i, j = jnp.diag_indices(min(value.shape[-2:]))
            return value.at[..., i, j].set(diagonal)

        return filldiagonal
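
The Numba side is registered the same way through
:func:`pytensor.link.numba.dispatch.numba_funcify`. A minimal sketch of what such a
registration could look like (it assumes Numba's support for :func:`numpy.fill_diagonal`;
the actual implementation in PyTensor may differ):

.. code:: python

    import numba
    import numpy as np

    from pytensor.link.numba.dispatch import numba_funcify
    from pytensor.tensor.extra_ops import FillDiagonal


    @numba_funcify.register(FillDiagonal)
    def numba_funcify_FillDiagonal(op, **kwargs):
        @numba.njit
        def filldiagonal(a, val):
            # Work on a copy so the caller's array is not mutated in place
            out = a.copy()
            np.fill_diagonal(out, val)
            return out

        return filldiagonal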


Step 4: Write tests
-------------------

Test that your registered `Op` is working correctly by adding tests to the
appropriate test suites in PyTensor (e.g. in ``tests.link.jax`` and one of
the modules in ``tests.link.numba``). The tests should ensure that your implementation can
handle the appropriate types of inputs and produce outputs equivalent to `Op.perform`.
Check the existing tests for the general outline of these kinds of tests. In
most cases, a helper function can be used to easily verify the correspondence
between a JAX/Numba implementation and its `Op`.

For example, the :func:`compare_jax_and_py` function streamlines the steps
involved in making comparisons with `Op.perform`.

Here's a small example of a test for :class:`FillDiagonal`:

.. code:: python

    import numpy as np
    import pytensor.tensor as pt
    import pytensor.tensor.basic as ptb
    from pytensor.configdefaults import config
    from pytensor.graph import FunctionGraph
    from pytensor.graph.op import get_test_value
    from tests.link.jax.test_basic import compare_jax_and_py


    def test_jax_FillDiagonal():
        """Test JAX conversion of the `FillDiagonal` `Op`."""

        # Create a symbolic input for the first input of `FillDiagonal`
        a = pt.matrix("a")

        # Create test value tag for a
        a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3))

        # Create a scalar value for the second input
        c = ptb.as_tensor(5)

        # Create the output variable
        out = pt.fill_diagonal(a, c)

        # Create a PyTensor `FunctionGraph`
        fgraph = FunctionGraph([a], [out])

        # Pass the graph and inputs to the testing function
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
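
A Numba test can follow the same pattern. The sketch below assumes a
``compare_numba_and_py`` helper analogous to :func:`compare_jax_and_py` (e.g. the one
in ``tests.link.numba.test_basic``); the exact helper name and signature may differ:

.. code:: python

    import numpy as np
    import pytensor.tensor as pt
    import pytensor.tensor.basic as ptb
    from pytensor.configdefaults import config
    from pytensor.graph import FunctionGraph
    from tests.link.numba.test_basic import compare_numba_and_py


    def test_numba_FillDiagonal():
        """Test Numba conversion of the `FillDiagonal` `Op`."""
        a = pt.matrix("a")
        c = ptb.as_tensor(5)
        out = pt.fill_diagonal(a, c)

        fgraph = FunctionGraph([a], [out])
        a_val = np.arange(9, dtype=config.floatX).reshape((3, 3))

        # Compare the Numba-compiled result against the Python `Op.perform`
        compare_numba_and_py(fgraph, [a_val])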

Note
----
In our previous example of extending JAX, the :class:`Eye`\ :class:`Op` was used with the test function as follows:

.. code:: python

    def test_jax_Eye():
        """Test JAX conversion of the `Eye` `Op`."""

        # Create a symbolic input for `Eye`
        x_at = pt.scalar()

        # Create a variable that is the output of an `Eye` `Op`
        eye_var = pt.eye(x_at)

        # Create a PyTensor `FunctionGraph`
        out_fg = FunctionGraph(outputs=[eye_var])

        # Pass the graph and any inputs to the testing function
        compare_jax_and_py(out_fg, [3])

This test now fails due to new restrictions in JAX + JIT,
as reported in issue `#654 <https://github.com/pymc-devs/pytensor/issues/654>`_.
All jitted functions now must have constant shapes, which means a graph like the
one of :class:`Eye` can never be translated to JAX, since it is fundamentally a
function with dynamic shapes. In other words, only PyTensor graphs with static shapes
can currently be translated to JAX.
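
As a rough illustration of what "static shapes" means in practice (a minimal sketch;
it assumes the ``"JAX"`` compilation mode provided by PyTensor and is not part of the
original example), a graph whose shapes are fully known at compile time can still be
run through the JAX backend:

.. code:: python

    import pytensor
    import pytensor.tensor as pt

    # Every shape in this graph is fixed and known ahead of time
    x = pt.vector("x", shape=(3,))
    out = pt.exp(x)

    # Compiling with the JAX linker works here because no shape is dynamic
    fn = pytensor.function([x], out, mode="JAX")
    fn([0.0, 1.0, 2.0])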