Enable any resolution for Unet #1029
Changes from all commits: 0cab989, b2166ea, 26725de, 5b75f11, 6b2ca90, eb81c1f, d5a80df
```diff
@@ -2,19 +2,24 @@
 import torch.nn as nn
 import torch.nn.functional as F
 
+from typing import Optional, Sequence
 from segmentation_models_pytorch.base import modules as md
 
 
-class DecoderBlock(nn.Module):
+class UnetDecoderBlock(nn.Module):
+    """A decoder block in the U-Net architecture that performs upsampling and feature fusion."""
+
     def __init__(
         self,
-        in_channels,
-        skip_channels,
-        out_channels,
-        use_batchnorm=True,
-        attention_type=None,
+        in_channels: int,
+        skip_channels: int,
+        out_channels: int,
+        use_batchnorm: bool = True,
+        attention_type: Optional[str] = None,
+        interpolation_mode: str = "nearest",
     ):
         super().__init__()
+        self.interpolation_mode = interpolation_mode
         self.conv1 = md.Conv2dReLU(
             in_channels + skip_channels,
             out_channels,
```
```diff
@@ -34,19 +39,31 @@
         )
         self.attention2 = md.Attention(attention_type, in_channels=out_channels)
 
-    def forward(self, x, skip=None):
-        x = F.interpolate(x, scale_factor=2, mode="nearest")
-        if skip is not None:
-            x = torch.cat([x, skip], dim=1)
-        x = self.attention1(x)
-        x = self.conv1(x)
-        x = self.conv2(x)
-        x = self.attention2(x)
-        return x
+    def forward(
+        self,
+        feature_map: torch.Tensor,
+        target_height: int,
+        target_width: int,
```
Review thread on lines +45 to +46 (`target_height` / `target_width`):

> Is this an intentional backwards-incompatible change? It is no longer possible to use U-Net without specifying height/width. Could we instead default to the same height/width as the input, like we previously did?

> Hmm, not sure I got this. This is just for a layer, and the Decoder passes height and width. Can you please specify what is broken?

> Ah, the problem was that I'm directly using …

> Part of the failures are related to the modified decoder forward; previously, input features were unpacked with …

> The rest can be resolved by renaming `center` -> `add_center_block` here.

> I know it's easy to support the new API, although it's harder to support both. I guess the question is whether this is intentional or not. Maybe it would help to add a "backwards-incompatible" label to PRs like this. Even better would be to deprecate the old syntax before completely removing it.

> Yes, that was intentional; it was added in the PR that fixes support for torch.script/compile/export. These are internal modules, so I hope it will not break too many things. But you are right, I will add the label, and I will also highlight it in the release notes.
```diff
+        skip_connection: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        feature_map = F.interpolate(
+            feature_map,
+            size=(target_height, target_width),
+            mode=self.interpolation_mode,
+        )
+        if skip_connection is not None:
+            feature_map = torch.cat([feature_map, skip_connection], dim=1)
+        feature_map = self.attention1(feature_map)
+        feature_map = self.conv1(feature_map)
+        feature_map = self.conv2(feature_map)
+        feature_map = self.attention2(feature_map)
+        return feature_map
 
 
-class CenterBlock(nn.Sequential):
-    def __init__(self, in_channels, out_channels, use_batchnorm=True):
+class UnetCenterBlock(nn.Sequential):
+    """Center block of the Unet decoder. Applied to the last feature map of the encoder."""
+
+    def __init__(self, in_channels: int, out_channels: int, use_batchnorm: bool = True):
         conv1 = md.Conv2dReLU(
             in_channels,
             out_channels,
```
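The key mechanism change sits in the `F.interpolate` call: a fixed `scale_factor=2` ties every decoder stage to power-of-two sizes, while an explicit `size=` target can match any encoder feature map. A standalone sketch in plain PyTorch (shapes are illustrative, following a stride-2 encoder on a 250x250 input):

```python
import torch
import torch.nn.functional as F

x = torch.randn(1, 256, 32, 32)  # decoder feature map from a 250x250 input

# Old behavior: fixed x2 upsampling; 32 -> 64, which cannot match the 63x63
# skip connection the encoder produces at the next stage, so torch.cat fails.
old = F.interpolate(x, scale_factor=2, mode="nearest")  # (1, 256, 64, 64)

# New behavior: upsample to the exact size of the skip connection.
new = F.interpolate(x, size=(63, 63), mode="nearest")   # (1, 256, 63, 63)
```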
```diff
@@ -65,14 +82,21 @@
 
 
 class UnetDecoder(nn.Module):
+    """The decoder part of the U-Net architecture.
+
+    Takes encoded features from different stages of the encoder and progressively upsamples them while
+    combining with skip connections. This helps preserve fine-grained details in the final segmentation.
+    """
+
     def __init__(
         self,
-        encoder_channels,
-        decoder_channels,
-        n_blocks=5,
-        use_batchnorm=True,
-        attention_type=None,
-        center=False,
+        encoder_channels: Sequence[int],
+        decoder_channels: Sequence[int],
+        n_blocks: int = 5,
+        use_batchnorm: bool = True,
+        attention_type: Optional[str] = None,
+        add_center_block: bool = False,
+        interpolation_mode: str = "nearest",
     ):
         super().__init__()
 
```
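For reference, constructing the decoder with the renamed keyword arguments from this hunk might look like the sketch below. The channel tuples are illustrative (shaped like a 5-stage ResNet encoder); in practice this is an internal module and the numbers come from the encoder:

```python
decoder = UnetDecoder(
    encoder_channels=(3, 64, 256, 512, 1024, 2048),  # e.g. encoder.out_channels
    decoder_channels=(256, 128, 64, 32, 16),
    n_blocks=5,
    use_batchnorm=True,
    attention_type=None,
    add_center_block=False,        # renamed from `center`
    interpolation_mode="nearest",  # new: configurable upsampling mode
)
```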
```diff
@@ -94,31 +118,45 @@
         skip_channels = list(encoder_channels[1:]) + [0]
         out_channels = decoder_channels
 
-        if center:
-            self.center = CenterBlock(
+        if add_center_block:
+            self.center = UnetCenterBlock(
                 head_channels, head_channels, use_batchnorm=use_batchnorm
             )
         else:
             self.center = nn.Identity()
 
-        # combine decoder keyword arguments
-        kwargs = dict(use_batchnorm=use_batchnorm, attention_type=attention_type)
-        blocks = [
-            DecoderBlock(in_ch, skip_ch, out_ch, **kwargs)
-            for in_ch, skip_ch, out_ch in zip(in_channels, skip_channels, out_channels)
-        ]
-        self.blocks = nn.ModuleList(blocks)
-
-    def forward(self, *features):
+        self.blocks = nn.ModuleList()
+        for block_in_channels, block_skip_channels, block_out_channels in zip(
+            in_channels, skip_channels, out_channels
+        ):
+            block = UnetDecoderBlock(
+                block_in_channels,
+                block_skip_channels,
+                block_out_channels,
+                use_batchnorm=use_batchnorm,
+                attention_type=attention_type,
+                interpolation_mode=interpolation_mode,
+            )
+            self.blocks.append(block)
+
+    def forward(self, *features: torch.Tensor) -> torch.Tensor:
+        # spatial shapes of features: [hw, hw/2, hw/4, hw/8, ...]
+        spatial_shapes = [feature.shape[2:] for feature in features]
+        spatial_shapes = spatial_shapes[::-1]
+
         features = features[1:]  # remove first skip with same spatial resolution
         features = features[::-1]  # reverse channels to start from head of encoder
 
         head = features[0]
-        skips = features[1:]
+        skip_connections = features[1:]
 
         x = self.center(head)
 
         for i, decoder_block in enumerate(self.blocks):
-            skip = skips[i] if i < len(skips) else None
-            x = decoder_block(x, skip)
+            # upsample to the next spatial shape
+            height, width = spatial_shapes[i + 1]
+            skip_connection = skip_connections[i] if i < len(skip_connections) else None
+            x = decoder_block(x, height, width, skip_connection=skip_connection)
 
         return x
```
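With the new `forward`, the target size at every stage is read straight from the encoder feature maps, so odd input sizes round-trip exactly. A sketch with dummy tensors, reusing the illustrative `decoder` constructed above (shapes follow a stride-2 encoder on a 250x250 input; the expected output shape assumes the decoder's usual channel bookkeeping, which is not fully shown in this diff):

```python
import torch

# Dummy encoder features: [hw, hw/2, hw/4, hw/8, hw/16, hw/32] for a 250x250 input.
shapes = [(250, 250), (125, 125), (63, 63), (32, 32), (16, 16), (8, 8)]
channels = (3, 64, 256, 512, 1024, 2048)
features = [torch.randn(1, c, h, w) for c, (h, w) in zip(channels, shapes)]

out = decoder(*features)
print(out.shape)  # expected (1, 16, 250, 250): the output matches the input size
```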