Skip to content

Commit d8bd323

Browse files
authored
Merge pull request qubvel-org#1 from ludics/add_mbv3
Add mobilenet_v3 in torchvision.models
2 parents 33dc950 + b28d3a3 commit d8bd323

File tree

5 files changed

+124
-7
lines changed

5 files changed

+124
-7
lines changed

README.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ The main features of this library are:
1212

1313
- High level API (just two lines to create a neural network)
1414
- 9 models architectures for binary and multi class segmentation (including legendary Unet)
15-
- 104 available encoders
15+
- 106 available encoders
1616
- All encoders have pre-trained weights for faster and better convergence
1717

1818
### [📚 Project Documentation 📚](http://smp.readthedocs.io/)
@@ -284,6 +284,8 @@ The following is a list of supported encoders in the SMP. Select the appropriate
284284
|Encoder |Weights |Params, M |
285285
|--------------------------------|:------------------------------:|:------------------------------:|
286286
|mobilenet_v2 |imagenet |2M |
287+
|mobilenet_v3_large |imagenet |3M |
288+
|mobilenet_v3_small |imagenet |1M |
287289

288290
</div>
289291
</details>

docs/encoders.rst

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -252,11 +252,15 @@ EfficientNet
252252
MobileNet
253253
~~~~~~~~~
254254

255-
+-----------------+------------+-------------+
256-
| Encoder | Weights | Params, M |
257-
+=================+============+=============+
258-
| mobilenet\_v2 | imagenet | 2M |
259-
+-----------------+------------+-------------+
255+
+---------------------+------------+-------------+
256+
| Encoder | Weights | Params, M |
257+
+=====================+============+=============+
258+
| mobilenet\_v2 | imagenet | 2M |
259+
+---------------------+------------+-------------+
260+
| mobilenet\_v3_large | imagenet | 3M |
261+
+---------------------+------------+-------------+
262+
| mobilenet\_v3_small | imagenet | 1M          |
263+
+---------------------+------------+-------------+
260264

261265
DPN
262266
~~~

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
torchvision>=0.3.0
1+
torchvision==0.9.0
22
pretrainedmodels==0.7.4
33
efficientnet-pytorch==0.6.3
44
timm==0.3.2

