Commit cb4ab40

Typify Sparse input variables in JAX linker
1 parent 4ce0b07 commit cb4ab40

File tree

4 files changed (+129, -50 lines)

pytensor/link/jax/dispatch/basic.py (+4, -4)

@@ -87,11 +87,11 @@ def assert_fn(x, *inputs):
 def jnp_safe_copy(x):
     try:
         res = jnp.copy(x)
-    except NotImplementedError:
-        warnings.warn(
-            "`jnp.copy` is not implemented yet. Using the object's `copy` method."
-        )
+    except (NotImplementedError, TypeError):
         if hasattr(x, "copy"):
+            warnings.warn(
+                "`jnp.copy` is not implemented yet. Using the object's `copy` method."
+            )
             res = jnp.array(x.copy())
         else:
             warnings.warn(f"Object has no `copy` method: {x}")
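For context, the reordering means the "not implemented" warning is now only emitted when a fallback `copy` method is actually used, and `TypeError` is caught as well, which in this commit's context presumably covers objects such as BCOO arrays that `jnp.copy` cannot convert. A minimal standalone sketch of the same control flow, using a hypothetical `HasCopy` object that is not part of the commit:

```python
import warnings

import numpy as np
import jax.numpy as jnp


class HasCopy:
    """Hypothetical object that JAX cannot convert directly but that exposes .copy()."""

    def copy(self):
        return np.zeros(3)


def safe_copy(x):
    # Same control flow as the patched jnp_safe_copy: warn only when the
    # fallback `copy` method is actually used, and also catch TypeError.
    try:
        return jnp.copy(x)
    except (NotImplementedError, TypeError):
        if hasattr(x, "copy"):
            warnings.warn("`jnp.copy` failed; using the object's `copy` method.")
            return jnp.array(x.copy())
        warnings.warn(f"Object has no `copy` method: {x}")


print(safe_copy(np.arange(3)))  # plain arrays go straight through jnp.copy
print(safe_copy(HasCopy()))     # expected to hit the except branch and warn
```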

pytensor/link/jax/dispatch/sparse.py (+44, -16)

@@ -1,38 +1,66 @@
 import jax.experimental.sparse as jsp
 from scipy.sparse import spmatrix
 
-from pytensor.graph.basic import Constant
+from pytensor.graph.type import HasDataType
 from pytensor.link.jax.dispatch import jax_funcify, jax_typify
-from pytensor.sparse.basic import Dot, StructuredDot
+from pytensor.sparse.basic import Dot, StructuredDot, Transpose
 from pytensor.sparse.type import SparseTensorType
+from pytensor.tensor import TensorType
 
 
 @jax_typify.register(spmatrix)
 def jax_typify_spmatrix(matrix, dtype=None, **kwargs):
-    # Note: This changes the type of the constants from CSR/CSC to BCOO
-    # We could add BCOO as a PyTensor type but this would only be useful for JAX graphs
-    # and it would break the premise of one graph -> multiple backends.
-    # The same situation happens with RandomGenerators...
     return jsp.BCOO.from_scipy_sparse(matrix)
 
 
+class BCOOType(TensorType, HasDataType):
+    """JAX-compatible BCOO type.
+
+    This type is not exposed to users directly.
+
+    It is introduced by the JIT linker in place of any SparseTensorType input
+    variables used in the original function. Nodes in the function graph will
+    still show the original types as inputs and outputs.
+    """
+
+    def filter(self, data, strict: bool = False, allow_downcast=None):
+        if isinstance(data, jsp.BCOO):
+            return data
+
+        if strict:
+            raise TypeError()
+
+        return jax_typify(data)
+
+
+@jax_typify.register(SparseTensorType)
+def jax_typify_SparseTensorType(type):
+    return BCOOType(
+        dtype=type.dtype,
+        shape=type.shape,
+        name=type.name,
+        broadcastable=type.broadcastable,
+    )
+
+
 @jax_funcify.register(Dot)
 @jax_funcify.register(StructuredDot)
 def jax_funcify_sparse_dot(op, node, **kwargs):
-    for input in node.inputs:
-        if isinstance(input.type, SparseTensorType) and not isinstance(input, Constant):
-            raise NotImplementedError(
-                "JAX sparse dot only implemented for constant sparse inputs"
-            )
-
-    if isinstance(node.outputs[0].type, SparseTensorType):
-        raise NotImplementedError("JAX sparse dot only implemented for dense outputs")
-
     @jsp.sparsify
     def sparse_dot(x, y):
         out = x @ y
-        if isinstance(out, jsp.BCOO):
+        if isinstance(out, jsp.BCOO) and not isinstance(
+            node.outputs[0].type, SparseTensorType
+        ):
             out = out.todense()
         return out
 
     return sparse_dot
+
+
+@jax_funcify.register(Transpose)
+def jax_funcify_sparse_transpose(op, **kwargs):
+    def sparse_transpose(x):
+        return x.T
+
+    return sparse_transpose
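To illustrate what the dispatch layer above builds on, here is a small standalone sketch using plain JAX and SciPy, without PyTensor: a SciPy sparse matrix is converted to a `BCOO` array, exactly as `jax_typify_spmatrix` returns, and under `jsp.sparsify` the matmul accepts the sparse operand. The helper name `sparse_dot` here is illustrative only.

```python
import numpy as np
import scipy.sparse
import jax.experimental.sparse as jsp

# What jax_typify_spmatrix does with a SciPy sparse input: convert it to BCOO.
x_sp = scipy.sparse.random(5, 40, density=0.25, format="csr", dtype="float32")
x_bcoo = jsp.BCOO.from_scipy_sparse(x_sp)


# A sparsified dot in the spirit of jax_funcify_sparse_dot: `@` handles the
# BCOO operand; for sparse @ dense the result is expected to be dense.
@jsp.sparsify
def sparse_dot(x, y):
    return x @ y


y = np.arange(40 * 3, dtype="float32").reshape(40, 3)
out = sparse_dot(x_bcoo, y)
np.testing.assert_allclose(out, x_sp.toarray() @ y, rtol=1e-5)
```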

pytensor/link/jax/linker.py (+11)

@@ -12,6 +12,7 @@ class JAXLinker(JITLinker):
 
     def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
         from pytensor.link.jax.dispatch import jax_funcify
+        from pytensor.sparse.type import SparseTensorType
         from pytensor.tensor.random.type import RandomType
 
         if any(
@@ -23,6 +24,16 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
                 "Input values should be provided in this format to avoid a conversion overhead."
             )
 
+        if any(
+            isinstance(inp.type, SparseTensorType)
+            and not isinstance(inp, SharedVariable)
+            for inp in fgraph.inputs
+        ):
+            warnings.warn(
+                "SparseTypes are implicitly converted to sparse BCOO arrays in JAX. "
+                "Input values should be provided in this format to avoid a conversion overhead."
+            )
+
         shared_rng_inputs = [
            inp
            for inp in fgraph.inputs
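The practical effect of this check is only a compile-time warning, not an error: explicit (non-shared) sparse inputs are now accepted and implicitly typified to BCOO. A hedged sketch of the user-facing behaviour, mirroring the new `test_sparse_io` test below:

```python
import scipy.sparse
from jax.experimental.sparse import BCOO

import pytensor.sparse as ps
from pytensor import function

x = ps.matrix(format="csr", name="x", dtype="float32")
fn = function([x], x.T, mode="JAX")  # emits the UserWarning added above

x_sp = scipy.sparse.random(5, 40, density=0.25, format="csr", dtype="float32")
fn(x_sp)                          # SciPy input: implicitly converted to BCOO
fn(BCOO.from_scipy_sparse(x_sp))  # BCOO input: avoids the conversion overhead
```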

tests/link/jax/test_sparse.py (+70, -30)

@@ -1,14 +1,49 @@
 import numpy as np
 import pytest
 import scipy.sparse
+from jax.experimental.sparse import BCOO
 
 import pytensor.sparse as ps
 import pytensor.tensor as pt
 from pytensor import function
-from pytensor.graph import FunctionGraph
+from pytensor.graph import Constant, FunctionGraph
+from pytensor.tensor.type import DenseTensorType
 from tests.link.jax.test_basic import compare_jax_and_py
 
 
+def assert_bcoo_arrays_close(a1, a2):
+    assert isinstance(a1, BCOO)
+    assert isinstance(a2, BCOO)
+    np.testing.assert_allclose(a1.todense(), a2.todense())
+
+
+@pytest.mark.parametrize("sparse_type", ("csc", "csr"))
+def test_sparse_io(sparse_type):
+    """Test explicit (non-shared) input and output sparse types in JAX."""
+    sparse_mat = ps.matrix(format=sparse_type, name="csc", dtype="float32")
+    sparse_mat_out = sparse_mat.T
+
+    with pytest.warns(
+        UserWarning,
+        match="SparseTypes are implicitly converted to sparse BCOO arrays",
+    ):
+        fn = function([sparse_mat], sparse_mat_out, mode="JAX")
+
+    sp_sparse_mat = scipy.sparse.random(
+        5, 40, density=0.25, format=sparse_type, dtype="float32"
+    )
+    jx_sparse_mat = BCOO.from_scipy_sparse(sp_sparse_mat)
+
+    sp_res = fn(sp_sparse_mat)
+    jx_res = fn(jx_sparse_mat)
+    assert_bcoo_arrays_close(sp_res, jx_sparse_mat.T)
+    assert_bcoo_arrays_close(jx_res, jx_sparse_mat.T)
+
+    # Chained applications
+    assert_bcoo_arrays_close(fn(fn(sp_sparse_mat)), jx_sparse_mat)
+    assert_bcoo_arrays_close(fn(fn(jx_sparse_mat)), jx_sparse_mat)
+
+
 @pytest.mark.parametrize(
     "op, x_type, y_type",
     [
@@ -19,57 +54,62 @@
         # structured_dot only allows matrix @ matrix
         (ps.structured_dot, pt.matrix, ps.matrix),
         (ps.structured_dot, ps.matrix, pt.matrix),
+        (ps.structured_dot, ps.matrix, ps.matrix),
     ],
 )
-def test_sparse_dot_constant_sparse(x_type, y_type, op):
+@pytest.mark.parametrize("x_constant", (False, True))
+@pytest.mark.parametrize("y_constant", (False, True))
+def test_sparse_dot(x_type, y_type, op, x_constant, y_constant):
     inputs = []
     test_values = []
 
     if x_type is ps.matrix:
-        x_sp = scipy.sparse.random(5, 40, density=0.25, format="csr", dtype="float32")
-        x_pt = ps.as_sparse_variable(x_sp, name="x")
+        x_test = scipy.sparse.random(5, 40, density=0.25, format="csr", dtype="float32")
+        x_pt = ps.as_sparse_variable(x_test, name="x")
     else:
-        x_pt = x_type("x", dtype="float32")
-        if x_pt.ndim == 1:
+        if x_type is pt.vector:
            x_test = np.arange(40, dtype="float32")
         else:
            x_test = np.arange(5 * 40, dtype="float32").reshape(5, 40)
+        x_pt = pt.as_tensor_variable(x_test, name="x")
+    assert isinstance(x_pt, Constant)
+
+    if not x_constant:
+        x_pt = x_pt.type(name="x")
     inputs.append(x_pt)
     test_values.append(x_test)
 
     if y_type is ps.matrix:
-        y_sp = scipy.sparse.random(40, 3, density=0.25, format="csc", dtype="float32")
-        y_pt = ps.as_sparse_variable(y_sp, name="y")
+        y_test = scipy.sparse.random(40, 3, density=0.25, format="csc", dtype="float32")
+        y_pt = ps.as_sparse_variable(y_test, name="y")
     else:
-        y_pt = y_type("y", dtype="float32")
-        if y_pt.ndim == 1:
+        if y_type is pt.vector:
            y_test = np.arange(40, dtype="float32")
         else:
            y_test = np.arange(40 * 3, dtype="float32").reshape(40, 3)
+        y_pt = pt.as_tensor_variable(y_test, name="y")
+    assert isinstance(y_pt, Constant)
+
+    if not y_constant:
+        y_pt = y_pt.type(name="y")
     inputs.append(y_pt)
     test_values.append(y_test)
 
     dot_pt = op(x_pt, y_pt)
     fgraph = FunctionGraph(inputs, [dot_pt])
-    compare_jax_and_py(fgraph, test_values)
-
-
-def test_sparse_dot_non_const_raises():
-    x_pt = pt.vector("x")
-
-    y_sp = scipy.sparse.random(40, 3, density=0.25, format="csc", dtype="float32")
-    y_pt = ps.as_sparse_variable(y_sp, name="y").type()
-
-    out = ps.dot(x_pt, y_pt)
-
-    msg = "JAX sparse dot only implemented for constant sparse inputs"
-
-    with pytest.raises(NotImplementedError, match=msg):
-        function([x_pt, y_pt], out, mode="JAX")
-
-    y_pt_shared = ps.shared(y_sp, name="y")
 
-    out = ps.dot(x_pt, y_pt_shared)
+    def assert_fn(x, y):
+        [x] = x
+        [y] = y
+        if hasattr(x, "todense"):
+            x = x.todense()
+        if hasattr(y, "todense"):
+            y = y.todense()
+        np.testing.assert_allclose(x, y)
 
-    with pytest.raises(NotImplementedError, match=msg):
-        function([x_pt], out, mode="JAX")
+    compare_jax_and_py(
+        fgraph,
+        test_values,
+        must_be_device_array=isinstance(dot_pt.type, DenseTensorType),
+        assert_fn=assert_fn,
+    )
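As a usage note, the capability that the deleted `test_sparse_dot_non_const_raises` used to forbid is exactly what the parametrized test above now exercises: a non-constant sparse operand compiles and runs under the JAX backend. A hedged sketch, with shapes and dtypes chosen to match the test above:

```python
import numpy as np
import scipy.sparse

import pytensor.sparse as ps
import pytensor.tensor as pt
from pytensor import function

x = pt.vector("x", dtype="float32")
y = ps.matrix(format="csc", name="y", dtype="float32")  # symbolic, non-constant

# Previously this raised NotImplementedError; it now compiles (with the
# BCOO conversion warning) and evaluates through the JAX backend.
fn = function([x, y], ps.dot(x, y), mode="JAX")

x_test = np.arange(40, dtype="float32")
y_test = scipy.sparse.random(40, 3, density=0.25, format="csc", dtype="float32")
np.testing.assert_allclose(fn(x_test, y_test), x_test @ y_test.toarray(), rtol=1e-5)
```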
