Implement sparsify transform #1155

Open
@ricardoV94

Description

Add a transform that automatically rebuilds a graph after one of its variables is replaced by a sparse one, so that the operations depending on it use their sparse counterparts.

Minimal example:

import pytensor.tensor as pt
import pytensor.sparse as ps

A = pt.matrix("A")
v = pt.vector("v")
out = A @ v

A_sparse = ps.csc_matrix(name="A_sparse")
# `sparsify` is the proposed transform; it does not exist yet
out_sparse = sparsify(out, replace={A: A_sparse})
# Should be equivalent to `ps.dot(A_sparse, v)`

Inspired by https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.sparse.sparsify.html
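
To make the intended behaviour concrete, here is a rough sketch of the idea on a toy expression graph. It deliberately does not use PyTensor: `Var`, `Apply`, `SPARSE_OPS` and this `sparsify` are all hypothetical names invented for illustration, and the sparsity check is simplified to "any direct input is a sparse variable". A real implementation would need to work on PyTensor graphs and track how sparsity propagates through each op.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Var:
    """Toy stand-in for a graph variable (hypothetical, not a PyTensor class)."""
    name: str
    sparse: bool = False


@dataclass(frozen=True)
class Apply:
    """Toy stand-in for an op applied to a tuple of inputs."""
    op: str
    inputs: tuple


# Hypothetical lookup from a dense op to its sparse counterpart.
SPARSE_OPS = {"dot": "sparse_dot"}


def sparsify(node, replace):
    """Return a rebuilt copy of `node` with variables swapped per `replace`.

    An op is switched to its sparse counterpart when one of its direct
    inputs is a sparse variable and a counterpart is registered.
    """
    if isinstance(node, Var):
        return replace.get(node, node)
    new_inputs = tuple(sparsify(i, replace) for i in node.inputs)
    any_sparse = any(isinstance(i, Var) and i.sparse for i in new_inputs)
    op = SPARSE_OPS.get(node.op, node.op) if any_sparse else node.op
    return Apply(op, new_inputs)


# Mirrors the minimal example above: out = A @ v
A = Var("A")
v = Var("v")
out = Apply("dot", (A, v))

A_sparse = Var("A_sparse", sparse=True)
out_sparse = sparsify(out, replace={A: A_sparse})
```

In this sketch `out_sparse` is a new `sparse_dot` node over `A_sparse` and `v`, while `out` is left untouched. Ops with no registered sparse counterpart are kept as they are, which a real transform would probably want to handle by densifying or raising instead.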
