
Commit af2a237

[deepspeed] partial ZeRO-3 support (#3076)
* [deepspeed] partial ZeRO-3 support
* cleanup
* improve deepspeed fixes
* Improve
* make style

Co-authored-by: Patrick von Platen <[email protected]>
1 parent d71db89 commit af2a237
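For context, this change targets text-to-image training runs driven by Accelerate's DeepSpeed integration with ZeRO stage 3 enabled. A minimal sketch of how such a run is typically configured from Python follows; the plugin arguments and mixed-precision choice are illustrative assumptions, not taken from this commit, and the script would normally be started with `accelerate launch`.

# Hedged sketch: enabling DeepSpeed ZeRO-3 through Accelerate (illustrative settings only).
# Assumes `accelerate` and `deepspeed` are installed and the process was started via `accelerate launch`.
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

deepspeed_plugin = DeepSpeedPlugin(
    zero_stage=3,                   # shard parameters, gradients and optimizer state across ranks
    gradient_accumulation_steps=1,
)
accelerator = Accelerator(mixed_precision="fp16", deepspeed_plugin=deepspeed_plugin)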

2 files changed: +48, -9 lines


examples/text_to_image/train_text_to_image.py

Lines changed: 30 additions & 4 deletions
@@ -29,13 +29,15 @@
 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
+from accelerate.state import AcceleratorState
 from accelerate.utils import ProjectConfiguration, set_seed
 from datasets import load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from torchvision import transforms
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
+from transformers.utils import ContextManagers
 
 import diffusers
 from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
@@ -464,10 +466,34 @@ def main():
     tokenizer = CLIPTokenizer.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
     )
-    text_encoder = CLIPTextModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
-    )
-    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+
+    def deepspeed_zero_init_disabled_context_manager():
+        """
+        returns either a context list that includes one that will disable zero.Init or an empty context list
+        """
+        deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
+        if deepspeed_plugin is None:
+            return []
+
+        return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
+
+    # Currently Accelerate doesn't know how to handle multiple models under Deepspeed ZeRO stage 3.
+    # For this to work properly all models must be run through `accelerate.prepare`. But accelerate
+    # will try to assign the same optimizer with the same weights to all models during
+    # `deepspeed.initialize`, which of course doesn't work.
+    #
+    # For now the following workaround will partially support Deepspeed ZeRO-3, by excluding the 2
+    # frozen models from being partitioned during `zero.Init` which gets called during
+    # `from_pretrained`. So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding
+    # across multiple GPUs and only UNet2DConditionModel will get ZeRO sharded.
+    with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
+        text_encoder = CLIPTextModel.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+        )
+        vae = AutoencoderKL.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
+        )
+
     unet = UNet2DConditionModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
     )
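The comment in the hunk above summarizes the workaround: only the UNet is passed through `accelerator.prepare`, so only it can safely be partitioned, while the frozen CLIP text encoder and VAE are loaded inside a context that switches `zero.Init` off and therefore stay replicated on each GPU. Below is a standalone sketch of the same pattern; the helper name and model identifiers are illustrative assumptions, not part of this commit.

# Hedged sketch of the selective zero.Init pattern shown in the diff above.
# Assumes the script was launched via `accelerate launch` with a DeepSpeed ZeRO-3 plugin;
# the model ids are placeholders used only for illustration.
import accelerate
from accelerate.state import AcceleratorState
from transformers import CLIPTextModel
from transformers.utils import ContextManagers

from diffusers import UNet2DConditionModel


def zero_init_disabled():
    # Return a list holding the plugin's "disable zero.Init" context when DeepSpeed
    # is active, otherwise an empty list so ContextManagers becomes a no-op.
    if not accelerate.state.is_initialized():
        return []
    plugin = AcceleratorState().deepspeed_plugin
    return [] if plugin is None else [plugin.zero3_init_context_manager(enable=False)]


# Frozen model: loaded with zero.Init disabled, so every rank keeps a full copy.
with ContextManagers(zero_init_disabled()):
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

# Trained model: loaded normally, so ZeRO-3 is free to shard its parameters.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")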

src/diffusers/training_utils.py

Lines changed: 18 additions & 5 deletions
@@ -1,3 +1,4 @@
+import contextlib
 import copy
 import os
 import random
@@ -6,7 +7,11 @@
 import numpy as np
 import torch
 
-from .utils import deprecate
+from .utils import deprecate, is_transformers_available
+
+
+if is_transformers_available():
+    import transformers
 
 
 def enable_full_determinism(seed: int):
@@ -197,11 +202,19 @@ def step(self, parameters: Iterable[torch.nn.Parameter]):
         self.cur_decay_value = decay
         one_minus_decay = 1 - decay
 
+        context_manager = contextlib.nullcontext
+        if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
+            import deepspeed
+
         for s_param, param in zip(self.shadow_params, parameters):
-            if param.requires_grad:
-                s_param.sub_(one_minus_decay * (s_param - param))
-            else:
-                s_param.copy_(param)
+            if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
+                context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)
+
+            with context_manager():
+                if param.requires_grad:
+                    s_param.sub_(one_minus_decay * (s_param - param))
+                else:
+                    s_param.copy_(param)
 
     def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
         """

0 commit comments
