-
Notifications
You must be signed in to change notification settings - Fork 130
Update example in "Adding JAX and Numba support for Ops" #687
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Two other things:
|
Thanks, this looks good to me. Can you open an issue (or open a PR) that fixes the theme.conf? Or is it just missing from the dev requirements? |
@HangenYuu these changes are already great compared to before. I would perhaps suggest showing the pytensor/pytensor/tensor/extra_ops.py Lines 273 to 299 in ef22377
What do you think? |
|
We may remove MaxAndMax in the future. CumOp will probably stay as it is though |
Good catch. Checking if this might be intentional @OriolAbril, but otherwise open an issue and we can attack this separately. |
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
I have changed the example to # General import
import jax.numpy as jnp
import numpy as np
import pytensor.tensor as pt
from pytensor.link.jax.dispatch import jax_funcify
# Import for testing
import pytest
from pytensor.configdefaults import config
from tests.link.jax.test_basic import compare_jax_and_py
from pytensor.graph import FunctionGraph
from pytensor.graph.op import get_test_value
# Import for the op to extend to JAX
from pytensor.tensor.extra_ops import CumOp
@jax_funcify.register(CumOp)
def jax_funcify_CumOp(op, **kwargs):
axis = op.axis
mode = op.mode
def cumop(x, axis=axis, mode=mode):
if mode == "add":
return jnp.cumsum(x, axis=axis)
else:
return jnp.cumprod(x, axis=axis)
return cumop
def test_jax_CumOp():
"""Test JAX conversion of the `CumOp` `Op`."""
a = pt.matrix("a")
a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3))
out = pt.cumsum(a, axis=0)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
out = pt.cumprod(a, axis=1)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
test_jax_CumOp() for the full implementation case and # General import
import jax.numpy as jnp
import numpy as np
import pytensor.tensor as pt
from pytensor.link.jax.dispatch import jax_funcify
# Import for testing
import pytest
from pytensor.configdefaults import config
from tests.link.jax.test_basic import compare_jax_and_py
from pytensor.graph import FunctionGraph
from pytensor.graph.op import get_test_value
# Import for the op to extend to JAX
from pytensor.tensor.extra_ops import CumOp
@jax_funcify.register(CumOp)
def jax_funcify_CumOp(op, **kwargs):
axis = op.axis
mode = op.mode
def cumop(x, axis=axis, mode=mode):
if mode == "add":
return jnp.cumsum(x, axis=axis)
else:
raise NotImplementedError("JAX does not support cumprod function at the moment.")
return cumop
def test_jax_CumOp():
"""Test JAX conversion of the `CumOp` `Op`."""
a = pt.matrix("a")
a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3))
out = pt.cumsum(a, axis=0)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
with pytest.raises(NotImplementedError):
out = pt.cumprod(a, axis=1)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
test_jax_CumOp() for the partial implementation case. |
I don't really remember, I do remember I fixed the docs after forking, but there were a lot of useless files from both theano and aesara times, it is perfectly possible one or even both of these are not even used but as they are there someone notices them and updates them from time to time |
@HangenYuu Thanks! Can you open a separate issue about theme.conf and cleaning up environment files? |
Description
Changed the documentation page "Adding JAX and Numba support for Ops" to using a different example. The example now uses the
fill_diagonal
method in theextra_ops
module, which I suspected should produce a static shape graph. It turned out to be correct, so I include that as the new example. I also updated minor details throughout the documentation.Related Issue
Checklist
Type of change