
Commit f95615b

Authored by HelloWorldBeginner, mhh001, and sayakpaul
Fixed the bug related to saving DeepSpeed models. (#6628)
* Fixed the bug related to saving DeepSpeed models.
* Add information about training SD models using DeepSpeed to the README.
* Apply suggestions from code review

Co-authored-by: mhh001 <[email protected]>
Co-authored-by: Sayak Paul <[email protected]>
1 parent a9288b4 commit f95615b

File tree

2 files changed: +65 additions, −4 deletions


examples/text_to_image/README_sdxl.md

Lines changed: 60 additions & 0 deletions
````diff
@@ -183,6 +183,66 @@ The above command will also run inference as fine-tuning progresses and log the
 
 * SDXL's VAE is known to suffer from numerical instability issues. This is why we also expose a CLI argument namely `--pretrained_vae_model_name_or_path` that lets you specify the location of a better VAE (such as [this one](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix)).
 
+### Using DeepSpeed
+Using DeepSpeed, one can reduce GPU memory consumption and train models on GPUs with less memory. DeepSpeed can offload model parameters to the machine's RAM, or distribute parameters, gradients, and optimizer states across multiple GPUs, which allows larger models to be trained on the same hardware.
+
+First, run the `accelerate config` command and select DeepSpeed, or set up DeepSpeed manually in an accelerate config file.
+
+Here is an example of a config file for using DeepSpeed. For a more detailed explanation of the configuration, refer to this [link](https://huggingface.co/docs/accelerate/usage_guides/deepspeed).
+```yaml
+compute_environment: LOCAL_MACHINE
+debug: true
+deepspeed_config:
+  gradient_accumulation_steps: 1
+  gradient_clipping: 1.0
+  offload_optimizer_device: none
+  offload_param_device: none
+  zero3_init_flag: false
+  zero_stage: 2
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: fp16
+num_machines: 1
+num_processes: 1
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
+```
+Save this configuration as an `accelerate_config.yaml` file, then pass its path through the `ACCELERATE_CONFIG_FILE` variable. This lets you train your SDXL LoRA model with DeepSpeed; other SD models can be trained with DeepSpeed in the same way.
+
+```shell
+export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
+export VAE_NAME="madebyollin/sdxl-vae-fp16-fix"
+export DATASET_NAME="lambdalabs/pokemon-blip-captions"
+export ACCELERATE_CONFIG_FILE="your accelerate_config.yaml"
+
+accelerate launch --config_file $ACCELERATE_CONFIG_FILE train_text_to_image_lora_sdxl.py \
+  --pretrained_model_name_or_path=$MODEL_NAME \
+  --pretrained_vae_model_name_or_path=$VAE_NAME \
+  --dataset_name=$DATASET_NAME --caption_column="text" \
+  --resolution=1024 \
+  --train_batch_size=1 \
+  --num_train_epochs=2 \
+  --checkpointing_steps=2 \
+  --learning_rate=1e-04 \
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=0 \
+  --mixed_precision="fp16" \
+  --max_train_steps=20 \
+  --validation_epochs=20 \
+  --seed=1234 \
+  --output_dir="sd-pokemon-model-lora-sdxl" \
+  --validation_prompt="cute dragon creature"
+```
+
 ### Finetuning the text encoder and UNet
 
 The script also allows you to finetune the `text_encoder` along with the `unet`.
````
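A hedged side note on the config above, not taken from this commit: both offload switches are off in the example. If GPU memory is still tight, ZeRO stage 2 can offload optimizer state to CPU RAM by editing only the `deepspeed_config` block, roughly as sketched below; offloading parameters as well would require `zero_stage: 3`.

```yaml
# Sketch of an alternative deepspeed_config block; assumes the rest of
# accelerate_config.yaml stays exactly as in the README example above.
deepspeed_config:
  gradient_accumulation_steps: 1
  gradient_clipping: 1.0
  offload_optimizer_device: cpu  # keep optimizer state in CPU RAM (ZeRO-Offload)
  offload_param_device: none     # parameter offload only takes effect with zero_stage: 3
  zero3_init_flag: false
  zero_stage: 2
```

Optimizer offload trades GPU memory for host-device transfer time, so expect slower steps in exchange for the smaller footprint.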

examples/text_to_image/train_text_to_image_lora_sdxl.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -652,21 +652,22 @@ def save_model_hook(models, weights, output_dir):
         text_encoder_two_lora_layers_to_save = None
 
         for model in models:
-            if isinstance(model, type(unwrap_model(unet))):
+            if isinstance(unwrap_model(model), type(unwrap_model(unet))):
                 unet_lora_layers_to_save = convert_state_dict_to_diffusers(get_peft_model_state_dict(model))
-            elif isinstance(model, type(unwrap_model(text_encoder_one))):
+            elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
                 text_encoder_one_lora_layers_to_save = convert_state_dict_to_diffusers(
                     get_peft_model_state_dict(model)
                 )
-            elif isinstance(model, type(unwrap_model(text_encoder_two))):
+            elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_two))):
                 text_encoder_two_lora_layers_to_save = convert_state_dict_to_diffusers(
                     get_peft_model_state_dict(model)
                 )
             else:
                 raise ValueError(f"unexpected save model: {model.__class__}")
 
             # make sure to pop weight so that corresponding model is not saved again
-            weights.pop()
+            if weights:
+                weights.pop()
 
         StableDiffusionXLPipeline.save_lora_weights(
             output_dir,
```
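For context on why both changes were needed: under DeepSpeed, Accelerate hands `save_model_hook` the wrapped engine object rather than the bare model, so the old `isinstance(model, type(unwrap_model(unet)))` check never matched; and because DeepSpeed manages checkpoint state itself, the `weights` list can arrive empty, which is presumably why the unconditional `weights.pop()` raised. A minimal, self-contained sketch of the type mismatch, where `Wrapper` is a stand-in assumption for the real DeepSpeed engine:

```python
# Minimal sketch with assumed names; Wrapper stands in for the engine class
# that DeepSpeed wraps the model in during training.
class Wrapper:
    def __init__(self, module):
        self.module = module

class UNet:
    pass

def unwrap_model(model):
    # Mirrors the training script's helper: peel the wrapper off if present.
    return model.module if isinstance(model, Wrapper) else model

model = Wrapper(UNet())  # what the save hook receives under DeepSpeed

print(isinstance(model, type(unwrap_model(model))))                # False: the old check misses the model
print(isinstance(unwrap_model(model), type(unwrap_model(model))))  # True: the fixed check matches

# Second half of the fix: with DeepSpeed managing state, `weights` can be empty,
# so an unconditional pop() raised IndexError; the commit guards it instead.
weights = []
if weights:
    weights.pop()
```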
