Skip to content

Update example in "Adding JAX and Numba support for Ops" #687

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Apr 10, 2024

Conversation

HangenYuu
Copy link
Contributor

Description

Changed the documentation page "Adding JAX and Numba support for Ops" to using a different example. The example now uses the fill_diagonal method in the extra_ops module, which I suspected should produce a static shape graph. It turned out to be correct, so I include that as the new example. I also updated minor details throughout the documentation.

Related Issue

Checklist

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

@HangenYuu
Copy link
Contributor Author

HangenYuu commented Mar 31, 2024

Here is the code snippet I used in a Jupyter Notebook after setting the correct environment from environment.yml with additional installation pip install -U "jax[cpu]" ipykernel.

# General import
import jax.numpy as jnp
import numpy as np
import pytensor.tensor as pt
from pytensor.link.jax.dispatch import jax_funcify

# Import for testing
import pytensor.tensor.basic as ptb
from pytensor.configdefaults import config
from tests.link.jax.test_basic import compare_jax_and_py
from pytensor.graph import FunctionGraph
from pytensor.graph.op import get_test_value

# Import for the op to extend to JAX
from pytensor.tensor.extra_ops import FillDiagonal

@jax_funcify.register(FillDiagonal)
def jax_funcify_FillDiagonal(op, **kwargs):
    def filldiagonal(value, diagonal):
        i, j = jnp.diag_indices(min(value.shape[-2:]))
        return value.at[..., i, j].set(diagonal)

    return filldiagonal


def test_jax_FillDiagonal():
    """Test JAX conversion of the `FillDiagonal` `Op`."""
    a = pt.matrix("a")
    a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3))

    c = ptb.as_tensor(5)

    out = pt.fill_diagonal(a, c)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

test_jax_FillDiagonal()

The working directory is the pytensor main repo folder on my machine. Here's a screenshot.
image

Currently the notebook is not part of the pull request since I am not sure whether you will need it or not.

@HangenYuu
Copy link
Contributor Author

Two other things:

  1. Docs build is successful in the PR check, but I have troubles building docs locally. As far as I can see, I don't see any theme.conf in the repo on GitHub?