segmentation_models_pytorch/encoders/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .inceptionv4 import inceptionv4_encoders
1111
from .efficientnet import efficient_net_encoders
1212
from .mobilenet import mobilenet_encoders
13+
from .mobilenet_v3 import mobilenet_v3_encoders
1314
from .xception import xception_encoders
1415
from .timm_efficientnet import timm_efficientnet_encoders
1516
from .timm_resnest import timm_resnest_encoders
@@ -28,6 +29,7 @@
2829
encoders.update(inceptionv4_encoders)
2930
encoders.update(efficient_net_encoders)
3031
encoders.update(mobilenet_encoders)
32+
encoders.update(mobilenet_v3_encoders)
3133
encoders.update(xception_encoders)
3234
encoders.update(timm_efficientnet_encoders)
3335
encoders.update(timm_resnest_encoders)
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
""" Each encoder should have following attributes and methods and be inherited from `_base.EncoderMixin`
2+
3+
Attributes:
4+
5+
_out_channels (list of int): specify number of channels for each encoder feature tensor
6+
_depth (int): specify number of stages in decoder (in other words number of downsampling operations)
7+
_in_channels (int): default number of input channels in first Conv2d layer for encoder (usually 3)
8+
9+
Methods:
10+
11+
forward(self, x: torch.Tensor)
12+
produce list of features of different spatial resolutions, each feature is a 4D torch.tensor of
13+
shape NCHW (features should be sorted in descending order according to spatial resolution, starting
14+
with resolution same as input `x` tensor).
15+
16+
Input: `x` with shape (1, 3, 64, 64)
17+
Output: [f0, f1, f2, f3, f4, f5] - features with corresponding shapes
18+
[(1, 3, 64, 64), (1, 64, 32, 32), (1, 128, 16, 16), (1, 256, 8, 8),
19+
(1, 512, 4, 4), (1, 1024, 2, 2)] (C - dim may differ)
20+
21+
also should support number of features according to specified depth, e.g. if depth = 5,
22+
number of feature tensors = 6 (one with same resolution as input and 5 downsampled),
23+
depth = 3 -> number of feature tensors = 4 (one with same resolution as input and 3 downsampled).
24+
"""
25+
26+
import torchvision
27+
import torch.nn as nn
28+
from torchvision.models.mobilenetv3 import _mobilenet_v3_conf
29+
30+
from ._base import EncoderMixin
31+
32+
33+
class MobileNetV3Encoder(torchvision.models.MobileNetV3, EncoderMixin):
    """MobileNetV3 backbone adapted for use as a segmentation encoder.

    Wraps ``torchvision.models.MobileNetV3`` so that ``forward`` returns a
    list of intermediate feature maps (one per stage) instead of class
    logits.  The classification head is removed; only ``self.features`` is
    used.

    Args:
        out_channels (tuple of int): number of channels of each returned
            feature tensor, starting with the input itself.
        stage_idxs (tuple of int): indices into ``self.features`` at which
            the network is split into downsampling stages.
        model_name (str): torchvision variant name, e.g.
            ``"mobilenet_v3_large"`` or ``"mobilenet_v3_small"``.
        depth (int): number of downsampling stages to run (forward returns
            ``depth + 1`` feature tensors).
        **kwargs: forwarded to ``_mobilenet_v3_conf`` and the torchvision
            ``MobileNetV3`` constructor.
    """

    def __init__(self, out_channels, stage_idxs, model_name, depth=5, **kwargs):
        # _mobilenet_v3_conf builds the inverted-residual layer spec for the
        # requested variant; torchvision 0.9 expects the extra params as a dict.
        inverted_residual_setting, last_channel = _mobilenet_v3_conf(model_name, kwargs)
        super().__init__(inverted_residual_setting, last_channel, **kwargs)

        self._depth = depth
        self._stage_idxs = stage_idxs
        self._out_channels = out_channels
        self._in_channels = 3

        # The classification head is never used for feature extraction.
        del self.classifier

    def get_stages(self):
        """Return the encoder split into ``depth + 1`` sequential stages.

        Stage 0 is the identity (feature at input resolution); the remaining
        stages are consecutive slices of ``self.features`` delimited by
        ``self._stage_idxs``.
        """
        return [
            nn.Identity(),
            self.features[:self._stage_idxs[0]],
            self.features[self._stage_idxs[0]:self._stage_idxs[1]],
            self.features[self._stage_idxs[1]:self._stage_idxs[2]],
            self.features[self._stage_idxs[2]:self._stage_idxs[3]],
            self.features[self._stage_idxs[3]:],
        ]

    def forward(self, x):
        """Run the first ``self._depth + 1`` stages, collecting each output."""
        stages = self.get_stages()

        features = []
        for i in range(self._depth + 1):
            x = stages[i](x)
            features.append(x)

        return features

    def load_state_dict(self, state_dict, **kwargs):
        """Load weights, silently dropping classification-head parameters.

        The head was deleted in ``__init__``, so ImageNet checkpoints carry
        classifier keys this module no longer has.  ``pop`` with a default
        also lets encoder-only checkpoints (no classifier keys) load without
        raising ``KeyError``.
        """
        for key in (
            "classifier.0.bias",
            "classifier.0.weight",
            "classifier.3.bias",
            "classifier.3.weight",
        ):
            state_dict.pop(key, None)
        return super().load_state_dict(state_dict, **kwargs)
72+
73+
74+
# Registry of MobileNetV3 encoder variants.  Both entries share the same
# ImageNet preprocessing; only the weight URL, channel layout, and stage
# split points differ, so the dict is generated from a small spec table.
mobilenet_v3_encoders = {
    model_name: {
        "encoder": MobileNetV3Encoder,
        "pretrained_settings": {
            "imagenet": {
                "mean": [0.485, 0.456, 0.406],
                "std": [0.229, 0.224, 0.225],
                "url": weights_url,
                "input_space": "RGB",
                "input_range": [0, 1],
            },
        },
        "params": {
            "out_channels": channels,
            "stage_idxs": split_points,
            "model_name": model_name,
        },
    }
    for model_name, weights_url, channels, split_points in (
        (
            "mobilenet_v3_large",
            "https://download.pytorch.org/models/mobilenet_v3_large-8738ca79.pth",
            (3, 16, 24, 40, 112, 960),
            (2, 4, 7, 13),
        ),
        (
            "mobilenet_v3_small",
            "https://download.pytorch.org/models/mobilenet_v3_small-047dcff4.pth",
            (3, 16, 16, 24, 40, 576),
            (1, 2, 4, 7),
        ),
    )
}

0 commit comments

Comments
 (0)