|
9 | 9 |
|
10 | 10 | @pytest.mark.parametrize('dtype,device', product(grad_dtypes, devices))
|
11 | 11 | def test_getitem(dtype, device):
|
12 |
| - mat = torch.randn(50, 40, dtype=dtype, device=device) |
| 12 | + m = 50 |
| 13 | + n = 40 |
| 14 | + k = 10 |
| 15 | + mat = torch.randn(m, n, dtype=dtype, device=device) |
13 | 16 | mat = SparseTensor.from_dense(mat)
|
14 | 17 |
|
15 |
| - idx1 = torch.randint(0, 50, (10, ), dtype=torch.long, device=device) |
16 |
| - idx2 = torch.randint(0, 40, (10, ), dtype=torch.long, device=device) |
| 18 | + idx1 = torch.randint(0, m, (k,), dtype=torch.long, device=device) |
| 19 | + idx2 = torch.randint(0, n, (k,), dtype=torch.long, device=device) |
| 20 | + bool1 = torch.zeros(m, dtype=torch.bool, device=device) |
| 21 | + bool2 = torch.zeros(n, dtype=torch.bool, device=device) |
| 22 | + bool1.scatter_(0, idx1, 1) |
| 23 | + bool2.scatter_(0, idx2, 1) |
| 24 | + # idx1 and idx2 may have duplicates |
| 25 | + k1_bool = bool1.nonzero().size(0) |
| 26 | + k2_bool = bool2.nonzero().size(0) |
17 | 27 |
|
18 |
| - assert mat[:10, :10].sizes() == [10, 10] |
19 |
| - assert mat[..., :10].sizes() == [50, 10] |
20 |
| - assert mat[idx1, idx2].sizes() == [10, 10] |
21 |
| - assert mat[idx1.tolist()].sizes() == [10, 40] |
| 28 | + idx1np = idx1.cpu().numpy() |
| 29 | + idx2np = idx2.cpu().numpy() |
| 30 | + bool1np = bool1.cpu().numpy() |
| 31 | + bool2np = bool2.cpu().numpy() |
| 32 | + |
| 33 | + idx1list = idx1np.tolist() |
| 34 | + idx2list = idx2np.tolist() |
| 35 | + bool1list = bool1np.tolist() |
| 36 | + bool2list = bool2np.tolist() |
| 37 | + |
| 38 | + assert mat[:k, :k].sizes() == [k, k] |
| 39 | + assert mat[..., :k].sizes() == [m, k] |
| 40 | + |
| 41 | + assert mat[idx1, idx2].sizes() == [k, k] |
| 42 | + assert mat[idx1np, idx2np].sizes() == [k, k] |
| 43 | + assert mat[idx1list, idx2list].sizes() == [k, k] |
| 44 | + |
| 45 | + assert mat[bool1, bool2].sizes() == [k1_bool, k2_bool] |
| 46 | + assert mat[bool1np, bool2np].sizes() == [k1_bool, k2_bool] |
| 47 | + assert mat[bool1list, bool2list].sizes() == [k1_bool, k2_bool] |
| 48 | + |
| 49 | + assert mat[idx1].sizes() == [k, n] |
| 50 | + assert mat[idx1np].sizes() == [k, n] |
| 51 | + assert mat[idx1list].sizes() == [k, n] |
| 52 | + |
| 53 | + assert mat[bool1].sizes() == [k1_bool, n] |
| 54 | + assert mat[bool1np].sizes() == [k1_bool, n] |
| 55 | + assert mat[bool1list].sizes() == [k1_bool, n] |
22 | 56 |
|
23 | 57 |
|
24 | 58 | @pytest.mark.parametrize('device', devices)
|
|
0 commit comments