Support zero-3 for FLUX training #10743

Open
@xiaoyewww

Description

Describe the bug

Due to memory limitations, I am attempting to use ZeRO-3 for Flux training on 8 GPUs with 32 GB each. I ran into a bug similar to the one reported in issue #1865 and applied the modification proposed in pull request #3076, but the same error persists. In my opinion, the fix does not work as expected, at least not entirely. Could you advise on how to modify it further?

The relevant code from https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/train_dreambooth_lora_flux.py#L1157 has been updated as follows:

    def deepspeed_zero_init_disabled_context_manager():
        """
        returns either a context list that includes one that will disable zero.Init or an empty context list
        """

        deepspeed_plugin = AcceleratorState().deepspeed_plugin if accelerate.state.is_initialized() else None
        print(f"deepspeed_plugin: {deepspeed_plugin}")
        if deepspeed_plugin is None:
            return []

        return [deepspeed_plugin.zero3_init_context_manager(enable=False)]

    with ContextManagers(deepspeed_zero_init_disabled_context_manager()):
        text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
        vae = AutoencoderKL.from_pretrained(
            args.pretrained_model_name_or_path,
            subfolder="vae",
            revision=args.revision,
            variant=args.variant,
        )
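
For reference, here is a minimal diagnostic sketch (my addition, not part of the training script) that could be placed right after the with block to check whether the loaded modules still hold ZeRO-3 partitioned parameters. Parameters converted by deepspeed.zero.Init carry a ds_status attribute and are replaced by small placeholder tensors until gathered, which is what would make an embedding weight non-2-D at forward time:

    # Diagnostic sketch (assumption): count parameters still partitioned by zero.Init.
    # Parameters touched by deepspeed.zero.Init gain a ds_status attribute, so any
    # non-zero count here means the module was not loaded with zero.Init fully disabled.
    def report_partitioned_params(module, name):
        partitioned = [n for n, p in module.named_parameters() if hasattr(p, "ds_status")]
        print(f"{name}: {len(partitioned)} ZeRO-3 partitioned parameters")

    report_partitioned_params(text_encoder_one, "text_encoder_one")
    report_partitioned_params(text_encoder_two, "text_encoder_two")
    report_partitioned_params(vae, "vae")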

Reproduction

DeepSpeed config (config/ds_config.json):

{
  "train_batch_size": "auto",
  "train_micro_batch_size_per_gpu": "auto",
  "gradient_accumulation_steps": "auto",
  "zero_optimization": {
    "stage": 3,
    "offload_optimizer": { "device": "cpu" },
    "stage3_gather_16bit_weights_on_model_save": false,
    "overlap_comm": false
  },
  "bf16": {
    "enabled": true
  },
  "fp16": {
    "enabled": false
  }
}
  

Accelerate config (config/accelerate_config.yaml):

compute_environment: LOCAL_MACHINE
deepspeed_config:
  deepspeed_config_file: "config/ds_config.json"
distributed_type: DEEPSPEED
machine_rank: 0
main_training_function: main
num_machines: 1
num_processes: 8

Training shell script:

#!/bin/bash

export MODEL_NAME="black-forest-labs/FLUX.1-dev"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="trained-flux"

export DS_SKIP_CUDA_CHECK=1

export ACCELERATE_CONFIG_FILE="config/accelerate_config.yaml"

ACCELERATE_CONFIG_FILE_PATH=${1:-$ACCELERATE_CONFIG_FILE}  

FLUXOUTPUT_DIR=flux_lora_output

mkdir -p $FLUXOUTPUT_DIR

accelerate launch --config_file $ACCELERATE_CONFIG_FILE_PATH train_dreambooth_lora_flux.py \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --mixed_precision="bf16" \
  --instance_prompt="a photo of sks dog" \
  --resolution=1024 \
  --train_batch_size=4 \
  --guidance_scale=1 \
  --gradient_accumulation_steps=1 \
  --learning_rate=1e-4 \
  --report_to="tensorboard" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=100 \
  --gradient_checkpointing \
  --seed="0"

Logs

RuntimeError: 'weight' must be 2-D
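
For context, the same message can be reproduced in isolation: torch's embedding op requires a 2-D weight, so a parameter that has been flattened into a ZeRO-3 placeholder would raise exactly this error. The snippet below is a standalone sketch, and the partitioned-weight explanation is an assumption about where the failure originates in this run:

import torch
import torch.nn.functional as F

# F.embedding rejects any weight that is not 2-D, e.g. a flattened placeholder
# left behind by ZeRO-3 parameter partitioning.
weight = torch.zeros(10)                # 1-D stand-in for a partitioned weight
F.embedding(torch.tensor([1]), weight)  # RuntimeError: 'weight' must be 2-D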

System Info

pytorch: 2.1.0
deepspeed: 0.14.0
accelerate: 1.3.0
diffusers: develop

Who can help?

No response
