-
Notifications
You must be signed in to change notification settings - Fork 129
Add JAX support for SortOp #657
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Codecov ReportAll modified and coverable lines are covered by tests ✅
Additional details and impacted files@@ Coverage Diff @@
## main #657 +/- ##
==========================================
+ Coverage 80.80% 80.82% +0.02%
==========================================
Files 162 162
Lines 46743 46820 +77
Branches 11419 11438 +19
==========================================
+ Hits 37770 37844 +74
+ Misses 6731 6725 -6
- Partials 2242 2251 +9
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, I left just two small modification suggestions for readability and more extensive testing
def sort(arr, *args): | ||
return jnp.sort(arr, *args) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def sort(arr, *args): | |
return jnp.sort(arr, *args) | |
def sort(arr, axis): | |
return jnp.sort(arr, axis=axis) |
tests/link/jax/test_tensor_basic.py
Outdated
def test_sort(): | ||
x = matrix("x") | ||
out = pytensor.tensor.sort(x) | ||
fgraph = FunctionGraph([x], [out]) | ||
arr = np.array([[1.0, 4.0], [5.0, 2.0]]) | ||
compare_jax_and_py(fgraph, [arr]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
def test_sort(): | |
x = matrix("x") | |
out = pytensor.tensor.sort(x) | |
fgraph = FunctionGraph([x], [out]) | |
arr = np.array([[1.0, 4.0], [5.0, 2.0]]) | |
compare_jax_and_py(fgraph, [arr]) | |
@pytest.mark.parametrize("axis", [None, -1]) | |
def test_sort(axis): | |
x = matrix("x", shape=(2, 2), dtype="float64") | |
out = pytensor.tensor.sort(x, axis=axis) | |
fgraph = FunctionGraph([x], [out]) | |
arr = np.array([[1.0, 4.0], [5.0, 2.0]]) | |
compare_jax_and_py(fgraph, [arr]) |
Thanks @HarshvirSandhu |
Description
Implement JAX conversion for SortOp
Related Issue
Checklist
Type of change