Skip to content

Commit 9882d42

Browse files
committed
Refactor model subclassing
1 parent 110dfd5 commit 9882d42

File tree

1 file changed

+2
-23
lines changed
  • segmentation_models_pytorch/unet

1 file changed

+2
-23
lines changed
+2-23
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,5 @@
1-
import torch
2-
import torch.nn as nn
31
from .decoder import UnetDecoder
4-
5-
from ..base.model import EncoderDecoder
2+
from ..base import EncoderDecoder
63
from ..encoders import get_encoder
74

85

@@ -18,7 +15,6 @@ def __init__(
1815
activation='sigmoid',
1916
center=False, # usefull for VGG models
2017
):
21-
2218
encoder = get_encoder(
2319
encoder_name,
2420
encoder_weights=encoder_weights
@@ -32,23 +28,6 @@ def __init__(
3228
center=center,
3329
)
3430

35-
# define activation function
36-
if activation == 'softmax':
37-
activation_fn = nn.Softmax(dim=1)
38-
elif activation == 'sigmoid':
39-
activation_fn = nn.Sigmoid()
40-
else:
41-
raise ValueError('Activation should be "sigmoid" or "softmax"')
42-
43-
super().__init__(encoder, decoder, activation_fn)
44-
45-
if encoder_weights is not None:
46-
self.set_preprocessing_params(
47-
input_size=encoder.input_size,
48-
input_space=encoder.input_space,
49-
input_range=encoder.input_range,
50-
mean=encoder.mean,
51-
std=encoder.std,
52-
)
31+
super().__init__(encoder, decoder, activation)
5332

5433
self.name = 'u-{}'.format(encoder_name)

0 commit comments

Comments
 (0)