File tree 2 files changed +11
-3
lines changed
2 files changed +11
-3
lines changed Original file line number Diff line number Diff line change @@ -141,8 +141,11 @@ def lion(
141
141
r"""Functional API that performs Lion algorithm computation.
142
142
"""
143
143
if foreach is None :
144
- # Placeholder for more complex foreach logic to be added when value is not set
145
- foreach = True
144
+ try :
145
+ # cannot do foreach if this overload doesn't exist when caution enabled
146
+ foreach = not caution or 'Scalar' in torch .ops .aten ._foreach_maximum .overloads ()
147
+ except :
148
+ foreach = False
146
149
147
150
if foreach and torch .jit .is_scripting ():
148
151
raise RuntimeError ('torch.jit.script not supported with foreach optimizers' )
Original file line number Diff line number Diff line change @@ -169,7 +169,12 @@ def nadamw(
169
169
' singleton tensors' )
170
170
171
171
if foreach is None :
172
- foreach = True
172
+ try :
173
+ # cannot do foreach if this overload doesn't exist when caution enabled
174
+ foreach = not caution or 'Scalar' in torch .ops .aten ._foreach_maximum .overloads ()
175
+ except :
176
+ foreach = False
177
+
173
178
if foreach and not torch .jit .is_scripting ():
174
179
func = _multi_tensor_nadamw
175
180
else :
You can’t perform that action at this time.
0 commit comments