Commit 441a602

Update MixVisionTransformer (#975)
* Update mix_transformer.py
* Fix style issue
* Replace LayerNorm
* Update LayerNorm
1 parent b93cf54 commit 441a602

File tree

1 file changed (+78, -64 lines changed)

segmentation_models_pytorch/encoders/mix_transformer.py

+78 -64
@@ -1,7 +1,12 @@
 # ---------------------------------------------------------------
 # Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
 #
-# This work is licensed under the NVIDIA Source Code License
+# Licensed under the NVIDIA Source Code License. For full license
+# terms, please refer to the LICENSE file provided with this code
+# or visit NVIDIA's official repository at
+# https://github.com/NVlabs/SegFormer/tree/master.
+#
+# This code has been modified.
 # ---------------------------------------------------------------
 import math
 import torch
@@ -11,6 +16,18 @@
 from timm.layers import DropPath, to_2tuple, trunc_normal_
 
 
+class LayerNorm(nn.LayerNorm):
+    def forward(self, x):
+        if x.ndim == 4:
+            B, C, H, W = x.shape
+            x = x.view(B, C, -1).transpose(1, 2)
+            x = super().forward(x)
+            x = x.transpose(1, 2).view(B, C, H, W)
+        else:
+            x = super().forward(x)
+        return x
+
+
 class Mlp(nn.Module):
     def __init__(
         self,
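
The `LayerNorm` subclass added at the top of the file is the core of this change: it behaves exactly like `nn.LayerNorm` on (B, N, C) token sequences, but also accepts NCHW feature maps by flattening the spatial dims, normalizing over channels, and reshaping back. A minimal standalone sketch (not part of the diff) that checks the 4D path against a plain `nn.LayerNorm`:

    import torch
    import torch.nn as nn

    class LayerNorm(nn.LayerNorm):
        # Same subclass as in the hunk above.
        def forward(self, x):
            if x.ndim == 4:
                B, C, H, W = x.shape
                x = x.view(B, C, -1).transpose(1, 2)    # NCHW -> (B, H*W, C)
                x = super().forward(x)                  # normalize over C
                x = x.transpose(1, 2).view(B, C, H, W)  # back to NCHW
            else:
                x = super().forward(x)
            return x

    x = torch.randn(2, 64, 16, 16)
    ln = LayerNorm(64)
    ref = nn.LayerNorm(64)
    ref.load_state_dict(ln.state_dict())  # share affine weight/bias
    expected = ref(x.flatten(2).transpose(1, 2)).transpose(1, 2).view(2, 64, 16, 16)
    assert torch.allclose(ln(x), expected, atol=1e-6)

Because the subclass keeps the `nn.LayerNorm` parameter names, existing checkpoints load unchanged.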
@@ -36,9 +53,6 @@ def _init_weights(self, m):
             trunc_normal_(m.weight, std=0.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
                 nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
-            nn.init.constant_(m.bias, 0)
-            nn.init.constant_(m.weight, 1.0)
         elif isinstance(m, nn.Conv2d):
             fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
             fan_out //= m.groups
@@ -86,7 +100,7 @@ def __init__(
         self.sr_ratio = sr_ratio
         if sr_ratio > 1:
             self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
-            self.norm = nn.LayerNorm(dim)
+            self.norm = LayerNorm(dim)
 
         self.apply(self._init_weights)
 
@@ -95,7 +109,7 @@ def _init_weights(self, m):
             trunc_normal_(m.weight, std=0.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
                 nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
+        elif isinstance(m, LayerNorm):
             nn.init.constant_(m.bias, 0)
             nn.init.constant_(m.weight, 1.0)
         elif isinstance(m, nn.Conv2d):
@@ -153,7 +167,7 @@ def __init__(
         attn_drop=0.0,
         drop_path=0.0,
         act_layer=nn.GELU,
-        norm_layer=nn.LayerNorm,
+        norm_layer=LayerNorm,
         sr_ratio=1,
     ):
         super().__init__()
@@ -185,7 +199,7 @@ def _init_weights(self, m):
             trunc_normal_(m.weight, std=0.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
                 nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
+        elif isinstance(m, LayerNorm):
             nn.init.constant_(m.bias, 0)
             nn.init.constant_(m.weight, 1.0)
         elif isinstance(m, nn.Conv2d):
@@ -195,10 +209,12 @@ def _init_weights(self, m):
             if m.bias is not None:
                 m.bias.data.zero_()
 
-    def forward(self, x, H, W):
+    def forward(self, x):
+        B, _, H, W = x.shape
+        x = x.flatten(2).transpose(1, 2)
         x = x + self.drop_path(self.attn(self.norm1(x), H, W))
         x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
-
+        x = x.transpose(1, 2).view(B, -1, H, W)
         return x
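
`Block.forward` now takes and returns a single NCHW tensor instead of `(x, H, W)`: it derives H and W from the input shape, flattens to tokens for the attention and MLP, then restores the spatial layout on the way out. A hedged sketch of the same wrapper pattern, with a plain `nn.Linear` standing in for the attention/MLP body (`TinyBlock` is illustrative, not the real `Block`):

    import torch
    import torch.nn as nn

    class TinyBlock(nn.Module):
        # Stand-in for Block; self.body replaces attn + mlp.
        def __init__(self, dim):
            super().__init__()
            self.body = nn.Linear(dim, dim)

        def forward(self, x):
            B, _, H, W = x.shape
            x = x.flatten(2).transpose(1, 2)         # NCHW -> (B, N, C)
            x = x + self.body(x)                     # residual token mixing
            x = x.transpose(1, 2).view(B, -1, H, W)  # (B, N, C) -> NCHW
            return x

    x = torch.randn(2, 32, 8, 8)
    assert TinyBlock(32)(x).shape == x.shape  # shape-preserving, so chainable

Because every block is now a one-argument, shape-preserving module, a stack of them composes with `nn.Sequential`, which the constructor changes below rely on.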

@@ -221,7 +237,7 @@ def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=7
             stride=stride,
             padding=(patch_size[0] // 2, patch_size[1] // 2),
         )
-        self.norm = nn.LayerNorm(embed_dim)
+        self.norm = LayerNorm(embed_dim)
 
         self.apply(self._init_weights)
 
@@ -230,7 +246,7 @@ def _init_weights(self, m):
             trunc_normal_(m.weight, std=0.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
                 nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
+        elif isinstance(m, LayerNorm):
             nn.init.constant_(m.bias, 0)
             nn.init.constant_(m.weight, 1.0)
         elif isinstance(m, nn.Conv2d):
@@ -242,11 +258,8 @@ def _init_weights(self, m):
 
     def forward(self, x):
        x = self.proj(x)
-        _, _, H, W = x.shape
-        x = x.flatten(2).transpose(1, 2)
         x = self.norm(x)
-
-        return x, H, W
+        return x
 
 
 class MixVisionTransformer(nn.Module):
@@ -264,7 +277,7 @@ def __init__(
         drop_rate=0.0,
         attn_drop_rate=0.0,
         drop_path_rate=0.0,
-        norm_layer=nn.LayerNorm,
+        norm_layer=LayerNorm,
         depths=[3, 4, 6, 3],
         sr_ratios=[8, 4, 2, 1],
     ):
@@ -307,8 +320,8 @@ def __init__(
             x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
         ]  # stochastic depth decay rule
         cur = 0
-        self.block1 = nn.ModuleList(
-            [
+        self.block1 = nn.Sequential(
+            *[
                 Block(
                     dim=embed_dims[0],
                     num_heads=num_heads[0],
@@ -327,8 +340,8 @@ def __init__(
         self.norm1 = norm_layer(embed_dims[0])
 
         cur += depths[0]
-        self.block2 = nn.ModuleList(
-            [
+        self.block2 = nn.Sequential(
+            *[
                 Block(
                     dim=embed_dims[1],
                     num_heads=num_heads[1],
@@ -347,8 +360,8 @@ def __init__(
         self.norm2 = norm_layer(embed_dims[1])
 
         cur += depths[1]
-        self.block3 = nn.ModuleList(
-            [
+        self.block3 = nn.Sequential(
+            *[
                 Block(
                     dim=embed_dims[2],
                     num_heads=num_heads[2],
@@ -367,8 +380,8 @@ def __init__(
         self.norm3 = norm_layer(embed_dims[2])
 
         cur += depths[2]
-        self.block4 = nn.ModuleList(
-            [
+        self.block4 = nn.Sequential(
+            *[
                 Block(
                     dim=embed_dims[3],
                     num_heads=num_heads[3],
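
All four stages switch from `nn.ModuleList`, which needed an explicit Python loop passing `H` and `W` into each block, to `nn.Sequential`, which is possible now that `Block.forward` takes a single tensor. A self-contained sketch of the equivalence, with `nn.Conv2d` layers standing in for the blocks:

    import torch
    import torch.nn as nn

    layers = [nn.Conv2d(32, 32, 3, padding=1) for _ in range(3)]  # block stand-ins
    x = torch.randn(1, 32, 8, 8)

    # Old pattern: ModuleList plus an explicit loop.
    y = x
    for blk in nn.ModuleList(layers):
        y = blk(y)

    # New pattern: one call runs the whole stack.
    assert torch.equal(nn.Sequential(*layers)(x), y)

Since both containers register children as `0`, `1`, `2`, ..., the state-dict keys (`block1.0....`, `block1.1....`) are unchanged and pretrained weights still load.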
@@ -396,7 +409,7 @@ def _init_weights(self, m):
             trunc_normal_(m.weight, std=0.02)
             if isinstance(m, nn.Linear) and m.bias is not None:
                 nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
+        elif isinstance(m, LayerNorm):
             nn.init.constant_(m.bias, 0)
             nn.init.constant_(m.weight, 1.0)
         elif isinstance(m, nn.Conv2d):
@@ -450,39 +463,30 @@ def reset_classifier(self, num_classes, global_pool=""):
         )
 
     def forward_features(self, x):
-        B = x.shape[0]
         outs = []
 
         # stage 1
-        x, H, W = self.patch_embed1(x)
-        for i, blk in enumerate(self.block1):
-            x = blk(x, H, W)
-        x = self.norm1(x)
-        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+        x = self.patch_embed1(x)
+        x = self.block1(x)
+        x = self.norm1(x).contiguous()
         outs.append(x)
 
         # stage 2
-        x, H, W = self.patch_embed2(x)
-        for i, blk in enumerate(self.block2):
-            x = blk(x, H, W)
-        x = self.norm2(x)
-        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+        x = self.patch_embed2(x)
+        x = self.block2(x)
+        x = self.norm2(x).contiguous()
         outs.append(x)
 
         # stage 3
-        x, H, W = self.patch_embed3(x)
-        for i, blk in enumerate(self.block3):
-            x = blk(x, H, W)
-        x = self.norm3(x)
-        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+        x = self.patch_embed3(x)
+        x = self.block3(x)
+        x = self.norm3(x).contiguous()
         outs.append(x)
 
         # stage 4
-        x, H, W = self.patch_embed4(x)
-        for i, blk in enumerate(self.block4):
-            x = blk(x, H, W)
-        x = self.norm4(x)
-        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
+        x = self.patch_embed4(x)
+        x = self.block4(x)
+        x = self.norm4(x).contiguous()
         outs.append(x)
 
         return outs
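
With every sub-module now mapping NCHW to NCHW, each stage of `forward_features` collapses from five lines (unpack H/W, loop over blocks, norm, reshape/permute) to a straight patch_embed -> block -> norm pipeline. A hedged sketch of stage 1 with stock stand-in modules (the real modules are `patch_embed1`, `block1`, and `norm1` above; mit_b0-style dims assumed):

    import torch
    import torch.nn as nn

    stage1 = nn.Sequential(
        nn.Conv2d(3, 32, kernel_size=7, stride=4, padding=3),  # OverlapPatchEmbed stand-in
        nn.Conv2d(32, 32, 3, padding=1),                       # block stack stand-in
        nn.GroupNorm(1, 32),                                   # placeholder norm, keeps NCHW
    )
    x = torch.randn(1, 3, 224, 224)
    print(stage1(x).contiguous().shape)  # torch.Size([1, 32, 56, 56]), the stride-4 feature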
@@ -500,7 +504,7 @@ def __init__(self, dim=768):
         self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
 
     def forward(self, x, H, W):
-        B, N, C = x.shape
+        B, _, C = x.shape
         x = x.transpose(1, 2).view(B, C, H, W)
         x = self.dwconv(x)
         x = x.flatten(2).transpose(1, 2)
@@ -522,21 +526,31 @@ def __init__(self, out_channels, depth=5, **kwargs):
         self._depth = depth
         self._in_channels = 3
 
-    def make_dilated(self, *args, **kwargs):
-        raise ValueError("MixVisionTransformer encoder does not support dilated mode")
-
-    def set_in_channels(self, in_channels, *args, **kwargs):
-        if in_channels != 3:
-            raise ValueError(
-                "MixVisionTransformer encoder does not support in_channels setting other than 3"
-            )
+    def get_stages(self):
+        return [
+            nn.Identity(),
+            nn.Identity(),
+            nn.Sequential(self.patch_embed1, self.block1, self.norm1),
+            nn.Sequential(self.patch_embed2, self.block2, self.norm2),
+            nn.Sequential(self.patch_embed3, self.block3, self.norm3),
+            nn.Sequential(self.patch_embed4, self.block4, self.norm4),
+        ]
 
     def forward(self, x):
+        stages = self.get_stages()
+
         # create dummy output for the first block
-        B, C, H, W = x.shape
+        B, _, H, W = x.shape
         dummy = torch.empty([B, 0, H // 2, W // 2], dtype=x.dtype, device=x.device)
 
-        return [x, dummy] + self.forward_features(x)[: self._depth - 1]
+        features = []
+        for i in range(self._depth + 1):
+            if i == 1:
+                features.append(dummy)
+            else:
+                x = stages[i](x).contiguous()
+                features.append(x)
+        return features
 
     def load_state_dict(self, state_dict):
         state_dict.pop("head.weight", None)
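
The `make_dilated` and `set_in_channels` overrides are dropped, and `forward` is rebuilt around a `get_stages` list: it walks `depth + 1` entries, substituting the zero-channel dummy at index 1 because MiT produces no stride-2 feature map. A hedged usage sketch (assumes segmentation_models_pytorch with the `mit_b0` encoder is installed; shapes shown for its embed_dims [32, 64, 160, 256]):

    import torch
    from segmentation_models_pytorch.encoders import get_encoder

    encoder = get_encoder("mit_b0", in_channels=3, depth=5, weights=None)
    features = encoder(torch.randn(1, 3, 224, 224))
    for i, f in enumerate(features):
        print(i, tuple(f.shape))
    # 0 (1, 3, 224, 224)   identity passthrough of the input
    # 1 (1, 0, 112, 112)   zero-channel dummy (no stride-2 output)
    # 2 (1, 32, 56, 56)    stage 1
    # 3 (1, 64, 28, 28)    stage 2
    # 4 (1, 160, 14, 14)   stage 3
    # 5 (1, 256, 7, 7)     stage 4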
@@ -568,7 +582,7 @@ def get_pretrained_cfg(name):
            num_heads=[1, 2, 5, 8],
            mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True,
-           norm_layer=partial(nn.LayerNorm, eps=1e-6),
+           norm_layer=partial(LayerNorm, eps=1e-6),
            depths=[2, 2, 2, 2],
            sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0,
@@ -585,7 +599,7 @@ def get_pretrained_cfg(name):
            num_heads=[1, 2, 5, 8],
            mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True,
-           norm_layer=partial(nn.LayerNorm, eps=1e-6),
+           norm_layer=partial(LayerNorm, eps=1e-6),
            depths=[2, 2, 2, 2],
            sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0,
@@ -602,7 +616,7 @@ def get_pretrained_cfg(name):
            num_heads=[1, 2, 5, 8],
            mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True,
-           norm_layer=partial(nn.LayerNorm, eps=1e-6),
+           norm_layer=partial(LayerNorm, eps=1e-6),
            depths=[3, 4, 6, 3],
            sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0,
@@ -619,7 +633,7 @@ def get_pretrained_cfg(name):
            num_heads=[1, 2, 5, 8],
            mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True,
-           norm_layer=partial(nn.LayerNorm, eps=1e-6),
+           norm_layer=partial(LayerNorm, eps=1e-6),
            depths=[3, 4, 18, 3],
            sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0,
@@ -636,7 +650,7 @@ def get_pretrained_cfg(name):
            num_heads=[1, 2, 5, 8],
            mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True,
-           norm_layer=partial(nn.LayerNorm, eps=1e-6),
+           norm_layer=partial(LayerNorm, eps=1e-6),
            depths=[3, 8, 27, 3],
            sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0,
@@ -653,7 +667,7 @@ def get_pretrained_cfg(name):
            num_heads=[1, 2, 5, 8],
            mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True,
-           norm_layer=partial(nn.LayerNorm, eps=1e-6),
+           norm_layer=partial(LayerNorm, eps=1e-6),
            depths=[3, 6, 40, 3],
            sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0,
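
All six pretrained configs (mit_b0 through mit_b5) make the same swap: `partial(nn.LayerNorm, eps=1e-6)` becomes `partial(LayerNorm, eps=1e-6)`, so every norm the model builds is the 4D-aware subclass. `functools.partial` pre-binds `eps` while leaving the dimension to be filled in per stage; a minimal sketch of that binding (using stock `nn.LayerNorm` for illustration):

    from functools import partial
    import torch.nn as nn

    norm_layer = partial(nn.LayerNorm, eps=1e-6)
    norm = norm_layer(64)  # equivalent to nn.LayerNorm(64, eps=1e-6)
    print(type(norm).__name__, norm.eps)  # LayerNorm 1e-06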
