Skip to content

Commit 13c231f

Browse files
committed
Added bifpn encoder (rishabh qubvel-org#1)
1 parent c96d340 commit 13c231f

File tree

4 files changed

+301
-0
lines changed

4 files changed

+301
-0
lines changed

segmentation_models_pytorch/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
from .unet import Unet
22
from .linknet import Linknet
33
from .fpn import FPN
4+
from .bifpn import BiFPN
5+
46
from .pspnet import PSPNet
57
from .pan import PAN
68

Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .model import BiFPN
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
from torch.autograd import Variable
5+
6+
class Conv3x3GNReLU(nn.Module):
    """3x3 convolution -> GroupNorm (32 groups) -> ReLU, optionally upsampled 2x.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the convolution; the
            GroupNorm layer is built with 32 groups over these channels.
        upsample: when true, the activated output is resized by a factor of 2
            using bilinear interpolation with ``align_corners=True``.
    """

    def __init__(self, in_channels, out_channels, upsample=False):
        super().__init__()
        self.upsample = upsample

        # The convolution keeps the spatial size (stride 1, padding 1) and has
        # no bias term.
        conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=(3, 3),
            stride=1,
            padding=1,
            bias=False,
        )
        norm = nn.GroupNorm(32, out_channels)
        activation = nn.ReLU(inplace=True)
        self.block = nn.Sequential(conv, norm, activation)

    def forward(self, x):
        """Apply conv/norm/ReLU and, if configured, the 2x bilinear upsampling."""
        out = self.block(x)
        if not self.upsample:
            return out
        return F.interpolate(out, scale_factor=2, mode="bilinear", align_corners=True)
23+
24+
class DepthwiseConvBlock(nn.Module):
    """Depthwise separable convolution followed by BatchNorm and ReLU.

    A per-channel (``groups=in_channels``) convolution is followed by a 1x1
    pointwise convolution that maps to ``out_channels``. Neither convolution
    has a bias term.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the pointwise convolution.
        kernel_size: kernel size of the depthwise convolution.
        stride: stride of the depthwise convolution.
        padding: padding of the depthwise convolution.
        dilation: dilation of the depthwise convolution.
        freeze_bn: accepted but not used by this block.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, freeze_bn=False):
        super().__init__()
        self.depthwise = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=False,
        )
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)

        # NOTE(review): PyTorch's `momentum` is the weight of the *new* batch
        # statistic; 0.9997 looks like a TensorFlow-style decay value --
        # confirm this is intended.
        self.bn = nn.BatchNorm2d(out_channels, eps=4e-5, momentum=0.9997)
        self.act = nn.ReLU()

    def forward(self, inputs):
        """Run depthwise conv, pointwise conv, batch norm and ReLU in sequence."""
        return self.act(self.bn(self.pointwise(self.depthwise(inputs))))
46+
47+
48+
class SegmentationBlock(nn.Module):
    """Stack of ``Conv3x3GNReLU`` layers that upsamples a feature map step by step.

    The first layer maps ``in_channels`` to ``out_channels`` and upsamples only
    when ``n_upsamples`` is non-zero. Every further layer keeps
    ``out_channels`` and always upsamples, so the stack holds
    ``max(1, n_upsamples)`` layers in total.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of every layer's output.
        n_upsamples: number of upsampling steps to perform.
    """

    def __init__(self, in_channels, out_channels, n_upsamples=0):
        super().__init__()

        head = Conv3x3GNReLU(in_channels, out_channels, upsample=bool(n_upsamples))
        # range(1, n) is empty for n <= 1, so no extra layers are added then.
        tail = [
            Conv3x3GNReLU(out_channels, out_channels, upsample=True)
            for _ in range(1, n_upsamples)
        ]
        self.block = nn.Sequential(head, *tail)

    def forward(self, x):
        """Pass ``x`` through the stacked layers."""
        return self.block(x)
62+
63+
64+
class MergeBlock(nn.Module):
    """Merge a sequence of feature maps by summation or channel concatenation.

    Args:
        policy: ``"add"`` to sum the inputs element-wise, ``"cat"`` to
            concatenate them along dimension 1.

    Raises:
        ValueError: if ``policy`` is neither ``"add"`` nor ``"cat"``.
    """

    def __init__(self, policy):
        super().__init__()
        if policy not in ("add", "cat"):
            message = "`merge_policy` must be one of: ['add', 'cat'], got {}".format(policy)
            raise ValueError(message)
        self.policy = policy

    def forward(self, x):
        """Merge the tensors in ``x`` according to the configured policy."""
        if self.policy == "cat":
            return torch.cat(x, dim=1)
        if self.policy == "add":
            return sum(x)
        # Only reachable if `self.policy` was changed after construction.
        raise ValueError(
            "`merge_policy` must be one of: ['add', 'cat'], got {}".format(self.policy)
        )
84+
85+
class ConvBlock(nn.Module):
    """Convolution block with Batch Normalization and ReLU activation.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels produced by the convolution.
        kernel_size: kernel size of the convolution.
        stride: stride of the convolution.
        padding: padding of the convolution.
        dilation: dilation of the convolution.
        freeze_bn: accepted but not used by this block.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, freeze_bn=False):
        super(ConvBlock, self).__init__()
        # `dilation` was previously accepted but silently dropped; forward it
        # so a non-default value actually takes effect (the default of 1
        # leaves existing behaviour unchanged).
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
        )
        # NOTE(review): PyTorch's `momentum` is the weight of the *new* batch
        # statistic; 0.9997 looks like a TensorFlow-style decay value --
        # confirm this is intended.
        self.bn = nn.BatchNorm2d(out_channels, momentum=0.9997, eps=4e-5)
        self.act = nn.ReLU()

    def forward(self, inputs):
        """Apply convolution, batch normalization and ReLU in sequence."""
        x = self.conv(inputs)
        x = self.bn(x)
        return self.act(x)
