Skip to content

Commit 4debbab

Browse files
committed
update
1 parent 20c3dd9 commit 4debbab

File tree

1 file changed

+123
-45
lines changed

1 file changed

+123
-45
lines changed

torch_sparse/tensor.py

Lines changed: 123 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -65,9 +65,15 @@ def from_edge_index(
6565
is_sorted: bool = False,
6666
trust_data: bool = False,
6767
):
68-
return SparseTensor(row=edge_index[0], rowptr=None, col=edge_index[1],
69-
value=edge_attr, sparse_sizes=sparse_sizes,
70-
is_sorted=is_sorted, trust_data=trust_data)
68+
return SparseTensor(
69+
row=edge_index[0],
70+
rowptr=None,
71+
col=edge_index[1],
72+
value=edge_attr,
73+
sparse_sizes=sparse_sizes,
74+
is_sorted=is_sorted,
75+
trust_data=trust_data,
76+
)
7177

7278
@classmethod
7379
def from_dense(self, mat: torch.Tensor, has_value: bool = True):
@@ -84,13 +90,22 @@ def from_dense(self, mat: torch.Tensor, has_value: bool = True):
8490
if has_value:
8591
value = mat[row, col]
8692

87-
return SparseTensor(row=row, rowptr=None, col=col, value=value,
88-
sparse_sizes=(mat.size(0), mat.size(1)),
89-
is_sorted=True, trust_data=True)
93+
return SparseTensor(
94+
row=row,
95+
rowptr=None,
96+
col=col,
97+
value=value,
98+
sparse_sizes=(mat.size(0), mat.size(1)),
99+
is_sorted=True,
100+
trust_data=True,
101+
)
90102

