Skip to content

Commit d175203

Browse files
Add JAX support for SortOp (#657)
1 parent ad55b69 commit d175203

File tree

2 files changed

+18
-0
lines changed

2 files changed

+18
-0
lines changed

pytensor/link/jax/dispatch/tensor_basic.py

+9
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
)
2323
from pytensor.tensor.exceptions import NotScalarConstantError
2424
from pytensor.tensor.shape import Shape_i
25+
from pytensor.tensor.sort import SortOp
2526

2627

2728
ARANGE_CONCRETE_VALUE_ERROR = """JAX requires the arguments of `jax.numpy.arange`
@@ -205,3 +206,11 @@ def tri(*args):
205206
return jnp.tri(*args, dtype=op.dtype)
206207

207208
return tri
209+
210+
211+
@jax_funcify.register(SortOp)
212+
def jax_funcify_Sort(op, **kwargs):
213+
def sort(arr, axis):
214+
return jnp.sort(arr, axis=axis)
215+
216+
return sort

tests/link/jax/test_tensor_basic.py

+9
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,15 @@ def test_tri():
218218
compare_jax_and_py(fgraph, [])
219219

220220

221+
@pytest.mark.parametrize("axis", [None, -1])
222+
def test_sort(axis):
223+
x = matrix("x", shape=(2, 2), dtype="float64")
224+
out = pytensor.tensor.sort(x, axis=axis)
225+
fgraph = FunctionGraph([x], [out])
226+
arr = np.array([[1.0, 4.0], [5.0, 2.0]])
227+
compare_jax_and_py(fgraph, [arr])
228+
229+
221230
def test_tri_nonconcrete():
222231
"""JAX cannot JIT-compile `jax.numpy.tri` when arguments are not concrete values."""
223232

0 commit comments

Comments
 (0)