
[deepspeed] partial ZeRO-3 support #3076

Merged · 6 commits · May 11, 2023

34 changes: 30 additions & 4 deletions examples/text_to_image/train_text_to_image.py
@@ -29,13 +29,15 @@
import transformers
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.state import AcceleratorState
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from torchvision import transforms
from tqdm.auto import tqdm
from transformers import CLIPTextModel, CLIPTokenizer
from transformers.utils import ContextManagers

import diffusers
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
@@ -464,10 +466,34 @@ def main():
tokenizer = CLIPTokenizer.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
)
text_encoder = CLIPTextModel.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)

def deepspeed_zero_init_disabled_context_manager():
    """
    Returns either a list containing a single context manager that disables `zero.Init`, or an empty list.
    """
    deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
    if deepspeed_plugin is None:
        return []

    return [deepspeed_plugin.zero3_init_context_manager(enable=False)]

# Currently Accelerate doesn't know how to handle multiple models under DeepSpeed ZeRO stage 3.
# For this to work properly all models must be run through `accelerate.prepare`. But accelerate
# will try to assign the same optimizer with the same weights to all models during
# `deepspeed.initialize`, which of course doesn't work.
#
# For now the following workaround will partially support DeepSpeed ZeRO-3, by excluding the two
# frozen models from being partitioned during `zero.Init`, which gets called during
# `from_pretrained`. So CLIPTextModel and AutoencoderKL will not enjoy the parameter sharding
# across multiple GPUs, and only UNet2DConditionModel will get ZeRO-sharded.
with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
[Inline review comment from a Contributor on this line: "works for me!"]

    text_encoder = CLIPTextModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
    )
    vae = AutoencoderKL.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
    )

unet = UNet2DConditionModel.from_pretrained(
    args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
)
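For context, the workaround above only takes effect when the script runs under Accelerate's DeepSpeed integration with ZeRO stage 3 enabled; otherwise `deepspeed_zero_init_disabled_context_manager()` returns an empty list and `ContextManagers([])` is a no-op. A minimal, hypothetical launch sketch (placeholder values, not part of this PR) could look like this:

# Hypothetical sketch, not part of this PR: enabling ZeRO-3 through Accelerate.
# With zero3_init_flag=True, `from_pretrained` would normally build models inside
# `zero.Init` (parameters sharded at construction time); the context manager above
# opts the frozen text encoder and VAE out of that.
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

deepspeed_plugin = DeepSpeedPlugin(
    zero_stage=3,                   # shard parameters, gradients and optimizer states
    gradient_accumulation_steps=1,  # placeholder value
    zero3_init_flag=True,           # construct models directly in sharded form
)
accelerator = Accelerator(mixed_precision="fp16", deepspeed_plugin=deepspeed_plugin)

# Only the trainable UNet is then passed through `accelerator.prepare`, so only it
# gets ZeRO-sharded; the frozen text encoder and VAE stay as ordinary replicated modules.
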
23 changes: 18 additions & 5 deletions src/diffusers/training_utils.py
@@ -1,3 +1,4 @@
import contextlib
import copy
import os
import random
@@ -6,7 +7,11 @@
import numpy as np
import torch

from .utils import deprecate
from .utils import deprecate, is_transformers_available


if is_transformers_available():
import transformers


def enable_full_determinism(seed: int):
@@ -197,11 +202,19 @@ def step(self, parameters: Iterable[torch.nn.Parameter]):
self.cur_decay_value = decay
one_minus_decay = 1 - decay

context_manager = contextlib.nullcontext()
if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
    import deepspeed

for s_param, param in zip(self.shadow_params, parameters):
    if param.requires_grad:
        s_param.sub_(one_minus_decay * (s_param - param))
    else:
        s_param.copy_(param)
    # Under ZeRO-3 each rank holds only a shard of `param`, so gather the full
    # tensor before reading it for the EMA update.
    if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
        context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)

    with context_manager:
        if param.requires_grad:
            s_param.sub_(one_minus_decay * (s_param - param))
        else:
            s_param.copy_(param)

def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
"""
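The gather is needed because under ZeRO-3 each rank stores only a shard of every parameter, so reading `param` directly inside the EMA update would see an incomplete tensor. A condensed sketch of the pattern used above (standalone and simplified, assuming DeepSpeed ZeRO-3 is active; not part of this PR):

import deepspeed
import torch

@torch.no_grad()
def ema_update(shadow: torch.Tensor, param: torch.nn.Parameter, one_minus_decay: float):
    # GatheredParameters temporarily materializes the full, un-sharded tensor.
    # modifier_rank=None: no rank writes back into `param`; only the local EMA
    # shadow copy (which is never partitioned) is updated.
    with deepspeed.zero.GatheredParameters(param, modifier_rank=None):
        if param.requires_grad:
            shadow.sub_(one_minus_decay * (shadow - param))
        else:
            shadow.copy_(param)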