Skip to content

Commit 9ebdd10

Browse files
[deepspeed] partial ZeRO-3 support (huggingface#3076)
* [deepspeed] partial ZeRO-3 support
* cleanup
* improve deepspeed fixes
* Improve
* make style

Co-authored-by: Patrick von Platen <[email protected]>
1 parent 6f0d3ea commit 9ebdd10

File tree

1 file changed

+18
-5
lines changed

1 file changed

+18
-5
lines changed

training_utils.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import contextlib
12
import copy
23
import os
34
import random
@@ -6,7 +7,11 @@
67
import numpy as np
78
import torch
89

9-
from .utils import deprecate
10+
from .utils import deprecate, is_transformers_available
11+
12+
13+
if is_transformers_available():
14+
import transformers
1015

1116

1217
def enable_full_determinism(seed: int):
@@ -197,11 +202,19 @@ def step(self, parameters: Iterable[torch.nn.Parameter]):
197202
self.cur_decay_value = decay
198203
one_minus_decay = 1 - decay
199204

205+
context_manager = contextlib.nullcontext
206+
if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
207+
import deepspeed
208+
200209
for s_param, param in zip(self.shadow_params, parameters):
201-
if param.requires_grad:
202-
s_param.sub_(one_minus_decay * (s_param - param))
203-
else:
204-
s_param.copy_(param)
210+
if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
211+
context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)
212+
213+
with context_manager():
214+
if param.requires_grad:
215+
s_param.sub_(one_minus_decay * (s_param - param))
216+
else:
217+
s_param.copy_(param)
205218

206219
def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
207220
"""

0 commit comments

Comments (0)