Skip to content

Commit bbca7b5

Browse files
committed
Typify Sparse input variables in JAX linker
1 parent afdf6b5 commit bbca7b5

File tree

3 files changed

+84
-51
lines changed

3 files changed

+84
-51
lines changed

pytensor/link/jax/dispatch/basic.py

+4 -4
Original file line number | Diff line number | Diff line change
@@ -87,11 +87,11 @@ def assert_fn(x, *inputs):
8787
def jnp_safe_copy(x):
8888
try:
8989
res = jnp.copy(x)
90-
except NotImplementedError:
91-
warnings.warn(
92-
"`jnp.copy` is not implemented yet. Using the object's `copy` method."
93-
)
90+
except (NotImplementedError, TypeError):
9491
if hasattr(x, "copy"):
92+
warnings.warn(
93+
"`jnp.copy` is not implemented yet. Using the object's `copy` method."
94+
)
9595
res = jnp.array(x.copy())
9696
else:
9797
warnings.warn(f"Object has no `copy` method: {x}")

pytensor/link/jax/dispatch/sparse.py

+44 -16
Original file line number | Diff line number | Diff line change
@@ -1,38 +1,66 @@
11
import jax.experimental.sparse as jsp
22
from scipy.sparse import spmatrix
33

4-
from pytensor.graph.basic import Constant
4+
from pytensor.graph.type import HasDataType
55
from pytensor.link.jax.dispatch import jax_funcify, jax_typify
6-
from pytensor.sparse.basic import Dot, StructuredDot
6+
from pytensor.sparse.basic import Dot, StructuredDot, Transpose
77
from pytensor.sparse.type import SparseTensorType
8+
from pytensor.tensor import TensorType
89

910

1011
@jax_typify.register(spmatrix)
1112
def jax_typify_spmatrix(matrix, dtype=None, **kwargs):
12-
# Note: This changes the type of the constants from CSR/CSC to BCOO
13-
# We could add BCOO as a PyTensor type but this would only be useful for JAX graphs
14-
# and it would break the premise of one graph -> multiple backends.
15-
# The same situation happens with RandomGenerators...
1613
return jsp.BCOO.from_scipy_sparse(matrix)
1714

1815

16+
class BCOOType(TensorType, HasDataType):
17+
"""JAX-compatible BCOO type.
18+
19+
This type is not exposed to users directly.
20+
21+
It is introduced by the JIT linker in place of any SparseTensorType input
22+
variables used in the original function. Nodes in the function graph will
23+
still show the original types as inputs and outputs.
24+
"""
25+
26+
def filter(self, data, strict: bool = False, allow_downcast=None):
27+
if isinstance(data, jsp.BCOO):
28+
return data
29+
30+
if strict:
31+
raise TypeError()
32+
33+
return jax_typify(data)
34+
35+
36+
@jax_typify.register(SparseTensorType)
37+
def jax_typify_SparseTensorType(type):
38+
return BCOOType(
39+
dtype=type.dtype,
40+
shape=type.shape,
41+
name=type.name,
42+
broadcastable=type.broadcastable,
43+
)
44+
45+
1946
@jax_funcify.register(Dot)
2047
@jax_funcify.register(StructuredDot)
2148
def jax_funcify_sparse_dot(op, node, **kwargs):
22-
for input in node.inputs:
23-
if isinstance(input.type, SparseTensorType) and not isinstance(input, Constant):
24-
raise NotImplementedError(
25-
"JAX sparse dot only implemented for constant sparse inputs"
26-
)
27-
28-
if isinstance(node.outputs[0].type, SparseTensorType):
29-
raise NotImplementedError("JAX sparse dot only implemented for dense outputs")
30-
3149
@jsp.sparsify
3250
def sparse_dot(x, y):
3351
out = x @ y
34-
if isinstance(out, jsp.BCOO):
52+
if isinstance(out, jsp.BCOO) and not isinstance(
53+
node.outputs[0].type, SparseTensorType
54+
):
3555
out = out.todense()
3656
return out
3757

3858
return sparse_dot
59+
60+
61+
@jax_funcify.register(Transpose)
62+
def jax_funciy_sparse_transpose(op, **kwargs):
63+
def sparse_transpose(x):
64+
return x.T
65+
66+
return sparse_transpose

tests/link/jax/test_sparse.py

+36 -31
Original file line number | Diff line number | Diff line change
@@ -4,8 +4,8 @@
44

