Commit a6bf472

Typify Sparse input variables in JAX linker
1 parent a023e5b commit a6bf472

4 files changed: +132 -50 lines


pytensor/link/jax/dispatch/basic.py (+4 -4)
@@ -87,11 +87,11 @@ def assert_fn(x, *inputs):
 def jnp_safe_copy(x):
     try:
         res = jnp.copy(x)
-    except NotImplementedError:
-        warnings.warn(
-            "`jnp.copy` is not implemented yet. Using the object's `copy` method."
-        )
+    except (NotImplementedError, TypeError):
         if hasattr(x, "copy"):
+            warnings.warn(
+                "`jnp.copy` is not implemented yet. Using the object's `copy` method."
+            )
             res = jnp.array(x.copy())
         else:
             warnings.warn(f"Object has no `copy` method: {x}")

pytensor/link/jax/dispatch/sparse.py (+44 -16)
@@ -1,38 +1,66 @@
 import jax.experimental.sparse as jsp
 from scipy.sparse import spmatrix
 
-from pytensor.graph.basic import Constant
+from pytensor.graph.type import HasDataType
 from pytensor.link.jax.dispatch import jax_funcify, jax_typify
-from pytensor.sparse.basic import Dot, StructuredDot
+from pytensor.sparse.basic import Dot, StructuredDot, Transpose
 from pytensor.sparse.type import SparseTensorType
+from pytensor.tensor import TensorType
 
 
 @jax_typify.register(spmatrix)
 def jax_typify_spmatrix(matrix, dtype=None, **kwargs):
-    # Note: This changes the type of the constants from CSR/CSC to BCOO
-    # We could add BCOO as a PyTensor type but this would only be useful for JAX graphs
-    # and it would break the premise of one graph -> multiple backends.
-    # The same situation happens with RandomGenerators...
     return jsp.BCOO.from_scipy_sparse(matrix)
 
 
+class BCOOType(TensorType, HasDataType):
+    """JAX-compatible BCOO type.
+
+    This type is not exposed to users directly.
+
+    It is introduced by the JIT linker in place of any SparseTensorType input
+    variables used in the original function. Nodes in the function graph will
+    still show the original types as inputs and outputs.
+    """
+
+    def filter(self, data, strict: bool = False, allow_downcast=None):
+        if isinstance(data, jsp.BCOO):
+            return data
+
+        if strict:
+            raise TypeError()
+
+        return jax_typify(data)
+
+
+@jax_typify.register(SparseTensorType)
+def jax_typify_SparseTensorType(type):
+    return BCOOType(
+        dtype=type.dtype,
+        shape=type.shape,
+        name=type.name,
+        broadcastable=type.broadcastable,
+    )
+
+
 @jax_funcify.register(Dot)
 @jax_funcify.register(StructuredDot)
 def jax_funcify_sparse_dot(op, node, **kwargs):
-    for input in node.inputs:
-        if isinstance(input.type, SparseTensorType) and not isinstance(input, Constant):
-            raise NotImplementedError(
-                "JAX sparse dot only implemented for constant sparse inputs"
-            )
-
-    if isinstance(node.outputs[0].type, SparseTensorType):
-        raise NotImplementedError("JAX sparse dot only implemented for dense outputs")
-
     @jsp.sparsify
     def sparse_dot(x, y):
         out = x @ y
-        if isinstance(out, jsp.BCOO):
+        if isinstance(out, jsp.BCOO) and not isinstance(
+            node.outputs[0].type, SparseTensorType
+        ):
             out = out.todense()
         return out
 
     return sparse_dot
+
+
+@jax_funcify.register(Transpose)
+def jax_funcify_sparse_transpose(op, **kwargs):
+    def sparse_transpose(x):
+        return x.T
+
+    return sparse_transpose
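For readers less familiar with jax.experimental.sparse, here is a small stand-alone sketch (not part of the commit) of the two JAX features the dispatchers above build on: BCOO.from_scipy_sparse, which is how jax_typify turns scipy CSR/CSC inputs into BCOO arrays, and jsp.sparsify, which lets a plain x @ y kernel accept BCOO operands as in sparse_dot. The name demo_sparse_dot and the shapes are illustrative only:

import jax.experimental.sparse as jsp
import numpy as np
import scipy.sparse


@jsp.sparsify
def demo_sparse_dot(x, y):
    # Same pattern as sparse_dot above, but densifying unconditionally.
    out = x @ y
    if isinstance(out, jsp.BCOO):
        out = out.todense()
    return out


sp_mat = scipy.sparse.random(5, 40, density=0.25, format="csr", dtype="float32")
bcoo_mat = jsp.BCOO.from_scipy_sparse(sp_mat)  # what jax_typify_spmatrix returns
dense_vec = np.arange(40, dtype="float32")

print(demo_sparse_dot(bcoo_mat, dense_vec).shape)  # dense result of shape (5,)
print(bcoo_mat.T.shape)                            # (40, 5); .T is all sparse_transpose needs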

pytensor/link/jax/linker.py (+11)
@@ -12,6 +12,7 @@ class JAXLinker(JITLinker):
 
     def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
         from pytensor.link.jax.dispatch import jax_funcify
+        from pytensor.sparse.type import SparseTensorType
         from pytensor.tensor.random.type import RandomType
 
         if any(
@@ -23,6 +24,16 @@ def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
                 "Input values should be provided in this format to avoid a conversion overhead."
             )
 
+        if any(
+            isinstance(inp.type, SparseTensorType)
+            and not isinstance(inp, SharedVariable)
+            for inp in fgraph.inputs
+        ):
+            warnings.warn(
+                "SparseTypes are implicitly converted to sparse BCOO arrays in JAX. "
+                "Input values should be provided in this format to avoid a conversion overhead."
+            )
+
         shared_rng_inputs = [
             inp
             for inp in fgraph.inputs
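The user-facing effect of this warning, sketched from the new test_sparse_io test below (assuming JAX is installed so mode="JAX" is available): compiling a graph with a non-shared sparse input now succeeds, the warning is emitted once at compile time, and the compiled function accepts either scipy sparse matrices (converted on every call) or JAX BCOO arrays directly.

import scipy.sparse

import pytensor.sparse as ps
from pytensor import function

x = ps.matrix(format="csr", dtype="float32", name="x")
# Warns: "SparseTypes are implicitly converted to sparse BCOO arrays in JAX. ..."
fn = function([x], x.T, mode="JAX")

x_val = scipy.sparse.random(5, 40, density=0.25, format="csr", dtype="float32")
print(fn(x_val).shape)  # BCOO output of shape (40, 5)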

tests/link/jax/test_sparse.py (+73 -30)
@@ -2,13 +2,51 @@
 import pytest
 import scipy.sparse
 
+
+jax = pytest.importorskip("jax")
+from jax.experimental.sparse import BCOO
+
 import pytensor.sparse as ps
 import pytensor.tensor as pt
 from pytensor import function
-from pytensor.graph import FunctionGraph
+from pytensor.graph import Constant, FunctionGraph
+from pytensor.tensor.type import DenseTensorType
 from tests.link.jax.test_basic import compare_jax_and_py
 
 
+def assert_bcoo_arrays_allclose(a1, a2):
+    assert isinstance(a1, BCOO)
+    assert isinstance(a2, BCOO)
+    np.testing.assert_allclose(a1.todense(), a2.todense())
+
+
+@pytest.mark.parametrize("sparse_type", ("csc", "csr"))
+def test_sparse_io(sparse_type):
+    """Test explicit (non-shared) input and output sparse types in JAX."""
+    sparse_mat = ps.matrix(format=sparse_type, name="csc", dtype="float32")
+    sparse_mat_out = sparse_mat.T
+
+    with pytest.warns(
+        UserWarning,
+        match="SparseTypes are implicitly converted to sparse BCOO arrays",
+    ):
+        fn = function([sparse_mat], sparse_mat_out, mode="JAX")
+
+    sp_sparse_mat = scipy.sparse.random(
+        5, 40, density=0.25, format=sparse_type, dtype="float32"
+    )
+    jx_sparse_mat = BCOO.from_scipy_sparse(sp_sparse_mat)
+
+    sp_res = fn(sp_sparse_mat)
+    jx_res = fn(jx_sparse_mat)
+    assert_bcoo_arrays_allclose(sp_res, jx_sparse_mat.T)
+    assert_bcoo_arrays_allclose(jx_res, jx_sparse_mat.T)
+
+    # Chained applications
+    assert_bcoo_arrays_allclose(fn(fn(sp_sparse_mat)), jx_sparse_mat)
+    assert_bcoo_arrays_allclose(fn(fn(jx_sparse_mat)), jx_sparse_mat)
+
+
 @pytest.mark.parametrize(
     "op, x_type, y_type",
     [
@@ -19,57 +57,62 @@
         # structured_dot only allows matrix @ matrix
         (ps.structured_dot, pt.matrix, ps.matrix),
         (ps.structured_dot, ps.matrix, pt.matrix),
+        (ps.structured_dot, ps.matrix, ps.matrix),
     ],
 )
-def test_sparse_dot_constant_sparse(x_type, y_type, op):
+@pytest.mark.parametrize("x_constant", (False, True))
+@pytest.mark.parametrize("y_constant", (False, True))
+def test_sparse_dot(x_type, y_type, op, x_constant, y_constant):
     inputs = []
     test_values = []
 
     if x_type is ps.matrix:
-        x_sp = scipy.sparse.random(5, 40, density=0.25, format="csr", dtype="float32")
-        x_pt = ps.as_sparse_variable(x_sp, name="x")
+        x_test = scipy.sparse.random(5, 40, density=0.25, format="csr", dtype="float32")
+        x_pt = ps.as_sparse_variable(x_test, name="x")
     else:
-        x_pt = x_type("x", dtype="float32")
-        if x_pt.ndim == 1:
+        if x_type is pt.vector:
             x_test = np.arange(40, dtype="float32")
         else:
             x_test = np.arange(5 * 40, dtype="float32").reshape(5, 40)
+        x_pt = pt.as_tensor_variable(x_test, name="x")
+    assert isinstance(x_pt, Constant)
+
+    if not x_constant:
+        x_pt = x_pt.type(name="x")
     inputs.append(x_pt)
     test_values.append(x_test)
 
     if y_type is ps.matrix:
-        y_sp = scipy.sparse.random(40, 3, density=0.25, format="csc", dtype="float32")
-        y_pt = ps.as_sparse_variable(y_sp, name="y")
+        y_test = scipy.sparse.random(40, 3, density=0.25, format="csc", dtype="float32")
+        y_pt = ps.as_sparse_variable(y_test, name="y")
     else:
-        y_pt = y_type("y", dtype="float32")
-        if y_pt.ndim == 1:
+        if y_type is pt.vector:
             y_test = np.arange(40, dtype="float32")
         else:
             y_test = np.arange(40 * 3, dtype="float32").reshape(40, 3)
+        y_pt = pt.as_tensor_variable(y_test, name="y")
+    assert isinstance(y_pt, Constant)
+
+    if not y_constant:
+        y_pt = y_pt.type(name="y")
     inputs.append(y_pt)
     test_values.append(y_test)
 
     dot_pt = op(x_pt, y_pt)
     fgraph = FunctionGraph(inputs, [dot_pt])
-    compare_jax_and_py(fgraph, test_values)
-
-
-def test_sparse_dot_non_const_raises():
-    x_pt = pt.vector("x")
-
-    y_sp = scipy.sparse.random(40, 3, density=0.25, format="csc", dtype="float32")
-    y_pt = ps.as_sparse_variable(y_sp, name="y").type()
-
-    out = ps.dot(x_pt, y_pt)
-
-    msg = "JAX sparse dot only implemented for constant sparse inputs"
-
-    with pytest.raises(NotImplementedError, match=msg):
-        function([x_pt, y_pt], out, mode="JAX")
-
-    y_pt_shared = ps.shared(y_sp, name="y")
 
-    out = ps.dot(x_pt, y_pt_shared)
+    def assert_fn(x, y):
+        [x] = x
+        [y] = y
+        if hasattr(x, "todense"):
+            x = x.todense()
+        if hasattr(y, "todense"):
+            y = y.todense()
+        np.testing.assert_allclose(x, y)
 
-    with pytest.raises(NotImplementedError, match=msg):
-        function([x_pt], out, mode="JAX")
+    compare_jax_and_py(
+        fgraph,
+        test_values,
+        must_be_device_array=isinstance(dot_pt.type, DenseTensorType),
+        assert_fn=assert_fn,
+    )
