Skip to content

Commit cca466b

Browse files
authored
Merge pull request #1 from qubvel/feature-linknet
Linknet model
2 parents 9882d42 + da8bce6 commit cca466b

File tree

4 files changed

+108
-0
lines changed

4 files changed

+108
-0
lines changed
+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .unet import Unet
2+
from .linknet import Linknet
23

34
from . import encoders
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .model import Linknet
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
import torch.nn as nn
2+
3+
from ..common.blocks import Conv2dReLU
4+
from ..base.model import Model
5+
6+
7+
class TransposeX2(nn.Module):
    """Upsample a feature map by a factor of 2 with a strided transposed convolution.

    The kernel_size=4 / stride=2 / padding=1 combination exactly doubles the
    spatial resolution. When ``use_batchnorm`` is true the order of modules is
    conv -> batchnorm -> relu; otherwise conv -> relu.

    Args:
        in_channels: number of channels of the incoming feature map.
        out_channels: number of channels produced by the transposed convolution.
        use_batchnorm: whether to insert a BatchNorm2d between conv and ReLU.
        **batchnorm_params: extra keyword arguments forwarded to BatchNorm2d.
    """

    def __init__(self, in_channels, out_channels, use_batchnorm=True, **batchnorm_params):
        super().__init__()
        conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1)
        relu = nn.ReLU(inplace=True)
        if use_batchnorm:
            modules = [conv, nn.BatchNorm2d(out_channels, **batchnorm_params), relu]
        else:
            modules = [conv, relu]
        self.block = nn.Sequential(*modules)

    def forward(self, x):
        return self.block(x)
22+
23+
24+
class DecoderBlock(nn.Module):
    """One LinkNet decoder stage: 1x1 reduce -> 2x upsample -> 1x1 expand.

    The channel count is squeezed to ``in_channels // 4`` before the learned
    upsampling and expanded to ``out_channels`` afterwards (the LinkNet
    bottleneck layout). The forward pass takes a two-element sequence
    ``[features, skip]`` and, when ``skip`` is not ``None``, adds it
    residually to the upsampled output.

    Args:
        in_channels: channels of the incoming (deeper) feature map.
        out_channels: channels of the produced feature map — must match the
            skip connection's channel count for the residual add.
        use_batchnorm: forwarded to every sub-module.
    """

    def __init__(self, in_channels, out_channels, use_batchnorm=True):
        super().__init__()

        bottleneck = in_channels // 4
        self.block = nn.Sequential(
            Conv2dReLU(in_channels, bottleneck, kernel_size=1, use_batchnorm=use_batchnorm),
            TransposeX2(bottleneck, bottleneck, use_batchnorm=use_batchnorm),
            Conv2dReLU(bottleneck, out_channels, kernel_size=1, use_batchnorm=use_batchnorm),
        )

    def forward(self, x):
        features, skip = x
        out = self.block(features)
        if skip is not None:
            out += skip
        return out
40+
41+
42+
class LinknetDecoder(Model):
    """LinkNet decoder: five residual upsampling stages plus a final 1x1 conv.

    Stages ``layer1``..``layer4`` each halve the channel count according to
    ``encoder_channels`` and fuse the matching encoder skip connection;
    ``layer5`` maps to ``prefinal_channels`` with no skip, and ``final_conv``
    projects to ``final_channels`` (the number of output classes).

    Args:
        encoder_channels: channel counts of the encoder outputs, ordered
            deepest-first (index 0 is the encoder head).
        prefinal_channels: width of the last decoder stage before the
            classification conv.
        final_channels: number of output channels (classes).
        use_batchnorm: forwarded to every decoder block.
    """

    def __init__(
            self,
            encoder_channels,
            prefinal_channels=32,
            final_channels=1,
            use_batchnorm=True,
    ):
        super().__init__()

        channels = encoder_channels

        # NOTE: attribute names are part of the checkpoint (state_dict) layout.
        self.layer1 = DecoderBlock(channels[0], channels[1], use_batchnorm=use_batchnorm)
        self.layer2 = DecoderBlock(channels[1], channels[2], use_batchnorm=use_batchnorm)
        self.layer3 = DecoderBlock(channels[2], channels[3], use_batchnorm=use_batchnorm)
        self.layer4 = DecoderBlock(channels[3], channels[4], use_batchnorm=use_batchnorm)
        self.layer5 = DecoderBlock(channels[4], prefinal_channels, use_batchnorm=use_batchnorm)
        self.final_conv = nn.Conv2d(prefinal_channels, final_channels, kernel_size=(1, 1))

        self.initialize()

    def forward(self, x):
        # x is the encoder's output list: [head, skip0, skip1, skip2, skip3, ...]
        head, *skips = x

        out = self.layer1([head, skips[0]])
        out = self.layer2([out, skips[1]])
        out = self.layer3([out, skips[2]])
        out = self.layer4([out, skips[3]])
        out = self.layer5([out, None])  # last stage has no skip connection
        return self.final_conv(out)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
from .decoder import LinknetDecoder
2+
from ..base import EncoderDecoder
3+
from ..encoders import get_encoder
4+
5+
6+
class Linknet(EncoderDecoder):
    """LinkNet segmentation model: a pretrained encoder plus a LinkNet decoder.

    Args:
        encoder_name: name of the backbone passed to ``get_encoder``.
        encoder_weights: pretrained weights identifier (e.g. ``'imagenet'``).
        decoder_use_batchnorm: whether decoder blocks use batch normalization.
        classes: number of output channels (segmentation classes).
        activation: final activation applied by the base class
            (e.g. ``'sigmoid'``).
    """

    def __init__(
            self,
            encoder_name='resnet34',
            encoder_weights='imagenet',
            decoder_use_batchnorm=True,
            classes=1,
            activation='sigmoid',
    ):

        backbone = get_encoder(
            encoder_name,
            encoder_weights=encoder_weights
        )

        head = LinknetDecoder(
            encoder_channels=backbone.out_shapes,
            prefinal_channels=32,
            final_channels=classes,
            use_batchnorm=decoder_use_batchnorm,
        )

        super().__init__(backbone, head, activation)

        self.name = 'link-{}'.format(encoder_name)

0 commit comments

Comments
 (0)