Description
Allow automatically rebuilding a graph when one of its variables is replaced by a sparse variable, so that the operations that consume it switch to their sparse counterparts.
Minimal example:

```python
import pytensor.tensor as pt
import pytensor.sparse as ps

A = pt.matrix("A")
v = pt.vector("v")
out = A @ v

A_sparse = ps.csc_matrix(name="A_sparse")
# `sparsify` is the proposed function; it does not exist yet
out_sparse = sparsify(out, replace={A: A_sparse})
# Should be equivalent to `ps.dot(A_sparse, v)`
```
Inspired by https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.sparsify.html
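To illustrate the proposed behaviour, here is a toy, pure-Python sketch of one way such a transform could work. It rebuilds the expression bottom-up, substitutes variables from `replace`, and re-dispatches any op whose input became sparse through a lookup table. It does not use PyTensor's graph classes. `Var`, `Node`, `SPARSE_RULES` and this `sparsify` are hypothetical stand-ins, and the op names `"matmul"` and `"sparse_dot"` are placeholders for the dense and sparse ops.

```python
from dataclasses import dataclass
from typing import Dict, Tuple, Union


@dataclass(frozen=True)
class Var:
    """Toy graph leaf; `sparse` marks a sparse-typed variable (hypothetical)."""
    name: str
    sparse: bool = False


@dataclass(frozen=True)
class Node:
    """Toy op application: an op name applied to sub-expressions."""
    op: str
    inputs: Tuple["Expr", ...]


Expr = Union[Var, Node]

# Hypothetical dispatch table: dense op name -> sparse-aware op name,
# used when the first operand of the op is sparse.
SPARSE_RULES: Dict[str, str] = {"matmul": "sparse_dot"}


def is_sparse(e: Expr) -> bool:
    # In this toy, only leaves can be sparse; every op yields a dense result
    # (e.g. sparse matrix @ dense vector -> dense vector).
    return isinstance(e, Var) and e.sparse


def sparsify(expr: Expr, replace: Dict[Var, Var]) -> Expr:
    """Rebuild `expr`, swapping variables per `replace` and switching any op
    whose first input became sparse to its sparse counterpart."""
    if isinstance(expr, Var):
        return replace.get(expr, expr)
    new_inputs = tuple(sparsify(i, replace) for i in expr.inputs)
    op = expr.op
    if new_inputs and is_sparse(new_inputs[0]):
        if op not in SPARSE_RULES:
            raise NotImplementedError(f"no sparse rule for {op!r}")
        op = SPARSE_RULES[op]
    return Node(op, new_inputs)


# Mirror the minimal example from the issue
A = Var("A")
v = Var("v")
out = Node("matmul", (A, v))
A_sparse = Var("A_sparse", sparse=True)
out_sparse = sparsify(out, replace={A: A_sparse})
print(out_sparse)
```

Raising `NotImplementedError` for ops without a sparse rule mirrors the choice that a real implementation would face: fail loudly, or densify the input and fall back to the dense op. The toy also rebuilds shared subexpressions once per use, since it does no memoization.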