image

  1. I notice that the current documentation contains a lot of outdated information (like how to install environment in Contributor Guide - we use environment.yml now, which also includes docs requirements instead of requirements.txt and requirements-rtd.txt. Is anyone actively working on updating the documentation? If not, I will open another PR to update docs to date. (I need some help with point 1 first though.)

@twiecki
Copy link
Member

twiecki commented Apr 1, 2024

Thanks, this looks good to me.

Can you open an issue (or open a PR) that fixes the theme.conf? Or is it just missing from the dev requirements?

@ricardoV94
Copy link
Member

@HangenYuu these changes are already great compared to before. I would perhaps suggest showing the CumOp Op because it also has __props__ which we should mention in the guide. These are the flags that can be used to parametrize a general Op and one should also introspect those when defining a dispatch function, in order to either support all variations or raise an explicit NonImplementedError for cases that are not supported. Perhaps CumOp(add) is supported but CumOp(prod) not in one of the backends (just as an example).

class CumOp(COp):
# See function cumsum/cumprod for docstring
__props__ = ("axis", "mode")
check_input = False
params_type = ParamsType(
c_axis=int_t, mode=EnumList(("MODE_ADD", "add"), ("MODE_MUL", "mul"))
)
def __init__(self, axis: int | None = None, mode="add"):
if mode not in ("add", "mul"):
raise ValueError(f'{type(self).__name__}: Unknown mode "{mode}"')
self.axis = axis
self.mode = mode
c_axis = property(lambda self: np.MAXDIMS if self.axis is None else self.axis)
def make_node(self, x):
x = ptb.as_tensor_variable(x)
out_type = x.type()
if self.axis is None:
out_type = vector(dtype=x.dtype) # Flatten
elif self.axis >= x.ndim or self.axis < -x.ndim:
raise ValueError(f"axis(={self.axis}) out of bounds")
return Apply(self, [x], [out_type])

What do you think?

@HangenYuu
Copy link
Contributor Author

  1. I have successfully builts docs. theme.conf is missing simply because there are two separate environments ins the repo, with two environment.yml files, one is in ./environment.yml, the other is in ./doc/environment.yml. In the first one, there are already dependencies for docs building so I was mistaken, but it misses pip install git+https://github.com/pymc-devs/pymc-sphinx-theme for the theme. Maybe we want to reconcile these 2 files.
  2. As CumOp, I can look into it. My initial attempt was with MaxandArgmax, but I dropped since 2 ops in 1 seemed too complex at the time 😅. But since users may need it, I should look into it.

@ricardoV94
Copy link
Member

We may remove MaxAndMax in the future. CumOp will probably stay as it is though

@twiecki
Copy link
Member

twiecki commented Apr 3, 2024

  1. I have successfully builts docs. theme.conf is missing simply because there are two separate environments ins the repo, with two environment.yml files, one is in ./environment.yml, the other is in ./doc/environment.yml. In the first one, there are already dependencies for docs building so I was mistaken, but it misses pip install git+https://github.com/pymc-devs/pymc-sphinx-theme for the theme. Maybe we want to reconcile these 2 files.

Good catch. Checking if this might be intentional @OriolAbril, but otherwise open an issue and we can attack this separately.

Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@HangenYuu
Copy link
Contributor Author

HangenYuu commented Apr 6, 2024

I have changed the example to CumOp. The code I run is now

# General import
import jax.numpy as jnp
import numpy as np
import pytensor.tensor as pt
from pytensor.link.jax.dispatch import jax_funcify

# Import for testing
import pytest
from pytensor.configdefaults import config
from tests.link.jax.test_basic import compare_jax_and_py
from pytensor.graph import FunctionGraph
from pytensor.graph.op import get_test_value

# Import for the op to extend to JAX
from pytensor.tensor.extra_ops import CumOp

@jax_funcify.register(CumOp)
def jax_funcify_CumOp(op, **kwargs):
    axis = op.axis
    mode = op.mode

    def cumop(x, axis=axis, mode=mode):
        if mode == "add":
            return jnp.cumsum(x, axis=axis)
        else:
            return jnp.cumprod(x, axis=axis)

    return cumop


def test_jax_CumOp():
    """Test JAX conversion of the `CumOp` `Op`."""
    a = pt.matrix("a")
    a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3))
    
    out = pt.cumsum(a, axis=0)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
    
    out = pt.cumprod(a, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

test_jax_CumOp()

for the full implementation case and

# General import
import jax.numpy as jnp
import numpy as np
import pytensor.tensor as pt
from pytensor.link.jax.dispatch import jax_funcify

# Import for testing
import pytest
from pytensor.configdefaults import config
from tests.link.jax.test_basic import compare_jax_and_py
from pytensor.graph import FunctionGraph
from pytensor.graph.op import get_test_value

# Import for the op to extend to JAX
from pytensor.tensor.extra_ops import CumOp

@jax_funcify.register(CumOp)
def jax_funcify_CumOp(op, **kwargs):
    axis = op.axis
    mode = op.mode

    def cumop(x, axis=axis, mode=mode):
        if mode == "add":
            return jnp.cumsum(x, axis=axis)
        else:
            raise NotImplementedError("JAX does not support cumprod function at the moment.")

    return cumop


def test_jax_CumOp():
    """Test JAX conversion of the `CumOp` `Op`."""
    a = pt.matrix("a")
    a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3))
    
    out = pt.cumsum(a, axis=0)
    fgraph = FunctionGraph([a], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
    
    with pytest.raises(NotImplementedError):
        out = pt.cumprod(a, axis=1)
        fgraph = FunctionGraph([a], [out])
        compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

test_jax_CumOp()

for the partial implementation case.

@OriolAbril
Copy link
Member

Good catch. Checking if this might be intentional @OriolAbril, but otherwise open an issue and we can attack this separately.

I don't really remember, I do remember I fixed the docs after forking, but there were a lot of useless files from both theano and aesara times, it is perfectly possible one or even both of these are not even used but as they are there someone notices them and updates them from time to time

@twiecki twiecki merged commit ca7e8b8 into pymc-devs:main Apr 10, 2024
12 checks passed
@twiecki
Copy link
Member

twiecki commented Apr 10, 2024

@HangenYuu Thanks! Can you open a separate issue about theme.conf and cleaning up environment files?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Update example in "Adding Jax support for Ops"
4 participants