-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathtitans.py
812 lines (633 loc) · 30.2 KB
/
titans.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
import math
from functools import partial
import torch
from torch import nn, Tensor
import torch.nn.functional as F
from torch.nn import Linear, Module
from torch.func import functional_call, vmap, grad
import einx
from einops import rearrange, repeat, pack, unpack
from einops.layers.torch import Rearrange, Reduce
from tensordict import TensorDict
from associative_scan import associative_scan, binary_operator, pad_at_dim
"""
ein notation:
b - batch (批次)
n - sequence (序列)
d - feature dimension (特征维度)
c - intra-chunk (块内维度)
"""
# 使用 partial 为 Linear 层创建一个不带偏置的版本
LinearNoBias = partial(Linear, bias = False)
def exists(v):
"""
检查变量是否存在(不为 None)。
参数:
v (Any): 任意变量。
返回:
bool: 如果 v 不为 None,则返回 True,否则返回 False。
"""
return v is not None
def default(v, d):
"""
如果变量存在(不为 None),则返回变量本身;否则返回默认值。
参数:
v (Any): 任意变量。
d (Any): 默认值。
返回:
Any: 如果 v 存在,则返回 v;否则返回 d。
"""
return v if exists(v) else d
def identity(t):
"""
返回输入张量本身。
参数:
t (Tensor): 输入张量。
返回:
Tensor: 输入张量。
"""
return t
def round_down_multiple(seq, mult):
"""
将序列长度向下取整到指定倍数的倍数。
参数:
seq (int): 序列长度。
mult (int): 倍数。
返回:
int: 向下取整后的序列长度。
"""
return seq // mult * mult
def round_up_multiple(seq, mult):
"""
将序列长度向上取整到指定倍数的倍数。
参数:
seq (int): 序列长度。
mult (int): 倍数。
返回:
int: 向上取整后的序列长度。
"""
return math.ceil(seq / mult) * mult
def pack_one_with_inverse(t, pattern):
"""
打包张量并返回用于解包的逆函数。
参数:
t (Tensor): 需要打包的张量。
pattern (Tuple[int, ...]): 打包模式,指定每个维度如何分割。
返回:
Tuple[Tensor, Callable]: 打包后的张量和一个用于解包的函数。
"""
packed, packed_shape = pack([t], pattern)
def inverse(out, inv_pattern = None):
"""
解包张量。
参数:
out (Tensor): 需要解包的张量。
inv_pattern (Tuple[int, ...], 可选): 解包模式,默认为 None。如果为 None,则使用默认的打包模式。
返回:
Tensor: 解包后的张量。
"""
inv_pattern = default(inv_pattern, pattern)
return unpack(out, packed_shape, inv_pattern)[0]
return packed, inverse
def softclamp_max(t, max_value):
"""
对张量进行软裁剪,限制其最大值。
参数:
t (Tensor): 输入张量。
max_value (float): 最大值。
返回:
Tensor: 软裁剪后的张量。
"""
half_max_value = max_value / 2
return ((t / half_max_value).tanh() * half_max_value) + half_max_value
def softclamp_grad_norm(t, max_value):
"""
对梯度进行软裁剪,限制其范数。
参数:
t (Tensor): 输入张量。
max_value (float): 最大范数。
返回:
Tensor: 软裁剪后的梯度。
"""
# 打包张量,以便在解包时恢复原始形状
t, inverse = pack_one_with_inverse(t, 'bn *')
# 计算梯度的范数
norm = t.norm(dim = -1, keepdim = True)
# 对范数进行软裁剪
clamped_norm = softclamp_max(norm, max_value)
# 根据范数的比例调整梯度
t = t * (clamped_norm / norm)
# 解包张量,恢复原始形状
return inverse(t)
class MultiheadRMSNorm(Module):
"""
多头RMS归一化(Multihead RMSNorm)模块。
该模块对输入张量应用RMS归一化,并使用多头参数对每个头进行缩放。
"""
def __init__(self, dim, heads):
"""
初始化多头RMS归一化模块。
参数:
dim (int): 特征维度。
heads (int): 头的数量。
"""
super().__init__()
# 初始化RMS归一化层,不使用可学习的仿射参数
self.rmsnorm = nn.RMSNorm(dim, elementwise_affine = False)
# 初始化多头缩放参数,形状为 (heads, 1, dim)
self.gamma = nn.Parameter(torch.zeros(heads, 1, dim))
def forward(self, x):
"""
前向传播方法。
参数:
x (Tensor): 输入张量,形状为 (batch_size, ..., dim)。
返回:
Tensor: 归一化并缩放后的张量,形状与输入相同。
"""
# 对输入张量应用RMS归一化
# 将多头缩放参数与归一化后的张量相加,并进行缩放
# gamma 的形状为 (heads, 1, dim),通过广播机制与 normed 对齐
return self.rmsnorm(x) * (self.gamma + 1.)
class MemoryMLP(Module):
"""
记忆多层感知机(Memory MLP)模块。
该模块由多个线性层组成,每个线性层后面跟随一个SiLU激活函数(除了第一个线性层)。
"""
def __init__(
self,
dim,
depth
):
"""
初始化记忆MLP模块。
参数:
dim (int): 输入和输出的特征维度。
depth (int): MLP的深度,即线性层的数量。
"""
super().__init__()
# 初始化参数列表,每个参数是一个线性层的权重矩阵,形状为 (dim, dim)
self.weights = nn.ParameterList([nn.Parameter(torch.randn(dim, dim)) for _ in range(depth)])
def forward(
self,
x
):
"""
前向传播方法。
参数:
x (Tensor): 输入张量,形状为 (batch_size, ..., dim)。
返回:
Tensor: MLP的输出,形状与输入相同。
"""
for ind, weight in enumerate(self.weights):
# 判断是否是第一个线性层
is_first = ind == 0
if not is_first:
# 如果不是第一个线性层,则应用SiLU激活函数
x = F.silu(x)
# 应用线性层
x = x @ weight
return x
def default_adaptive_step_transform(adaptive_step, max_lr = 1e-2):
"""
默认的自适应步长转换函数。
将自适应步长转换为学习率,范围从0到max_lr。
参数:
adaptive_step (Tensor): 自适应步长张量。
max_lr (float, 可选): 最大学习率,默认为1e-2。
返回:
Tensor: 转换后的学习率张量。
"""
return adaptive_step.sigmoid() * max_lr
def default_loss_fn(pred, target):
"""
默认的损失函数。
计算预测值与目标值之间的均方误差(MSE)。
参数:
pred (Tensor): 预测值张量。
target (Tensor): 目标值张量。
返回:
Tensor: 计算得到的损失值。
"""
return (pred - target).pow(2).mean(dim = -1)
class NeuralMemory(Module):
"""
神经记忆模块(Neural Memory Module)。
该模块实现了神经记忆机制,通过记忆模型存储和检索信息,并在训练过程中动态调整学习率和动量。
"""
def __init__(
self,
dim,
chunk_size = 1,
dim_head = None,
heads = 1,
model: Module | None = None,
store_memory_loss_fn = default_loss_fn,
adaptive_step_transform = default_adaptive_step_transform,
pre_rmsnorm = True,
post_rmsnorm = True,
max_grad_norm: float | None = None,
use_accelerated_scan = False,
default_mlp_kwargs: dict = dict(
depth = 2
)
):
"""
初始化神经记忆模块。
参数:
dim (int): 特征维度。
chunk_size (int, 可选): 块大小,默认为1。
dim_head (int, 可选): 每个注意力头的维度,默认为 None。如果为 None,则使用 `dim`。
heads (int, 可选): 注意力头的数量,默认为1。
model (Module, 可选): 记忆模型,默认为 None。如果为 None,则使用默认的 `MemoryMLP` 模型。
store_memory_loss_fn (Callable[[Tensor, Tensor], Tensor], 可选): 存储记忆时的损失函数,默认为默认的损失函数。
adaptive_step_transform (Callable[[Tensor], Tensor], 可选): 自适应步长转换函数,默认为默认的转换函数。
pre_rmsnorm (bool, 可选): 是否在存储前应用RMS归一化,默认为 True。
post_rmsnorm (bool, 可选): 是否在存储后应用RMS归一化,默认为 True。
max_grad_norm (float, 可选): 存储记忆时的最大梯度范数,默认为 None。
use_accelerated_scan (bool, 可选): 是否使用加速扫描,默认为 False。
default_mlp_kwargs (Dict[str, Any], 可选): 默认的MLP参数,默认为深度为2。
"""
super().__init__()
# 如果未指定每个头的维度,则使用特征维度
dim_head = default(dim_head, dim)
# norms
# 归一化层
# 检索前的RMS归一化
self.retrieve_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
# 存储前的RMS归一化
self.store_norm = nn.RMSNorm(dim) if pre_rmsnorm else nn.Identity()
# 存储后的多头RMS归一化
self.multihead_rmsnorm = MultiheadRMSNorm(dim_head, heads) if post_rmsnorm else nn.Identity()
# maybe multi-headed
# 多头处理
# 计算内部特征维度
dim_inner = dim_head * heads
# 保存注意力头的数量
self.heads = heads
# 将批次和头维度合并
self.split_heads = Rearrange('b n (h d) -> (b h) n d', h = heads)
# 将头和批次维度分开
self.merge_heads = Rearrange('b h n d -> b n (h d)')
# 如果有多个头,则使用线性层合并头;否则,使用恒等函数
self.combine_heads = LinearNoBias(dim_inner, dim) if heads > 1 else nn.Identity()
self.retrieve_gate = nn.Sequential(
LinearNoBias(dim, heads), # 线性层,将特征维度映射到头的数量
Rearrange('b n h -> b h n 1'), # 重塑张量形状
nn.Sigmoid() # 应用Sigmoid激活函数
) if heads > 1 else None # 如果只有一个头,则不需要门控机制
# memory mlp
# 记忆模型
if not exists(model):
# 如果未提供记忆模型,则使用默认的 `MemoryMLP` 模型
model = MemoryMLP(dim_head, **default_mlp_kwargs)
assert not exists(next(model.buffers(), None)), 'model cannot have buffers for now'
# the memory is the weights of the model
# 保存记忆模型
self.memory_model = model
# the chunk size within the paper where adaptive step, momentum, weight decay are shared
# 保存块大小
self.chunk_size = chunk_size
# prepare function for per sample gradients from model above, using torch.func
# 准备用于计算每个样本梯度的函数,使用 torch.func
def forward_and_loss(params, inputs, loss_weights, target):
# 使用记忆模型进行前向传播
pred = functional_call(self.memory_model, params, inputs)
# 计算损失,默认为均方误差
loss = self.store_memory_loss_fn(pred, target) # simple mse loss in paper - eq (12) - |M(k) - v|²
# 乘以损失权重
loss = loss * loss_weights
return loss.sum()
# 对每个样本计算梯度
self.per_sample_grad_fn = vmap(grad(forward_and_loss), in_dims = (None, 0, 0, 0))
# queries for retrieving from the model
# 查询函数,用于从模型中检索信息
self.to_queries = LinearNoBias(dim, dim_inner) # 线性层,将特征维度映射到内部特征维度
# keys and values for storing to the model
# 键和值函数,用于向模型中存储信息
self.to_keys_values = LinearNoBias(dim, dim_inner * 2) # 线性层,将特征维度映射到键和值维度
self.store_memory_loss_fn = store_memory_loss_fn # 保存存储记忆时的损失函数
# empty memory embed
# 空记忆嵌入
# 初始化空记忆嵌入为全零张量
self.empty_memory_embed = nn.Parameter(torch.zeros(dim))
# 使用正态分布初始化空记忆嵌入
nn.init.normal_(self.empty_memory_embed, std = 0.02)
# learned adaptive learning rate and momentum
# todo - explore mlp layerwise learned lr / momentum
# 学习到的自适应学习率和动量
self.to_momentum = nn.Sequential(
Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size), # 对块内的特征进行平均
LinearNoBias(dim, heads), # 线性层,将特征维度映射到头的数量
Rearrange('b n h -> (b h) n 1') # 重塑张量形状
)
self.to_adaptive_step = nn.Sequential(
LinearNoBias(dim, heads), # 线性层,将特征维度映射到头的数量
Rearrange('b n h -> (b h) n') # 重塑张量形状
)
# 保存自适应步长转换函数
self.adaptive_step_transform = adaptive_step_transform
# allow for softclamp the gradient norms for storing memories
# 允许对存储记忆时的梯度范数进行软裁剪
self.max_grad_norm = max_grad_norm
# weight decay factor
# 权重衰减因子
self.to_decay_factor = nn.Sequential(
Reduce('b (n c) ... -> b n ...', 'mean', c = chunk_size), # 对块内的特征进行平均
LinearNoBias(dim, heads), # 线性层,将特征维度映射到头的数量
Rearrange('b n h -> (b h) n 1') # 重塑张量形状
)
# maybe use accelerated scan
# 是否使用加速扫描
self.use_accelerated_scan = use_accelerated_scan
def init_weights_and_momentum(self):
"""
初始化记忆模型的权重和动量。
返回:
Tuple[TensorDict, TensorDict]: 初始化的权重和动量,分别为 TensorDict 对象。
"""
# 获取记忆模型的所有参数,并将其转换为 TensorDict 对象
params = TensorDict(dict(self.memory_model.named_parameters()))
# 初始化权重为零张量
init_weights = params.clone().zero_()
# 初始化动量为零张量
init_momentum = params.clone().zero_()
# 返回初始化的权重和动量
return init_weights, init_momentum
def init_empty_memory_embed(self, batch, seq_len):
"""
初始化空记忆嵌入。
参数:
batch (int): 批次大小。
seq_len (int): 序列长度。
返回:
Tensor: 初始化后的空记忆嵌入,形状为 (batch, seq_len, dim)。
"""
# 重复空记忆嵌入,生成形状为 (batch, seq_len, dim) 的张量
return repeat(self.empty_memory_embed, 'd -> b n d', b = batch, n = seq_len)
def store_memories(
self,
seq,
past_state: tuple[dict[str, Tensor], dict[str, Tensor]]
):
"""
存储记忆并更新记忆模型的权重和动量。
参数:
seq (Tensor): 输入序列,形状为 (batch, seq_len, dim)。
past_state (Tuple[Dict[str, Tensor], Dict[str, Tensor]]): 过去的状态,包含权重和动量。
返回:
Tuple[Dict[str, Tensor], Tuple[Dict[str, Tensor], Dict[str, Tensor]]]: 更新后的权重和动量,以及新的状态。
"""
# 对输入序列应用存储前的归一化
seq = self.store_norm(seq)
# curtail sequence by multiple of the chunk size
# only a complete chunk of the sequence provides the memory for the next chunk
# 计算序列长度和块大小
seq_len, chunk_size = seq.shape[-2], self.chunk_size
# 将序列长度向下取整到块大小的倍数,确保每个块完整
round_down_seq_len = round_down_multiple(seq_len, self.chunk_size)
# 截断序列,使其长度为块大小的倍数
seq = seq[:, :round_down_seq_len]
# curr weights + past weights, in the case that the initial weights are learned
# 获取当前记忆模型的权重
curr_weights = TensorDict(dict(self.memory_model.named_parameters()))
# 将过去的状态转换为 TensorDict 对象
past_state = tuple(TensorDict(d) for d in past_state)
past_weights, past_momentum = past_state
# 将当前权重与过去权重相加
curr_weights = curr_weights + past_weights
# pack batch and sequence dimension
# 计算自适应学习率:
# 对输入序列应用自适应步长模块(to_adaptive_step),然后应用自适应步长转换函数(adaptive_step_transform)
adaptive_lr = self.to_adaptive_step(seq)
adaptive_lr = self.adaptive_step_transform(adaptive_lr)
# 计算自适应动量:
# 对输入序列应用动量模块(to_momentum),然后使用 sigmoid 函数将其值压缩到 (0, 1) 之间。
adaptive_momentum = self.to_momentum(seq).sigmoid()
# 计算权重衰减因子:
# 对输入序列应用衰减因子模块(to_decay_factor),然后使用 sigmoid 函数将其值压缩到 (0, 1) 之间。
decay_factor = self.to_decay_factor(seq).sigmoid()
# keys and values
# 分离键和值:
# 对输入序列应用键值模块(to_keys_values),然后将其在最后一个维度上分割成两部分,分别作为键和值。
keys, values = self.to_keys_values(seq).chunk(2, dim = -1)
# maybe multi head
# 处理多头:
# 对键和值应用多头重塑(split_heads),将批次和头数维度合并。
keys, values = map(self.split_heads, (keys, values))
# 获取批次大小
batch = keys.shape[0]
# take care of chunking
# 处理块:
# 将键和值在序列维度上重塑为 (batch * n, c, d),其中 c 是块内维度,d 是特征维度。
keys, values = tuple(rearrange(t, 'b (n c) d -> (b n) c d', c = self.chunk_size) for t in (keys, values))
# 重塑自适应学习率:
# 将自适应学习率重塑为 (batch * n, c),以便与键和值对齐。
adaptive_lr = rearrange(adaptive_lr, 'b (n c) -> (b n) c', c = self.chunk_size)
# get grads and extra auxiliary loss (for backwarding through qkv projection in base neural memory module)
# 计算梯度并计算辅助损失:
# 使用 per_sample_grad_fn 计算每个样本的梯度,传入当前权重、键、自适应学习率和值。
grads = self.per_sample_grad_fn(dict(curr_weights), keys, adaptive_lr, values)
grads = TensorDict(grads)
# maybe softclamp grad norm
# 如果存在最大梯度范数,则对梯度进行软裁剪
if exists(self.max_grad_norm):
grads = grads.apply(lambda t: softclamp_grad_norm(t, self.max_grad_norm))
# restore batch and sequence dimension
# 恢复批次和序列维度:
# 将梯度张量从 (batch * n, ...) 重塑为 (batch, n, ...)。
grads = grads.apply(lambda t: rearrange(t, '(b n) ... -> b n ...', b = batch))
# negative gradients, adaptive lr already applied as loss weight
# 计算惊喜(surprises):
# 将梯度取负数,因为梯度下降需要负梯度。
surprises = grads.apply(lambda t: -t)
# determine scan function
# 定义默认的关联扫描函数:
# 使用 associative_scan 和 binary_operator 对输入的 gates 和 inputs 进行扫描。
def default_associative_scan(gates, inputs):
_, outputs = associative_scan(binary_operator, (gates, inputs))
return outputs
# 如果使用加速扫描:
if self.use_accelerated_scan:
from accelerated_scan.triton import scan as triton_scan
from accelerated_scan.warp import scan as warp_scan
scan = triton_scan if seq.is_cuda else warp_scan
# 定义加速扫描函数:
# 1. 对 gates 和 inputs 进行扩展和重塑。
# 2. 对序列长度进行填充,使其为2的幂。
# 3. 调用扫描函数。
# 4. 截取填充后的结果,并恢复原始形状。
def accelerate_scan_fn(gates, inputs):
gates = gates.expand_as(inputs)
gates, inputs = tuple(rearrange(t, 'b n d -> b d n') for t in (gates, inputs))
seq_len = gates.shape[-1]
next_power_two_seq_len = 2 ** max(5, int(math.ceil(math.log2(seq_len))))
gates = F.pad(gates, (0, next_power_two_seq_len - seq_len))
inputs = F.pad(inputs, (0, next_power_two_seq_len - seq_len))
outputs = scan(gates.contiguous(), inputs.contiguous())
outputs = outputs[..., :seq_len]
outputs = rearrange(outputs, 'b d n -> b n d')
return outputs
scan_fn = accelerate_scan_fn
else:
scan_fn = default_associative_scan
# momentum + weight decay - momentum is the new contribution, as most linear RNNs have learned forgetting gates
# 计算动量和更新:
# 1. 对每个参数名和对应的惊喜(surprise)进行迭代。
# 2. 使用 pack_one_with_inverse 对惊喜进行打包,并获取逆函数。
# 3. 使用 scan_fn 计算动量。
# 4. 再次使用 scan_fn 计算更新(考虑权重衰减)。
# 5. 将更新和动量逆打包,并存储到 updates 和 next_momentum 中。
next_momentum = TensorDict()
updates = TensorDict()
for param_name, surprise in surprises.items():
surprise, inverse_pack = pack_one_with_inverse(surprise, 'b n *')
# derive momentum with associative scan - eq (10)
# 计算动量:
# 使用关联扫描函数,根据自适应动量和惊喜计算动量。
momentum = scan_fn(adaptive_momentum, surprise) # momentum is S / surprise in the paper
# use associative scan again for learned forgetting (weight decay) - eq (13)
# 计算更新:
# 使用关联扫描函数,根据权重衰减因子和动量计算更新。
update = scan_fn(1. - decay_factor, momentum) # momentum is S / surprise in the paper
updates[param_name] = inverse_pack(update)
next_momentum[param_name] = inverse_pack(momentum)
# compute the next weight per batch
# 计算每个批次的下一个权重:
# 对每个参数,获取最后一个更新,并将其添加到当前权重中。
last_update = updates.apply(lambda t: t[:, -1])
next_state = (curr_weights + last_update, next_momentum)
return updates, next_state
def retrieve_memories(
self,
seq,
past_weights: dict[str, Tensor] | None = None,
):
"""
从记忆中检索信息。
参数:
seq (Tensor): 输入序列,形状为 (batch, seq_len, dim)。
past_weights (Dict[str, Tensor], 可选): 过去的权重,默认为 None。
返回:
Tensor: 检索到的记忆,形状为 (batch, seq_len + chunk_size - 1, dim)。
"""
# 获取块大小
chunk_size = self.chunk_size
# 获取批次大小和序列长度
batch, seq_len = seq.shape[:2]
# 对输入序列应用检索前的归一化
seq = self.retrieve_norm(seq)
assert seq_len >= chunk_size
# 截取序列,从第 (chunk_size - 1) 个时间步开始
seq = seq[:, (chunk_size - 1):]
# 获取截取后的序列长度
curtailed_seq_len = seq.shape[-2]
# 计算下一个序列长度,向上取整到块大小的倍数
next_seq_len = round_up_multiple(curtailed_seq_len, chunk_size)
# 计算需要填充的长度
padding = next_seq_len - curtailed_seq_len
# 判断是否需要填充
needs_pad = padding > 0
if needs_pad:
# 如果需要填充,则在序列维度上填充,填充值为0
seq = pad_at_dim(seq, (0, padding), dim = 1)
# the parameters of the memory model stores the memories of the key / values
# when the MLP has only 1 weight matrix, it is equivalent to `kv` fast weight memories from linear attention literature (recall fetching of memories is q @ (kv)) / schmidhuber's paper
# 获取当前记忆模型的权重
curr_weights = TensorDict(dict(self.memory_model.named_parameters()))
if exists(past_weights):
# 如果存在过去权重,则将其转换为 TensorDict 对象,并断言键与当前权重一致
past_weights = TensorDict(past_weights)
assert past_weights.keys() == curr_weights.keys()
# 将当前权重与过去权重相加
curr_weights = curr_weights + past_weights
# sequence Float['b n d'] to queries
# 将序列从 Float['b n d'] 转换为查询
queries = self.to_queries(seq)
# maybe multihead
# 处理多头
queries = self.split_heads(queries)
# fetch values from memory model
# 重塑权重张量形状为 (batch * n, ...),以便与查询对齐
curr_weights = curr_weights.apply(lambda t: rearrange(t, 'b n ... -> (b n) ...'))
# 重塑查询张量形状为 (batch * n, c, d),其中 c 是块内维度
queries = rearrange(queries, 'b (n c) d -> (b n) c d', c = chunk_size)
# forward functional call
# 使用记忆模型进行前向传播,获取值
values = functional_call(self.memory_model, dict(curr_weights), queries)
# reconstitute batch dimension
# 恢复批次和头的维度,形状为 (batch, heads, n * c, d)
values = rearrange(values, '(b h n) c d -> b h (n c) d', b = batch, h = self.heads)
# 应用多头RMS归一化
values = self.multihead_rmsnorm(values)
# maybe gate
# 如果存在检索门控机制,则应用门控
if exists(self.retrieve_gate):
values = values * self.retrieve_gate(seq)
# maybe merge heads and combine
# 合并多头
values = self.merge_heads(values)
# 组合多头
values = self.combine_heads(values)
# restore, pad with empty memory embed
# 恢复填充:
# 初始化空记忆嵌入,形状为 (batch, chunk_size - 1, dim)
empty_memory_embeds = self.init_empty_memory_embed(values.shape[0], chunk_size - 1)
# 将空记忆嵌入与检索到的记忆连接起来,形状为 (batch, chunk_size, dim)
values = torch.cat((empty_memory_embeds, values), dim = -2)
if needs_pad:
# 如果之前进行了填充,则去除末尾的填充部分
values = values[:, :-padding]
# 返回检索到的记忆
return values
def forward(
self,
seq,
store_seq = None,
past_state: tuple[dict[str, Tensor], dict[str, Tensor]] | None = None,
return_next_memories = False
):
"""
前向传播方法。
该方法实现了记忆的存储、检索以及更新过程。根据输入序列和过去状态,模型可以存储新的记忆,检索现有的记忆,并返回当前或下一个记忆状态。
参数:
seq (Tensor): 输入序列,形状为 (batch, seq_len, dim)。
- `batch`: 批次大小。
- `seq_len`: 序列长度。
- `dim`: 特征维度。
store_seq (Tensor, 可选): 用于存储的序列,默认为 None。
- 如果为 None,则使用输入序列 `seq` 进行记忆存储。
past_state (Tuple[Dict[str, Tensor], Dict[str, Tensor]], 可选): 过去的状态,包含权重和动量,默认为 None。
- 第一个字典包含过去的权重。
- 第二个字典包含过去的动量。
return_next_memories (bool, 可选): 是否返回下一个记忆状态,默认为 False。
- 如果为 True,则返回更新后的权重和动量。
- 如果为 False,则仅返回检索到的记忆。
返回:
Tuple[Tensor, Optional[Tuple[Dict[str, Tensor], Dict[str, Tensor]]]]:
- 如果 `return_next_memories` 为 False,则返回检索到的记忆,形状为 (batch, seq_len + chunk_size - 1, dim)。
- 如果 `return_next_memories` 为 True,则返回一个包含检索到的记忆和下一个记忆状态的元组。
"""
# 获取输入序列的批次大小和序列长度
batch, seq_len = seq.shape[:2]
if seq_len < self.chunk_size:
# 如果序列长度小于块大小,则返回初始化后的空记忆嵌入
return self.init_empty_memory_embed(batch, seq_len)
if exists(past_state):
# 如果存在过去状态,则将其转换为 TensorDict 对象
past_state = tuple(TensorDict(d) for d in past_state)
if not exists(past_state):
# 如果不存在过去状态,则初始化权重和动量
past_state = self.init_weights_and_momentum()
# 如果未提供存储序列,则使用输入序列
store_seq = default(store_seq, seq)
# 存储记忆并获取更新和下一个记忆状态
updates, next_memories = self.store_memories(store_seq, past_state)
# 获取过去的权重
past_weights, _ = past_state
# 检索记忆:使用过去的权重和更新进行检索
retrieved = self.retrieve_memories(seq, past_weights + updates)
if not return_next_memories:
# 如果不需要返回下一个记忆状态,则返回检索到的记忆
return retrieved
# 如果需要返回下一个记忆状态,则返回检索到的记忆和下一个记忆状态
return retrieved, next_memories