File tree 1 file changed +18
-5
lines changed
1 file changed +18
-5
lines changed Original file line number Diff line number Diff line change
1
+ import contextlib
1
2
import copy
2
3
import os
3
4
import random
6
7
import numpy as np
7
8
import torch
8
9
9
- from .utils import deprecate
10
+ from .utils import deprecate , is_transformers_available
11
+
12
+
13
+ if is_transformers_available ():
14
+ import transformers
10
15
11
16
12
17
def enable_full_determinism (seed : int ):
@@ -197,11 +202,19 @@ def step(self, parameters: Iterable[torch.nn.Parameter]):
197
202
self .cur_decay_value = decay
198
203
one_minus_decay = 1 - decay
199
204
205
+ context_manager = contextlib .nullcontext
206
+ if is_transformers_available () and transformers .deepspeed .is_deepspeed_zero3_enabled ():
207
+ import deepspeed
208
+
200
209
for s_param , param in zip (self .shadow_params , parameters ):
201
- if param .requires_grad :
202
- s_param .sub_ (one_minus_decay * (s_param - param ))
203
- else :
204
- s_param .copy_ (param )
210
+ if is_transformers_available () and transformers .deepspeed .is_deepspeed_zero3_enabled ():
211
+ context_manager = deepspeed .zero .GatheredParameters (param , modifier_rank = None )
212
+
213
+ with context_manager ():
214
+ if param .requires_grad :
215
+ s_param .sub_ (one_minus_decay * (s_param - param ))
216
+ else :
217
+ s_param .copy_ (param )
205
218
206
219
def copy_to (self , parameters : Iterable [torch .nn .Parameter ]) -> None :
207
220
"""
You can’t perform that action at this time.
0 commit comments