@@ -65,9 +65,15 @@ def from_edge_index(
65
65
is_sorted : bool = False ,
66
66
trust_data : bool = False ,
67
67
):
68
- return SparseTensor (row = edge_index [0 ], rowptr = None , col = edge_index [1 ],
69
- value = edge_attr , sparse_sizes = sparse_sizes ,
70
- is_sorted = is_sorted , trust_data = trust_data )
68
+ return SparseTensor (
69
+ row = edge_index [0 ],
70
+ rowptr = None ,
71
+ col = edge_index [1 ],
72
+ value = edge_attr ,
73
+ sparse_sizes = sparse_sizes ,
74
+ is_sorted = is_sorted ,
75
+ trust_data = trust_data ,
76
+ )
71
77
72
78
@classmethod
73
79
def from_dense (self , mat : torch .Tensor , has_value : bool = True ):
@@ -84,13 +90,22 @@ def from_dense(self, mat: torch.Tensor, has_value: bool = True):
84
90
if has_value :
85
91
value = mat [row , col ]
86
92
87
- return SparseTensor (row = row , rowptr = None , col = col , value = value ,
88
- sparse_sizes = (mat .size (0 ), mat .size (1 )),
89
- is_sorted = True , trust_data = True )
93
+ return SparseTensor (
94
+ row = row ,
95
+ rowptr = None ,
96
+ col = col ,
97
+ value = value ,
98
+ sparse_sizes = (mat .size (0 ), mat .size (1 )),
99
+ is_sorted = True ,
100
+ trust_data = True ,
101
+ )
90
102
91
103
@classmethod
92
- def from_torch_sparse_coo_tensor (self , mat : torch .Tensor ,
93
- has_value : bool = True ):
104
+ def from_torch_sparse_coo_tensor (
105
+ self ,
106
+ mat : torch .Tensor ,
107
+ has_value : bool = True ,
108
+ ):
94
109
mat = mat .coalesce ()
95
110
index = mat ._indices ()
96
111
row , col = index [0 ], index [1 ]
@@ -99,27 +114,46 @@ def from_torch_sparse_coo_tensor(self, mat: torch.Tensor,
99
114
if has_value :
100
115
value = mat .values ()
101
116
102
- return SparseTensor (row = row , rowptr = None , col = col , value = value ,
103
- sparse_sizes = (mat .size (0 ), mat .size (1 )),
104
- is_sorted = True , trust_data = True )
117
+ return SparseTensor (
118
+ row = row ,
119
+ rowptr = None ,
120
+ col = col ,
121
+ value = value ,
122
+ sparse_sizes = (mat .size (0 ), mat .size (1 )),
123
+ is_sorted = True ,
124
+ trust_data = True ,
125
+ )
105
126
106
127
@classmethod
107
- def from_torch_sparse_csr_tensor (self , mat : torch .Tensor ,
108
- has_value : bool = True ):
128
+ def from_torch_sparse_csr_tensor (
129
+ self ,
130
+ mat : torch .Tensor ,
131
+ has_value : bool = True ,
132
+ ):
109
133
rowptr = mat .crow_indices ()
110
134
col = mat .col_indices ()
111
135
112
136
value : Optional [torch .Tensor ] = None
113
137
if has_value :
114
138
value = mat .values ()
115
139
116
- return SparseTensor (row = None , rowptr = rowptr , col = col , value = value ,
117
- sparse_sizes = (mat .size (0 ), mat .size (1 )),
118
- is_sorted = True , trust_data = True )
140
+ return SparseTensor (
141
+ row = None ,
142
+ rowptr = rowptr ,
143
+ col = col ,
144
+ value = value ,
145
+ sparse_sizes = (mat .size (0 ), mat .size (1 )),
146
+ is_sorted = True ,
147
+ trust_data = True ,
148
+ )
119
149
120
150
@classmethod
121
- def eye (self , M : int , N : Optional [int ] = None , has_value : bool = True ,
122
- dtype : Optional [int ] = None , device : Optional [torch .device ] = None ,
151
+ def eye (self ,
152
+ M : int ,
153
+ N : Optional [int ] = None ,
154
+ has_value : bool = True ,
155
+ dtype : Optional [int ] = None ,
156
+ device : Optional [torch .device ] = None ,
123
157
fill_cache : bool = False ):
124
158
125
159
N = M if N is None else N
@@ -214,13 +248,19 @@ def csc(self) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
214
248
def has_value (self ) -> bool :
215
249
return self .storage .has_value ()
216
250
217
- def set_value_ (self , value : Optional [torch .Tensor ],
218
- layout : Optional [str ] = None ):
251
+ def set_value_ (
252
+ self ,
253
+ value : Optional [torch .Tensor ],
254
+ layout : Optional [str ] = None ,
255
+ ):
219
256
self .storage .set_value_ (value , layout )
220
257
return self
221
258
222
- def set_value (self , value : Optional [torch .Tensor ],
223
- layout : Optional [str ] = None ):
259
+ def set_value (
260
+ self ,
261
+ value : Optional [torch .Tensor ],
262
+ layout : Optional [str ] = None ,
263
+ ):
224
264
return self .from_storage (self .storage .set_value (value , layout ))
225
265
226
266
def sparse_sizes (self ) -> Tuple [int , int ]:
@@ -275,13 +315,21 @@ def __eq__(self, other) -> bool:
275
315
# Utility functions #######################################################
276
316
277
317
def fill_value_ (self , fill_value : float , dtype : Optional [int ] = None ):
278
- value = torch .full ((self .nnz (), ), fill_value , dtype = dtype ,
279
- device = self .device ())
318
+ value = torch .full (
319
+ (self .nnz (), ),
320
+ fill_value ,
321
+ dtype = dtype ,
322
+ device = self .device (),
323
+ )
280
324
return self .set_value_ (value , layout = 'coo' )
281
325
282
326
def fill_value (self , fill_value : float , dtype : Optional [int ] = None ):
283
- value = torch .full ((self .nnz (), ), fill_value , dtype = dtype ,
284
- device = self .device ())
327
+ value = torch .full (
328
+ (self .nnz (), ),
329
+ fill_value ,
330
+ dtype = dtype ,
331
+ device = self .device (),
332
+ )
285
333
return self .set_value (value , layout = 'coo' )
286
334
287
335
def sizes (self ) -> List [int ]:
@@ -373,8 +421,8 @@ def to_symmetric(self, reduce: str = "sum"):
373
421
value = torch .cat ([value , value ])[perm ]
374
422
value = segment_csr (value , ptr , reduce = reduce )
375
423
376
- new_row = torch .cat ([row , col ], dim = 0 , out = perm )[idx ]
377
- new_col = torch .cat ([col , row ], dim = 0 , out = perm )[idx ]
424
+ new_row = torch .cat ([row , col ], dim = 0 )[idx ]
425
+ new_col = torch .cat ([col , row ], dim = 0 )[idx ]
378
426
379
427
out = SparseTensor (
380
428
row = new_row ,
@@ -406,8 +454,11 @@ def requires_grad(self) -> bool:
406
454
else :
407
455
return False
408
456
409
- def requires_grad_ (self , requires_grad : bool = True ,
410
- dtype : Optional [int ] = None ):
457
+ def requires_grad_ (
458
+ self ,
459
+ requires_grad : bool = True ,
460
+ dtype : Optional [int ] = None ,
461
+ ):
411
462
if requires_grad and not self .has_value ():
412
463
self .fill_value_ (1. , dtype )
413
464
@@ -478,21 +529,29 @@ def to_dense(self, dtype: Optional[int] = None) -> torch.Tensor:
478
529
row , col , value = self .coo ()
479
530
480
531
if value is not None :
481
- mat = torch .zeros (self .sizes (), dtype = value .dtype ,
482
- device = self .device ())
532
+ mat = torch .zeros (
533
+ self .sizes (),
534
+ dtype = value .dtype ,
535
+ device = self .device (),
536
+ )
483
537
else :
484
538
mat = torch .zeros (self .sizes (), dtype = dtype , device = self .device ())
485
539
486
540
if value is not None :
487
541
mat [row , col ] = value
488
542
else :
489
- mat [row , col ] = torch .ones (self .nnz (), dtype = mat .dtype ,
490
- device = mat .device )
543
+ mat [row , col ] = torch .ones (
544
+ self .nnz (),
545
+ dtype = mat .dtype ,
546
+ device = mat .device ,
547
+ )
491
548
492
549
return mat
493
550
494
551
def to_torch_sparse_coo_tensor (
495
- self , dtype : Optional [int ] = None ) -> torch .Tensor :
552
+ self ,
553
+ dtype : Optional [int ] = None ,
554
+ ) -> torch .Tensor :
496
555
row , col , value = self .coo ()
497
556
index = torch .stack ([row , col ], dim = 0 )
498
557
@@ -502,7 +561,9 @@ def to_torch_sparse_coo_tensor(
502
561
return torch .sparse_coo_tensor (index , value , self .sizes ())
503
562
504
563
def to_torch_sparse_csr_tensor (
505
- self , dtype : Optional [int ] = None ) -> torch .Tensor :
564
+ self ,
565
+ dtype : Optional [int ] = None ,
566
+ ) -> torch .Tensor :
506
567
rowptr , col , value = self .csr ()
507
568
508
569
if value is None :
@@ -511,7 +572,9 @@ def to_torch_sparse_csr_tensor(
511
572
return torch .sparse_csr_tensor (rowptr , col , value , self .sizes ())
512
573
513
574
def to_torch_sparse_csc_tensor (
514
- self , dtype : Optional [int ] = None ) -> torch .Tensor :
575
+ self ,
576
+ dtype : Optional [int ] = None ,
577
+ ) -> torch .Tensor :
515
578
colptr , row , value = self .csc ()
516
579
517
580
if value is None :
@@ -548,8 +611,11 @@ def cpu(self) -> SparseTensor:
548
611
return self .device_as (torch .tensor (0. , device = 'cpu' ))
549
612
550
613
551
- def cuda (self , device : Optional [Union [int , str ]] = None ,
552
- non_blocking : bool = False ):
614
+ def cuda (
615
+ self ,
616
+ device : Optional [Union [int , str ]] = None ,
617
+ non_blocking : bool = False ,
618
+ ):
553
619
return self .device_as (torch .tensor (0. , device = device or 'cuda' ))
554
620
555
621
@@ -654,17 +720,29 @@ def from_scipy(mat: ScipySparseMatrix, has_value: bool = True) -> SparseTensor:
654
720
value = torch .from_numpy (mat .data )
655
721
sparse_sizes = mat .shape [:2 ]
656
722
657
- storage = SparseStorage (row = row , rowptr = rowptr , col = col , value = value ,
658
- sparse_sizes = sparse_sizes , rowcount = None ,
659
- colptr = colptr , colcount = None , csr2csc = None ,
660
- csc2csr = None , is_sorted = True )
723
+ storage = SparseStorage (
724
+ row = row ,
725
+ rowptr = rowptr ,
726
+ col = col ,
727
+ value = value ,
728
+ sparse_sizes = sparse_sizes ,
729
+ rowcount = None ,
730
+ colptr = colptr ,
731
+ colcount = None ,
732
+ csr2csc = None ,
733
+ csc2csr = None ,
734
+ is_sorted = True ,
735
+ )
661
736
662
737
return SparseTensor .from_storage (storage )
663
738
664
739
665
740
@torch .jit .ignore
666
- def to_scipy (self : SparseTensor , layout : Optional [str ] = None ,
667
- dtype : Optional [torch .dtype ] = None ) -> ScipySparseMatrix :
741
+ def to_scipy (
742
+ self : SparseTensor ,
743
+ layout : Optional [str ] = None ,
744
+ dtype : Optional [torch .dtype ] = None ,
745
+ ) -> ScipySparseMatrix :
668
746
assert self .dim () == 2
669
747
layout = get_layout (layout )
670
748
0 commit comments