-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathconformer.py
541 lines (445 loc) · 19.7 KB
/
conformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
import torch
from torch import nn, einsum
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
def exists(val):
"""
检查一个值是否存在(即不为None)。
Args:
val: 任意类型的值。
Returns:
bool: 如果值不为None,则返回True;否则返回False。
"""
return val is not None
def default(val, d):
"""
如果值存在(即不为None),则返回该值;否则,返回默认值。
Args:
val: 任意类型的值。
d: 默认值。
Returns:
任意类型: 如果val不为None,则返回val;否则返回d。
"""
return val if exists(val) else d
def calc_same_padding(kernel_size):
"""
计算用于保持输入输出尺寸相同的填充大小。
在卷积操作中,为了保持输入和输出的空间尺寸相同,
需要在输入的边缘进行适当的填充。
Args:
kernel_size (int): 卷积核的大小。
Returns:
tuple: 包含填充大小的元组,格式为 (pad_left, pad_right)。
"""
pad = kernel_size // 2
return (pad, pad - (kernel_size + 1) % 2)
class Swish(nn.Module):
"""
Swish激活函数类。
Swish是一种自门控激活函数,定义为:f(x) = x * sigmoid(x)。
它在某些深度学习任务中比ReLU表现更好。
Args:
None
Returns:
torch.Tensor: 经过Swish激活函数处理后的张量。
"""
def forward(self, x):
return x * x.sigmoid()
class GLU(nn.Module):
"""
GLU(Gated Linear Unit)激活函数类。
GLU是一种门控机制,将输入张量分成两部分,一部分作为门控信号,另一部分作为输出信号。
Args:
dim (int): 分割维度的索引。
Returns:
torch.Tensor: 经过GLU激活函数处理后的张量。
"""
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x):
# 将输入张量沿指定维度分割成两部分
out, gate = x.chunk(2, dim=self.dim)
# 应用门控机制:输出 = out * sigmoid(gate)
return out * gate.sigmoid()
class DepthWiseConv1d(nn.Module):
"""
深度可分离卷积1D类。
深度可分离卷积将标准卷积分解为深度卷积和逐点卷积,
以减少计算量和参数量。
Args:
chan_in (int): 输入通道数。
chan_out (int): 输出通道数。
kernel_size (int): 卷积核的大小。
padding (int 或 tuple): 填充大小。
Returns:
torch.Tensor: 经过深度可分离卷积处理后的张量。
"""
def __init__(self, chan_in, chan_out, kernel_size, padding):
super().__init__()
self.padding = padding
# 定义深度可分离卷积层,groups=chan_in表示每个输入通道有独立的卷积核
self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups = chan_in)
def forward(self, x):
# 对输入张量进行填充
x = F.pad(x, self.padding)
# 应用深度可分离卷积
return self.conv(x)
class Scale(nn.Module):
"""
缩放层类。
该层对输入张量应用一个缩放因子。
Args:
scale (float): 缩放因子。
fn (nn.Module): 要应用的神经网络模块。
Returns:
torch.Tensor: 经过缩放处理后的张量。
"""
def __init__(self, scale, fn):
super().__init__()
# 要应用的神经网络模块
self.fn = fn
# 缩放因子
self.scale = scale
def forward(self, x, **kwargs):
# 应用神经网络模块,并乘以缩放因子
return self.fn(x, **kwargs) * self.scale
class PreNorm(nn.Module):
"""
预归一化层类。
该层在应用某个神经网络模块之前,对输入张量进行层归一化。
Args:
dim (int): 归一化的维度。
fn (nn.Module): 要应用的神经网络模块。
Returns:
torch.Tensor: 经过预归一化处理后的张量。
"""
def __init__(self, dim, fn):
super().__init__()
# 要应用的神经网络模块
self.fn = fn
# 层归一化层
self.norm = nn.LayerNorm(dim)
def forward(self, x, **kwargs):
# 对输入张量进行层归一化
x = self.norm(x)
# 应用神经网络模块
return self.fn(x, **kwargs)
class Attention(nn.Module):
"""
自注意力机制(Self-Attention)实现。
该类实现了多头自注意力机制(Multi-Head Self-Attention),并支持相对位置编码。
它可以用于各种Transformer模型中,以捕捉输入序列中不同位置之间的关系。
Args:
dim (int): 输入和输出的维度大小。
heads (int, optional): 多头注意力的头数,默认为8。
dim_head (int, optional): 每个注意力头的维度大小,默认为64。
dropout (float, optional): Dropout概率,默认为0。
max_pos_emb (int, optional): 最大位置嵌入尺寸,默认为512。
"""
def __init__(
self,
dim,
heads = 8,
dim_head = 64,
dropout = 0.,
max_pos_emb = 512
):
super().__init__()
# 计算内部维度,用于线性变换
inner_dim = dim_head * heads
# 多头注意力的头数
self.heads= heads
# 缩放因子,用于缩放注意力得分
self.scale = dim_head ** -0.5
# 线性变换层,用于计算查询(Q)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
# 线性变换层,用于计算键(K)和值(V)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
# 线性变换层,用于输出投影
self.to_out = nn.Linear(inner_dim, dim)
# 最大位置嵌入尺寸
self.max_pos_emb = max_pos_emb
# 相对位置嵌入层,用于计算相对位置编码
self.rel_pos_emb = nn.Embedding(2 * max_pos_emb + 1, dim_head)
# Dropout层,用于防止过拟合
self.dropout = nn.Dropout(dropout)
def forward(
self,
x,
context = None,
mask = None,
context_mask = None
):
"""
前向传播方法,执行多头自注意力计算。
Args:
x (torch.Tensor): 输入张量,形状为 (batch_size, sequence_length, dim)。
context (torch.Tensor, optional): 上下文张量,用于跨注意力机制。如果为None,则使用x作为上下文。
mask (torch.Tensor, optional): 输入张量的掩码,用于屏蔽某些位置。
context_mask (torch.Tensor, optional): 上下文张量的掩码,用于屏蔽某些位置。
Returns:
torch.Tensor: 经过多头自注意力处理后的输出张量,形状为 (batch_size, sequence_length, dim)。
"""
n, device, h, max_pos_emb, has_context = x.shape[-2], x.device, self.heads, self.max_pos_emb, exists(context)
# 如果没有提供上下文张量,则使用输入张量作为上下文
context = default(context, x)
# 计算查询(Q)、键(K)和值(V)
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
# 重塑张量以适应多头注意力机制,形状变为 (batch_size, heads, sequence_length, dim_head)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
# 计算注意力得分(未缩放的)
dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
# Shaw的相对位置编码实现
seq = torch.arange(n, device = device) # 生成序列索引
# 计算距离矩阵,形状为 (sequence_length, sequence_length)
dist = rearrange(seq, 'i -> i ()') - rearrange(seq, 'j -> () j')
# 将距离限制在 [-max_pos_emb, max_pos_emb] 之间,并加上偏移量
dist = dist.clamp(-max_pos_emb, max_pos_emb) + max_pos_emb
# 获取相对位置嵌入,形状为 (2 * max_pos_emb + 1, dim_head)
rel_pos_emb = self.rel_pos_emb(dist).to(q)
# 计算相对位置注意力得分,形状为 (batch_size, heads, sequence_length, sequence_length)
pos_attn = einsum('b h n d, n r d -> b h n r', q, rel_pos_emb) * self.scale
# 将相对位置注意力得分加到原始注意力得分上
dots = dots + pos_attn
# 处理掩码
if exists(mask) or exists(context_mask):
# 如果没有提供输入掩码,则使用全1掩码
mask = default(mask, lambda: torch.ones(*x.shape[:2], device = device))
# 如果没有提供上下文掩码,则使用全1掩码或输入掩码
context_mask = default(context_mask, mask) if not has_context else default(context_mask, lambda: torch.ones(*context.shape[:2], device = device))
# 获取掩码填充值(最小浮点数)
mask_value = -torch.finfo(dots.dtype).max
# 重塑掩码以匹配注意力得分的形状
mask = rearrange(mask, 'b i -> b () i ()') * rearrange(context_mask, 'b j -> b () () j')
# 对注意力得分进行掩码填充,掩码为0的位置填充为mask_value
dots.masked_fill_(~mask, mask_value)
# 对注意力得分应用softmax函数,得到注意力权重
attn = dots.softmax(dim = -1)
# 通过注意力权重对值进行加权求和,得到输出
out = einsum('b h i j, b h j d -> b h i d', attn, v)
# 重塑输出张量的形状为 (batch_size, sequence_length, inner_dim)
out = rearrange(out, 'b h n d -> b n (h d)')
# 通过线性变换层进行输出投影
out = self.to_out(out)
# 应用Dropout正则化
return self.dropout(out)
class FeedForward(nn.Module):
"""
前馈神经网络(Feed-Forward Network, FFN)类。
该网络通常用于Transformer模型中,作为多头注意力机制后的位置前馈网络。
它由线性变换、激活函数和Dropout层组成。
Args:
dim (int): 输入和输出的维度大小。
mult (int, optional): 内部维度相对于输入维度的倍数,默认为4。
dropout (float, optional): Dropout概率,默认为0。
Returns:
torch.Tensor: 经过前馈神经网络处理后的张量。
"""
def __init__(
self,
dim,
mult = 4,
dropout = 0.
):
super().__init__()
# 定义前馈神经网络的序列结构
self.net = nn.Sequential(
nn.Linear(dim, dim * mult), # 线性变换,扩展维度
Swish(), # 应用Swish激活函数
nn.Dropout(dropout), # 应用Dropout正则化
nn.Linear(dim * mult, dim), # 线性变换,恢复原始维度
nn.Dropout(dropout) # 应用Dropout正则化
)
def forward(self, x):
"""
前向传播方法,执行前馈神经网络的计算。
Args:
x (torch.Tensor): 输入张量,形状为 (batch_size, sequence_length, dim)。
Returns:
torch.Tensor: 经过前馈神经网络处理后的输出张量,形状为 (batch_size, sequence_length, dim)。
"""
return self.net(x)
# ===========================
# Conformer 卷积模块类(ConformerConvModule)
# ===========================
class ConformerConvModule(nn.Module):
"""
Conformer模型的卷积模块类。
该模块实现了Conformer模型中的卷积部分,用于捕捉局部特征。
它结合了层归一化、1D卷积、门控线性单元(GLU)、深度可分离卷积、批量归一化、Swish激活函数等。
Args:
dim (int): 输入和输出的维度大小。
causal (bool, optional): 是否为因果卷积,默认为False。
expansion_factor (int, optional): 扩展因子,用于扩展内部维度,默认为2。
kernel_size (int, optional): 卷积核的大小,默认为31。
dropout (float, optional): Dropout概率,默认为0。
Returns:
torch.Tensor: 经过卷积模块处理后的张量。
"""
def __init__(
self,
dim,
causal = False,
expansion_factor = 2,
kernel_size = 31,
dropout = 0.
):
super().__init__()
# 计算内部维度
inner_dim = dim * expansion_factor
# 计算填充大小,如果是因果卷积,则只填充左侧;否则进行相同填充
padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)
# 定义卷积模块的序列结构
self.net = nn.Sequential(
nn.LayerNorm(dim), # 对输入进行层归一化
Rearrange('b n c -> b c n'), # 重排张量形状以适应1D卷积
nn.Conv1d(dim, inner_dim * 2, 1), # 1D卷积,扩展内部维度
GLU(dim=1), # 应用GLU激活函数
DepthWiseConv1d(inner_dim, inner_dim, kernel_size = kernel_size, padding = padding), # 深度可分离卷积
nn.BatchNorm1d(inner_dim) if not causal else nn.Identity(), # 如果不是因果卷积,则应用批量归一化;否则使用恒等映射
Swish(), # 应用Swish激活函数
nn.Conv1d(inner_dim, dim, 1), # 1D卷积,恢复原始维度
Rearrange('b c n -> b n c'), # 重排张量形状以恢复原始形状
nn.Dropout(dropout) # 应用Dropout正则化
)
def forward(self, x):
"""
前向传播方法,执行卷积模块的计算。
Args:
x (torch.Tensor): 输入张量,形状为 (batch_size, sequence_length, dim)。
Returns:
torch.Tensor: 经过卷积模块处理后的输出张量,形状为 (batch_size, sequence_length, dim)。
"""
return self.net(x)
class ConformerBlock(nn.Module):
"""
Conformer模块块。
该模块块结合了前馈神经网络(FFN)、多头自注意力机制(Multi-Head Self-Attention)和卷积模块(ConformerConvModule),
并通过层归一化和残差连接来增强模型的表达能力。
Args:
dim (int): 输入和输出的维度大小。
dim_head (int, optional): 每个注意力头的维度大小,默认为64。
heads (int, optional): 多头注意力的头数,默认为8。
ff_mult (int, optional): 前馈神经网络内部维度相对于输入维度的倍数,默认为4。
conv_expansion_factor (int, optional): 卷积模块的扩展因子,默认为2。
conv_kernel_size (int, optional): 卷积核的大小,默认为31。
attn_dropout (float, optional): 多头自注意力机制的Dropout概率,默认为0。
ff_dropout (float, optional): 前馈神经网络的Dropout概率,默认为0。
conv_dropout (float, optional): 卷积模块的Dropout概率,默认为0。
conv_causal (bool, optional): 是否为因果卷积,默认为False。
Returns:
torch.Tensor: 经过Conformer模块块处理后的张量。
"""
def __init__(
self,
*,
dim,
dim_head = 64,
heads = 8,
ff_mult = 4,
conv_expansion_factor = 2,
conv_kernel_size = 31,
attn_dropout = 0.,
ff_dropout = 0.,
conv_dropout = 0.,
conv_causal = False
):
super().__init__()
# 定义第一个前馈神经网络(FFN)
self.ff1 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
# 定义多头自注意力机制
self.attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout)
# 定义卷积模块
self.conv = ConformerConvModule(dim = dim, causal = conv_causal, expansion_factor = conv_expansion_factor, kernel_size = conv_kernel_size, dropout = conv_dropout)
# 定义第二个前馈神经网络(FFN)
self.ff2 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
# 使用预归一化(PreNorm)对注意力机制和前馈神经网络进行归一化处理
self.attn = PreNorm(dim, self.attn)
self.ff1 = Scale(0.5, PreNorm(dim, self.ff1))
self.ff2 = Scale(0.5, PreNorm(dim, self.ff2))
# 定义层归一化层,用于最后的输出归一化
self.post_norm = nn.LayerNorm(dim)
def forward(self, x, mask = None):
"""
前向传播方法,执行Conformer模块块的计算。
Args:
x (torch.Tensor): 输入张量,形状为 (batch_size, sequence_length, dim)。
mask (torch.Tensor, optional): 输入张量的掩码,用于屏蔽某些位置。
Returns:
torch.Tensor: 经过Conformer模块块处理后的输出张量,形状为 (batch_size, sequence_length, dim)。
"""
# 应用第一个前馈神经网络,并添加残差连接
x = self.ff1(x) + x
# 应用多头自注意力机制,并添加残差连接
x = self.attn(x, mask = mask) + x
# 应用卷积模块,并添加残差连接
x = self.conv(x) + x
# 应用第二个前馈神经网络,并添加残差连接
x = self.ff2(x) + x
# 应用层归一化
x = self.post_norm(x)
return x
class Conformer(nn.Module):
"""
Conformer模型类。
该模型由多个Conformer模块块堆叠而成,能够处理序列数据,如语音识别、自然语言处理等任务。
Args:
dim (int): 输入和输出的维度大小。
depth (int): Conformer模块块的堆叠深度。
dim_head (int, optional): 每个注意力头的维度大小,默认为64。
heads (int, optional): 多头注意力的头数,默认为8。
ff_mult (int, optional): 前馈神经网络内部维度相对于输入维度的倍数,默认为4。
conv_expansion_factor (int, optional): 卷积模块的扩展因子,默认为2。
conv_kernel_size (int, optional): 卷积核的大小,默认为31。
attn_dropout (float, optional): 多头自注意力机制的Dropout概率,默认为0。
ff_dropout (float, optional): 前馈神经网络的Dropout概率,默认为0。
conv_dropout (float, optional): 卷积模块的Dropout概率,默认为0。
conv_causal (bool, optional): 是否为因果卷积,默认为False。
Returns:
torch.Tensor: 经过Conformer模型处理后的张量。
"""
def __init__(
self,
dim,
*,
depth,
dim_head = 64,
heads = 8,
ff_mult = 4,
conv_expansion_factor = 2,
conv_kernel_size = 31,
attn_dropout = 0.,
ff_dropout = 0.,
conv_dropout = 0.,
conv_causal = False
):
super().__init__()
# 设置维度大小
self.dim = dim
# 初始化模块列表,用于存储多个Conformer模块块
self.layers = nn.ModuleList([])
# 堆叠多个Conformer模块块
for _ in range(depth):
self.layers.append(ConformerBlock(
dim = dim,
dim_head = dim_head,
heads = heads,
ff_mult = ff_mult,
conv_expansion_factor = conv_expansion_factor,
conv_kernel_size = conv_kernel_size,
conv_causal = conv_causal
))
def forward(self, x):
"""
前向传播方法,执行Conformer模型的计算。
Args:
x (torch.Tensor): 输入张量,形状为 (batch_size, sequence_length, dim)。
Returns:
torch.Tensor: 经过Conformer模型处理后的输出张量,形状为 (batch_size, sequence_length, dim)。
"""
# 遍历所有Conformer模块块,并应用它们到输入张量
for block in self.layers:
x = block(x)
return x