Skip to content

Commit 7c32d3b

Browse files
committed
Work around _foreach_maximum issue, need scalar other support
1 parent 7cf6836 commit 7c32d3b

File tree

2 files changed

+11
-3
lines changed

timm/optim/lion.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -141,8 +141,11 @@ def lion(
141141
r"""Functional API that performs Lion algorithm computation.
142142
"""
143143
if foreach is None:
144-
# Placeholder for more complex foreach logic to be added when value is not set
145-
foreach = True
144+
try:
145+
# cannot do foreach if this overload doesn't exist when caution enabled
146+
foreach = not caution or 'Scalar' in torch.ops.aten._foreach_maximum.overloads()
147+
except:
148+
foreach = False
146149

147150
if foreach and torch.jit.is_scripting():
148151
raise RuntimeError('torch.jit.script not supported with foreach optimizers')

timm/optim/nadamw.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,12 @@ def nadamw(
169169
' singleton tensors')
170170

171171
if foreach is None:
172-
foreach = True
172+
try:
173+
# cannot do foreach if this overload doesn't exist when caution enabled
174+
foreach = not caution or 'Scalar' in torch.ops.aten._foreach_maximum.overloads()
175+
except:
176+
foreach = False
177+
173178
if foreach and not torch.jit.is_scripting():
174179
func = _multi_tensor_nadamw
175180
else:

0 commit comments

Comments (0)