Skip to content

Commit b969ee7

Browse files
committed
Fix dropout layer
1 parent 92745a9 commit b969ee7

File tree

1 file changed

+6
-2
lines changed
  • segmentation_models_pytorch/base

1 file changed

+6
-2
lines changed

segmentation_models_pytorch/base/heads.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,15 @@
55
class SegmentationHead(nn.Sequential):
66

77
def __init__(self, in_channels, out_channels, kernel_size=3, dropout=None, activation=None, upsampling=1):
8-
dropout = nn.Dropout(p=dropout, inplace=False) if dropout else nn.Identity()
8+
if dropout:
9+
dropout = nn.Dropout(p=dropout, inplace=False) if dropout else nn.Identity()
910
conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
1011
upsampling = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
1112
activation = Activation(activation)
12-
super().__init__(dropout, conv2d, upsampling, activation)
13+
if dropout:
14+
super().__init__(dropout, conv2d, upsampling, activation)
15+
else:
16+
super().__init__(conv2d, upsampling, activation)
1317

1418

1519
class ClassificationHead(nn.Sequential):

0 commit comments

Comments
 (0)