
Commit ffc686b

Change the injection method of conditions on lynxnet (#225)
1 parent 8dd53f7 commit ffc686b

3 files changed (+19, -10 lines)

configs/templates/config_acoustic.yaml

Lines changed: 1 addition & 1 deletion
@@ -63,7 +63,7 @@ backbone_args:
   num_channels: 1024
   num_layers: 6
   kernel_size: 31
-  dropout_rate: 0.0
+  dropout_rate: 0.1
 #backbone_type: 'wavenet'
 #backbone_args:
 #  num_channels: 512
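Note: the dropout_rate value above only matters if it reaches the dropout argument of the LYNXNet modules changed below; that mapping is assumed here, not shown in this commit. A minimal sketch of the selection pattern that the _dropout branch in LYNXConvModule appears to use:

# Sketch only; assumes backbone_args' dropout_rate is passed through as the module's
# 'dropout' kwarg, and that the if/else around _dropout mirrors the hunk shown later.
# make_dropout is a hypothetical helper for illustration.
import torch.nn as nn

def make_dropout(p: float) -> nn.Module:
    # p > 0 enables real dropout; p == 0 (the old default) degenerates to a no-op
    return nn.Dropout(p) if p > 0 else nn.Identity()

print(make_dropout(0.1))  # Dropout(p=0.1) with the new template value
print(make_dropout(0.0))  # Identity() with the old value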

configs/templates/config_variance.yaml

Lines changed: 2 additions & 0 deletions
@@ -87,6 +87,7 @@ pitch_prediction_args:
 # backbone_args:
 # num_layers: 6
 # num_channels: 512
+# dropout_rate: 0.1
 
 variances_prediction_args:
   total_repeat_bins: 48
@@ -99,6 +100,7 @@ variances_prediction_args:
 # backbone_args:
 # num_layers: 6
 # num_channels: 384
+# dropout_rate: 0.1
 
 lambda_dur_loss: 1.0
 lambda_pitch_loss: 1.0

modules/backbones/lynxnet.py

Lines changed: 16 additions & 9 deletions
@@ -10,6 +10,12 @@
 from utils.hparams import hparams
 
 
+class Conv1d(torch.nn.Conv1d):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        nn.init.kaiming_normal_(self.weight)
+
+
 class SwiGLU(nn.Module):
     # Swish-Applies the gated linear unit function.
     def __init__(self, dim=-1):
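For reference, the Conv1d subclass defined in this hunk is a drop-in replacement for torch.nn.Conv1d that re-initializes its weight with Kaiming-normal initialization right after construction. A small usage sketch (the channel sizes are arbitrary examples):

# Usage sketch of the subclass added above; channel sizes are arbitrary.
import torch
import torch.nn as nn

class Conv1d(torch.nn.Conv1d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        nn.init.kaiming_normal_(self.weight)  # He-normal init replaces the default

proj = Conv1d(8, 16, 1)            # same call signature as nn.Conv1d
y = proj(torch.randn(2, 8, 100))   # (batch, channels, time) -> (2, 16, 100)
print(y.shape)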
@@ -39,7 +45,7 @@ def calc_same_padding(kernel_size):
         pad = kernel_size // 2
         return pad, pad - (kernel_size + 1) % 2
 
-    def __init__(self, dim, expansion_factor, kernel_size=31, activation='PReLU', dropout=0.):
+    def __init__(self, dim, expansion_factor, kernel_size=31, activation='PReLU', dropout=0.1):
         super().__init__()
         inner_dim = dim * expansion_factor
         activation_classes = {
@@ -57,7 +63,7 @@ def __init__(self, dim, expansion_factor, kernel_size=31, activation='PReLU', dr
         else:
             _dropout = nn.Identity()
         self.net = nn.Sequential(
-            nn.LayerNorm(dim),
+            nn.LayerNorm(dim, eps=1e-6),
             Transpose((1, 2)),
             nn.Conv1d(dim, inner_dim * 2, 1),
             SwiGLU(dim=1),
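The surrounding Sequential doubles the channel count with the pointwise Conv1d and then gates it back down through SwiGLU(dim=1). The SwiGLU body is not part of this diff; a hedged sketch, assuming it chunks the input in two along the given dim and gates one half with SiLU (Swish):

# Sketch under an assumption about SwiGLU's behavior (its body is not shown in this hunk):
# split the doubled channels in half along dim and gate one half with SiLU.
import torch
import torch.nn.functional as F

def swiglu(x: torch.Tensor, dim: int = 1) -> torch.Tensor:
    a, b = x.chunk(2, dim=dim)
    return a * F.silu(b)

x = torch.randn(2, 1024, 100)    # e.g. output of nn.Conv1d(dim, inner_dim * 2, 1)
print(swiglu(x, dim=1).shape)    # channel dim halved: torch.Size([2, 512, 100])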
@@ -73,16 +79,17 @@ def forward(self, x):
 
 
 class LYNXNetResidualLayer(nn.Module):
-    def __init__(self, dim_cond, dim, expansion_factor, kernel_size=31, activation='PReLU', dropout=0.):
+    def __init__(self, dim_cond, dim, expansion_factor, kernel_size=31, activation='PReLU', dropout=0.1):
         super().__init__()
         self.diffusion_projection = nn.Conv1d(dim, dim, 1)
         self.conditioner_projection = nn.Conv1d(dim_cond, dim, 1)
         self.convmodule = LYNXConvModule(dim=dim, expansion_factor=expansion_factor, kernel_size=kernel_size,
                                          activation=activation, dropout=dropout)
 
     def forward(self, x, conditioner, diffusion_step):
+        x = x + self.conditioner_projection(conditioner)
         res_x = x.transpose(1, 2)
-        x = x + self.diffusion_projection(diffusion_step) + self.conditioner_projection(conditioner)
+        x = x + self.diffusion_projection(diffusion_step)
         x = x.transpose(1, 2)
         x = self.convmodule(x)  # (#batch, dim, length)
         x = x + res_x
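This hunk is the injection-method change named in the commit title: the conditioner is now projected and added to x before res_x is captured, so the condition is carried by both the convolution branch and the residual/skip path, while the diffusion-step embedding is still added only after the skip copy is taken. A minimal sketch of the reordering (shapes assumed (batch, dim, time); cond_proj and step_proj stand for the already-projected tensors):

# Sketch of the reordering only, not the full residual layer.
import torch

def forward_old(x, cond_proj, step_proj, convmodule):
    res_x = x.transpose(1, 2)           # skip copy taken BEFORE conditioning
    x = x + step_proj + cond_proj
    x = convmodule(x.transpose(1, 2))
    return x + res_x

def forward_new(x, cond_proj, step_proj, convmodule):
    x = x + cond_proj                   # condition injected first ...
    res_x = x.transpose(1, 2)           # ... so the skip copy already carries it
    x = x + step_proj
    x = convmodule(x.transpose(1, 2))
    return x + res_x

# Toy check with an identity stand-in for the conv module:
B, dim, T = 2, 4, 8
x, c, s = torch.randn(B, dim, T), torch.randn(B, dim, T), torch.randn(B, dim, 1)
print(forward_new(x, c, s, convmodule=lambda t: t).shape)  # torch.Size([2, 8, 4])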
@@ -93,7 +100,7 @@ def forward(self, x, conditioner, diffusion_step):
 
 class LYNXNet(nn.Module):
     def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansion_factor=2, kernel_size=31,
-                 activation='PReLU', dropout=0.):
+                 activation='PReLU', dropout=0.1):
         """
         LYNXNet(Linear Gated Depthwise Separable Convolution Network)
         TIPS:You can control the style of the generated results by modifying the 'activation',
@@ -104,7 +111,7 @@ def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansio
         super().__init__()
         self.in_dims = in_dims
         self.n_feats = n_feats
-        self.input_projection = nn.Conv1d(in_dims * n_feats, num_channels, 1)
+        self.input_projection = Conv1d(in_dims * n_feats, num_channels, 1)
         self.diffusion_embedding = nn.Sequential(
             SinusoidalPosEmb(num_channels),
             nn.Linear(num_channels, num_channels * 4),
@@ -124,8 +131,8 @@ def __init__(self, in_dims, n_feats, *, num_layers=6, num_channels=512, expansio
                 for i in range(num_layers)
             ]
         )
-        self.norm = nn.LayerNorm(num_channels)
-        self.output_projection = nn.Conv1d(num_channels, in_dims * n_feats, kernel_size=1)
+        self.norm = nn.LayerNorm(num_channels, eps=1e-6)
+        self.output_projection = Conv1d(num_channels, in_dims * n_feats, kernel_size=1)
         nn.init.zeros_(self.output_projection.weight)
 
     def forward(self, spec, diffusion_step, cond):
@@ -142,7 +149,7 @@ def forward(self, spec, diffusion_step, cond):
         x = spec.flatten(start_dim=1, end_dim=2)  # [B, F x M, T]
 
         x = self.input_projection(x)  # x [B, residual_channel, T]
-        x = F.gelu(x)
+        # x = F.gelu(x)
 
         diffusion_step = self.diffusion_embedding(diffusion_step).unsqueeze(-1)

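Two smaller consequences of the remaining hunks, with a hedged sanity check: the GELU after the input projection is now commented out, so the Kaiming-initialized input projection feeds the residual layers directly; and although the output projection is also switched to the Kaiming-initialized Conv1d, the nn.init.zeros_ call that follows still zeroes its weight, so the output projection remains effectively zero-initialized.

# Sanity-check sketch; assumes construction order exactly as in the hunks above
# (Kaiming init inside Conv1d.__init__, then zeros_ applied afterwards).
import torch
import torch.nn as nn

class Conv1d(torch.nn.Conv1d):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        nn.init.kaiming_normal_(self.weight)

output_projection = Conv1d(512, 128, kernel_size=1)    # example channel sizes
nn.init.zeros_(output_projection.weight)               # runs after the Kaiming init
print(torch.count_nonzero(output_projection.weight))   # tensor(0)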