100+
101+
class BiFPNBlock(nn.Module):
    """One Bi-directional Feature Pyramid Network (BiFPN) layer.

    Takes five feature maps (p3 ... p7), fuses them along a top-down pathway
    and then along a bottom-up pathway using learnable, ReLU-clamped and
    normalized fusion weights, and returns five feature maps.

    Args:
        feature_size: number of channels of every input and output feature
            map (each internal ``DepthwiseConvBlock`` maps ``feature_size``
            to ``feature_size``).
        epsilon: small constant added to the weight sum before normalization
            to avoid division by zero.
    """

    def __init__(self, feature_size=64, epsilon=0.0001):
        super(BiFPNBlock, self).__init__()
        self.epsilon = epsilon

        self.p3_td = DepthwiseConvBlock(feature_size, feature_size)
        self.p4_td = DepthwiseConvBlock(feature_size, feature_size)
        self.p5_td = DepthwiseConvBlock(feature_size, feature_size)
        self.p6_td = DepthwiseConvBlock(feature_size, feature_size)

        self.p4_out = DepthwiseConvBlock(feature_size, feature_size)
        self.p5_out = DepthwiseConvBlock(feature_size, feature_size)
        self.p6_out = DepthwiseConvBlock(feature_size, feature_size)
        self.p7_out = DepthwiseConvBlock(feature_size, feature_size)

        # Fusion weights start at 1 so every input contributes equally.
        # `torch.Tensor(rows, cols)` would leave them as uninitialised memory
        # (possibly NaN/inf or huge values).
        self.w1 = nn.Parameter(torch.ones(2, 4))
        self.w1_relu = nn.ReLU()
        self.w2 = nn.Parameter(torch.ones(3, 4))
        self.w2_relu = nn.ReLU()

    def _normalized(self, weights):
        """Normalize non-negative fusion weights so each column sums to ~1."""
        # Out-of-place on purpose: an in-place `/=` would overwrite the ReLU
        # output that autograd needs for the backward pass.
        return weights / (torch.sum(weights, dim=0) + self.epsilon)

    @staticmethod
    def _resize_like(x, reference):
        """Resize ``x`` (default nearest mode) to the spatial size of ``reference``."""
        # Resizing to the target's size instead of a fixed scale factor gives
        # the same result for exact 2x pyramids and also works when a level
        # has an odd height/width.
        return F.interpolate(x, size=reference.shape[-2:])

    def forward(self, inputs):
        """Fuse the pyramid levels.

        Args:
            inputs: sequence of five tensors ``(p3, p4, p5, p6, p7)``,
                ordered from the largest to the smallest spatial size
                (assumed from how neighbouring levels are resized onto each
                other -- TODO confirm against caller).

        Returns:
            List ``[p3_out, p4_out, p5_out, p6_out, p7_out]``.
        """
        p3_x, p4_x, p5_x, p6_x, p7_x = inputs

        w1 = self._normalized(self.w1_relu(self.w1))
        w2 = self._normalized(self.w2_relu(self.w2))

        # Top-down pathway: each level is fused with the *already refined*
        # coarser level, so information propagates from p7 down to p3.
        p7_td = p7_x
        p6_td = self.p6_td(w1[0, 0] * p6_x + w1[1, 0] * self._resize_like(p7_td, p6_x))
        p5_td = self.p5_td(w1[0, 1] * p5_x + w1[1, 1] * self._resize_like(p6_td, p5_x))
        p4_td = self.p4_td(w1[0, 2] * p4_x + w1[1, 2] * self._resize_like(p5_td, p4_x))
        p3_td = self.p3_td(w1[0, 3] * p3_x + w1[1, 3] * self._resize_like(p4_td, p3_x))

        # Bottom-up pathway: each level combines its input, its top-down
        # feature and the resized output of the finer level below it.
        p3_out = p3_td
        p4_out = self.p4_out(
            w2[0, 0] * p4_x + w2[1, 0] * p4_td + w2[2, 0] * self._resize_like(p3_out, p4_x)
        )
        p5_out = self.p5_out(
            w2[0, 1] * p5_x + w2[1, 1] * p5_td + w2[2, 1] * self._resize_like(p4_out, p5_x)
        )
        p6_out = self.p6_out(
            w2[0, 2] * p6_x + w2[1, 2] * p6_td + w2[2, 2] * self._resize_like(p5_out, p6_x)
        )
        p7_out = self.p7_out(
            w2[0, 3] * p7_x + w2[1, 3] * p7_td + w2[2, 3] * self._resize_like(p6_out, p7_x)
        )

        return [p3_out, p4_out, p5_out, p6_out, p7_out]
