# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: Apache-2.0
import math
+ from typing import Optional, Tuple

import torch
from torch.optim.optimizer import Optimizer

-
- def mars_single_tensor(
-         p,
-         grad,
-         exp_avg,
-         exp_avg_sq,
-         lr,
-         weight_decay,
-         beta1,
-         beta2,
-         last_grad,
-         eps,
-         step,
-         gamma,
-         mars_type,
-         is_grad_2d,
-         optimize_1d,
-         lr_1d_factor,
-         betas_1d,
+ from ._types import ParamsT
+
+
+ def _mars_single_tensor_step(
+         p: torch.Tensor,
+         grad: torch.Tensor,
+         exp_avg: torch.Tensor,
+         exp_avg_sq: torch.Tensor,
+         lr: float,
+         weight_decay: float,
+         beta1: float,
+         beta2: float,
+         last_grad: torch.Tensor,
+         eps: float,
+         step: int,
+         gamma: float,
+         mars_type: str,
+         is_grad_2d: bool,
+         optimize_1d: bool,
+         lr_1d_factor: float,
+         betas_1d: Tuple[float, float],
+         caution: bool,
):
-     # optimize_1d: use MARS for 1d para, not: use AdamW for 1d para
+     # optimize_1d ==> use MARS for 1d params, else use AdamW
    if optimize_1d or is_grad_2d:
        one_minus_beta1 = 1. - beta1
-         c_t = (grad - last_grad).mul_(gamma * (beta1 / one_minus_beta1)).add_(grad)
-         c_t_norm = torch.norm(c_t)
-         if c_t_norm > 1.:
-             c_t = c_t / c_t_norm
+         if step == 1:
+             # timm addition: with no grad history, the first step uses the raw grad for consistency (otherwise tests fail)
+             c_t = grad
+         else:
+             c_t = (grad - last_grad).mul_(gamma * (beta1 / one_minus_beta1)).add_(grad)
+             c_t_norm = torch.norm(c_t)
+             if c_t_norm > 1.:
+                 c_t = c_t / c_t_norm
        exp_avg.mul_(beta1).add_(c_t, alpha=one_minus_beta1)
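+         # caution mask (assumed per 'Cautious Optimizers', https://arxiv.org/abs/2411.16085): zero momentum entries
+         # whose sign disagrees with the current grad, dividing the survivors by the fraction kept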
+         if caution:
+             mask = (exp_avg * grad > 0).to(grad.dtype)
+             mask.div_(mask.mean().clamp_(min=1e-3))
+             exp_avg = exp_avg * mask
        if mars_type == "adamw":
            exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1. - beta2)
            bias_correction1 = 1.0 - beta1 ** step
@@ -64,6 +76,10 @@ def mars_single_tensor(
        bias_correction1 = 1.0 - beta1_1d ** step
        bias_correction2 = 1.0 - beta2_1d ** step
        denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
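+         # same caution masking, applied on the AdamW path used for 1d params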
+         if caution:
+             mask = (exp_avg * grad > 0).to(grad.dtype)
+             mask.div_(mask.mean().clamp_(min=1e-3))
+             exp_avg = exp_avg * mask
        update = p * weight_decay + (exp_avg / bias_correction1).div_(denom)
        p.add_(update, alpha=-(lr * lr_1d_factor))
    return exp_avg, exp_avg_sq
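For reference, the correction term computed in the hunk above for the MARS path is

$$c_t = g_t + \gamma \, \frac{\beta_1}{1 - \beta_1} \, (g_t - g_{t-1}),$$

rescaled to unit norm whenever $\lVert c_t \rVert > 1$; on the first step, with no gradient history, $c_t = g_1$. This $c_t$ drives the first-moment update and, for `mars_type="adamw"`, the second-moment update as well.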
@@ -78,16 +94,17 @@ class Mars(Optimizer):
    """
    def __init__(
            self,
-             params,
-             lr=3e-3,
-             betas=(0.9, 0.99),
-             eps=1e-8,
-             weight_decay=0.,
-             gamma=0.025,
-             mars_type="adamw",
-             optimize_1d=False,
-             lr_1d_factor=1.0,
-             betas_1d=None,
+             params: ParamsT,
+             lr: float = 3e-3,
+             betas: Tuple[float, float] = (0.9, 0.99),
+             eps: float = 1e-8,
+             weight_decay: float = 0.,
+             gamma: float = 0.025,
+             mars_type: str = "adamw",
+             optimize_1d: bool = False,
+             lr_1d_factor: float = 1.0,
+             betas_1d: Optional[Tuple[float, float]] = None,
+             caution: bool = False
    ):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
@@ -109,9 +126,15 @@ def __init__(
            optimize_1d=optimize_1d,
            lr_1d_factor=lr_1d_factor,
            betas_1d=betas_1d or betas,
+             caution=caution,
        )
        super(Mars, self).__init__(params, defaults)

+     def __setstate__(self, state):
+         super(Mars, self).__setstate__(state)
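+         # param groups saved before the 'caution' option existed won't have the key; default it to off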
+         for group in self.param_groups:
+             group.setdefault('caution', False)
+
    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.
@@ -134,7 +157,6 @@ def step(self, closure=None):
                    raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')

                state = self.state[p]
-                 # ('----- starting a parameter state', state.keys(), 'Length of state', len(state))
                # State initialization
                if len(state) <= 1:
                    state['step'] = 0
@@ -155,7 +177,8 @@ def step(self, closure=None):
                beta1, beta2 = group['betas']
                is_grad_2d = grad.ndim >= 2

-                 mars_single_tensor(
+                 # FIXME add multi-tensor (if usage warrants), make more standard
+                 _mars_single_tensor_step(
                    p,
                    grad,
                    exp_avg,
@@ -173,6 +196,7 @@ def step(self, closure=None):
                    optimize_1d=group['optimize_1d'],
                    lr_1d_factor=group['lr_1d_factor'],
                    betas_1d=group['betas_1d'],
+                     caution=group['caution'],
                )

                state['last_grad'] = grad
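A minimal usage sketch of the optimizer with the new `caution` flag; the import path and the toy model/data are assumptions for illustration, not part of the diff:

```python
import torch
import torch.nn.functional as F

from timm.optim.mars import Mars  # import path assumed from this file's location

model = torch.nn.Linear(128, 10)  # placeholder model
opt = Mars(
    model.parameters(),
    lr=3e-3,
    betas=(0.9, 0.99),
    weight_decay=0.1,
    caution=True,        # enable the new cautious-update masking
    optimize_1d=False,   # 1d params (biases, norms) take the AdamW path
)

x = torch.randn(8, 128)
y = torch.randint(0, 10, (8,))
loss = F.cross_entropy(model(x), y)
loss.backward()
opt.step()
opt.zero_grad()
```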