24
24
25
25
26
26
class Upsample1D (nn .Module ):
27
- """
28
- An upsampling layer with an optional convolution.
27
+ """A 1D upsampling layer with an optional convolution.
29
28
30
29
Parameters:
31
- channels: channels in the inputs and outputs.
32
- use_conv: a bool determining if a convolution is applied.
33
- use_conv_transpose:
34
- out_channels:
30
+ channels (`int`):
31
+ number of channels in the inputs and outputs.
32
+ use_conv (`bool`, default `False`):
33
+ option to use a convolution.
34
+ use_conv_transpose (`bool`, default `False`):
35
+ option to use a convolution transpose.
36
+ out_channels (`int`, optional):
37
+ number of output channels. Defaults to `channels`.
35
38
"""
36
39
37
40
def __init__ (self , channels , use_conv = False , use_conv_transpose = False , out_channels = None , name = "conv" ):
@@ -62,14 +65,17 @@ def forward(self, x):
62
65
63
66
64
67
class Downsample1D (nn .Module ):
65
- """
66
- A downsampling layer with an optional convolution.
68
+ """A 1D downsampling layer with an optional convolution.
67
69
68
70
Parameters:
69
- channels: channels in the inputs and outputs.
70
- use_conv: a bool determining if a convolution is applied.
71
- out_channels:
72
- padding:
71
+ channels (`int`):
72
+ number of channels in the inputs and outputs.
73
+ use_conv (`bool`, default `False`):
74
+ option to use a convolution.
75
+ out_channels (`int`, optional):
76
+ number of output channels. Defaults to `channels`.
77
+ padding (`int`, default `1`):
78
+ padding for the convolution.
73
79
"""
74
80
75
81
def __init__ (self , channels , use_conv = False , out_channels = None , padding = 1 , name = "conv" ):
@@ -93,14 +99,17 @@ def forward(self, x):
93
99
94
100
95
101
class Upsample2D (nn .Module ):
96
- """
97
- An upsampling layer with an optional convolution.
102
+ """A 2D upsampling layer with an optional convolution.
98
103
99
104
Parameters:
100
- channels: channels in the inputs and outputs.
101
- use_conv: a bool determining if a convolution is applied.
102
- use_conv_transpose:
103
- out_channels:
105
+ channels (`int`):
106
+ number of channels in the inputs and outputs.
107
+ use_conv (`bool`, default `False`):
108
+ option to use a convolution.
109
+ use_conv_transpose (`bool`, default `False`):
110
+ option to use a convolution transpose.
111
+ out_channels (`int`, optional):
112
+ number of output channels. Defaults to `channels`.
104
113
"""
105
114
106
115
def __init__ (self , channels , use_conv = False , use_conv_transpose = False , out_channels = None , name = "conv" ):
@@ -162,14 +171,17 @@ def forward(self, hidden_states, output_size=None):
162
171
163
172
164
173
class Downsample2D (nn .Module ):
165
- """
166
- A downsampling layer with an optional convolution.
174
+ """A 2D downsampling layer with an optional convolution.
167
175
168
176
Parameters:
169
- channels: channels in the inputs and outputs.
170
- use_conv: a bool determining if a convolution is applied.
171
- out_channels:
172
- padding:
177
+ channels (`int`):
178
+ number of channels in the inputs and outputs.
179
+ use_conv (`bool`, default `False`):
180
+ option to use a convolution.
181
+ out_channels (`int`, optional):
182
+ number of output channels. Defaults to `channels`.
183
+ padding (`int`, default `1`):
184
+ padding for the convolution.
173
185
"""
174
186
175
187
def __init__ (self , channels , use_conv = False , out_channels = None , padding = 1 , name = "conv" ):
@@ -209,6 +221,19 @@ def forward(self, hidden_states):
209
221
210
222
211
223
class FirUpsample2D (nn .Module ):
224
+ """A 2D FIR upsampling layer with an optional convolution.
225
+
226
+ Parameters:
227
+ channels (`int`):
228
+ number of channels in the inputs and outputs.
229
+ use_conv (`bool`, default `False`):
230
+ option to use a convolution.
231
+ out_channels (`int`, optional):
232
+ number of output channels. Defaults to `channels`.
233
+ fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
234
+ kernel for the FIR filter.
235
+ """
236
+
212
237
def __init__ (self , channels = None , out_channels = None , use_conv = False , fir_kernel = (1 , 3 , 3 , 1 )):
213
238
super ().__init__ ()
214
239
out_channels = out_channels if out_channels else channels
@@ -309,6 +334,19 @@ def forward(self, hidden_states):
309
334
310
335
311
336
class FirDownsample2D (nn .Module ):
337
+ """A 2D FIR downsampling layer with an optional convolution.
338
+
339
+ Parameters:
340
+ channels (`int`):
341
+ number of channels in the inputs and outputs.
342
+ use_conv (`bool`, default `False`):
343
+ option to use a convolution.
344
+ out_channels (`int`, optional):
345
+ number of output channels. Defaults to `channels`.
346
+ fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
347
+ kernel for the FIR filter.
348
+ """
349
+
312
350
def __init__ (self , channels = None , out_channels = None , use_conv = False , fir_kernel = (1 , 3 , 3 , 1 )):
313
351
super ().__init__ ()
314
352
out_channels = out_channels if out_channels else channels
0 commit comments