
Commit b93cf54

brianhou0208 and qubvel authored

[new model] Add SegFormer (#944)

* add SegFormer
* update tests/test_models
* update docs/models
* update readme
* fix segmentation_head kernel size
* add timm constructor arguments
* Small fix on typing
* Add conversion script
* Add notebook example
* Format

Co-authored-by: Pavel Iakubovskii <[email protected]>

1 parent 8f55d8f · commit b93cf54

File tree

9 files changed: +473 -2 lines


README.md (+2 -1)

```diff
@@ -19,7 +19,7 @@ Segmentation based on [PyTorch](https://pytorch.org/).**
 The main features of this library are:
 
 - High-level API (just two lines to create a neural network)
-- 10 models architectures for binary and multi class segmentation (including legendary Unet)
+- 11 models architectures for binary and multi class segmentation (including legendary Unet)
 - 124 available encoders (and 500+ encoders from [timm](https://github.com/rwightman/pytorch-image-models))
 - All encoders have pre-trained weights for faster and better convergence
 - Popular metrics and losses for training routines
@@ -95,6 +95,7 @@ Congratulations! You are done! Now you can train your model with your favorite f
 - DeepLabV3 [[paper](https://arxiv.org/abs/1706.05587)] [[docs](https://smp.readthedocs.io/en/latest/models.html#deeplabv3)]
 - DeepLabV3+ [[paper](https://arxiv.org/abs/1802.02611)] [[docs](https://smp.readthedocs.io/en/latest/models.html#id9)]
 - UPerNet [[paper](https://arxiv.org/abs/1807.10221)] [[docs](https://smp.readthedocs.io/en/latest/models.html#upernet)]
+- Segformer [[paper](https://arxiv.org/abs/2105.15203)] [[docs](https://smp.readthedocs.io/en/latest/models.html#segformer)]
 
 #### Encoders <a name="encoders"></a>
```
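The README's "just two lines" promise applies to the new architecture as well. A minimal sketch (the encoder, weights, and class count below are illustrative choices, not mandated by the commit):

```python
import segmentation_models_pytorch as smp

# Any supported encoder works; resnet34 with ImageNet weights is illustrative.
model = smp.Segformer(
    encoder_name="resnet34",
    encoder_weights="imagenet",
    classes=1,
)
```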

docs/models.rst (+8)

```diff
@@ -73,3 +73,11 @@ PAN
 UPerNet
 ~~~~~~~
 .. autoclass:: segmentation_models_pytorch.UPerNet
+
+
+.. _segformer:
+
+Segformer
+~~~~~~~~~
+.. autoclass:: segmentation_models_pytorch.Segformer
+
```

examples/segformer_inference_pretrained.ipynb (+131)

Large diffs are not rendered by default.
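The notebook's contents are not rendered on this page. A minimal sketch of what pretrained inference looks like, assuming a converted checkpoint published under the `smp-hub/{model_name}` naming used by the conversion script below (the exact repo id is an assumption):

```python
import torch
import segmentation_models_pytorch as smp

# Hypothetical hub id following the smp-hub/{model_name} convention.
model = smp.from_pretrained("smp-hub/segformer-b2-512x512-ade-160k").eval()

with torch.no_grad():
    mask = model(torch.rand(1, 3, 512, 512))  # -> (1, num_classes, 512, 512)
```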

segmentation_models_pytorch/__init__.py (+3)

```diff
@@ -15,6 +15,7 @@
 from .decoders.deeplabv3 import DeepLabV3, DeepLabV3Plus
 from .decoders.pan import PAN
 from .decoders.upernet import UPerNet
+from .decoders.segformer import Segformer
 from .base.hub_mixin import from_pretrained
 
 from .__version__ import __version__
@@ -50,6 +51,7 @@ def create_model(
         DeepLabV3Plus,
         PAN,
         UPerNet,
+        Segformer,
     ]
     archs_dict = {a.__name__.lower(): a for a in archs}
     try:
@@ -85,6 +87,7 @@ def create_model(
     "DeepLabV3Plus",
     "PAN",
     "UPerNet",
+    "Segformer",
     "from_pretrained",
     "create_model",
    "__version__",
```
segmentation_models_pytorch/decoders/segformer/__init__.py (+3)

```diff
@@ -0,0 +1,3 @@
+from .model import Segformer
+
+__all__ = ["Segformer"]
```
Conversion script (new file, +148)

```diff
@@ -0,0 +1,148 @@
+import torch
+import argparse
+import requests
+import numpy as np
+import huggingface_hub
+import albumentations as A
+import matplotlib.pyplot as plt
+
+from PIL import Image
+import segmentation_models_pytorch as smp
+
+
+def convert_state_dict_to_smp(state_dict: dict):
+    # fmt: off
+
+    if "state_dict" in state_dict:
+        state_dict = state_dict["state_dict"]
+
+    new_state_dict = {}
+
+    # Map the backbone components to the encoder
+    keys = list(state_dict.keys())
+    for key in keys:
+        if key.startswith("backbone"):
+            new_key = key.replace("backbone", "encoder")
+            new_state_dict[new_key] = state_dict.pop(key)
+
+    # Map the linear_cX layers to MLP stages
+    for i in range(4):
+        base = f"decode_head.linear_c{i+1}.proj"
+        new_state_dict[f"decoder.mlp_stage.{3-i}.linear.weight"] = state_dict.pop(f"{base}.weight")
+        new_state_dict[f"decoder.mlp_stage.{3-i}.linear.bias"] = state_dict.pop(f"{base}.bias")
+
+    # Map fuse_stage components
+    fuse_base = "decode_head.linear_fuse"
+    fuse_weights = {
+        "decoder.fuse_stage.0.weight": state_dict.pop(f"{fuse_base}.conv.weight"),
+        "decoder.fuse_stage.1.weight": state_dict.pop(f"{fuse_base}.bn.weight"),
+        "decoder.fuse_stage.1.bias": state_dict.pop(f"{fuse_base}.bn.bias"),
+        "decoder.fuse_stage.1.running_mean": state_dict.pop(f"{fuse_base}.bn.running_mean"),
+        "decoder.fuse_stage.1.running_var": state_dict.pop(f"{fuse_base}.bn.running_var"),
+        "decoder.fuse_stage.1.num_batches_tracked": state_dict.pop(f"{fuse_base}.bn.num_batches_tracked"),
+    }
+    new_state_dict.update(fuse_weights)
+
+    # Map final layer components
+    new_state_dict["segmentation_head.0.weight"] = state_dict.pop("decode_head.linear_pred.weight")
+    new_state_dict["segmentation_head.0.bias"] = state_dict.pop("decode_head.linear_pred.bias")
+
+    del state_dict["decode_head.conv_seg.weight"]
+    del state_dict["decode_head.conv_seg.bias"]
+
+    assert len(state_dict) == 0, f"Unmapped keys: {state_dict.keys()}"
+
+    # fmt: on
+    return new_state_dict
+
+
+def get_np_image():
+    url = "https://huggingface.co/datasets/hf-internal-testing/fixtures_ade20k/resolve/main/ADE_val_00000001.jpg"
+    image = Image.open(requests.get(url, stream=True).raw)
+    return np.array(image)
+
+
+def main(args):
+    original_checkpoint = torch.load(args.path, map_location="cpu", weights_only=True)
+    smp_state_dict = convert_state_dict_to_smp(original_checkpoint)
+
+    config = original_checkpoint["meta"]["config"]
+    num_classes = int(config.split("num_classes=")[1].split(",\n")[0])
+    decoder_dims = int(config.split("embed_dim=")[1].split(",\n")[0])
+    height, width = [
+        int(x) for x in config.split("crop_size=(")[1].split("), ")[0].split(",")
+    ]
+    model_size = args.path.split("segformer.")[1][:2]
+
+    # Create the model
+    model = smp.create_model(
+        in_channels=3,
+        classes=num_classes,
+        arch="segformer",
+        encoder_name=f"mit_{model_size}",
+        encoder_weights=None,
+        decoder_segmentation_channels=decoder_dims,
+    ).eval()
+
+    # Load the converted state dict
+    model.load_state_dict(smp_state_dict, strict=True)
+
+    # Preprocessing params
+    preprocessing = A.Compose(
+        [
+            A.Resize(height, width, p=1),
+            A.Normalize(
+                mean=[123.675, 116.28, 103.53],
+                std=[58.395, 57.12, 57.375],
+                max_pixel_value=1.0,
+                p=1,
+            ),
+        ]
+    )
+
+    # Prepare the input
+    image = get_np_image()
+    normalized_image = preprocessing(image=image)["image"]
+    tensor = torch.tensor(normalized_image).permute(2, 0, 1).unsqueeze(0).float()
+
+    # Forward pass
+    with torch.no_grad():
+        mask = model(tensor)
+
+    # Postprocessing
+    mask = torch.nn.functional.interpolate(
+        mask, size=(image.shape[0], image.shape[1]), mode="bilinear"
+    )
+    mask = torch.argmax(mask, dim=1)
+    mask = mask.squeeze().cpu().numpy()
+
+    model_name = args.path.split("/")[-1].replace(".pth", "").replace(".", "-")
+
+    model.save_pretrained(model_name)
+    preprocessing.save_pretrained(model_name)
+
+    # fmt: off
+    plt.subplot(121), plt.axis('off'), plt.imshow(image), plt.title('Input Image')
+    plt.subplot(122), plt.axis('off'), plt.imshow(mask), plt.title('Output Mask')
+    plt.savefig(f"{model_name}/example_mask.png")
+    # fmt: on
+
+    if args.push_to_hub:
+        repo_id = f"smp-hub/{model_name}"
+        api = huggingface_hub.HfApi()
+        api.create_repo(repo_id=repo_id, repo_type="model")
+        api.upload_folder(folder_path=model_name, repo_id=repo_id)
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--path",
+        type=str,
+        default="weights/trained_models/segformer.b2.512x512.ade.160k.pth",
+    )
+    parser.add_argument("--push_to_hub", action="store_true")
+    args = parser.parse_args()
+
+    main(args)
```
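As a usage note, a checkpoint converted and saved by this script can be restored through the hub mixin; the directory name below is hypothetical, derived from the script's default `--path`:

```python
import segmentation_models_pytorch as smp

# Local directory written by model.save_pretrained(...) above (hypothetical name).
model = smp.from_pretrained("segformer-b2-512x512-ade-160k").eval()
```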
segmentation_models_pytorch/decoders/segformer/decoder.py (+72)

```diff
@@ -0,0 +1,72 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from segmentation_models_pytorch.base import modules as md
+
+
+class MLP(nn.Module):
+    def __init__(self, skip_channels, segmentation_channels):
+        super().__init__()
+
+        self.linear = nn.Linear(skip_channels, segmentation_channels)
+
+    def forward(self, x: torch.Tensor):
+        batch, _, height, width = x.shape
+        x = x.flatten(2).transpose(1, 2)
+        x = self.linear(x)
+        x = x.transpose(1, 2).reshape(batch, -1, height, width).contiguous()
+        return x
+
+
+class SegformerDecoder(nn.Module):
+    def __init__(
+        self,
+        encoder_channels,
+        encoder_depth=5,
+        segmentation_channels=256,
+    ):
+        super().__init__()
+
+        if encoder_depth < 3:
+            raise ValueError(
+                "Encoder depth for Segformer decoder cannot be less than 3, got {}.".format(
+                    encoder_depth
+                )
+            )
+
+        if encoder_channels[1] == 0:
+            encoder_channels = tuple(
+                channel for index, channel in enumerate(encoder_channels) if index != 1
+            )
+        encoder_channels = encoder_channels[::-1]
+
+        self.mlp_stage = nn.ModuleList(
+            [MLP(channel, segmentation_channels) for channel in encoder_channels[:-1]]
+        )
+
+        self.fuse_stage = md.Conv2dReLU(
+            in_channels=(len(encoder_channels) - 1) * segmentation_channels,
+            out_channels=segmentation_channels,
+            kernel_size=1,
+            use_batchnorm=True,
+        )
+
+    def forward(self, *features):
+        # Resize all features to 1/4 of the input resolution
+        target_size = [dim // 4 for dim in features[0].shape[2:]]
+
+        features = features[2:] if features[1].size(1) == 0 else features[1:]
+        features = features[::-1]  # reverse channels to start from head of encoder
+
+        resized_features = []
+        for feature, stage in zip(features, self.mlp_stage):
+            feature = stage(feature)
+            resized_feature = F.interpolate(
+                feature, size=target_size, mode="bilinear", align_corners=False
+            )
+            resized_features.append(resized_feature)
+
+        output = self.fuse_stage(torch.cat(resized_features, dim=1))
+
+        return output
```
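To make the decoder's data flow concrete, here is a minimal sketch pushing a dummy feature pyramid through it. The channel layout mimics a resnet34-style encoder at depth 5, and the 224×224 input is illustrative; the import resolves once this commit is installed:

```python
import torch
from segmentation_models_pytorch.decoders.segformer.decoder import SegformerDecoder

# A depth-5 encoder yields 6 features: identity (input) plus strides 2..32.
channels = (3, 64, 64, 128, 256, 512)  # resnet34-like out_channels
features = [
    torch.rand(1, c, 224 // s, 224 // s)
    for c, s in zip(channels, (1, 2, 4, 8, 16, 32))
]

decoder = SegformerDecoder(encoder_channels=channels, encoder_depth=5)
out = decoder(*features)
print(out.shape)  # torch.Size([1, 256, 56, 56]): 1/4 of the input resolution
```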
segmentation_models_pytorch/decoders/segformer/model.py (+93)

```diff
@@ -0,0 +1,93 @@
+from typing import Any, Optional, Union, Callable
+
+from segmentation_models_pytorch.base import (
+    ClassificationHead,
+    SegmentationHead,
+    SegmentationModel,
+)
+from segmentation_models_pytorch.encoders import get_encoder
+
+from .decoder import SegformerDecoder
+
+
+class Segformer(SegmentationModel):
+    """Segformer is a simple and efficient design for semantic segmentation with Transformers
+
+    Args:
+        encoder_name: Name of the classification model that will be used as an encoder (a.k.a backbone)
+            to extract features of different spatial resolution
+        encoder_depth: A number of stages used in the encoder in range [3, 5]. Each stage generates features
+            two times smaller in spatial dimensions than the previous one (e.g. for depth 0 we will have features
+            with shapes [(N, C, H, W),], for depth 1 - [(N, C, H, W), (N, C, H // 2, W // 2)] and so on).
+            Default is 5
+        encoder_weights: One of **None** (random initialization), **"imagenet"** (pre-training on ImageNet) and
+            other pretrained weights (see table with available weights for each encoder_name)
+        decoder_segmentation_channels: A number of convolution filters in segmentation blocks, default is 256
+        in_channels: A number of input channels for the model, default is 3 (RGB images)
+        classes: A number of classes for output mask (or you can think as a number of channels of output mask)
+        activation: An activation function to apply after the final convolution layer.
+            Available options are **"sigmoid"**, **"softmax"**, **"logsoftmax"**, **"tanh"**, **"identity"**,
+            **callable** and **None**.
+            Default is **None**
+        aux_params: Dictionary with parameters of the auxiliary output (classification head). Auxiliary output is built
+            on top of encoder if **aux_params** is not **None** (default). Supported params:
+                - classes (int): A number of classes
+                - pooling (str): One of "max", "avg". Default is "avg"
+                - dropout (float): Dropout factor in [0, 1)
+                - activation (str): An activation function to apply "sigmoid"/"softmax"
+                    (could be **None** to return logits)
+        kwargs: Arguments passed to the encoder class ``__init__()`` function. Applies only to ``timm`` models. Keys with ``None`` values are pruned before passing.
+
+    Returns:
+        ``torch.nn.Module``: **Segformer**
+
+    .. _Segformer:
+        https://arxiv.org/abs/2105.15203
+
+    """
+
+    def __init__(
+        self,
+        encoder_name: str = "resnet34",
+        encoder_depth: int = 5,
+        encoder_weights: Optional[str] = "imagenet",
+        decoder_segmentation_channels: int = 256,
+        in_channels: int = 3,
+        classes: int = 1,
+        activation: Optional[Union[str, Callable]] = None,
+        aux_params: Optional[dict] = None,
+        **kwargs: dict[str, Any],
+    ):
+        super().__init__()
+
+        self.encoder = get_encoder(
+            encoder_name,
+            in_channels=in_channels,
+            depth=encoder_depth,
+            weights=encoder_weights,
+            **kwargs,
+        )
+
+        self.decoder = SegformerDecoder(
+            encoder_channels=self.encoder.out_channels,
+            encoder_depth=encoder_depth,
+            segmentation_channels=decoder_segmentation_channels,
+        )
+
+        self.segmentation_head = SegmentationHead(
+            in_channels=decoder_segmentation_channels,
+            out_channels=classes,
+            activation=activation,
+            kernel_size=1,
+            upsampling=4,
+        )
+
+        if aux_params is not None:
+            self.classification_head = ClassificationHead(
+                in_channels=self.encoder.out_channels[-1], **aux_params
+            )
+        else:
+            self.classification_head = None
+
+        self.name = "segformer-{}".format(encoder_name)
+        self.initialize()
```
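Finally, an end-to-end sketch with a MixTransformer encoder, the backbone from the paper. `mit_b0`, the input size, and the aux head are illustrative; note that `mit_*` encoders report 0 channels for their absent stride-2 stage, which is exactly the case the decoder's `encoder_channels[1] == 0` branch handles:

```python
import torch
import segmentation_models_pytorch as smp

model = smp.Segformer(
    encoder_name="mit_b0",
    encoder_weights=None,        # random init; avoids a download in this sketch
    classes=2,
    aux_params={"classes": 2},   # optional classification head on the encoder
).eval()

with torch.no_grad():
    mask, label = model(torch.rand(1, 3, 512, 512))
print(mask.shape, label.shape)  # (1, 2, 512, 512) and (1, 2)
```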
