Skip to content

Add function for the addition of two matrices #177

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Oct 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions test/test_add.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from itertools import product

import pytest
import torch
from torch_sparse import SparseTensor, add

from .utils import dtypes, devices, tensor


@pytest.mark.parametrize('dtype,device', product(dtypes, devices))
def test_add(dtype, device):
rowA = torch.tensor([0, 0, 1, 2, 2], device=device)
colA = torch.tensor([0, 2, 1, 0, 1], device=device)
valueA = tensor([1, 2, 4, 1, 3], dtype, device)
A = SparseTensor(row=rowA, col=colA, value=valueA)

rowB = torch.tensor([0, 0, 1, 2, 2], device=device)
colB = torch.tensor([1, 2, 2, 1, 2], device=device)
valueB = tensor([2, 3, 1, 2, 4], dtype, device)
B = SparseTensor(row=rowB, col=colB, value=valueB)

C = A + B
rowC, colC, valueC = C.coo()

assert rowC.tolist() == [0, 0, 0, 1, 1, 2, 2, 2]
assert colC.tolist() == [0, 1, 2, 1, 2, 0, 1, 2]
assert valueC.tolist() == [1, 2, 5, 4, 1, 1, 5, 4]

@torch.jit.script
def jit_add(A: SparseTensor, B: SparseTensor) -> SparseTensor:
return add(A, B)

jit_add(A, B)
2 changes: 2 additions & 0 deletions torch_sparse/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from .eye import eye # noqa
from .spmm import spmm # noqa
from .spspmm import spspmm # noqa
from .spadd import spadd # noqa

__all__ = [
'SparseStorage',
Expand Down Expand Up @@ -111,5 +112,6 @@
'eye',
'spmm',
'spspmm',
'spadd',
'__version__',
]
71 changes: 53 additions & 18 deletions torch_sparse/add.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,69 @@
from typing import Optional

import torch
from torch import Tensor
from torch_scatter import gather_csr
from torch_sparse.tensor import SparseTensor


def add(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
other = gather_csr(other.squeeze(1), rowptr)
pass
elif other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise...
other = other.squeeze(0)[col]
else:
raise ValueError(
f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or '
f'(1, {src.size(1)}, ...), but got size {other.size()}.')
if value is not None:
value = other.to(value.dtype).add_(value)
@torch.jit._overload # noqa: F811
def add(src, other): # noqa: F811
# type: (SparseTensor, Tensor) -> SparseTensor
pass


@torch.jit._overload # noqa: F811
def add(src, other): # noqa: F811
# type: (SparseTensor, SparseTensor) -> SparseTensor
pass


def add(src, other): # noqa: F811
if isinstance(other, Tensor):
rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise.
other = gather_csr(other.squeeze(1), rowptr)
elif other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise.
other = other.squeeze(0)[col]
else:
raise ValueError(
f'Size mismatch: Expected size ({src.size(0)}, 1, ...) or '
f'(1, {src.size(1)}, ...), but got size {other.size()}.')
if value is not None:
value = other.to(value.dtype).add_(value)
else:
value = other.add_(1)
return src.set_value(value, layout='coo')

elif isinstance(other, SparseTensor):
rowA, colA, valueA = src.coo()
rowB, colB, valueB = other.coo()

row = torch.cat([rowA, rowB], dim=0)
col = torch.cat([colA, colB], dim=0)

value: Optional[Tensor] = None
if valueA is not None and valueB is not None:
value = torch.cat([valueA, valueB], dim=0)

M = max(src.size(0), other.size(0))
N = max(src.size(1), other.size(1))
sparse_sizes = (M, N)

out = SparseTensor(row=row, col=col, value=value,
sparse_sizes=sparse_sizes)
out = out.coalesce(reduce='sum')
return out

else:
value = other.add_(1)
return src.set_value(value, layout='coo')
raise NotImplementedError


def add_(src: SparseTensor, other: torch.Tensor) -> SparseTensor:
rowptr, col, value = src.csr()
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise...
if other.size(0) == src.size(0) and other.size(1) == 1: # Row-wise.
other = gather_csr(other.squeeze(1), rowptr)
pass
elif other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise...
elif other.size(0) == 1 and other.size(1) == src.size(1): # Col-wise.
other = other.squeeze(0)[col]
else:
raise ValueError(
Expand Down
18 changes: 18 additions & 0 deletions torch_sparse/spadd.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import torch
from torch_sparse import coalesce


def spadd(indexA, valueA, indexB, valueB, m, n):
"""Matrix addition of two sparse matrices.

Args:
indexA (:class:`LongTensor`): The index tensor of first sparse matrix.
valueA (:class:`Tensor`): The value tensor of first sparse matrix.
indexB (:class:`LongTensor`): The index tensor of second sparse matrix.
valueB (:class:`Tensor`): The value tensor of second sparse matrix.
m (int): The first dimension of the sparse matrices.
n (int): The second dimension of the sparse matrices.
"""
index = torch.cat([indexA, indexB], dim=-1)
value = torch.cat([valueA, valueB], dim=0)
return coalesce(index=index, value=value, m=m, n=n, op='add')
2 changes: 1 addition & 1 deletion torch_sparse/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def sparse_reshape(self, num_rows: int, num_cols: int):

idx = self.sparse_size(1) * self.row() + self.col()

row = idx // num_cols
row = torch.div(idx, num_cols, rounding_mode='floor')
col = idx % num_cols
assert row.dtype == torch.long and col.dtype == torch.long

Expand Down