# ---------------------------------------------------------------
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
- # This work is licensed under the NVIDIA Source Code License
+ # Licensed under the NVIDIA Source Code License. For full license
+ # terms, please refer to the LICENSE file provided with this code
+ # or visit NVIDIA's official repository at
+ # https://github.com/NVlabs/SegFormer/tree/master.
+ #
+ # This code has been modified.
# ---------------------------------------------------------------
import math

import torch

from timm.layers import DropPath, to_2tuple, trunc_normal_

+ class LayerNorm(nn.LayerNorm):
+     def forward(self, x):
+         if x.ndim == 4:
+             B, C, H, W = x.shape
+             x = x.view(B, C, -1).transpose(1, 2)
+             x = super().forward(x)
+             x = x.transpose(1, 2).view(B, C, H, W)
+         else:
+             x = super().forward(x)
+         return x
+
+
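The added LayerNorm subclass normalizes over the channel dimension for both token sequences of shape (B, N, C) and channels-first feature maps of shape (B, C, H, W), restoring the original layout afterwards. A minimal sketch of the expected shape behaviour, assuming torch is imported and the class above is in scope (the concrete sizes are illustrative, not taken from the patch):

# Hypothetical shapes, for illustration only.
norm = LayerNorm(64)
tokens = torch.randn(2, 196, 64)          # (B, N, C): falls through to plain nn.LayerNorm
feature_map = torch.randn(2, 64, 14, 14)  # (B, C, H, W): flatten, normalize, restore
assert norm(tokens).shape == (2, 196, 64)
assert norm(feature_map).shape == (2, 64, 14, 14)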
class Mlp(nn.Module):
    def __init__(
        self,
@@ -36,9 +53,6 @@ def _init_weights(self, m):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
-         elif isinstance(m, nn.LayerNorm):
-             nn.init.constant_(m.bias, 0)
-             nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
@@ -86,7 +100,7 @@ def __init__(
        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
-             self.norm = nn.LayerNorm(dim)
+             self.norm = LayerNorm(dim)

        self.apply(self._init_weights)
@@ -95,7 +109,7 @@ def _init_weights(self, m):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
-         elif isinstance(m, nn.LayerNorm):
+         elif isinstance(m, LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
@@ -153,7 +167,7 @@ def __init__(
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
-         norm_layer=nn.LayerNorm,
+         norm_layer=LayerNorm,
        sr_ratio=1,
    ):
        super().__init__()
@@ -185,7 +199,7 @@ def _init_weights(self, m):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
-         elif isinstance(m, nn.LayerNorm):
+         elif isinstance(m, LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
@@ -195,10 +209,12 @@ def _init_weights(self, m):
            if m.bias is not None:
                m.bias.data.zero_()

-     def forward(self, x, H, W):
+     def forward(self, x):
+         B, _, H, W = x.shape
+         x = x.flatten(2).transpose(1, 2)
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
-
+         x = x.transpose(1, 2).view(B, -1, H, W)
        return x

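With this change, Block consumes and produces a channels-first (B, C, H, W) feature map, flattening to tokens internally and reshaping back on the way out, which is what allows the per-stage blocks below to be wrapped in nn.Sequential. A rough sketch of the new calling convention, using assumed constructor values and shapes:

# Illustrative only; dim, num_heads and sr_ratio are example values, not the actual configs.
block = Block(dim=64, num_heads=1, sr_ratio=8)
x = torch.randn(2, 64, 56, 56)  # NCHW output of a patch embedding
y = block(x)                    # tokens are handled inside forward
assert y.shape == x.shape       # the same NCHW layout comes back out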
@@ -221,7 +237,7 @@ def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=7
            stride=stride,
            padding=(patch_size[0] // 2, patch_size[1] // 2),
        )
-         self.norm = nn.LayerNorm(embed_dim)
+         self.norm = LayerNorm(embed_dim)

        self.apply(self._init_weights)
@@ -230,7 +246,7 @@ def _init_weights(self, m):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
-         elif isinstance(m, nn.LayerNorm):
+         elif isinstance(m, LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
@@ -242,11 +258,8 @@ def _init_weights(self, m):

    def forward(self, x):
        x = self.proj(x)
-         _, _, H, W = x.shape
-         x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)
-
-         return x, H, W
+         return x


class MixVisionTransformer(nn.Module):
@@ -264,7 +277,7 @@ def __init__(
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0,
-         norm_layer=nn.LayerNorm,
+         norm_layer=LayerNorm,
        depths=[3, 4, 6, 3],
        sr_ratios=[8, 4, 2, 1],
    ):
@@ -307,8 +320,8 @@ def __init__(
            x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
        ]  # stochastic depth decay rule
        cur = 0
-         self.block1 = nn.ModuleList(
-             [
+         self.block1 = nn.Sequential(
+             *[
                Block(
                    dim=embed_dims[0],
                    num_heads=num_heads[0],
@@ -327,8 +340,8 @@ def __init__(
        self.norm1 = norm_layer(embed_dims[0])

        cur += depths[0]
-         self.block2 = nn.ModuleList(
-             [
+         self.block2 = nn.Sequential(
+             *[
                Block(
                    dim=embed_dims[1],
                    num_heads=num_heads[1],
@@ -347,8 +360,8 @@ def __init__(
        self.norm2 = norm_layer(embed_dims[1])

        cur += depths[1]
-         self.block3 = nn.ModuleList(
-             [
+         self.block3 = nn.Sequential(
+             *[
                Block(
                    dim=embed_dims[2],
                    num_heads=num_heads[2],
@@ -367,8 +380,8 @@ def __init__(
        self.norm3 = norm_layer(embed_dims[2])

        cur += depths[2]
-         self.block4 = nn.ModuleList(
-             [
+         self.block4 = nn.Sequential(
+             *[
                Block(
                    dim=embed_dims[3],
                    num_heads=num_heads[3],
@@ -396,7 +409,7 @@ def _init_weights(self, m):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
-         elif isinstance(m, nn.LayerNorm):
+         elif isinstance(m, LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
@@ -450,39 +463,30 @@ def reset_classifier(self, num_classes, global_pool=""):
        )

    def forward_features(self, x):
-         B = x.shape[0]
        outs = []

        # stage 1
-         x, H, W = self.patch_embed1(x)
-         for i, blk in enumerate(self.block1):
-             x = blk(x, H, W)
-         x = self.norm1(x)
-         x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+         x = self.patch_embed1(x)
+         x = self.block1(x)
+         x = self.norm1(x).contiguous()
        outs.append(x)

        # stage 2
-         x, H, W = self.patch_embed2(x)
-         for i, blk in enumerate(self.block2):
-             x = blk(x, H, W)
-         x = self.norm2(x)
-         x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+         x = self.patch_embed2(x)
+         x = self.block2(x)
+         x = self.norm2(x).contiguous()
        outs.append(x)

        # stage 3
-         x, H, W = self.patch_embed3(x)
-         for i, blk in enumerate(self.block3):
-             x = blk(x, H, W)
-         x = self.norm3(x)
-         x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+         x = self.patch_embed3(x)
+         x = self.block3(x)
+         x = self.norm3(x).contiguous()
        outs.append(x)

        # stage 4
-         x, H, W = self.patch_embed4(x)
-         for i, blk in enumerate(self.block4):
-             x = blk(x, H, W)
-         x = self.norm4(x)
-         x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+         x = self.patch_embed4(x)
+         x = self.block4(x)
+         x = self.norm4(x).contiguous()
        outs.append(x)

        return outs
@@ -500,7 +504,7 @@ def __init__(self, dim=768):
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
-         B, N, C = x.shape
+         B, _, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)
@@ -522,21 +526,31 @@ def __init__(self, out_channels, depth=5, **kwargs):
        self._depth = depth
        self._in_channels = 3

-     def make_dilated(self, *args, **kwargs):
-         raise ValueError("MixVisionTransformer encoder does not support dilated mode")
-
-     def set_in_channels(self, in_channels, *args, **kwargs):
-         if in_channels != 3:
-             raise ValueError(
-                 "MixVisionTransformer encoder does not support in_channels setting other than 3"
-             )
+     def get_stages(self):
+         return [
+             nn.Identity(),
+             nn.Identity(),
+             nn.Sequential(self.patch_embed1, self.block1, self.norm1),
+             nn.Sequential(self.patch_embed2, self.block2, self.norm2),
+             nn.Sequential(self.patch_embed3, self.block3, self.norm3),
+             nn.Sequential(self.patch_embed4, self.block4, self.norm4),
+         ]

    def forward(self, x):
+         stages = self.get_stages()
+
        # create dummy output for the first block
-         B, C, H, W = x.shape
+         B, _, H, W = x.shape
        dummy = torch.empty([B, 0, H // 2, W // 2], dtype=x.dtype, device=x.device)

-         return [x, dummy] + self.forward_features(x)[: self._depth - 1]
+         features = []
+         for i in range(self._depth + 1):
+             if i == 1:
+                 features.append(dummy)
+             else:
+                 x = stages[i](x).contiguous()
+                 features.append(x)
+         return features

    def load_state_dict(self, state_dict):
        state_dict.pop("head.weight", None)
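After this rewrite, forward builds the feature pyramid by running the stages from get_stages in order: index 0 is the unmodified input, index 1 a zero-channel dummy tensor at half resolution, and the remaining entries the channels-first outputs of the four transformer stages. A hedged usage sketch (the encoder instance and input size are assumptions, not part of the patch):

# Illustrative only: `encoder` is an instance of the patched encoder class with depth=5.
images = torch.randn(2, 3, 224, 224)
features = encoder(images)
assert len(features) == encoder._depth + 1    # input + dummy + 4 stage outputs
assert features[0].shape == images.shape      # identity stage passes the input through
assert features[1].shape[1] == 0              # zero-channel placeholder at H/2 x W/2
# features[2:] hold the NCHW feature maps from the four MixVisionTransformer stages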
@@ -568,7 +582,7 @@ def get_pretrained_cfg(name):
        num_heads=[1, 2, 5, 8],
        mlp_ratios=[4, 4, 4, 4],
        qkv_bias=True,
-         norm_layer=partial(nn.LayerNorm, eps=1e-6),
+         norm_layer=partial(LayerNorm, eps=1e-6),
        depths=[2, 2, 2, 2],
        sr_ratios=[8, 4, 2, 1],
        drop_rate=0.0,
@@ -585,7 +599,7 @@ def get_pretrained_cfg(name):
        num_heads=[1, 2, 5, 8],
        mlp_ratios=[4, 4, 4, 4],
        qkv_bias=True,
-         norm_layer=partial(nn.LayerNorm, eps=1e-6),
+         norm_layer=partial(LayerNorm, eps=1e-6),
        depths=[2, 2, 2, 2],
        sr_ratios=[8, 4, 2, 1],
        drop_rate=0.0,
@@ -602,7 +616,7 @@ def get_pretrained_cfg(name):
        num_heads=[1, 2, 5, 8],
        mlp_ratios=[4, 4, 4, 4],
        qkv_bias=True,
-         norm_layer=partial(nn.LayerNorm, eps=1e-6),
+         norm_layer=partial(LayerNorm, eps=1e-6),
        depths=[3, 4, 6, 3],
        sr_ratios=[8, 4, 2, 1],
        drop_rate=0.0,
@@ -619,7 +633,7 @@ def get_pretrained_cfg(name):
        num_heads=[1, 2, 5, 8],
        mlp_ratios=[4, 4, 4, 4],
        qkv_bias=True,
-         norm_layer=partial(nn.LayerNorm, eps=1e-6),
+         norm_layer=partial(LayerNorm, eps=1e-6),
        depths=[3, 4, 18, 3],
        sr_ratios=[8, 4, 2, 1],
        drop_rate=0.0,
@@ -636,7 +650,7 @@ def get_pretrained_cfg(name):
        num_heads=[1, 2, 5, 8],
        mlp_ratios=[4, 4, 4, 4],
        qkv_bias=True,
-         norm_layer=partial(nn.LayerNorm, eps=1e-6),
+         norm_layer=partial(LayerNorm, eps=1e-6),
        depths=[3, 8, 27, 3],
        sr_ratios=[8, 4, 2, 1],
        drop_rate=0.0,
@@ -653,7 +667,7 @@ def get_pretrained_cfg(name):
        num_heads=[1, 2, 5, 8],
        mlp_ratios=[4, 4, 4, 4],
        qkv_bias=True,
-         norm_layer=partial(nn.LayerNorm, eps=1e-6),
+         norm_layer=partial(LayerNorm, eps=1e-6),
        depths=[3, 6, 40, 3],
        sr_ratios=[8, 4, 2, 1],
        drop_rate=0.0,