55
import pytensor.sparse as ps
66
import pytensor.tensor as pt
7-
from pytensor import function
8-
from pytensor.graph import FunctionGraph
7+
from pytensor.graph import Constant, FunctionGraph
8+
from pytensor.tensor.type import DenseTensorType
99
from tests.link.jax.test_basic import compare_jax_and_py
1010

1111

@@ -19,57 +19,62 @@
1919
# structured_dot only allows matrix @ matrix
2020
(ps.structured_dot, pt.matrix, ps.matrix),
2121
(ps.structured_dot, ps.matrix, pt.matrix),
22+
(ps.structured_dot, ps.matrix, ps.matrix),
2223
],
2324
)
24-
def test_sparse_dot_constant_sparse(x_type, y_type, op):
25+
@pytest.mark.parametrize("x_constant", (False, True))
26+
@pytest.mark.parametrize("y_constant", (False, True))
27+
def test_sparse_dot(x_type, y_type, op, x_constant, y_constant):
2528
inputs = []
2629
test_values = []
2730

2831
if x_type is ps.matrix:
29-
x_sp = scipy.sparse.random(5, 40, density=0.25, format="csr", dtype="float32")
30-
x_pt = ps.as_sparse_variable(x_sp, name="x")
32+
x_test = scipy.sparse.random(5, 40, density=0.25, format="csr", dtype="float32")
33+
x_pt = ps.as_sparse_variable(x_test, name="x")
3134
else:
32-
x_pt = x_type("x", dtype="float32")
33-
if x_pt.ndim == 1:
35+
if x_type is pt.vector:
3436
x_test = np.arange(40, dtype="float32")
3537
else:
3638
x_test = np.arange(5 * 40, dtype="float32").reshape(5, 40)
39+
x_pt = pt.as_tensor_variable(x_test, name="x")
40+
assert isinstance(x_pt, Constant)
41+
42+
if not x_constant:
43+
x_pt = x_pt.type(name="x")
3744
inputs.append(x_pt)
3845
test_values.append(x_test)
3946

4047
if y_type is ps.matrix:
41-
y_sp = scipy.sparse.random(40, 3, density=0.25, format="csc", dtype="float32")
42-
y_pt = ps.as_sparse_variable(y_sp, name="y")
48+
y_test = scipy.sparse.random(40, 3, density=0.25, format="csc", dtype="float32")
49+
y_pt = ps.as_sparse_variable(y_test, name="y")
4350
else:
44-
y_pt = y_type("y", dtype="float32")
45-
if y_pt.ndim == 1:
51+
if y_type is pt.vector:
4652
y_test = np.arange(40, dtype="float32")
4753
else:
4854
y_test = np.arange(40 * 3, dtype="float32").reshape(40, 3)
55+
y_pt = pt.as_tensor_variable(y_test, name="y")
56+
assert isinstance(y_pt, Constant)
57+
58+
if not y_constant:
59+
y_pt = y_pt.type(name="y")
4960
inputs.append(y_pt)
5061
test_values.append(y_test)
5162

5263
dot_pt = op(x_pt, y_pt)
5364
fgraph = FunctionGraph(inputs, [dot_pt])
54-
compare_jax_and_py(fgraph, test_values)
55-
56-
57-
def test_sparse_dot_non_const_raises():
58-
x_pt = pt.vector("x")
59-
60-
y_sp = scipy.sparse.random(40, 3, density=0.25, format="csc", dtype="float32")
61-
y_pt = ps.as_sparse_variable(y_sp, name="y").type()
62-
63-
out = ps.dot(x_pt, y_pt)
64-
65-
msg = "JAX sparse dot only implemented for constant sparse inputs"
66-
67-
with pytest.raises(NotImplementedError, match=msg):
68-
function([x_pt, y_pt], out, mode="JAX")
69-
70-
y_pt_shared = ps.shared(y_sp, name="y")
7165

72-
out = ps.dot(x_pt, y_pt_shared)
66+
def assert_fn(x, y):
67+
[x] = x
68+
[y] = y
69+
if hasattr(x, "todense"):
70+
x = x.todense()
71+
if hasattr(y, "todense"):
72+
y = y.todense()
73+
np.testing.assert_allclose(x, y)
7374

74-
with pytest.raises(NotImplementedError, match=msg):
75-
function([x_pt], out, mode="JAX")
75+
compare_jax_and_py(
76+
fgraph,
77+
test_values,
78+
must_be_device_array=isinstance(dot_pt.type, DenseTensorType),
79+
assert_fn=assert_fn,
80+
)

0 commit comments

Comments
 (0)