@@ -5,15 +5,16 @@
 import logging
 from dataclasses import dataclass
 from functools import partial
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union, Protocol, Iterator
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union
 from fnmatch import fnmatch
 import importlib

 import torch
 import torch.nn as nn
-import torch.optim as optim
+import torch.optim

 from ._param_groups import param_groups_layer_decay, param_groups_weight_decay
+from ._types import ParamsT, OptimType, OptimizerCallable
 from .adabelief import AdaBelief
 from .adafactor import Adafactor
 from .adafactor_bv import AdafactorBigVision
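The three names imported from `._types` replace the ad-hoc aliases removed further down in this diff. The module itself is not shown here; below is a minimal, hypothetical sketch of what it could contain, assuming it simply centralizes the deleted aliases (the real `timm/optim/_types.py` may differ):

# Hypothetical sketch of timm/optim/_types.py (not part of this diff).
from typing import Any, Callable, Dict, Iterable, Type, Union

import torch

# Parameters accepted by an optimizer constructor: tensors or param-group dicts.
ParamsT = Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]

# An optimizer class object, e.g. torch.optim.AdamW or a timm optimizer class.
OptimType = Type[torch.optim.Optimizer]

# Anything that returns an optimizer when called with params and keyword args,
# e.g. an optimizer class or a functools.partial with defaults pre-bound.
OptimizerCallable = Callable[..., torch.optim.Optimizer]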
@@ -39,11 +40,6 @@

 _logger = logging.getLogger(__name__)

-# Type variables
-T = TypeVar('T')
-Params = Union[Iterator[nn.Parameter], Iterator[Dict[str, Any]]]
-OptimType = TypeVar('OptimType', bound='optim.Optimizer')
-

 def _import_class(class_string: str) -> Type:
     """Dynamically import a class from a string."""
@@ -55,11 +51,6 @@ def _import_class(class_string: str) -> Type:
         raise ImportError(f"Could not import {class_string}: {e}")


-class OptimizerCallable(Protocol):
-    """Protocol for optimizer constructor signatures."""
-
-    def __call__(self, params: Params, **kwargs) -> optim.Optimizer: ...
-

 @dataclass(frozen=True)
 class OptimInfo:
@@ -76,7 +67,7 @@ class OptimInfo:
         defaults: Optional default parameters for the optimizer
     """
     name: str
-    opt_class: Union[str, Type[optim.Optimizer]]
+    opt_class: Union[str, OptimType]
     description: str = ''
     has_eps: bool = True
     has_momentum: bool = False
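`opt_class` is now annotated as `Union[str, OptimType]`, so an entry can point at either a class object or a dotted import path resolved lazily by `_import_class`. A hedged sketch of registering a custom entry, assuming an `OptimizerRegistry` instance named `registry` and a hypothetical `my_package.optim.MyAdamW`:

# Sketch: registering a custom optimizer entry with a string opt_class.
registry.register(OptimInfo(
    name='my_adamw',
    opt_class='my_package.optim.MyAdamW',   # hypothetical dotted path, imported on demand
    description='Custom AdamW variant registered by import path',
    has_betas=True,
    defaults={'weight_decay': 0.01},
))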
@@ -185,7 +176,7 @@ def get_optimizer_class(
             self,
             name_or_info: Union[str, OptimInfo],
             bind_defaults: bool = True,
-    ) -> Union[Type[optim.Optimizer], OptimizerCallable]:
+    ) -> Union[OptimType, OptimizerCallable]:
         """Get the optimizer class with any default arguments applied.

         This allows direct instantiation of optimizers with their default configs
@@ -234,17 +225,17 @@ def get_optimizer_class(

     def create_optimizer(
             self,
-            model_or_params: Union[nn.Module, Params],
+            model_or_params: Union[nn.Module, ParamsT],
             opt: str,
             lr: Optional[float] = None,
             weight_decay: float = 0.,
             momentum: float = 0.9,
             foreach: Optional[bool] = None,
             weight_decay_exclude_1d: bool = True,
             layer_decay: Optional[float] = None,
-            param_group_fn: Optional[Callable[[nn.Module], Params]] = None,
+            param_group_fn: Optional[Callable[[nn.Module], ParamsT]] = None,
             **kwargs: Any,
-    ) -> optim.Optimizer:
+    ) -> torch.optim.Optimizer:
         """Create an optimizer instance.

         Args:
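`ParamsT` covers a plain parameter iterable as well as explicit param-group dicts, and `param_group_fn` lets callers build the groups themselves. A sketch of the latter, assuming a registry instance named `registry` and an `nn.Module` called `model` (neither is defined in this hunk):

import torch.nn as nn

# Hypothetical param-group builder: no weight decay for 1d params (biases, norms).
def no_decay_for_1d(model: nn.Module):
    decay, no_decay = [], []
    for p in model.parameters():
        if not p.requires_grad:
            continue
        (no_decay if p.ndim <= 1 else decay).append(p)
    return [
        {'params': decay},
        {'params': no_decay, 'weight_decay': 0.},
    ]

# `registry` is assumed to be an OptimizerRegistry instance.
optimizer = registry.create_optimizer(
    model,
    opt='adamw',
    lr=1e-3,
    weight_decay=0.05,
    param_group_fn=no_decay_for_1d,
)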
@@ -347,15 +338,15 @@ def _register_sgd_variants(registry: OptimizerRegistry) -> None:
     sgd_optimizers = [
         OptimInfo(
             name='sgd',
-            opt_class=optim.SGD,
+            opt_class=torch.optim.SGD,
             description='torch.Optim Stochastic Gradient Descent (SGD) with Nesterov momentum',
             has_eps=False,
             has_momentum=True,
             defaults={'nesterov': True}
         ),
         OptimInfo(
             name='momentum',
-            opt_class=optim.SGD,
+            opt_class=torch.optim.SGD,
             description='torch.Optim Stochastic Gradient Descent (SGD) with classical momentum',
             has_eps=False,
             has_momentum=True,
@@ -386,13 +377,13 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None:
     adam_optimizers = [
         OptimInfo(
             name='adam',
-            opt_class=optim.Adam,
+            opt_class=torch.optim.Adam,
             description='torch.optim.Adam, Adaptive Moment Estimation',
             has_betas=True
         ),
         OptimInfo(
             name='adamw',
-            opt_class=optim.AdamW,
+            opt_class=torch.optim.AdamW,
             description='torch.optim.AdamW, Adam with decoupled weight decay',
             has_betas=True
         ),
@@ -448,7 +439,7 @@ def _register_adam_variants(registry: OptimizerRegistry) -> None:
         ),
         OptimInfo(
             name='adamax',
-            opt_class=optim.Adamax,
+            opt_class=torch.optim.Adamax,
             description='torch.optim.Adamax, Adam with infinity norm for more stable updates',
             has_betas=True
         ),
@@ -526,6 +517,87 @@ def _register_lamb_lars(registry: OptimizerRegistry) -> None:
         registry.register(opt)


+def _register_cautious_optimizers(registry: OptimizerRegistry) -> None:
+    cautious_optimizers = [
+        OptimInfo(
+            name='cadafactor',
+            opt_class=Adafactor,
+            description='Cautious Adafactor',
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='cadafactorbv',
+            opt_class=AdafactorBigVision,
+            description='Cautious Big Vision Adafactor',
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='cadamw',
+            opt_class=AdamWLegacy,
+            description='Cautious AdamW',
+            has_betas=True,
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='cadopt',
+            opt_class=Adopt,
+            description='Cautious Adopt',
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='cadoptw',
+            opt_class=Adopt,
+            description='Cautious AdoptW (decoupled decay)',
+            defaults={'decoupled': True, 'caution': True}
+        ),
+        OptimInfo(
+            name='clamb',
+            opt_class=Lamb,
+            description='Cautious LAMB',
+            has_betas=True,
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='claprop',
+            opt_class=LaProp,
+            description='Cautious LaProp',
+            has_betas=True,
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='clion',
+            opt_class=Lion,
+            description='Cautious Lion',
+            has_eps=False,
+            has_betas=True,
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='cnadamw',
+            opt_class=NAdamW,
+            description='Cautious NAdamW',
+            has_betas=True,
+            defaults={'caution': True}
+        ),
+        OptimInfo(
+            name='crmsproptf',
+            opt_class=RMSpropTF,
+            description='Cautious TensorFlow-style RMSprop',
+            has_momentum=True,
+            defaults={'alpha': 0.9, 'caution': True}
+        ),
+        OptimInfo(
+            name='csgdw',
+            opt_class=SGDW,
+            description='Cautious SGD with decoupled weight decay and Nesterov momentum',
+            has_eps=False,
+            has_momentum=True,
+            defaults={'nesterov': True, 'caution': True}
+        ),
+    ]
+    for opt in cautious_optimizers:
+        registry.register(opt)
+

 def _register_other_optimizers(registry: OptimizerRegistry) -> None:
     """Register miscellaneous optimizers"""
     other_optimizers = [
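The cautious variants above are plain registry entries that bind `caution=True` on the underlying timm optimizers. A sketch of selecting one through the public factory, assuming the module-level exports stay as they are today:

import timm
from timm.optim import create_optimizer_v2

model = timm.create_model('resnet50', pretrained=False)
# 'cadamw' resolves to AdamWLegacy with caution=True bound from the registry defaults.
optimizer = create_optimizer_v2(model, opt='cadamw', lr=1e-3, weight_decay=0.05)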
@@ -545,12 +617,12 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
         ),
         OptimInfo(
             name='adadelta',
-            opt_class=optim.Adadelta,
+            opt_class=torch.optim.Adadelta,
             description='torch.optim.Adadelta, Adapts learning rates based on running windows of gradients'
         ),
         OptimInfo(
             name='adagrad',
-            opt_class=optim.Adagrad,
+            opt_class=torch.optim.Adagrad,
             description='torch.optim.Adagrad, Adapts learning rates using cumulative squared gradients',
             defaults={'eps': 1e-8}
         ),
@@ -617,7 +689,7 @@ def _register_other_optimizers(registry: OptimizerRegistry) -> None:
         ),
         OptimInfo(
             name='rmsprop',
-            opt_class=optim.RMSprop,
+            opt_class=torch.optim.RMSprop,
             description='torch.optim.RMSprop, Root Mean Square Propagation',
             has_momentum=True,
             defaults={'alpha': 0.9}
@@ -765,6 +837,7 @@ def _register_default_optimizers() -> None:
     _register_other_optimizers(default_registry)
     _register_apex_optimizers(default_registry)
     _register_bnb_optimizers(default_registry)
+    _register_cautious_optimizers(default_registry)

     # Register aliases
     default_registry.register_alias('nesterov', 'sgd')
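With the cautious family registered on the default registry, the new names show up through the usual discovery helpers. A hedged sketch, assuming `list_optimizers` keeps its wildcard filter (backed by the `fnmatch` import at the top of the file):

from timm.optim import list_optimizers

# Wildcard filtering via fnmatch; should include 'cadamw', 'cadopt', 'clion', etc.
cautious_names = list_optimizers('c*')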
@@ -839,7 +912,7 @@ def get_optimizer_info(name: str) -> OptimInfo:
 def get_optimizer_class(
         name: str,
         bind_defaults: bool = True,
-) -> Union[Type[optim.Optimizer], OptimizerCallable]:
+) -> Union[OptimType, OptimizerCallable]:
     """Get optimizer class by name with option to bind default arguments.

     Retrieves the optimizer class or a partial function with default arguments bound.
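Caller-side usage is unchanged by the annotation switch; a sketch, assuming `get_optimizer_class` remains exported from `timm.optim` and `model` is an `nn.Module` in scope:

from timm.optim import get_optimizer_class

# bind_defaults=True (the default): 'sgd' returns a partial with nesterov=True bound.
sgd_factory = get_optimizer_class('sgd')
optimizer = sgd_factory(model.parameters(), lr=0.1, momentum=0.9)

# bind_defaults=False: the plain torch.optim.SGD class, no defaults applied.
sgd_cls = get_optimizer_class('sgd', bind_defaults=False)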
@@ -874,17 +947,17 @@


 def create_optimizer_v2(
-        model_or_params: Union[nn.Module, Params],
+        model_or_params: Union[nn.Module, ParamsT],
         opt: str = 'sgd',
         lr: Optional[float] = None,
         weight_decay: float = 0.,
         momentum: float = 0.9,
         foreach: Optional[bool] = None,
         filter_bias_and_bn: bool = True,
         layer_decay: Optional[float] = None,
-        param_group_fn: Optional[Callable[[nn.Module], Params]] = None,
+        param_group_fn: Optional[Callable[[nn.Module], ParamsT]] = None,
         **kwargs: Any,
-) -> optim.Optimizer:
+) -> torch.optim.Optimizer:
     """Create an optimizer instance via timm registry.

     Creates and configures an optimizer with appropriate parameter groups and settings.
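A sketch of a typical call, assuming `model` is an `nn.Module`; extra keyword arguments such as `betas` flow through `**kwargs` to the optimizer constructor:

from timm.optim import create_optimizer_v2

optimizer = create_optimizer_v2(
    model,                 # 1d params get weight_decay=0 while filter_bias_and_bn=True
    opt='adamw',
    lr=5e-4,
    weight_decay=0.05,
    layer_decay=0.75,      # per-layer LR decay via the param-group helpers
    betas=(0.9, 0.95),     # passed through **kwargs to torch.optim.AdamW
)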
@@ -985,7 +1058,11 @@ def optimizer_kwargs(cfg):
     return kwargs


-def create_optimizer(args, model, filter_bias_and_bn=True):
+def create_optimizer(
+        args,
+        model: Union[nn.Module, ParamsT],
+        filter_bias_and_bn: bool = True,
+) -> torch.optim.Optimizer:
     """ Legacy optimizer factory for backwards compatibility.
     NOTE: Use create_optimizer_v2 for new code.
     """