Skip to content

Commit 28f1295

Browse files
authored
Merge pull request #176 from shi27feng/patch-1
Update storage.py
2 parents 23709f9 + 9b5d3c7 commit 28f1295

File tree

2 files changed

+27
-17
lines changed

2 files changed

+27
-17
lines changed

torch_sparse/storage.py

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,8 @@ def __init__(self, row: Optional[torch.Tensor] = None,
3434
rowptr: Optional[torch.Tensor] = None,
3535
col: Optional[torch.Tensor] = None,
3636
value: Optional[torch.Tensor] = None,
37-
sparse_sizes: Optional[Tuple[int, int]] = None,
37+
sparse_sizes: Optional[Tuple[Optional[int],
38+
Optional[int]]] = None,
3839
rowcount: Optional[torch.Tensor] = None,
3940
colptr: Optional[torch.Tensor] = None,
4041
colcount: Optional[torch.Tensor] = None,
@@ -48,26 +49,33 @@ def __init__(self, row: Optional[torch.Tensor] = None,
4849
assert col.dim() == 1
4950
col = col.contiguous()
5051

51-
if sparse_sizes is None:
52+
M: int = 0
53+
if sparse_sizes is None or sparse_sizes[0] is None:
5254
if rowptr is not None:
5355
M = rowptr.numel() - 1
5456
elif row is not None and row.numel() > 0:
55-
M = row.max().item() + 1
56-
elif row is not None and row.numel() == 0:
57-
M = 0
58-
else:
59-
raise ValueError
57+
M = int(row.max()) + 1
58+
else:
59+
_M = sparse_sizes[0]
60+
assert _M is not None
61+
M = _M
62+
if rowptr is not None:
63+
assert rowptr.numel() - 1 == M
64+
elif row is not None and row.numel() > 0:
65+
assert int(row.max()) < M
66+
67+
N: int = 0
68+
if sparse_sizes is None or sparse_sizes[1] is None:
6069
if col.numel() > 0:
61-
N = col.max().item() + 1
62-
else:
63-
N = 0
64-
sparse_sizes = (int(M), int(N))
70+
N = int(col.max()) + 1
6571
else:
66-
assert len(sparse_sizes) == 2
67-
if row is not None and row.numel() > 0:
68-
assert row.max().item() < sparse_sizes[0]
72+
_N = sparse_sizes[1]
73+
assert _N is not None
74+
N = _N
6975
if col.numel() > 0:
70-
assert col.max().item() < sparse_sizes[1]
76+
assert int(col.max()) < N
77+
78+
sparse_sizes = (M, N)
7179

7280
if row is not None:
7381
assert row.dtype == torch.long

torch_sparse/tensor.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ def __init__(self, row: Optional[torch.Tensor] = None,
1616
rowptr: Optional[torch.Tensor] = None,
1717
col: Optional[torch.Tensor] = None,
1818
value: Optional[torch.Tensor] = None,
19-
sparse_sizes: Optional[Tuple[int, int]] = None,
19+
sparse_sizes: Optional[Tuple[Optional[int],
20+
Optional[int]]] = None,
2021
is_sorted: bool = False):
2122
self.storage = SparseStorage(row=row, rowptr=rowptr, col=col,
2223
value=value, sparse_sizes=sparse_sizes,
@@ -39,7 +40,8 @@ def from_storage(self, storage: SparseStorage):
3940
@classmethod
4041
def from_edge_index(self, edge_index: torch.Tensor,
4142
edge_attr: Optional[torch.Tensor] = None,
42-
sparse_sizes: Optional[Tuple[int, int]] = None,
43+
sparse_sizes: Optional[Tuple[Optional[int],
44+
Optional[int]]] = None,
4345
is_sorted: bool = False):
4446
return SparseTensor(row=edge_index[0], rowptr=None, col=edge_index[1],
4547
value=edge_attr, sparse_sizes=sparse_sizes,

0 commit comments

Comments
 (0)