
Commit 7cf6836

Cautious optimizer impl plus some typing cleanup.

1 parent aeb1ed7 commit 7cf6836

13 files changed: +526 −239 lines changed

tests/test_optim.py

Lines changed: 1 addition & 1 deletion
@@ -298,7 +298,7 @@ def test_optim_factory(optimizer):
     assert isinstance(opt_info, OptimInfo)
 
     lr = (1e-2,) * 4
-    if optimizer in ('mars',):
+    if optimizer in ('mars', 'nadam', 'claprop', 'crmsproptf', 'cadafactorbv', 'csgdw', 'clamb'):
         lr = (1e-3,) * 4
 
     try:

timm/optim/_optim_factory.py

Lines changed: 107 additions & 30 deletions
@@ -5,15 +5,16 @@
 import logging
 from dataclasses import dataclass
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union, Protocol, Iterator
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
 from fnmatch import fnmatch
 import importlib
 
 import torch
 import torch.nn as nn
-import torch.optim as optim
+import torch.optim
 
 from ._param_groups import param_groups_layer_decay, param_groups_weight_decay
+from ._types import ParamsT, OptimType, OptimizerCallable
 from .adabelief import AdaBelief
 from .adafactor import Adafactor
 from .adafactor_bv import AdafactorBigVision
@@ -39,11 +40,6 @@
 
 _logger = logging.getLogger(__name__)
 
-# Type variables
-T = TypeVar('T')
-Params = Union[Iterator[nn.Parameter], Iterator[Dict[str, Any]]]
-OptimType = TypeVar('OptimType', bound='optim.Optimizer')
-
 
 def _import_class(class_string: str) -> Type:
     """Dynamically import a class from a string."""
@@ -55,11 +51,6 @@ def _import_class(class_string: str) -> Type:
         raise ImportError(f"Could not import {class_string}: {e}")
 
 
-class OptimizerCallable(Protocol):
-    """Protocol for optimizer constructor signatures."""
-
-    def __call__(self, params: Params, **kwargs) -> optim.Optimizer: ...
-
 
 @dataclass(frozen=True)
 class OptimInfo:
@@ -76,7 +67,7 @@ class OptimInfo:
         defaults: Optional default parameters for the optimizer
     """
     name: str
-    opt_class: Union[str, Type[optim.Optimizer]]
+    opt_class: Union[str, OptimType]
     description: str = ''
     has_eps: bool = True
     has_momentum: bool = False
@@ -185,7 +176,7 @@ def get_optimizer_class(
             self,
             name_or_info: Union[str, OptimInfo],
             bind_defaults: bool = True,
-    ) -> Union[Type[optim.Optimizer], OptimizerCallable]:
+    ) -> Union[OptimType, OptimizerCallable]:
         """Get the optimizer class with any default arguments applied.
 
         This allows direct instantiation of optimizers with their default configs
@@ -234,17 +225,17 @@ def get_optimizer_class(
 
     def create_optimizer(
             self,
-            model_or_params: Union[nn.Module, Params],
+            model_or_params: Union[nn.Module, ParamsT],
             opt: str,
             lr: Optional[float] = None,
             weight_decay: float = 0.,
             momentum: float = 0.9,
             foreach: Optional[bool] = None,
             weight_decay_exclude_1d: bool = True,
             layer_decay: Optional[float] = None,
-            param_group_fn: Optional[Callable[[nn.Module], Params]] = None,
+            param_group_fn: Optional[Callable[[nn.Module], ParamsT]] = None,
             **kwargs: Any,
-    ) -> optim.Optimizer:
+    ) -> torch.optim.Optimizer:
         """Create an optimizer instance.
 
         Args:
@@ -347,15 +338,15 @@ def _register_sgd_variants(registry: OptimizerRegistry) -> None:
     sgd_optimizers = [
         OptimInfo(
             name='sgd',
-            opt_class=optim.SGD,
+            opt_class=torch.optim.SGD,
             description='torch.Optim Stochastic Gradient Descent (SGD) with Nesterov momentum',
             has_eps=False,
             has_momentum=True,
             defaults={'nesterov': True}
         ),
         OptimInfo(
             name='momentum',
-            opt_class=optim.SGD,
+            opt_class=torch.optim.SGD,
             description='torch.Optim Stochastic Gradient Descent (SGD) with classical momentum',
             has_eps=False,
             has_momentum=True,
@@ -386,13 +377,13 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None:
     adam_optimizers = [
         OptimInfo(
             name='adam',
-            opt_class=optim.Adam,
+            opt_class=torch.optim.Adam,
             description='torch.optim.Adam, Adaptive Moment Estimation',
             has_betas=True
         ),
         OptimInfo(
             name='adamw',
-            opt_class=optim.AdamW,
+            opt_class=torch.optim.AdamW,
             description='torch.optim.AdamW, Adam with decoupled weight decay',
             has_betas=True
         ),
@@ -448,7 +439,7 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None:
         ),
         OptimInfo(
             name='adamax',
-            opt_class=optim.Adamax,
+            opt_class=torch.optim.Adamax,
             description='torch.optim.Adamax, Adam with infinity norm for more stable updates',
             has_betas=True
         ),
@@ -526,6 +517,87 @@ def _register_lamb_lars(registry: OptimizerRegistry) -> None:
         registry.register(opt)
 
 
+def _register_cautious_optimizers(registry: OptimizerRegistry) -> None:
+    cautious_optimizers = [
+        OptimInfo(
+            name='cadafactor',
+            opt_class=Adafactor,
+            description='Cautious Adafactor',
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='cadafactorbv',
+            opt_class=AdafactorBigVision,
+            description='Cautious Big Vision Adafactor',
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='cadamw',
+            opt_class=AdamWLegacy,
+            description='Cautious AdamW',
+            has_betas=True,
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='cadopt',
+            opt_class=Adopt,
+            description='Cautious Adopt',
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='cadoptw',
+            opt_class=Adopt,
+            description='Cautious AdoptW (decoupled decay)',
+            defaults={'decoupled': True, 'caution': True}
+        ),
+        OptimInfo(
+            name='clamb',
+            opt_class=Lamb,
+            description='Cautious LAMB',
+            has_betas=True,
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='claprop',
+            opt_class=LaProp,
+            description='Cautious LaProp',
+            has_betas=True,
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='clion',
+            opt_class=Lion,
+            description='Cautious Lion',
+            has_eps=False,
+            has_betas=True,
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='cnadamw',
+            opt_class=NAdamW,
+            description='Cautious NAdamW',
+            has_betas=True,
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='crmsproptf',
+            opt_class=RMSpropTF,
+            description='Cautious TensorFlow-style RMSprop',
+            has_momentum=True,
+            defaults={'alpha': 0.9, 'caution': True}
+        ),
+        OptimInfo(
+            name='csgdw',
+            opt_class=SGDW,
+            description='Cautious SGD with decoupled weight decay and Nesterov momentum',
+            has_eps=False,
+            has_momentum=True,
+            defaults={'nesterov': True, 'caution': True}
+        ),
+    ]
+    for opt in cautious_optimizers:
+        registry.register(opt)
+
 def _register_other_optimizers(registry: OptimizerRegistry) -> None:
     """Register miscellaneous optimizers"""
     other_optimizers = [
@@ -545,12 +617,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
         ),
         OptimInfo(
             name='adadelta',
-            opt_class=optim.Adadelta,
+            opt_class=torch.optim.Adadelta,
             description='torch.optim.Adadelta, Adapts learning rates based on running windows of gradients'
         ),
         OptimInfo(
             name='adagrad',
-            opt_class=optim.Adagrad,
+            opt_class=torch.optim.Adagrad,
             description='torch.optim.Adagrad, Adapts learning rates using cumulative squared gradients',
             defaults={'eps': 1e-8}
         ),
@@ -617,7 +689,7 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
         ),
         OptimInfo(
             name='rmsprop',
-            opt_class=optim.RMSprop,
+            opt_class=torch.optim.RMSprop,
             description='torch.optim.RMSprop, Root Mean Square Propagation',
             has_momentum=True,
             defaults={'alpha': 0.9}
@@ -765,6 +837,7 @@ def _register_default_optimizers() -> None:
     _register_other_optimizers(default_registry)
     _register_apex_optimizers(default_registry)
     _register_bnb_optimizers(default_registry)
+    _register_cautious_optimizers(default_registry)
 
     # Register aliases
     default_registry.register_alias('nesterov', 'sgd')
@@ -839,7 +912,7 @@ def get_optimizer_info(name: str) -> OptimInfo:
 def get_optimizer_class(
         name: str,
         bind_defaults: bool = True,
-) -> Union[Type[optim.Optimizer], OptimizerCallable]:
+) -> Union[OptimType, OptimizerCallable]:
     """Get optimizer class by name with option to bind default arguments.
 
     Retrieves the optimizer class or a partial function with default arguments bound.
@@ -874,17 +947,17 @@ def get_optimizer_class(
 
 
 def create_optimizer_v2(
-        model_or_params: Union[nn.Module, Params],
+        model_or_params: Union[nn.Module, ParamsT],
         opt: str = 'sgd',
         lr: Optional[float] = None,
         weight_decay: float = 0.,
         momentum: float = 0.9,
         foreach: Optional[bool] = None,
         filter_bias_and_bn: bool = True,
         layer_decay: Optional[float] = None,
-        param_group_fn: Optional[Callable[[nn.Module], Params]] = None,
+        param_group_fn: Optional[Callable[[nn.Module], ParamsT]] = None,
         **kwargs: Any,
-) -> optim.Optimizer:
+) -> torch.optim.Optimizer:
     """Create an optimizer instance via timm registry.
 
     Creates and configures an optimizer with appropriate parameter groups and settings.
@@ -985,7 +1058,11 @@ def optimizer_kwargs(cfg):
     return kwargs
 
 
-def create_optimizer(args, model, filter_bias_and_bn=True):
+def create_optimizer(
+        args,
+        model: Union[nn.Module, ParamsT],
+        filter_bias_and_bn: bool = True,
+) -> torch.optim.Optimizer:
     """ Legacy optimizer factory for backwards compatibility.
     NOTE: Use create_optimizer_v2 for new code.
     """

timm/optim/_types.py

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+from typing import Any, Dict, Iterable, Union, Protocol, Type
+try:
+    from typing import TypeAlias, TypeVar
+except ImportError:
+    from typing_extensions import TypeAlias, TypeVar
+
+import torch
+import torch.optim
+
+try:
+    from torch.optim.optimizer import ParamsT
+except (ImportError, TypeError):
+    ParamsT: TypeAlias = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]
+
+
+OptimType = Type[torch.optim.Optimizer]
+
+
+class OptimizerCallable(Protocol):
+    """Protocol for optimizer constructor signatures."""
+
+    def __call__(self, params: ParamsT, **kwargs) -> torch.optim.Optimizer: ...
+
+
+__all__ = ['ParamsT', 'OptimType', 'OptimizerCallable']
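
With the new module in place, ParamsT, OptimType, and OptimizerCallable together describe the two things the factory can hand back: a bare optimizer class, or a functools.partial carrying the registered defaults. A small usage sketch, assuming create_optimizer_v2 and get_optimizer_class are re-exported from timm.optim as in current releases; the model and hyperparameters are illustrative:

import torch.nn as nn
from timm.optim import create_optimizer_v2, get_optimizer_class

model = nn.Linear(16, 4)

# Cautious AdamW by registry name; the bound default {'caution': True} is applied.
opt = create_optimizer_v2(model, opt='cadamw', lr=1e-3, weight_decay=0.05)

# With bind_defaults=True the result may be a partial (an OptimizerCallable),
# so call sites are typed Union[OptimType, OptimizerCallable] rather than a class.
opt_factory = get_optimizer_class('csgdw', bind_defaults=True)
opt = opt_factory(model.parameters(), lr=0.1, momentum=0.9)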