91103
@classmethod
92-
def from_torch_sparse_coo_tensor(self, mat: torch.Tensor,
93-
has_value: bool = True):
104+
def from_torch_sparse_coo_tensor(
105+
self,
106+
mat: torch.Tensor,
107+
has_value: bool = True,
108+
):
94109
mat = mat.coalesce()
95110
index = mat._indices()
96111
row, col = index[0], index[1]
@@ -99,27 +114,46 @@ def from_torch_sparse_coo_tensor(self, mat: torch.Tensor,
99114
if has_value:
100115
value = mat.values()
101116

102-
return SparseTensor(row=row, rowptr=None, col=col, value=value,
103-
sparse_sizes=(mat.size(0), mat.size(1)),
104-
is_sorted=True, trust_data=True)
117+
return SparseTensor(
118+
row=row,
119+
rowptr=None,
120+
col=col,
121+
value=value,
122+
sparse_sizes=(mat.size(0), mat.size(1)),
123+
is_sorted=True,
124+
trust_data=True,
125+
)
105126

106127
@classmethod
107-
def from_torch_sparse_csr_tensor(self, mat: torch.Tensor,
108-
has_value: bool = True):
128+
def from_torch_sparse_csr_tensor(
129+
self,
130+
mat: torch.Tensor,
131+
has_value: bool = True,
132+
):
109133
rowptr = mat.crow_indices()
110134
col = mat.col_indices()
111135

112136
value: Optional[torch.Tensor] = None
113137
if has_value:
114138
value = mat.values()
115139

116-
return SparseTensor(row=None, rowptr=rowptr, col=col, value=value,
117-
sparse_sizes=(mat.size(0), mat.size(1)),
118-
is_sorted=True, trust_data=True)
140+
return SparseTensor(
141+
row=None,
142+
rowptr=rowptr,
143+
col=col,
144+
value=value,
145+
sparse_sizes=(mat.size(0), mat.size(1)),
146+
is_sorted=True,
147+
trust_data=True,
148+
)
119149

120150
@classmethod
121-
def eye(self, M: int, N: Optional[int] = None, has_value: bool = True,
122-
dtype: Optional[int] = None, device: Optional[torch.device] = None,
151+
def eye(self,
152+
M: int,
153+
N: Optional[int] = None,
154+
has_value: bool = True,
155+
dtype: Optional[int] = None,
156+
device: Optional[torch.device] = None,
123157
fill_cache: bool = False):
124158

125159
N = M if N is None else N
@@ -214,13 +248,19 @@ def csc(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
214248
def has_value(self) -> bool:
215249
return self.storage.has_value()
216250

217-
def set_value_(self, value: Optional[torch.Tensor],
218-
layout: Optional[str] = None):
251+
def set_value_(
252+
self,
253+
value: Optional[torch.Tensor],
254+
layout: Optional[str] = None,
255+
):
219256
self.storage.set_value_(value, layout)
220257
return self
221258

222-
def set_value(self, value: Optional[torch.Tensor],
223-
layout: Optional[str] = None):
259+
def set_value(
260+
self,
261+
value: Optional[torch.Tensor],
262+
layout: Optional[str] = None,
263+
):
224264
return self.from_storage(self.storage.set_value(value, layout))
225265

226266
def sparse_sizes(self) -> Tuple[int, int]:
@@ -275,13 +315,21 @@ def __eq__(self, other) -> bool:
275315
# Utility functions #######################################################
276316

277317
def fill_value_(self, fill_value: float, dtype: Optional[int] = None):
278-
value = torch.full((self.nnz(), ), fill_value, dtype=dtype,
279-
device=self.device())
318+
value = torch.full(
319+
(self.nnz(), ),
320+
fill_value,
321+
dtype=dtype,
322+
device=self.device(),
323+
)
280324
return self.set_value_(value, layout='coo')
281325

282326
def fill_value(self, fill_value: float, dtype: Optional[int] = None):
283-
value = torch.full((self.nnz(), ), fill_value, dtype=dtype,
284-
device=self.device())
327+
value = torch.full(
328+
(self.nnz(), ),
329+
fill_value,
330+
dtype=dtype,
331+
device=self.device(),
332+
)
285333
return self.set_value(value, layout='coo')
286334

287335
def sizes(self) -> List[int]:
@@ -373,8 +421,8 @@ def to_symmetric(self, reduce: str = "sum"):
373421
value = torch.cat([value, value])[perm]
374422
value = segment_csr(value, ptr, reduce=reduce)
375423

376-
new_row = torch.cat([row, col], dim=0, out=perm)[idx]
377-
new_col = torch.cat([col, row], dim=0, out=perm)[idx]
424+
new_row = torch.cat([row, col], dim=0)[idx]
425+
new_col = torch.cat([col, row], dim=0)[idx]
378426

379427
out = SparseTensor(
380428
row=new_row,
@@ -406,8 +454,11 @@ def requires_grad(self) -> bool:
406454
else:
407455
return False
408456

409-
def requires_grad_(self, requires_grad: bool = True,
410-
dtype: Optional[int] = None):
457+
def requires_grad_(
458+
self,
459+
requires_grad: bool = True,
460+
dtype: Optional[int] = None,
461+
):
411462
if requires_grad and not self.has_value():
412463
self.fill_value_(1., dtype)
413464

@@ -478,21 +529,29 @@ def to_dense(self, dtype: Optional[int] = None) -> torch.Tensor:
478529
row, col, value = self.coo()
479530

480531
if value is not None:
481-
mat = torch.zeros(self.sizes(), dtype=value.dtype,
482-
device=self.device())
532+
mat = torch.zeros(
533+
self.sizes(),
534+
dtype=value.dtype,
535+
device=self.device(),
536+
)
483537
else:
484538
mat = torch.zeros(self.sizes(), dtype=dtype, device=self.device())
485539

486540
if value is not None:
487541
mat[row, col] = value
488542
else:
489-
mat[row, col] = torch.ones(self.nnz(), dtype=mat.dtype,
490-
device=mat.device)
543+
mat[row, col] = torch.ones(
544+
self.nnz(),
545+
dtype=mat.dtype,
546+
device=mat.device,
547+
)
491548

492549
return mat
493550

494551
def to_torch_sparse_coo_tensor(
495-
self, dtype: Optional[int] = None) -> torch.Tensor:
552+
self,
553+
dtype: Optional[int] = None,
554+
) -> torch.Tensor:
496555
row, col, value = self.coo()
497556
index = torch.stack([row, col], dim=0)
498557

@@ -502,7 +561,9 @@ def to_torch_sparse_coo_tensor(
502561
return torch.sparse_coo_tensor(index, value, self.sizes())
503562

504563
def to_torch_sparse_csr_tensor(
505-
self, dtype: Optional[int] = None) -> torch.Tensor:
564+
self,
565+
dtype: Optional[int] = None,
566+
) -> torch.Tensor:
506567
rowptr, col, value = self.csr()
507568

508569
if value is None:
@@ -511,7 +572,9 @@ def to_torch_sparse_csr_tensor(
511572
return torch.sparse_csr_tensor(rowptr, col, value, self.sizes())
512573

513574
def to_torch_sparse_csc_tensor(
514-
self, dtype: Optional[int] = None) -> torch.Tensor:
575+
self,
576+
dtype: Optional[int] = None,
577+
) -> torch.Tensor:
515578
colptr, row, value = self.csc()
516579

517580
if value is None:
@@ -548,8 +611,11 @@ def cpu(self) -> SparseTensor:
548611
return self.device_as(torch.tensor(0., device='cpu'))
549612

550613

551-
def cuda(self, device: Optional[Union[int, str]] = None,
552-
non_blocking: bool = False):
614+
def cuda(
615+
self,
616+
device: Optional[Union[int, str]] = None,
617+
non_blocking: bool = False,
618+
):
553619
return self.device_as(torch.tensor(0., device=device or 'cuda'))
554620

555621

@@ -654,17 +720,29 @@ def from_scipy(mat: ScipySparseMatrix, has_value: bool = True) -> SparseTensor:
654720
value = torch.from_numpy(mat.data)
655721
sparse_sizes = mat.shape[:2]
656722

657-
storage = SparseStorage(row=row, rowptr=rowptr, col=col, value=value,
658-
sparse_sizes=sparse_sizes, rowcount=None,
659-
colptr=colptr, colcount=None, csr2csc=None,
660-
csc2csr=None, is_sorted=True)
723+
storage = SparseStorage(
724+
row=row,
725+
rowptr=rowptr,
726+
col=col,
727+
value=value,
728+
sparse_sizes=sparse_sizes,
729+
rowcount=None,
730+
colptr=colptr,
731+
colcount=None,
732+
csr2csc=None,
733+
csc2csr=None,
734+
is_sorted=True,
735+
)
661736

662737
return SparseTensor.from_storage(storage)
663738

664739

665740
@torch.jit.ignore
666-
def to_scipy(self: SparseTensor, layout: Optional[str] = None,
667-
dtype: Optional[torch.dtype] = None) -> ScipySparseMatrix:
741+
def to_scipy(
742+
self: SparseTensor,
743+
layout: Optional[str] = None,
744+
dtype: Optional[torch.dtype] = None,
745+
) -> ScipySparseMatrix:
668746
assert self.dim() == 2
669747
layout = get_layout(layout)
670748

0 commit comments

Comments
 (0)