import torch
import torch.nn as nn
import torch.nn.functional as F


class Conv3x3GNReLU(nn.Module):
    def __init__(self, in_channels, out_channels, upsample=False):
        super().__init__()
        self.upsample = upsample
        self.block = nn.Sequential(
            nn.Conv2d(
                in_channels, out_channels, (3, 3), stride=1, padding=1, bias=False
            ),
            nn.GroupNorm(32, out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        x = self.block(x)
        if self.upsample:
            x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=True)
        return x


class DepthwiseConvBlock(nn.Module):
    """
    Depthwise separable convolution: a depthwise conv followed by a 1x1
    pointwise conv, then BatchNorm and ReLU.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1):
        super().__init__()
        self.depthwise = nn.Conv2d(in_channels, in_channels, kernel_size, stride,
                                   padding, dilation, groups=in_channels, bias=False)
        self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1,
                                   stride=1, padding=0, dilation=1, groups=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels, momentum=0.9997, eps=4e-5)
        self.act = nn.ReLU()

    def forward(self, inputs):
        x = self.depthwise(inputs)
        x = self.pointwise(x)
        x = self.bn(x)
        return self.act(x)


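# Illustrative shape check for DepthwiseConvBlock (sizes here are hypothetical,
# not taken from the original code). The depthwise 3x3 conv uses
# groups=in_channels (one filter per channel) and the 1x1 pointwise conv mixes
# channels, which needs far fewer parameters than a dense 3x3 convolution:
#
#   >>> block = DepthwiseConvBlock(64, 64, kernel_size=3, padding=1)
#   >>> block(torch.randn(1, 64, 32, 32)).shape
#   torch.Size([1, 64, 32, 32])

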
class SegmentationBlock(nn.Module):
    def __init__(self, in_channels, out_channels, n_upsamples=0):
        super().__init__()

        blocks = [Conv3x3GNReLU(in_channels, out_channels, upsample=bool(n_upsamples))]

        if n_upsamples > 1:
            for _ in range(1, n_upsamples):
                blocks.append(Conv3x3GNReLU(out_channels, out_channels, upsample=True))

        self.block = nn.Sequential(*blocks)

    def forward(self, x):
        return self.block(x)


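# Each SegmentationBlock upsamples by 2**n_upsamples overall, which is how the
# five pyramid levels are brought to a common resolution before merging. A
# hypothetical check:
#
#   >>> seg = SegmentationBlock(64, 128, n_upsamples=2)
#   >>> seg(torch.randn(1, 64, 16, 16)).shape
#   torch.Size([1, 128, 64, 64])

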
class MergeBlock(nn.Module):
    def __init__(self, policy):
        super().__init__()
        if policy not in ["add", "cat"]:
            raise ValueError(
                "`merge_policy` must be one of: ['add', 'cat'], got {}".format(policy)
            )
        self.policy = policy

    def forward(self, x):
        if self.policy == "add":
            return sum(x)
        return torch.cat(x, dim=1)


class ConvBlock(nn.Module):
    """
    Convolution block with Batch Normalization and ReLU activation.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1):
        super().__init__()
        # bias is redundant here since BatchNorm follows the convolution
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False)
        self.bn = nn.BatchNorm2d(out_channels, momentum=0.9997, eps=4e-5)
        self.act = nn.ReLU()

    def forward(self, inputs):
        x = self.conv(inputs)
        x = self.bn(x)
        return self.act(x)


class BiFPNBlock(nn.Module):
    """
    Bi-directional Feature Pyramid Network block.
    """

    def __init__(self, feature_size=64, epsilon=0.0001):
        super().__init__()
        self.epsilon = epsilon

        self.p3_td = DepthwiseConvBlock(feature_size, feature_size)
        self.p4_td = DepthwiseConvBlock(feature_size, feature_size)
        self.p5_td = DepthwiseConvBlock(feature_size, feature_size)
        self.p6_td = DepthwiseConvBlock(feature_size, feature_size)

        self.p4_out = DepthwiseConvBlock(feature_size, feature_size)
        self.p5_out = DepthwiseConvBlock(feature_size, feature_size)
        self.p6_out = DepthwiseConvBlock(feature_size, feature_size)
        self.p7_out = DepthwiseConvBlock(feature_size, feature_size)

        # Fusion weights for the top-down (w1) and bottom-up (w2) pathways,
        # initialized to 1 as in the EfficientDet paper. torch.Tensor(2, 4)
        # would allocate uninitialized memory.
        self.w1 = nn.Parameter(torch.ones(2, 4))
        self.w1_relu = nn.ReLU()
        self.w2 = nn.Parameter(torch.ones(3, 4))
        self.w2_relu = nn.ReLU()

    def forward(self, inputs):
        p3_x, p4_x, p5_x, p6_x, p7_x = inputs

        # Fast normalized fusion: force the weights to be non-negative with
        # ReLU, then normalize each fusion's weights to sum to (almost) 1.
        # The division must stay out-of-place; an in-place `w1 /= ...` would
        # modify the ReLU output that autograd needs for the backward pass.
        w1 = self.w1_relu(self.w1)
        w1 = w1 / (torch.sum(w1, dim=0) + self.epsilon)
        w2 = self.w2_relu(self.w2)
        w2 = w2 / (torch.sum(w2, dim=0) + self.epsilon)

        # Top-down pathway: each level fuses its input with the upsampled
        # top-down feature of the level above (not the raw input of that
        # level, so the pathway actually cascades).
        p7_td = p7_x
        p6_td = self.p6_td(w1[0, 0] * p6_x + w1[1, 0] * F.interpolate(p7_td, scale_factor=2))
        p5_td = self.p5_td(w1[0, 1] * p5_x + w1[1, 1] * F.interpolate(p6_td, scale_factor=2))
        p4_td = self.p4_td(w1[0, 2] * p4_x + w1[1, 2] * F.interpolate(p5_td, scale_factor=2))
        p3_td = self.p3_td(w1[0, 3] * p3_x + w1[1, 3] * F.interpolate(p4_td, scale_factor=2))

        # Bottom-up pathway: each level fuses its input, its top-down feature,
        # and the downsampled output of the finer level below.
        p3_out = p3_td
        p4_out = self.p4_out(w2[0, 0] * p4_x + w2[1, 0] * p4_td + w2[2, 0] * F.interpolate(p3_out, scale_factor=0.5))
        p5_out = self.p5_out(w2[0, 1] * p5_x + w2[1, 1] * p5_td + w2[2, 1] * F.interpolate(p4_out, scale_factor=0.5))
        p6_out = self.p6_out(w2[0, 2] * p6_x + w2[1, 2] * p6_td + w2[2, 2] * F.interpolate(p5_out, scale_factor=0.5))
        p7_out = self.p7_out(w2[0, 3] * p7_x + w2[1, 3] * p7_td + w2[2, 3] * F.interpolate(p6_out, scale_factor=0.5))

        return [p3_out, p4_out, p5_out, p6_out, p7_out]


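# The fusion above is the "fast normalized fusion" from the EfficientDet paper
# (Tan et al., 2020): O = sum_i (w_i / (eps + sum_j w_j)) * I_i, with each w_i
# kept non-negative via ReLU. It approximates softmax-weighted fusion at a
# much lower computational cost.

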
class BiFPNDecoder(nn.Module):
    def __init__(
        self,
        encoder_channels,
        encoder_depth=4,
        feature_size=64,
        num_layers=2,
        segmentation_channels=128,
        dropout=0.2,
        merge_policy="add",
        epsilon=0.0001,
    ):
        super().__init__()

        # Five pyramid levels are merged, so "cat" multiplies the channel
        # count by 5.
        self.out_channels = segmentation_channels if merge_policy == "add" else segmentation_channels * 5
        if encoder_depth < 3:
            raise ValueError("Encoder depth for BiFPN decoder cannot be less than 3, got {}.".format(encoder_depth))

        encoder_channels = encoder_channels[::-1]
        encoder_channels = encoder_channels[:encoder_depth + 1]
        size = encoder_channels
        self.p3 = nn.Conv2d(size[2], feature_size, kernel_size=1, stride=1, padding=0)
        self.p4 = nn.Conv2d(size[1], feature_size, kernel_size=1, stride=1, padding=0)
        self.p5 = nn.Conv2d(size[0], feature_size, kernel_size=1, stride=1, padding=0)

        # p6 is obtained via a 3x3 stride-2 conv on C5
        self.p6 = nn.Conv2d(size[0], feature_size, kernel_size=3, stride=2, padding=1)

        # p7 is obtained via a 3x3 stride-2 conv (with BN and ReLU) on p6
        self.p7 = ConvBlock(feature_size, feature_size, kernel_size=3, stride=2, padding=1)

        bifpns = []
        for _ in range(num_layers):
            bifpns.append(BiFPNBlock(feature_size, epsilon))
        self.bifpn = nn.Sequential(*bifpns)

        self.seg_blocks = nn.ModuleList([
            SegmentationBlock(feature_size, segmentation_channels, n_upsamples=n_upsamples)
            for n_upsamples in [0, 1, 2, 3, 4]
        ])

        self.merge = MergeBlock(merge_policy)
        self.dropout = nn.Dropout2d(p=dropout, inplace=True)

    def forward(self, *features):
        c3, c4, c5 = features[-3:]
        p3_x = self.p3(c3)
        p4_x = self.p4(c4)
        p5_x = self.p5(c5)
        p6_x = self.p6(c5)
        p7_x = self.p7(p6_x)
        features = [p3_x, p4_x, p5_x, p6_x, p7_x]
        p3_out, p4_out, p5_out, p6_out, p7_out = self.bifpn(features)
        feature_pyramid = [
            seg_block(p)
            for seg_block, p in zip(self.seg_blocks, [p3_out, p4_out, p5_out, p6_out, p7_out])
        ]
        x = self.merge(feature_pyramid)
        x = self.dropout(x)
        return x


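# Example usage (a minimal sketch; the encoder channel list and input size
# below are hypothetical and would normally come from the chosen backbone):
#
#   >>> encoder_channels = [3, 64, 256, 512, 1024, 2048]
#   >>> decoder = BiFPNDecoder(encoder_channels)
#   >>> feats = [torch.randn(1, c, 256 // 2 ** i, 256 // 2 ** i)
#   ...          for i, c in enumerate(encoder_channels)]
#   >>> decoder(*feats).shape  # output at the resolution of c3 (stride 8)
#   torch.Size([1, 128, 32, 32])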