Skip to content

Commit 88c6ceb

Browse files
bwdeng20timrusty1s
authored
fix(SparseTensor.__getitem__): support np.ndarray and fix `List[b… (#194)
* fix(`SparseTensor.__getitem__`): support `np.ndarray` and fix `List[bool]` support indexing with np.ndarray & fix bug merging from indexing with List[bool] * style(tensor, test_tensor): pep8 E501 too long support indexing with np.ndarray & fix bug merging from indexing with List[bool] * update * typo * typo Co-authored-by: tim <[email protected]> Co-authored-by: rusty1s <[email protected]>
1 parent efc9808 commit 88c6ceb

File tree

2 files changed

+50
-10
lines changed

2 files changed

+50
-10
lines changed

test/test_tensor.py

Lines changed: 41 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,50 @@
99

1010
@pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
1111
def test_getitem(dtype, device):
12-
mat = torch.randn(50, 40, dtype=dtype, device=device)
12+
m = 50
13+
n = 40
14+
k = 10
15+
mat = torch.randn(m, n, dtype=dtype, device=device)
1316
mat = SparseTensor.from_dense(mat)
1417

15-
idx1 = torch.randint(0, 50, (10, ), dtype=torch.long, device=device)
16-
idx2 = torch.randint(0, 40, (10, ), dtype=torch.long, device=device)
18+
idx1 = torch.randint(0, m, (k,), dtype=torch.long, device=device)
19+
idx2 = torch.randint(0, n, (k,), dtype=torch.long, device=device)
20+
bool1 = torch.zeros(m, dtype=torch.bool, device=device)
21+
bool2 = torch.zeros(n, dtype=torch.bool, device=device)
22+
bool1.scatter_(0, idx1, 1)
23+
bool2.scatter_(0, idx2, 1)
24+
# idx1 and idx2 may have duplicates
25+
k1_bool = bool1.nonzero().size(0)
26+
k2_bool = bool2.nonzero().size(0)
1727

18-
assert mat[:10, :10].sizes() == [10, 10]
19-
assert mat[..., :10].sizes() == [50, 10]
20-
assert mat[idx1, idx2].sizes() == [10, 10]
21-
assert mat[idx1.tolist()].sizes() == [10, 40]
28+
idx1np = idx1.cpu().numpy()
29+
idx2np = idx2.cpu().numpy()
30+
bool1np = bool1.cpu().numpy()
31+
bool2np = bool2.cpu().numpy()
32+
33+
idx1list = idx1np.tolist()
34+
idx2list = idx2np.tolist()
35+
bool1list = bool1np.tolist()
36+
bool2list = bool2np.tolist()
37+
38+
assert mat[:k, :k].sizes() == [k, k]
39+
assert mat[..., :k].sizes() == [m, k]
40+
41+
assert mat[idx1, idx2].sizes() == [k, k]
42+
assert mat[idx1np, idx2np].sizes() == [k, k]
43+
assert mat[idx1list, idx2list].sizes() == [k, k]
44+
45+
assert mat[bool1, bool2].sizes() == [k1_bool, k2_bool]
46+
assert mat[bool1np, bool2np].sizes() == [k1_bool, k2_bool]
47+
assert mat[bool1list, bool2list].sizes() == [k1_bool, k2_bool]
48+
49+
assert mat[idx1].sizes() == [k, n]
50+
assert mat[idx1np].sizes() == [k, n]
51+
assert mat[idx1list].sizes() == [k, n]
52+
53+
assert mat[bool1].sizes() == [k1_bool, n]
54+
assert mat[bool1np].sizes() == [k1_bool, n]
55+
assert mat[bool1list].sizes() == [k1_bool, n]
2256

2357

2458
@pytest.mark.parametrize('device', devices)

torch_sparse/tensor.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from typing import Optional, List, Tuple, Dict, Union, Any
33

44
import torch
5+
import numpy as np
56
import scipy.sparse
67
from torch_scatter import segment_csr
78

@@ -468,7 +469,6 @@ def is_shared(self: SparseTensor) -> bool:
468469

469470
def to(self, *args: Optional[List[Any]],
470471
**kwargs: Optional[Dict[str, Any]]) -> SparseTensor:
471-
472472
device, dtype, non_blocking = torch._C._nn._parse_to(*args, **kwargs)[:3]
473473

474474
if dtype is not None:
@@ -491,15 +491,21 @@ def cuda(self, device: Optional[Union[int, str]] = None,
491491
def __getitem__(self: SparseTensor, index: Any) -> SparseTensor:
492492
index = list(index) if isinstance(index, tuple) else [index]
493493
# More than one `Ellipsis` is not allowed...
494-
if len([i for i in index if not torch.is_tensor(i) and i == ...]) > 1:
494+
if len([
495+
i for i in index
496+
if not isinstance(i, (torch.Tensor, np.ndarray)) and i == ...
497+
]) > 1:
495498
raise SyntaxError
496499

497500
dim = 0
498501
out = self
499502
while len(index) > 0:
500503
item = index.pop(0)
501504
if isinstance(item, (list, tuple)):
502-
item = torch.tensor(item, dtype=torch.long, device=self.device())
505+
item = torch.tensor(item, device=self.device())
506+
if isinstance(item, np.ndarray):
507+
item = torch.from_numpy(item).to(self.device())
508+
503509
if isinstance(item, int):
504510
out = out.select(dim, item)
505511
dim += 1

0 commit comments

Comments
 (0)