-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsoundstorm.py
2176 lines (1766 loc) · 81.9 KB
/
soundstorm.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from __future__ import annotations
import math
from random import random, randrange
from functools import wraps
from contextlib import nullcontext
from collections import namedtuple
from pathlib import Path
from tqdm import tqdm
import torch
from torch.amp import autocast
from torch import Tensor, nn, einsum
from torch.nn import Module, ModuleList
import torch.nn.functional as F
from einops import rearrange, reduce, repeat, unpack, pack
from einops.layers.torch import Rearrange, EinMix
from beartype import beartype
from beartype.door import is_bearable
from beartype.typing import Any
from spear_tts_pytorch import TextToSemantic
from audiolm_pytorch import SoundStream
from audiolm_pytorch import HubertWithKmeans, FairseqVQWav2Vec
from gateloop_transformer import SimpleGateLoopLayer as GateLoop
from attend import Attend
# helpers
def exists(val):
"""
检查一个值是否存在(即不为 None)。
参数:
val: 需要检查的值。
返回:
bool: 如果值存在(不为 None),则返回 True;否则返回 False。
"""
return val is not None
def default(val, d):
"""
如果值存在,则返回该值;否则返回默认值。
参数:
val: 需要检查的可选值。
d: 默认值。
返回:
Any: 如果 val 存在,则返回 val;否则返回 d。
"""
return val if exists(val) else d
def divisible_by(numer, denom):
"""
检查一个数是否可以被另一个数整除。
参数:
numer (int): 被除数。
denom (int): 除数。
返回:
bool: 如果 numer 可以被 denom 整除,则返回 True;否则返回 False。
"""
return (numer % denom) == 0
def calc_same_padding(kernel_size):
"""
计算用于保持输入和输出尺寸相同的填充大小。
参数:
kernel_size (int): 卷积核的大小。
返回:
Tuple[int, int]: 填充大小,通常为 (pad, pad) 或 (pad, pad - 1)。
"""
pad = kernel_size // 2
return (pad, pad - (kernel_size + 1) % 2)
def eval_decorator(fn):
"""
创建一个装饰器,用于在评估模式下运行函数,并在函数执行前后保持模型的训练状态。
参数:
fn (function): 需要装饰的函数。
返回:
function: 装饰后的函数。
"""
@wraps(fn)
def inner(model, *args, **kwargs):
was_training = model.training
model.eval()
out = fn(model, *args, **kwargs)
model.train(was_training)
return out
return inner
# sampling helpers
def top_k(logits, thres = 0.9):
"""
对输入的 logits 应用 top-k 过滤。
参数:
logits (Tensor): 输入的 logits 张量。
thres (float, 可选): 保留 top-k 的阈值,范围在 0 到 1 之间。默认值为 0.9。
返回:
Tensor: 应用 top-k 过滤后的 logits 张量。
"""
# 计算要保留的 top-k 值
k = math.ceil((1 - thres) * logits.shape[-1])
# 获取 top-k 的值和索引
val, ind = logits.topk(k, dim = -1)
# 创建一个与 logits 形状相同的张量,并用负无穷填充
probs = torch.full_like(logits, float('-inf'))
# 将 top-k 的值填充回 probs 张量
probs.scatter_(2, ind, val)
# 返回应用 top-k 过滤后的 probs 张量
return probs
def log(t, eps = 1e-10):
"""
计算输入张量的对数,并添加一个极小值以防止数值不稳定。
参数:
t (Tensor): 输入张量。
eps (float, 可选): 添加的极小值,防止 log(0) 导致的数值不稳定。默认值为 1e-10。
返回:
Tensor: 输入张量的对数结果。
"""
return torch.log(t + eps)
def gumbel_noise(t):
"""
生成与输入张量形状相同的 Gumbel 噪声。
参数:
t (Tensor): 输入张量,用于确定噪声的形状。
返回:
Tensor: 与输入张量形状相同的 Gumbel 噪声。
"""
# 生成均匀分布的噪声,范围在 0 到 1 之间
noise = torch.zeros_like(t).uniform_(0, 1)
# 应用 Gumbel 变换生成 Gumbel 噪声
return -log(-log(noise))
def gumbel_sample(t, temperature = 1., dim = -1):
"""
对输入张量应用 Gumbel-Softmax 采样。
参数:
t (Tensor): 输入张量。
temperature (float, 可选): 温度参数,控制采样的平滑程度。默认值为 1.0。
dim (int, 可选): 采样的维度。默认值为 -1。
返回:
Tensor: 采样后的张量。
"""
return ((t / max(temperature, 1e-10)) + gumbel_noise(t)).argmax(dim = dim)
# prob helpers
# 概率辅助函数
def sample_prob(prob):
"""
根据给定的概率进行采样。
参数:
prob (float): 采样概率。
返回:
bool: 如果随机数小于等于概率,则返回 True;否则返回 False。
"""
return random() < prob
def coin_flip():
"""
进行一次硬币抛掷采样(50% 的概率)。
返回:
bool: 返回 True 或 False,概率各为 50%。
"""
return sample_prob(0.5)
# tensor helpers
@beartype
def get_mask_subset_prob(
mask: Tensor,
prob: float | Tensor,
min_mask: int = 0,
min_keep_mask: int = 0
):
"""
根据给定的概率和约束条件,从输入的掩码中随机选择子集。
参数:
mask (Tensor): 输入的掩码张量,形状为 (batch_size, sequence_length)。
prob (float 或 Tensor): 掩码的概率,可以是浮点数或形状为 (batch_size,) 的张量。
min_mask (int, 可选): 每个样本的最小掩码数量。默认值为 0。
min_keep_mask (int, 可选): 每个样本的最小保留掩码数量。默认值为 0。
返回:
Tensor: 生成的子集掩码张量,形状为 (batch_size, sequence_length)。
"""
# 获取批大小、序列长度和设备信息
batch, seq, device = *mask.shape, mask.device
if isinstance(prob, Tensor):
# 如果 prob 是张量,则重塑为 (batch_size, 1)
prob = rearrange(prob, 'b -> b 1')
# 计算每个样本中需要掩码的总数量
total = mask.sum(dim = -1, keepdim = True)
# 计算每个样本中最多可以掩码的数量
max_mask = (total - min_keep_mask).clamp(min = 0)
# 计算需要掩码的数量,并确保不小于 min_mask
num_to_mask = (total * prob).long().clamp(min = min_mask)
# 确保需要掩码的数量不超过最大允许的掩码数量
num_to_mask = torch.minimum(num_to_mask, max_mask)
# 生成随机的 logits 张量,范围在 [0, 1) 之间
logits = torch.rand((batch, seq), device = device)
# 将不需要掩码的位置填充为 -1
logits = logits.masked_fill(~mask, -1)
# 对 logits 进行排序,获取排序后的索引,并转换为浮点数
randperm = logits.argsort(dim = -1).argsort(dim = -1).float()
# 计算每个样本中填充的数量
num_padding = (~mask).sum(dim = -1, keepdim = True)
# 调整排序索引以排除填充位置
randperm -= num_padding
# 根据需要掩码的数量生成子集掩码
subset_mask = randperm < num_to_mask
# 确保不需要掩码的位置保持为 False
subset_mask.masked_fill_(~mask, False)
# 返回生成的子集掩码
return subset_mask
# schedules
def linear_schedule(t):
"""
线性调度函数。
参数:
t (float): 当前时间步长,范围在 [0, 1] 之间。
返回:
float: 调度后的值,范围在 [0, 1] 之间。
"""
# 返回线性递减的值
return 1 - t
def cosine_schedule(t):
""" https://arxiv.org/abs/2202.04200 """
"""
余弦调度函数。
参数:
t (float): 当前时间步长,范围在 [0, 1] 之间。
返回:
float: 调度后的值,范围在 [0, 1] 之间。
"""
# 返回余弦递减的值
return torch.cos(t * math.pi / 2)
# rotary embedding
class RotaryEmbedding(Module):
"""
旋转位置编码(Rotary Position Embedding)模块。
该模块用于在自注意力机制中引入位置信息,通过旋转输入向量来实现。
参数:
dim (int): 输入特征的维度。
theta (float, 可选): 控制频率的缩放因子。默认值为 10000。
"""
def __init__(self, dim, theta = 10000):
super().__init__()
# 计算逆频率,用于生成旋转角度
inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))
# 注册逆频率缓冲区,不作为模型参数保存
self.register_buffer("inv_freq", inv_freq, persistent = False)
@property
def device(self):
"""
获取当前设备信息。
返回:
torch.device: 当前设备(CPU 或 GPU)。
"""
return next(self.buffers()).device
@autocast('cuda', enabled = False)
def forward(self, seq_len):
"""
生成旋转位置编码。
参数:
seq_len (int): 序列长度。
返回:
Tensor: 旋转位置编码张量,形状为 (seq_len, dim)。
"""
# 生成位置索引张量
t = torch.arange(seq_len, device = self.device).type_as(self.inv_freq)
# 计算频率张量
freqs = torch.einsum('i , j -> i j', t, self.inv_freq)
# 将频率张量复制并拼接,以匹配输入特征的维度
freqs = torch.cat((freqs, freqs), dim = -1)
# 返回旋转位置编码
return freqs
def rotate_half(x):
"""
对输入张量的后半部分进行旋转。
参数:
x (Tensor): 输入张量,形状为 (..., dim)。
返回:
Tensor: 旋转后的张量,形状为 (..., dim)。
"""
# 将输入张量拆分为两部分
x1, x2 = x.chunk(2, dim=-1)
# 将后半部分取反并与前半部分拼接,实现旋转
return torch.cat((-x2, x1), dim=-1)
@autocast('cuda', enabled = False)
def apply_rotary_pos_emb(pos, t):
"""
应用旋转位置编码到输入张量。
参数:
pos (Tensor): 旋转位置编码张量,形状为 (seq_len, dim)。
t (Tensor): 输入张量,形状为 (batch_size, seq_len, dim)。
返回:
Tensor: 应用旋转位置编码后的张量,形状为 (batch_size, seq_len, dim)。
"""
# 应用旋转位置编码
return (t * pos.cos()) + (rotate_half(t) * pos.sin())
# t5 relative positional bias
class T5RelativePositionBias(Module):
"""
T5 相对位置偏置模块,用于在 Transformer 模型中引入相对位置信息。
该模块通过桶化(bucketing)方法将相对位置映射到不同的桶中,并学习每个桶的偏置。
参数:
scale (float, 可选): 偏置的缩放因子。默认值为 1.0。
num_buckets (int, 可选): 桶的数量。默认值为 32。
max_distance (int, 可选): 最大相对距离,超过此距离的相对位置将被映射到同一个桶中。默认值为 128。
heads (int, 可选): 注意力头的数量。默认值为 8。
"""
def __init__(
self,
scale = 1.,
num_buckets = 32,
max_distance = 128,
heads = 8
):
super().__init__()
# 保存缩放因子
self.scale = scale
# 保存桶的数量
self.num_buckets = num_buckets
# 保存最大相对距离
self.max_distance = max_distance
# 定义嵌入层,用于学习每个桶的偏置
self.relative_attention_bias = nn.Embedding(num_buckets, heads)
@staticmethod
def _relative_position_bucket(
relative_position,
num_buckets = 32,
max_distance = 128
):
"""
将相对位置映射到不同的桶中。
参数:
relative_position (Tensor): 输入的相对位置张量。
num_buckets (int, 可选): 桶的数量。默认值为 32。
max_distance (int, 可选): 最大相对距离。默认值为 128。
返回:
Tensor: 映射后的桶索引张量。
"""
# 初始化返回值
ret = 0
# 取相对位置的负值
n = -relative_position
# 将桶的数量减半
num_buckets //= 2
# 如果相对位置小于 0,则加上桶数量的一半
ret += (n < 0).long() * num_buckets
# 取绝对值
n = torch.abs(n)
# 计算精确桶的最大索引
max_exact = num_buckets // 2
# 判断相对位置是否小于精确桶的最大索引
is_small = n < max_exact
# 计算大于精确桶的相对位置对应的桶索引
val_if_large = max_exact + (
torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
).long()
val_if_large = torch.min(
val_if_large,
# 确保桶索引不超过最大桶索引
torch.full_like(val_if_large, num_buckets - 1)
)
# 根据相对位置的大小选择桶索引
ret += torch.where(is_small, n, val_if_large)
# 返回映射后的桶索引
return ret
@property
def device(self):
"""
获取当前设备信息。
返回:
torch.device: 当前设备(CPU 或 GPU)。
"""
return next(self.parameters()).device
def forward(self, n):
"""
前向传播方法,用于计算相对位置偏置。
参数:
n (int): 序列长度。
返回:
Tensor: 相对位置偏置张量,形状为 (heads, n, n)。
"""
# 生成位置索引张量
pos = torch.arange(n, device = self.device).long()
# 计算相对位置张量
rel_pos = rearrange(pos, 'j -> 1 j') - rearrange(pos, 'i -> i 1')
# 将相对位置映射到桶中
rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
# 获取桶对应的偏置值
values = self.relative_attention_bias(rp_bucket)
# 重塑偏置张量的形状
bias = rearrange(values, 'i j h -> h i j')
# 返回缩放后的偏置
return bias * self.scale
# conformer
class Swish(Module):
"""
Swish 激活函数模块。
Swish 是一种自门控激活函数,定义为 x * sigmoid(x)。
参数:
无
"""
def forward(self, x):
"""
前向传播方法,应用 Swish 激活函数。
参数:
x (Tensor): 输入张量。
返回:
Tensor: 应用 Swish 激活后的张量。
"""
return x * x.sigmoid()
class GLU(Module):
"""
GLU(Gated Linear Unit)模块。
GLU 是一种门控机制,通过将输入分成两部分,一部分作为门控信号,另一部分作为输出信号。
参数:
dim (int): 分割维度的索引。
"""
def __init__(self, dim):
super().__init__()
# 保存分割维度的索引
self.dim = dim
def forward(self, x):
"""
前向传播方法,应用 GLU 门控机制。
参数:
x (Tensor): 输入张量。
返回:
Tensor: 应用 GLU 门控后的张量。
"""
# 将输入张量分成两部分
out, gate = x.chunk(2, dim=self.dim)
# 应用 GLU 门控机制
return out * gate.sigmoid()
class DepthWiseConv1d(Module):
"""
深度可分离卷积1D模块。
深度可分离卷积将标准卷积分解为深度卷积和逐点卷积,从而减少参数量和计算量。
参数:
chan_in (int): 输入通道数。
chan_out (int): 输出通道数。
kernel_size (int): 卷积核大小。
padding (int): 填充大小。
"""
def __init__(self, chan_in, chan_out, kernel_size, padding):
super().__init__()
# 保存填充大小
self.padding = padding
# 定义深度可分离卷积层
self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups = chan_in)
def forward(self, x, mask = None):
"""
前向传播方法,应用深度可分离卷积。
参数:
x (Tensor): 输入张量。
mask (Tensor, 可选): 掩码张量,用于掩码卷积操作。默认值为 None。
返回:
Tensor: 卷积后的张量。
"""
if exists(mask):
# 重塑掩码张量的形状
mask = rearrange(mask, 'b n -> b 1 n')
# 应用掩码
x = x.masked_fill(~mask, 0.)
# 对输入张量进行填充
x = F.pad(x, self.padding)
# 进行卷积操作
out = self.conv(x)
if exists(mask):
# 再次应用掩码
out = out.masked_fill(~mask, 0.)
# 返回卷积后的张量
return out
# attention, feedforward, and conv module
class Scale(Module):
"""
Scale 模块,用于在函数输出上应用缩放因子。
参数:
scale (float): 缩放因子。
fn (callable): 需要应用缩放因子的函数。
"""
def __init__(self, scale, fn):
super().__init__()
self.fn = fn
self.scale = scale
def forward(self, x, **kwargs):
"""
前向传播方法,应用函数并乘以缩放因子。
参数:
x (Tensor): 输入张量。
**kwargs: 传递给函数的附加关键字参数。
返回:
Tensor: 应用函数并缩放后的张量。
"""
return self.fn(x, **kwargs) * self.scale
class ChanLayerNorm(Module):
"""
ChanLayerNorm 模块,用于对每个通道进行层归一化。
参数:
dim (int): 输入张量的通道维度。
"""
def __init__(self, dim):
super().__init__()
# 定义可学习的缩放参数 gamma,形状为 (1, dim, 1)
self.gamma = nn.Parameter(torch.ones(1, dim, 1))
def forward(self, x):
"""
前向传播方法,应用通道层归一化。
参数:
x (Tensor): 输入张量。
返回:
Tensor: 归一化后的张量。
"""
# 根据数据类型设置极小值,防止数值不稳定
eps = 1e-6 if x.dtype == torch.float32 else 1e-4
# 计算每个通道的方差
var = torch.var(x, dim = 1, unbiased = False, keepdim = True)
# 计算每个通道的均值
mean = torch.mean(x, dim = 1, keepdim = True)
# 应用通道层归一化
return (x - mean) * var.clamp(min = eps).rsqrt() * self.gamma
class PreNorm(Module):
"""
PreNorm 模块,用于在函数应用之前对输入进行层归一化。
参数:
dim (int): 输入张量的特征维度。
fn (callable): 需要应用归一化的函数。
"""
def __init__(self, dim, fn):
super().__init__()
self.fn = fn
self.norm = nn.LayerNorm(dim)
def forward(self, x, **kwargs):
"""
前向传播方法,应用层归一化并调用函数。
参数:
x (Tensor): 输入张量。
**kwargs: 传递给函数的附加关键字参数。
返回:
Tensor: 应用函数后的张量。
"""
# 应用层归一化
x = self.norm(x)
return self.fn(x, **kwargs)
class Attention(Module):
"""
注意力机制模块,用于计算输入序列的注意力权重和输出。
参数:
dim (int): 输入特征的维度。
heads (int, 可选): 注意力头的数量。默认值为 8。
dim_head (int, 可选): 每个注意力头的维度。默认值为 64。
dropout (float, 可选): Dropout 概率。默认值为 0.0。
flash (bool, 可选): 是否使用 FlashAttention 优化注意力计算。默认值为 True。
"""
def __init__(
self,
dim,
heads = 8,
dim_head = 64,
dropout = 0.,
flash = True
):
super().__init__()
# 计算内部维度
inner_dim = dim_head * heads
# 保存注意力头的数量
self.heads= heads
# 计算缩放因子
self.scale = dim_head ** -0.5
# 定义 Attend 模块,用于注意力计算
self.attend = Attend(
flash = flash,
dropout = dropout
)
# 定义 Dropout 层
self.dropout = nn.Dropout(dropout)
# 定义线性变换层,用于生成查询 (Q)
self.to_q = nn.Linear(dim, inner_dim, bias = False)
# 定义线性变换层,用于生成键 (K) 和值 (V)
self.to_kv = nn.Linear(dim, inner_dim * 2, bias = False)
# 定义线性变换层,用于生成输出
self.to_out = nn.Linear(inner_dim, dim)
def forward(
self,
x,
context = None,
mask = None,
rotary_emb = None,
attn_bias = None,
return_values = False,
value_residual = None
):
"""
前向传播方法,应用注意力机制。
参数:
x (Tensor): 输入张量。
context (Tensor, 可选): 上下文张量,用于生成键和值。默认值为 None,表示使用输入张量作为上下文。
mask (Tensor, 可选): 掩码张量,用于掩码注意力计算。默认值为 None。
rotary_emb (Tensor, 可选): 旋转位置编码,用于位置感知注意力。默认值为 None。
attn_bias (Tensor, 可选): 注意力偏置,用于调整注意力权重。默认值为 None。
return_values (bool, 可选): 是否返回值张量。默认值为 False。
value_residual (Tensor, 可选): 值残差,用于残差连接。默认值为 None。
返回:
Tensor: 注意力输出张量。如果 return_values 为 True,则返回元组 (输出张量, 值张量)。
"""
# 获取序列长度、设备信息、注意力头数量以及是否存在上下文
n, device, h, has_context = x.shape[-2], x.device, self.heads, exists(context)
# 如果没有提供上下文,则使用输入张量作为上下文
context = default(context, x)
# 生成查询 (Q)、键 (K) 和值 (V)
q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim = -1))
# 重塑张量形状
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), (q, k, v))
if exists(value_residual):
# 如果存在值残差,则将其与值张量混合
v = 0.5 * (v + value_residual)
if exists(rotary_emb):
# 应用旋转位置编码到查询
q = apply_rotary_pos_emb(rotary_emb, q)
# 应用旋转位置编码到键
k = apply_rotary_pos_emb(rotary_emb, k)
# 进行注意力计算
out = self.attend(q, k, v, mask = mask, attn_bias = attn_bias)
# 重塑输出张量形状
out = rearrange(out, 'b h n d -> b n (h d)')
# 应用输出线性变换
out = self.to_out(out)
# 如果不需要返回值张量,则返回输出
if not return_values:
return out
# 如果需要返回值张量,则返回输出和值张量
return out, v
class FeedForward(Module):
"""
前馈神经网络模块,用于对输入进行非线性变换。
参数:
dim (int): 输入特征的维度。
mult (int, 可选): 隐藏层维度的乘法因子。默认值为 4。
dropout (float, 可选): Dropout 概率。默认值为 0.0。
"""
def __init__(
self,
dim,
mult = 4,
dropout = 0.
):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, dim * mult), # 线性变换层,将维度从 dim 增加到 dim * mult
Swish(), # Swish 激活函数
nn.Dropout(dropout), # Dropout 层
nn.Linear(dim * mult, dim), # 线性变换层,将维度从 dim * mult 减少到 dim
nn.Dropout(dropout) # Dropout 层
)
def forward(self, x):
"""
前向传播方法,应用前馈神经网络。
参数:
x (Tensor): 输入张量。
返回:
Tensor: 应用前馈神经网络后的张量。
"""
return self.net(x)
class ConformerConvModule(Module):
"""
Conformer 卷积模块,用于在 Conformer 模型中引入卷积操作。
参数:
dim (int): 输入特征的维度。
causal (bool, 可选): 是否使用因果卷积。默认值为 False。
expansion_factor (int, 可选): 隐藏层维度的扩展因子。默认值为 2。
kernel_size (int, 可选): 卷积核大小。默认值为 31。
dropout (float, 可选): Dropout 概率。默认值为 0.0。
"""
def __init__(
self,
dim,
causal = False,
expansion_factor = 2,
kernel_size = 31,
dropout = 0.
):
super().__init__()
# 计算内部维度
inner_dim = dim * expansion_factor
# 计算填充大小
padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)
self.net1 = nn.Sequential(
nn.LayerNorm(dim), # 层归一化
Rearrange('b n c -> b c n'), # 重塑张量形状
nn.Conv1d(dim, inner_dim * 2, 1), # 1D 卷积层,扩展维度
GLU(dim=1) # GLU 门控机制
)
# 深度可分离卷积
self.ds_conv = DepthWiseConv1d(inner_dim, inner_dim, kernel_size = kernel_size, padding = padding)
self.net2 = nn.Sequential(
Swish(), # Swish 激活函数
ChanLayerNorm(inner_dim), # 通道层归一化
nn.Conv1d(inner_dim, dim, 1), # 1D 卷积层,恢复维度
Rearrange('b c n -> b n c'), # 重塑张量形状
nn.Dropout(dropout) # Dropout 层
)
def forward(self, x, mask = None):
"""
前向传播方法,应用 Conformer 卷积模块。
参数:
x (Tensor): 输入张量。
mask (Tensor, 可选): 掩码张量,用于掩码卷积操作。默认值为 None。
返回:
Tensor: 应用 Conformer 卷积模块后的张量。
"""
# 应用第一个网络块
x = self.net1(x)
# 应用深度可分离卷积
x = self.ds_conv(x, mask = mask)
# 应用第二个网络块
return self.net2(x)
# Conformer Block
class ConformerBlock(Module):
"""
Conformer 块模块,结合了前馈神经网络、注意力机制和卷积操作。
参数:
dim (int): 输入特征的维度。
dim_head (int, 可选): 每个注意力头的维度。默认值为 64。
heads (int, 可选): 注意力头的数量。默认值为 8。
ff_mult (int, 可选): 前馈神经网络隐藏层维度的乘法因子。默认值为 4。
conv_expansion_factor (int, 可选): 卷积模块隐藏层维度的扩展因子。默认值为 2。
conv_kernel_size (int, 可选): 卷积核大小。默认值为 31。
attn_dropout (float, 可选): 注意力机制的 Dropout 概率。默认值为 0.0。
attn_flash (bool, 可选): 是否使用 FlashAttention。默认值为 True。
ff_dropout (float, 可选): 前馈神经网络的 Dropout 概率。默认值为 0.0。
conv_dropout (float, 可选): 卷积模块的 Dropout 概率。默认值为 0.0。
conv_causal (bool, 可选): 是否使用因果卷积。默认值为 False。
use_gateloop_layers (bool, 可选): 是否使用门控循环层。默认值为 False。
"""
def __init__(
self,
*,
dim,
dim_head = 64,
heads = 8,
ff_mult = 4,
conv_expansion_factor = 2,
conv_kernel_size = 31,
attn_dropout = 0.,
attn_flash = True,
ff_dropout = 0.,
conv_dropout = 0.,
conv_causal = False,
use_gateloop_layers = False
):
super().__init__()
# 定义第一个前馈神经网络
self.ff1 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
# 定义门控循环层(如果启用)
self.gateloop = GateLoop(dim) if use_gateloop_layers else None
# 定义注意力机制
self.attn = Attention(dim = dim, dim_head = dim_head, heads = heads, dropout = attn_dropout, flash = attn_flash)
# 定义卷积模块
self.conv = ConformerConvModule(dim = dim, causal = conv_causal, expansion_factor = conv_expansion_factor, kernel_size = conv_kernel_size, dropout = conv_dropout)
# 定义第二个前馈神经网络
self.ff2 = FeedForward(dim = dim, mult = ff_mult, dropout = ff_dropout)
# 对注意力机制应用前置归一化
self.attn = PreNorm(dim, self.attn)
# 对第一个前馈神经网络应用缩放和前置归一化
self.ff1 = Scale(0.5, PreNorm(dim, self.ff1))
# 对第二个前馈神经网络应用缩放和前置归一化
self.ff2 = Scale(0.5, PreNorm(dim, self.ff2))
# 定义层归一化层
self.post_norm = nn.LayerNorm(dim)
def forward(
self,
x,
mask = None,
rotary_emb = None,
attn_bias = None,
attn_value_residual = None,
return_values = False
):
"""
前向传播方法,应用 Conformer 块。
参数:
x (Tensor): 输入张量。
mask (Tensor, 可选): 掩码张量,用于掩码注意力计算。默认值为 None。
rotary_emb (Tensor, 可选): 旋转位置编码,用于位置感知注意力。默认值为 None。
attn_bias (Tensor, 可选): 注意力偏置,用于调整注意力权重。默认值为 None。
attn_value_residual (Tensor, 可选): 注意力值残差,用于残差连接。默认值为 None。
return_values (bool, 可选): 是否返回值张量。默认值为 False。
返回:
Tensor: 应用 Conformer 块后的张量。如果 return_values 为 True,则返回元组 (输出张量, 注意力值张量)。
"""
# 应用第一个前馈神经网络并添加残差连接
x = self.ff1(x) + x
if exists(self.gateloop):
# 应用门控循环层并添加残差连接
x = self.gateloop(x) + x
# 应用注意力机制并返回值
attn_out, attn_values = self.attn(x, mask = mask, rotary_emb = rotary_emb, attn_bias = attn_bias, value_residual = attn_value_residual, return_values = True)
# 添加注意力输出残差
x = attn_out + x
# 应用卷积模块并添加残差
x = self.conv(x, mask = mask) + x
# 应用第二个前馈神经网络并添加残差
x = self.ff2(x) + x
# 应用层归一化
x = self.post_norm(x)
if not return_values:
# 如果不需要返回值张量,则返回输出
return x
# 如果需要返回值张量,则返回输出和注意力值
return x, attn_values
# Conformer
class Conformer(Module):