 import transformers
 from accelerate import Accelerator
 from accelerate.logging import get_logger
+from accelerate.state import AcceleratorState
 from accelerate.utils import ProjectConfiguration, set_seed
 from datasets import load_dataset
 from huggingface_hub import create_repo, upload_folder
 from packaging import version
 from torchvision import transforms
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
+from transformers.utils import ContextManagers

 import diffusers
 from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel

@@ -464,10 +466,34 @@ def main():
     tokenizer = CLIPTokenizer.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="tokenizer", revision=args.revision
     )
-    text_encoder = CLIPTextModel.from_pretrained(
-        args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
-    )
-    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision)
+
+    def deepspeed_zero_init_disabled_context_manager():
+        """
+        Returns either a list containing the context manager that disables zero.Init, or an empty list.
+        """
+        deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
+        if deepspeed_plugin is None:
+            return []
+
+        return [deepspeed_plugin.zero3_init_context_manager(enable=False)]
+
+    # Currently Accelerate doesn't know how to handle multiple models under DeepSpeed ZeRO stage 3.
+    # For this to work properly, all models must be run through `accelerate.prepare`, but Accelerate
+    # will try to assign the same optimizer with the same weights to all models during
+    # `deepspeed.initialize`, which of course doesn't work.
+    #
+    # For now, the following workaround partially supports DeepSpeed ZeRO-3 by excluding the two
+    # frozen models from being partitioned during `zero.Init`, which gets called inside
+    # `from_pretrained`. So CLIPTextModel and AutoencoderKL do not get parameter sharding across
+    # multiple GPUs; only UNet2DConditionModel is ZeRO-sharded.
+    with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
+        text_encoder = CLIPTextModel.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
+        )
+        vae = AutoencoderKL.from_pretrained(
+            args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision
+        )
+
     unet = UNet2DConditionModel.from_pretrained(
         args.pretrained_model_name_or_path, subfolder="unet", revision=args.non_ema_revision
     )
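For context on the pattern used above: `transformers.utils.ContextManagers` simply enters every context manager in the list it receives, and an empty list is a no-op, which is why the helper can return `[]` when no DeepSpeed plugin is configured. Below is a minimal, self-contained sketch of that behavior; it is not part of the diff, and the `announce` helper is purely illustrative.

from contextlib import contextmanager

from transformers.utils import ContextManagers


@contextmanager
def announce(name):
    # Illustrative stand-in for deepspeed_plugin.zero3_init_context_manager(enable=False)
    print(f"enter {name}")
    yield
    print(f"exit {name}")


# Empty list: nothing is entered, mirroring the non-DeepSpeed code path.
with ContextManagers([]):
    print("models load with zero.Init left as-is")

# Non-empty list: every context manager in the list wraps the body, mirroring the
# path where zero.Init is disabled while the frozen text encoder and VAE load.
with ContextManagers([announce("zero.Init disabled")]):
    print("models load without ZeRO-3 parameter partitioning")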