-
Notifications
You must be signed in to change notification settings - Fork 129
Implement @as_jax_op
to wrap a JAX function for use in PyTensor
#1120
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
e428fb1
9cb4cc5
43759ba
2543c9d
36a71d2
d4a0b6a
65984b0
5960947
104df83
e11777e
d2e788f
ab326e5
48fbf0a
b8c4523
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -24,4 +24,4 @@ dependencies: | |
- pip | ||
- pip: | ||
- sphinx_sitemap | ||
- -e .. | ||
- -e ..[jax] | ||
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -63,6 +63,13 @@ Convert to Variable | |
|
||
.. autofunction:: pytensor.as_symbolic(...) | ||
|
||
Wrap JAX functions | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Reminds me, may want to add something on |
||
================== | ||
|
||
.. autofunction:: as_jax_op(...) | ||
|
||
Alias for :func:`pytensor.link.jax.ops.as_jax_op` | ||
|
||
Debug | ||
===== | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -167,6 +167,18 @@ def get_underlying_scalar_constant(v): | |
from pytensor.scan.views import foldl, foldr, map, reduce | ||
from pytensor.compile.builders import OpFromGraph | ||
|
||
try: | ||
import pytensor.link.jax.ops | ||
from pytensor.link.jax.ops import as_jax_op | ||
Comment on lines
+170
to
+172
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do not force eager import of JAX. You can make it inside the We did some effort to reduce import times of the library |
||
except ImportError as e: | ||
import_error_as_jax_op = e | ||
|
||
def as_jax_op(*args, **kwargs): | ||
raise ImportError( | ||
"JAX and/or equinox are not installed. Install them" | ||
" to use this function: pip install pytensor[jax]" | ||
) from import_error_as_jax_op | ||
|
||
# isort: on | ||
|
||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
should be reverted?