1
- import torch
2
- import torch .nn as nn
3
1
from .decoder import UnetDecoder
4
-
5
- from ..base .model import EncoderDecoder
2
+ from ..base import EncoderDecoder
6
3
from ..encoders import get_encoder
7
4
8
5
@@ -18,7 +15,6 @@ def __init__(
18
15
activation = 'sigmoid' ,
19
16
center = False , # usefull for VGG models
20
17
):
21
-
22
18
encoder = get_encoder (
23
19
encoder_name ,
24
20
encoder_weights = encoder_weights
@@ -32,23 +28,6 @@ def __init__(
32
28
center = center ,
33
29
)
34
30
35
- # define activation function
36
- if activation == 'softmax' :
37
- activation_fn = nn .Softmax (dim = 1 )
38
- elif activation == 'sigmoid' :
39
- activation_fn = nn .Sigmoid ()
40
- else :
41
- raise ValueError ('Activation should be "sigmoid" or "softmax"' )
42
-
43
- super ().__init__ (encoder , decoder , activation_fn )
44
-
45
- if encoder_weights is not None :
46
- self .set_preprocessing_params (
47
- input_size = encoder .input_size ,
48
- input_space = encoder .input_space ,
49
- input_range = encoder .input_range ,
50
- mean = encoder .mean ,
51
- std = encoder .std ,
52
- )
31
+ super ().__init__ (encoder , decoder , activation )
53
32
54
33
self .name = 'u-{}' .format (encoder_name )
0 commit comments