Skip to content

Commit 2f57c5d

Browse files
committed
Merge remote-tracking branch 'brianhou0208/fix_deeplab' into merge/fix_deeplab
2 parents 737b24f + c97d43a commit 2f57c5d

File tree

2 files changed

+34
-29
lines changed

2 files changed

+34
-29
lines changed

segmentation_models_pytorch/decoders/deeplabv3/decoder.py

+10 -21
Original file line number | Diff line number | Diff line change
@@ -61,7 +61,6 @@ def __init__(
6161
nn.BatchNorm2d(out_channels),
6262
nn.ReLU(),
6363
)
64-
self.out_channels = out_channels
6564

6665
def forward(self, *features):
6766
return super().forward(features[-1])
@@ -88,9 +87,6 @@ def __init__(
8887
"Output stride should be 8 or 16, got {}.".format(output_stride)
8988
)
9089

91-
self.out_channels = out_channels
92-
self.output_stride = output_stride
93-
9490
self.aspp = nn.Sequential(
9591
ASPP(
9692
encoder_channels[-1],
@@ -106,17 +102,10 @@ def __init__(
106102
nn.ReLU(),
107103
)
108104

109-
scale_factor = 2 if output_stride == 8 else 4
105+
scale_factor = 4 if output_stride == 16 and encoder_depth > 3 else 2
110106
self.up = nn.UpsamplingBilinear2d(scale_factor=scale_factor)
111107

112-
if encoder_depth == 3 and output_stride == 8:
113-
self.highres_input_index = -2
114-
elif encoder_depth == 3 or encoder_depth == 4:
115-
self.highres_input_index = -3
116-
else:
117-
self.highres_input_index = -4
118-
119-
highres_in_channels = encoder_channels[self.highres_input_index]
108+
highres_in_channels = encoder_channels[2]
120109
highres_out_channels = 48 # proposed by authors of paper
121110
self.block1 = nn.Sequential(
122111
nn.Conv2d(
@@ -140,7 +129,7 @@ def __init__(
140129
def forward(self, *features):
141130
aspp_features = self.aspp(features[-1])
142131
aspp_features = self.up(aspp_features)
143-
high_res_features = self.block1(features[self.highres_input_index])
132+
high_res_features = self.block1(features[2])
144133
concat_features = torch.cat([aspp_features, high_res_features], dim=1)
145134
fused_features = self.block2(concat_features)
146135
return fused_features
@@ -240,13 +229,13 @@ def forward(self, x):
240229
class SeparableConv2d(nn.Sequential):
241230
def __init__(
242231
self,
243-
in_channels,
244-
out_channels,
245-
kernel_size,
246-
stride=1,
247-
padding=0,
248-
dilation=1,
249-
bias=True,
232+
in_channels: int,
233+
out_channels: int,
234+
kernel_size: int,
235+
stride: int = 1,
236+
padding: int = 0,
237+
dilation: int = 1,
238+
bias: bool = True,
250239
):
251240
dephtwise_conv = nn.Conv2d(
252241
in_channels,

segmentation_models_pytorch/decoders/deeplabv3/model.py

+24 -8
Original file line number | Diff line number | Diff line change
@@ -35,15 +35,16 @@ class DeepLabV3(SegmentationModel):
3535
Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
3636
**callable** and **None**.
3737
Default is **None**
38-
upsampling: Final upsampling factor (should have the same value as ``encoder_output_stride`` to preserve input-output spatial shape identity).
38+
upsampling: Final upsampling factor. Default is **None** to preserve input-output spatial shape identity
3939
aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
4040
on top of encoder if **aux_params** is not **None** (default). Supported params:
4141
- classes (int): A number of classes
4242
- pooling (str): One of "max", "avg". Default is "avg"
4343
- dropout (float): Dropout factor in [0, 1)
4444
- activation (str): An activation function to apply "sigmoid"/"softmax"
4545
(could be **None** to return logits)
46-
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
46+
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models.
47+
Keys with ``None`` values are pruned before passing.
4748
4849
Returns:
4950
``torch.nn.Module``: **DeepLabV3**
@@ -72,6 +73,13 @@ def __init__(
7273
):
7374
super().__init__()
7475

76+
if encoder_output_stride not in [8, 16]:
77+
raise ValueError(
78+
"DeeplabV3 support output stride 8 or 16, got {}.".format(
79+
encoder_output_stride
80+
)
81+
)
82+
7583
self.encoder = get_encoder(
7684
encoder_name,
7785
in_channels=in_channels,
@@ -81,6 +89,14 @@ def __init__(
8189
**kwargs,
8290
)
8391

92+
if upsampling is None:
93+
if encoder_depth <= 3:
94+
scale_factor = 2**encoder_depth
95+
else:
96+
scale_factor = encoder_output_stride
97+
else:
98+
scale_factor = upsampling
99+
84100
self.decoder = DeepLabV3Decoder(
85101
in_channels=self.encoder.out_channels[-1],
86102
out_channels=decoder_channels,
@@ -90,11 +106,11 @@ def __init__(
90106
)
91107

92108
self.segmentation_head = SegmentationHead(
93-
in_channels=self.decoder.out_channels,
109+
in_channels=decoder_channels,
94110
out_channels=classes,
95111
activation=activation,
96112
kernel_size=1,
97-
upsampling=encoder_output_stride if upsampling is None else upsampling,
113+
upsampling=scale_factor,
98114
)
99115

100116
if aux_params is not None:
@@ -129,16 +145,16 @@ class DeepLabV3Plus(SegmentationModel):
129145
Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
130146
**callable** and **None**.
131147
Default is **None**
132-
upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity. In case
133-
**encoder_depth** and **encoder_output_stride** are 3 and 16 resp., set **upsampling** to 2 to preserve.
148+
upsampling: Final upsampling factor. Default is 4 to preserve input-output spatial shape identity.
134149
aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is build
135150
on top of encoder if **aux_params** is not **None** (default). Supported params:
136151
- classes (int): A number of classes
137152
- pooling (str): One of "max", "avg". Default is "avg"
138153
- dropout (float): Dropout factor in [0, 1)
139154
- activation (str): An activation function to apply "sigmoid"/"softmax"
140155
(could be **None** to return logits)
141-
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
156+
kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models.
157+
Keys with ``None`` values are pruned before passing.
142158
143159
Returns:
144160
``torch.nn.Module``: **DeepLabV3Plus**
@@ -187,7 +203,7 @@ def __init__(
187203
)
188204

189205
self.segmentation_head = SegmentationHead(
190-
in_channels=self.decoder.out_channels,
206+
in_channels=decoder_channels,
191207
out_channels=classes,
192208
activation=activation,
193209
kernel_size=1,

0 commit comments

Comments
 (0)