Fix gradient computation in to_symmetric #327

Merged 1 commit on May 22, 2023
torch_sparse/tensor.py: 168 changes (123 additions, 45 deletions)
@@ -65,9 +65,15 @@ def from_edge_index(
         is_sorted: bool = False,
         trust_data: bool = False,
     ):
-        return SparseTensor(row=edge_index[0], rowptr=None, col=edge_index[1],
-                            value=edge_attr, sparse_sizes=sparse_sizes,
-                            is_sorted=is_sorted, trust_data=trust_data)
+        return SparseTensor(
+            row=edge_index[0],
+            rowptr=None,
+            col=edge_index[1],
+            value=edge_attr,
+            sparse_sizes=sparse_sizes,
+            is_sorted=is_sorted,
+            trust_data=trust_data,
+        )

     @classmethod
     def from_dense(self, mat: torch.Tensor, has_value: bool = True):
@@ -84,13 +90,22 @@ def from_dense(self, mat: torch.Tensor, has_value: bool = True):
         if has_value:
             value = mat[row, col]

-        return SparseTensor(row=row, rowptr=None, col=col, value=value,
-                            sparse_sizes=(mat.size(0), mat.size(1)),
-                            is_sorted=True, trust_data=True)
+        return SparseTensor(
+            row=row,
+            rowptr=None,
+            col=col,
+            value=value,
+            sparse_sizes=(mat.size(0), mat.size(1)),
+            is_sorted=True,
+            trust_data=True,
+        )

     @classmethod
-    def from_torch_sparse_coo_tensor(self, mat: torch.Tensor,
-                                     has_value: bool = True):
+    def from_torch_sparse_coo_tensor(
+        self,
+        mat: torch.Tensor,
+        has_value: bool = True,
+    ):
         mat = mat.coalesce()
         index = mat._indices()
         row, col = index[0], index[1]
@@ -99,27 +114,46 @@ def from_torch_sparse_coo_tensor(self, mat: torch.Tensor,
         if has_value:
             value = mat.values()

-        return SparseTensor(row=row, rowptr=None, col=col, value=value,
-                            sparse_sizes=(mat.size(0), mat.size(1)),
-                            is_sorted=True, trust_data=True)
+        return SparseTensor(
+            row=row,
+            rowptr=None,
+            col=col,
+            value=value,
+            sparse_sizes=(mat.size(0), mat.size(1)),
+            is_sorted=True,
+            trust_data=True,
+        )

     @classmethod
-    def from_torch_sparse_csr_tensor(self, mat: torch.Tensor,
-                                     has_value: bool = True):
+    def from_torch_sparse_csr_tensor(
+        self,
+        mat: torch.Tensor,
+        has_value: bool = True,
+    ):
         rowptr = mat.crow_indices()
         col = mat.col_indices()

         value: Optional[torch.Tensor] = None
         if has_value:
             value = mat.values()

-        return SparseTensor(row=None, rowptr=rowptr, col=col, value=value,
-                            sparse_sizes=(mat.size(0), mat.size(1)),
-                            is_sorted=True, trust_data=True)
+        return SparseTensor(
+            row=None,
+            rowptr=rowptr,
+            col=col,
+            value=value,
+            sparse_sizes=(mat.size(0), mat.size(1)),
+            is_sorted=True,
+            trust_data=True,
+        )

     @classmethod
-    def eye(self, M: int, N: Optional[int] = None, has_value: bool = True,
-            dtype: Optional[int] = None, device: Optional[torch.device] = None,
+    def eye(self,
+            M: int,
+            N: Optional[int] = None,
+            has_value: bool = True,
+            dtype: Optional[int] = None,
+            device: Optional[torch.device] = None,
             fill_cache: bool = False):

         N = M if N is None else N
@@ -214,13 +248,19 @@ def csc(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
     def has_value(self) -> bool:
         return self.storage.has_value()

-    def set_value_(self, value: Optional[torch.Tensor],
-                   layout: Optional[str] = None):
+    def set_value_(
+        self,
+        value: Optional[torch.Tensor],
+        layout: Optional[str] = None,
+    ):
         self.storage.set_value_(value, layout)
         return self

-    def set_value(self, value: Optional[torch.Tensor],
-                  layout: Optional[str] = None):
+    def set_value(
+        self,
+        value: Optional[torch.Tensor],
+        layout: Optional[str] = None,
+    ):
         return self.from_storage(self.storage.set_value(value, layout))

     def sparse_sizes(self) -> Tuple[int, int]:
@@ -275,13 +315,21 @@ def __eq__(self, other) -> bool:
     # Utility functions #######################################################

     def fill_value_(self, fill_value: float, dtype: Optional[int] = None):
-        value = torch.full((self.nnz(), ), fill_value, dtype=dtype,
-                           device=self.device())
+        value = torch.full(
+            (self.nnz(), ),
+            fill_value,
+            dtype=dtype,
+            device=self.device(),
+        )
         return self.set_value_(value, layout='coo')

     def fill_value(self, fill_value: float, dtype: Optional[int] = None):
-        value = torch.full((self.nnz(), ), fill_value, dtype=dtype,
-                           device=self.device())
+        value = torch.full(
+            (self.nnz(), ),
+            fill_value,
+            dtype=dtype,
+            device=self.device(),
+        )
         return self.set_value(value, layout='coo')

     def sizes(self) -> List[int]:
@@ -373,8 +421,8 @@ def to_symmetric(self, reduce: str = "sum"):
             value = torch.cat([value, value])[perm]
             value = segment_csr(value, ptr, reduce=reduce)

-        new_row = torch.cat([row, col], dim=0, out=perm)[idx]
-        new_col = torch.cat([col, row], dim=0, out=perm)[idx]
+        new_row = torch.cat([row, col], dim=0)[idx]
+        new_col = torch.cat([col, row], dim=0)[idx]

         out = SparseTensor(
             row=new_row,
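The dropped `out=perm` arguments are the actual fix. `perm` is the index tensor that autograd saves for the backward pass of `torch.cat([value, value])[perm]` a few lines above, and reusing it as an output buffer mutates it in-place, so backpropagating through `value` fails with a version-counter error. A minimal sketch of the failure mode outside `SparseTensor` (variable names here are illustrative, not taken from the PR):

```python
import torch

value = torch.randn(4, requires_grad=True)
row = torch.tensor([0, 1, 2, 3])
col = torch.tensor([1, 0, 3, 2])

perm = torch.randperm(8)
# Autograd saves `perm` to route gradients in the backward of this indexing op.
gathered = torch.cat([value, value])[perm]

# Old behavior: reuse `perm` as scratch space for an unrelated concatenation.
torch.cat([row, col], dim=0, out=perm)  # in-place write into a saved tensor

# Typically raises: "one of the variables needed for gradient computation has
# been modified by an inplace operation".
gathered.sum().backward()
```

Allocating fresh outputs for the two concatenations costs one extra temporary but keeps the autograd graph intact.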
@@ -406,8 +454,11 @@ def requires_grad(self) -> bool:
         else:
             return False

-    def requires_grad_(self, requires_grad: bool = True,
-                       dtype: Optional[int] = None):
+    def requires_grad_(
+        self,
+        requires_grad: bool = True,
+        dtype: Optional[int] = None,
+    ):
         if requires_grad and not self.has_value():
             self.fill_value_(1., dtype)

@@ -478,21 +529,29 @@ def to_dense(self, dtype: Optional[int] = None) -> torch.Tensor:
         row, col, value = self.coo()

         if value is not None:
-            mat = torch.zeros(self.sizes(), dtype=value.dtype,
-                              device=self.device())
+            mat = torch.zeros(
+                self.sizes(),
+                dtype=value.dtype,
+                device=self.device(),
+            )
         else:
             mat = torch.zeros(self.sizes(), dtype=dtype, device=self.device())

         if value is not None:
             mat[row, col] = value
         else:
-            mat[row, col] = torch.ones(self.nnz(), dtype=mat.dtype,
-                                       device=mat.device)
+            mat[row, col] = torch.ones(
+                self.nnz(),
+                dtype=mat.dtype,
+                device=mat.device,
+            )

         return mat

     def to_torch_sparse_coo_tensor(
-            self, dtype: Optional[int] = None) -> torch.Tensor:
+        self,
+        dtype: Optional[int] = None,
+    ) -> torch.Tensor:
         row, col, value = self.coo()
         index = torch.stack([row, col], dim=0)

@@ -502,7 +561,9 @@ def to_torch_sparse_coo_tensor(
         return torch.sparse_coo_tensor(index, value, self.sizes())

     def to_torch_sparse_csr_tensor(
-            self, dtype: Optional[int] = None) -> torch.Tensor:
+        self,
+        dtype: Optional[int] = None,
+    ) -> torch.Tensor:
         rowptr, col, value = self.csr()

         if value is None:
@@ -511,7 +572,9 @@ def to_torch_sparse_csr_tensor(
         return torch.sparse_csr_tensor(rowptr, col, value, self.sizes())

     def to_torch_sparse_csc_tensor(
-            self, dtype: Optional[int] = None) -> torch.Tensor:
+        self,
+        dtype: Optional[int] = None,
+    ) -> torch.Tensor:
         colptr, row, value = self.csc()

         if value is None:
@@ -548,8 +611,11 @@ def cpu(self) -> SparseTensor:
     return self.device_as(torch.tensor(0., device='cpu'))


-def cuda(self, device: Optional[Union[int, str]] = None,
-         non_blocking: bool = False):
+def cuda(
+    self,
+    device: Optional[Union[int, str]] = None,
+    non_blocking: bool = False,
+):
     return self.device_as(torch.tensor(0., device=device or 'cuda'))


@@ -654,17 +720,29 @@ def from_scipy(mat: ScipySparseMatrix, has_value: bool = True) -> SparseTensor:
         value = torch.from_numpy(mat.data)
     sparse_sizes = mat.shape[:2]

-    storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
-                            sparse_sizes=sparse_sizes, rowcount=None,
-                            colptr=colptr, colcount=None, csr2csc=None,
-                            csc2csr=None, is_sorted=True)
+    storage = SparseStorage(
+        row=row,
+        rowptr=rowptr,
+        col=col,
+        value=value,
+        sparse_sizes=sparse_sizes,
+        rowcount=None,
+        colptr=colptr,
+        colcount=None,
+        csr2csc=None,
+        csc2csr=None,
+        is_sorted=True,
+    )

     return SparseTensor.from_storage(storage)


 @torch.jit.ignore
-def to_scipy(self: SparseTensor, layout: Optional[str] = None,
-             dtype: Optional[torch.dtype] = None) -> ScipySparseMatrix:
+def to_scipy(
+    self: SparseTensor,
+    layout: Optional[str] = None,
+    dtype: Optional[torch.dtype] = None,
+) -> ScipySparseMatrix:
     assert self.dim() == 2
     layout = get_layout(layout)

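With the fix in place, gradients flow end to end through `to_symmetric`. A quick check, sketched under the assumption that `torch_sparse` is installed from this branch (the matrix values are arbitrary):

```python
import torch
from torch_sparse import SparseTensor

row = torch.tensor([0, 0, 1, 2])
col = torch.tensor([1, 2, 0, 1])
value = torch.randn(4, requires_grad=True)

mat = SparseTensor(row=row, col=col, value=value, sparse_sizes=(3, 3))
sym = mat.to_symmetric(reduce='sum')

# Backpropagate through the symmetrized values; before this PR, the in-place
# `out=perm` writes inside to_symmetric could break this backward pass.
sym.storage.value().sum().backward()
print(value.grad.shape)  # torch.Size([4])
```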