
Commit 303f769

Add cautious mars, improve test reliability by skipping grad diff for first step
1 parent 82e8677

3 files changed: 69 additions, 36 deletions

tests/test_optim.py

Lines changed: 2 additions & 0 deletions

@@ -300,6 +300,8 @@ def test_optim_factory(optimizer):
     lr = (1e-2,) * 4
     if optimizer in ('mars', 'nadam', 'claprop', 'crmsproptf', 'cadafactorbv', 'csgdw', 'clamb'):
         lr = (1e-3,) * 4
+    elif optimizer in ('cmars',):
+        lr = (1e-4,) * 4
 
     try:
         if not opt_info.second_order:  # basic tests don't support second order right now

timm/optim/_optim_factory.py

Lines changed: 7 additions & 0 deletions

@@ -572,6 +572,13 @@ def _register_cautious_optimizers(registry: OptimizerRegistry) -> None:
            has_betas=True,
            defaults={'caution': True}
        ),
+        OptimInfo(
+            name='cmars',
+            opt_class=Mars,
+            description='Cautious MARS',
+            has_betas=True,
+            defaults={'caution': True}
+        ),
        OptimInfo(
            name='cnadamw',
            opt_class=NAdamW,
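
For context, the registry entry above exposes the new 'cmars' name through timm's optimizer factory. A minimal usage sketch, assuming the usual create_optimizer_v2 entry point (the toy model and hyper-parameters are illustrative, not from this commit):

import torch
from timm.optim import create_optimizer_v2

model = torch.nn.Linear(10, 2)
# 'cmars' resolves to the Mars class with caution enabled via the registry entry above
optimizer = create_optimizer_v2(model, opt='cmars', lr=1e-4, weight_decay=0.05)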

timm/optim/mars.py

Lines changed: 60 additions & 36 deletions
@@ -14,38 +14,50 @@
 # Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
 # SPDX-License-Identifier: Apache-2.0
 import math
+from typing import Optional, Tuple
 
 import torch
 from torch.optim.optimizer import Optimizer
 
-
-def mars_single_tensor(
-        p,
-        grad,
-        exp_avg,
-        exp_avg_sq,
-        lr,
-        weight_decay,
-        beta1,
-        beta2,
-        last_grad,
-        eps,
-        step,
-        gamma,
-        mars_type,
-        is_grad_2d,
-        optimize_1d,
-        lr_1d_factor,
-        betas_1d,
+from ._types import ParamsT
+
+
+def _mars_single_tensor_step(
+        p: torch.Tensor,
+        grad: torch.Tensor,
+        exp_avg: torch.Tensor,
+        exp_avg_sq: torch.Tensor,
+        lr: float,
+        weight_decay: float,
+        beta1: float,
+        beta2: float,
+        last_grad: torch.Tensor,
+        eps: float,
+        step: int,
+        gamma: float,
+        mars_type: str,
+        is_grad_2d: bool,
+        optimize_1d: bool,
+        lr_1d_factor: float,
+        betas_1d: Tuple[float, float],
+        caution: bool,
 ):
-    # optimize_1d: use MARS for 1d para, not: use AdamW for 1d para
+    # optimize_1d ==> use MARS for 1d param, else use AdamW
     if optimize_1d or is_grad_2d:
         one_minus_beta1 = 1. - beta1
-        c_t = (grad - last_grad).mul_(gamma * (beta1 / one_minus_beta1)).add_(grad)
-        c_t_norm = torch.norm(c_t)
-        if c_t_norm > 1.:
-            c_t = c_t / c_t_norm
+        if step == 1:
+            # this is a timm addition, making first step more consistent when no grad history, otherwise tests fail
+            c_t = grad
+        else:
+            c_t = (grad - last_grad).mul_(gamma * (beta1 / one_minus_beta1)).add_(grad)
+            c_t_norm = torch.norm(c_t)
+            if c_t_norm > 1.:
+                c_t = c_t / c_t_norm
         exp_avg.mul_(beta1).add_(c_t, alpha=one_minus_beta1)
+        if caution:
+            mask = (exp_avg * grad > 0).to(grad.dtype)
+            mask.div_(mask.mean().clamp_(min=1e-3))
+            exp_avg = exp_avg * mask
     if mars_type == "adamw":
         exp_avg_sq.mul_(beta2).addcmul_(c_t, c_t, value=1. - beta2)
         bias_correction1 = 1.0 - beta1 ** step
@@ -64,6 +76,10 @@ def mars_single_tensor(
         bias_correction1 = 1.0 - beta1_1d ** step
         bias_correction2 = 1.0 - beta2_1d ** step
         denom = (exp_avg_sq.sqrt() / math.sqrt(bias_correction2)).add_(eps)
+        if caution:
+            mask = (exp_avg * grad > 0).to(grad.dtype)
+            mask.div_(mask.mean().clamp_(min=1e-3))
+            exp_avg = exp_avg * mask
         update = p * weight_decay + (exp_avg / bias_correction1).div_(denom)
         p.add_(update, alpha=-(lr * lr_1d_factor))
     return exp_avg, exp_avg_sq
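
For reference, a minimal self-contained sketch of the caution masking that both branches above apply (the helper name is illustrative): momentum components whose sign disagrees with the current gradient are zeroed, and the survivors are rescaled by the inverse of the mask mean so the update keeps a comparable overall magnitude.

import torch

def apply_caution_mask(exp_avg: torch.Tensor, grad: torch.Tensor) -> torch.Tensor:
    # 1.0 where momentum and gradient agree in sign, 0.0 where they disagree
    mask = (exp_avg * grad > 0).to(grad.dtype)
    # rescale surviving components; the clamp avoids dividing by ~0 when nearly everything is masked
    mask.div_(mask.mean().clamp_(min=1e-3))
    return exp_avg * mask

# toy example: the middle component disagrees in sign and is dropped
exp_avg = torch.tensor([0.5, -0.2, 0.1])
grad = torch.tensor([1.0, 0.3, 0.2])
print(apply_caution_mask(exp_avg, grad))  # ~[0.75, 0.0, 0.15]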
@@ -78,16 +94,17 @@ class Mars(Optimizer):
     """
     def __init__(
             self,
-            params,
-            lr=3e-3,
-            betas=(0.9, 0.99),
-            eps=1e-8,
-            weight_decay=0.,
-            gamma=0.025,
-            mars_type="adamw",
-            optimize_1d=False,
-            lr_1d_factor=1.0,
-            betas_1d=None,
+            params: ParamsT,
+            lr: float = 3e-3,
+            betas: Tuple[float, float] = (0.9, 0.99),
+            eps: float = 1e-8,
+            weight_decay: float = 0.,
+            gamma: float = 0.025,
+            mars_type: str = "adamw",
+            optimize_1d: bool = False,
+            lr_1d_factor: float = 1.0,
+            betas_1d: Optional[Tuple[float, float]] = None,
+            caution: bool = False
     ):
         if not 0.0 <= lr:
             raise ValueError("Invalid learning rate: {}".format(lr))
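
A short usage sketch constructing the optimizer directly from this module with the new flag (the toy model and settings are illustrative):

import torch
from timm.optim.mars import Mars

model = torch.nn.Linear(10, 2)
optimizer = Mars(model.parameters(), lr=3e-3, weight_decay=0.01, caution=True)

x = torch.randn(4, 10)
loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()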
@@ -109,9 +126,15 @@ def __init__(
             optimize_1d=optimize_1d,
             lr_1d_factor=lr_1d_factor,
             betas_1d=betas_1d or betas,
+            caution=caution,
         )
         super(Mars, self).__init__(params, defaults)
 
+    def __setstate__(self, state):
+        super(Mars, self).__setstate__(state)
+        for group in self.param_groups:
+            group.setdefault('caution', False)
+
     @torch.no_grad()
     def step(self, closure=None):
         """Performs a single optimization step.
@@ -134,7 +157,6 @@ def step(self, closure=None):
                     raise RuntimeError('Adam does not support sparse gradients, please consider SparseAdam instead')
 
                 state = self.state[p]
-                # ('----- starting a parameter state', state.keys(), 'Length of state', len(state))
                 # State initialization
                 if len(state) <= 1:
                     state['step'] = 0
@@ -155,7 +177,8 @@ def step(self, closure=None):
                 beta1, beta2 = group['betas']
                 is_grad_2d = grad.ndim >= 2
 
-                mars_single_tensor(
+                # FIXME add multi-tensor (if usage warrants), make more standard
+                _mars_single_tensor_step(
                     p,
                     grad,
                     exp_avg,
@@ -173,6 +196,7 @@ def step(self, closure=None):
                     optimize_1d=group['optimize_1d'],
                     lr_1d_factor=group['lr_1d_factor'],
                     betas_1d=group['betas_1d'],
+                    caution=group['caution'],
                 )
 
                 state['last_grad'] = grad
