@@ -34,7 +34,8 @@ def __init__(self, row: Optional[torch.Tensor] = None,
34
34
rowptr : Optional [torch .Tensor ] = None ,
35
35
col : Optional [torch .Tensor ] = None ,
36
36
value : Optional [torch .Tensor ] = None ,
37
- sparse_sizes : Optional [Tuple [int , int ]] = None ,
37
+ sparse_sizes : Optional [Tuple [Optional [int ],
38
+ Optional [int ]]] = None ,
38
39
rowcount : Optional [torch .Tensor ] = None ,
39
40
colptr : Optional [torch .Tensor ] = None ,
40
41
colcount : Optional [torch .Tensor ] = None ,
@@ -48,26 +49,33 @@ def __init__(self, row: Optional[torch.Tensor] = None,
48
49
assert col .dim () == 1
49
50
col = col .contiguous ()
50
51
51
- if sparse_sizes is None :
52
+ M : int = 0
53
+ if sparse_sizes is None or sparse_sizes [0 ] is None :
52
54
if rowptr is not None :
53
55
M = rowptr .numel () - 1
54
56
elif row is not None and row .numel () > 0 :
55
- M = row .max ().item () + 1
56
- elif row is not None and row .numel () == 0 :
57
- M = 0
58
- else :
59
- raise ValueError
57
+ M = int (row .max ()) + 1
58
+ else :
59
+ _M = sparse_sizes [0 ]
60
+ assert _M is not None
61
+ M = _M
62
+ if rowptr is not None :
63
+ assert rowptr .numel () - 1 == M
64
+ elif row is not None and row .numel () > 0 :
65
+ assert int (row .max ()) < M
66
+
67
+ N : int = 0
68
+ if sparse_sizes is None or sparse_sizes [1 ] is None :
60
69
if col .numel () > 0 :
61
- N = col .max ().item () + 1
62
- else :
63
- N = 0
64
- sparse_sizes = (int (M ), int (N ))
70
+ N = int (col .max ()) + 1
65
71
else :
66
- assert len ( sparse_sizes ) == 2
67
- if row is not None and row . numel () > 0 :
68
- assert row . max (). item () < sparse_sizes [ 0 ]
72
+ _N = sparse_sizes [ 1 ]
73
+ assert _N is not None
74
+ N = _N
69
75
if col .numel () > 0 :
70
- assert col .max ().item () < sparse_sizes [1 ]
76
+ assert int (col .max ()) < N
77
+
78
+ sparse_sizes = (M , N )
71
79
72
80
if row is not None :
73
81
assert row .dtype == torch .long
0 commit comments