148+
149+
150+
class BiFPNDecoder(nn.Module):
    """Segmentation decoder built on stacked BiFPN layers.

    The last three encoder feature maps are projected to ``feature_size``
    channels (p3, p4, p5), two further levels (p6, p7) are derived with
    stride-2 convolutions, the five levels are refined by ``num_layers``
    ``BiFPNBlock`` layers, and each level is passed through a
    ``SegmentationBlock`` (with 0..4 upsampling steps) before being merged.

    Args:
        encoder_channels: channel counts of the encoder outputs, ordered as
            the encoder produces them (the list is reversed internally).
        encoder_depth: number of encoder stages; must be at least 3.
        feature_size: channel count used inside the BiFPN layers.
        num_layers: number of stacked ``BiFPNBlock`` layers.
        segmentation_channels: channels produced by each segmentation block.
        dropout: probability for the final ``Dropout2d``.
        merge_policy: ``"add"`` or ``"cat"``; how the five levels are merged.
        epsilon: normalization constant forwarded to every ``BiFPNBlock``.

    Raises:
        ValueError: if ``encoder_depth`` is smaller than 3.
    """

    def __init__(
        self,
        encoder_channels,
        encoder_depth=4,
        feature_size=64,
        num_layers=2,
        segmentation_channels=128,
        dropout=0.2,
        merge_policy="add",
        epsilon=0.0001,
    ):
        super().__init__()

        if encoder_depth < 3:
            raise ValueError("Encoder depth for FPN decoder cannot be less than 3, got {}.".format(encoder_depth))

        # Deepest stage first, so size[0] is the last encoder output.
        encoder_channels = encoder_channels[::-1]
        encoder_channels = encoder_channels[:encoder_depth + 1]
        size = encoder_channels

        # 1x1 lateral projections of the three deepest encoder stages.
        self.p3 = nn.Conv2d(size[2], feature_size, kernel_size=1, stride=1, padding=0)
        self.p4 = nn.Conv2d(size[1], feature_size, kernel_size=1, stride=1, padding=0)
        self.p5 = nn.Conv2d(size[0], feature_size, kernel_size=1, stride=1, padding=0)

        # p6 is obtained via a 3x3 stride-2 conv on C5
        self.p6 = nn.Conv2d(size[0], feature_size, kernel_size=3, stride=2, padding=1)

        # p7 is obtained via a 3x3 stride-2 conv on p6, followed by batch
        # normalization and ReLU (see ConvBlock).
        self.p7 = ConvBlock(feature_size, feature_size, kernel_size=3, stride=2, padding=1)

        # Forward `epsilon` so the constructor argument actually takes effect.
        self.bifpn = nn.Sequential(
            *[BiFPNBlock(feature_size, epsilon=epsilon) for _ in range(num_layers)]
        )

        # One segmentation block per pyramid level; level i is upsampled i
        # times so that all levels end up at the p3 resolution.
        self.seg_blocks = nn.ModuleList([
            SegmentationBlock(feature_size, segmentation_channels, n_upsamples=n_upsamples)
            for n_upsamples in [0, 1, 2, 3, 4]
        ])

        # With "cat" all five levels are concatenated along the channel axis,
        # so the output has 5x (not 4x) the segmentation channels.
        if merge_policy == "add":
            self.out_channels = segmentation_channels
        else:
            self.out_channels = segmentation_channels * len(self.seg_blocks)

        self.merge = MergeBlock(merge_policy)
        self.dropout = nn.Dropout2d(p=dropout, inplace=True)

    def forward(self, *features):
        """Decode encoder features into a single merged feature map.

        Args:
            *features: encoder outputs; only the last three are used.

        Returns:
            The merged (summed or concatenated) pyramid after dropout.
        """
        c3, c4, c5 = features[-3:]

        p3_x = self.p3(c3)
        p4_x = self.p4(c4)
        p5_x = self.p5(c5)
        p6_x = self.p6(c5)
        p7_x = self.p7(p6_x)

        # nn.Sequential hands the whole list to each BiFPNBlock in turn.
        pyramid = self.bifpn([p3_x, p4_x, p5_x, p6_x, p7_x])

        feature_pyramid = [seg_block(p) for seg_block, p in zip(self.seg_blocks, pyramid)]
        x = self.merge(feature_pyramid)
        x = self.dropout(x)
        return x
+92
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
from typing import Optional, Union
2+
from .decoder import BiFPNDecoder
3+
from ..base import SegmentationModel, SegmentationHead, ClassificationHead
4+
from ..encoders import get_encoder
5+
6+
7+
class BiFPN(SegmentationModel):
    """BiFPN_ segmentation model: an encoder followed by a BiFPN-based decoder.

    Args:
        encoder_name: name of classification model (without last dense layers) used as feature
            extractor to build segmentation model.
        encoder_depth: number of stages used in decoder, larger depth - more features are generated.
            e.g. for depth=3 encoder will generate list of features with following spatial shapes
            [(H,W), (H/2, W/2), (H/4, W/4), (H/8, W/8)], so in general the deepest feature will have
            spatial resolution (H/(2^depth), W/(2^depth)]
        encoder_weights: one of ``None`` (random initialization), ``imagenet`` (pre-training on ImageNet).
        feature_size: a number of convolution filters used inside the BiFPN layers of the decoder.
        decoder_segmentation_channels: a number of convolution filters in segmentation blocks of the decoder.
        decoder_merge_policy: determines how to merge outputs inside the decoder.
            One of [``add``, ``cat``]
        decoder_dropout: spatial dropout rate in range (0, 1).
        in_channels: number of input channels for model, default is 3.
        classes: a number of classes for output (output shape - ``(batch, classes, h, w)``).
        activation (str, callable): activation function used in ``.predict(x)`` method for inference.
            One of [``sigmoid``, ``softmax2d``, callable, None]
        upsampling: optional, final upsampling factor
            (default is 4 to preserve input -> output spatial shape identity)
        aux_params: if specified model will have additional classification auxiliary output
            build on top of encoder, supported params:
                - classes (int): number of classes
                - pooling (str): one of 'max', 'avg'. Default is 'avg'.
                - dropout (float): dropout factor in [0, 1)
                - activation (str): activation function to apply "sigmoid"/"softmax" (could be None to return logits)

    Returns:
        ``torch.nn.Module``: **BiFPN**

    .. _BiFPN:
        https://arxiv.org/abs/1911.09070

    """

    def __init__(
        self,
        encoder_name: str = "resnet34",
        encoder_depth: int = 4,
        encoder_weights: Optional[str] = "imagenet",
        feature_size: int = 256,
        decoder_segmentation_channels: int = 128,
        decoder_merge_policy: str = "add",
        decoder_dropout: float = 0.2,
        in_channels: int = 3,
        classes: int = 1,
        activation: Optional[str] = None,
        upsampling: int = 4,
        aux_params: Optional[dict] = None,
    ):
        super().__init__()

        self.encoder = get_encoder(
            encoder_name,
            in_channels=in_channels,
            depth=encoder_depth,
            weights=encoder_weights,
        )

        self.decoder = BiFPNDecoder(
            encoder_channels=self.encoder.out_channels,
            encoder_depth=encoder_depth,
            feature_size=feature_size,
            segmentation_channels=decoder_segmentation_channels,
            dropout=decoder_dropout,
            merge_policy=decoder_merge_policy,
        )

        self.segmentation_head = SegmentationHead(
            in_channels=self.decoder.out_channels,
            out_channels=classes,
            activation=activation,
            kernel_size=1,
            upsampling=upsampling,
        )

        # The auxiliary classifier, when requested, reads the deepest
        # encoder feature map.
        if aux_params is not None:
            self.classification_head = ClassificationHead(
                in_channels=self.encoder.out_channels[-1], **aux_params
            )
        else:
            self.classification_head = None

        # Label the model as BiFPN; the previous "fpn-" prefix was carried
        # over from the FPN model and made the two indistinguishable by name.
        self.name = "bifpn-{}".format(encoder_name)
        self.initialize()

0 commit comments

Comments
 (0)