
Commit 3045fb2

[DreamBooth] add text encoder LoRA support in the DreamBooth training script (#3130)
* add: LoRA text encoder support for DreamBooth example.
* fix initialization.
* fix: modification call.
* add: entry in the readme.
* use dog dataset from hub.
* fix: params to clip.
* add entry to the LoRA doc.
* add: tests for lora.
* remove unnecessary list comprehension.
1 parent 7b0ba48 commit 3045fb2

5 files changed: +197 additions, −32 deletions

docs/source/en/training/dreambooth.mdx

Lines changed: 12 additions & 1 deletion
@@ -60,7 +60,18 @@ DreamBooth finetuning is very sensitive to hyperparameters and easy to overfit.
 
 <frameworkcontent>
 <pt>
-Let's try DreamBooth with a [few images of a dog](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ); download and save them to a directory and then set the `INSTANCE_DIR` environment variable to that path:
+Let's try DreamBooth with a
+[few images of a dog](https://huggingface.co/datasets/diffusers/dog-example);
+download and save them to a directory and then set the `INSTANCE_DIR` environment variable to that path:
+
+```python
+local_dir = "./path_to_training_images"
+snapshot_download(
+    "diffusers/dog-example",
+    local_dir=local_dir, repo_type="dataset",
+    ignore_patterns=".gitattributes",
+)
+```
 
 ```bash
 export MODEL_NAME="CompVis/stable-diffusion-v1-4"
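Note that the snippet added in this hunk uses `snapshot_download` without importing it (the README version of the same snippet, further down, does include the import). A self-contained version, with the target directory as a placeholder, would be:

```python
from huggingface_hub import snapshot_download

local_dir = "./path_to_training_images"  # placeholder; any writable directory works
snapshot_download(
    "diffusers/dog-example",
    local_dir=local_dir,
    repo_type="dataset",
    ignore_patterns=".gitattributes",  # skip repo housekeeping files
)
```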

docs/source/en/training/lora.mdx

Lines changed: 8 additions & 1 deletion
@@ -16,7 +16,9 @@ specific language governing permissions and limitations under the License.
 
 <Tip warning={true}>
 
-Currently, LoRA is only supported for the attention layers of the [`UNet2DConditionalModel`].
+Currently, LoRA is only supported for the attention layers of the [`UNet2DConditionalModel`]. We also
+support LoRA fine-tuning of the text encoder for DreamBooth in a limited capacity. For more details on how we support
+LoRA fine-tuning of the text encoder, refer to the discussion on [this PR](https://github.com/huggingface/diffusers/pull/2918).
 
 </Tip>
 
@@ -175,6 +177,11 @@ accelerate launch train_dreambooth_lora.py \
   --push_to_hub
 ```
 
+It's also possible to additionally fine-tune the text encoder with LoRA. In most cases this leads
+to better results with a slight increase in compute. To allow fine-tuning the text encoder with LoRA,
+specify the `--train_text_encoder` option when launching the `train_dreambooth_lora.py` script.
+
+
 ### Inference[[dreambooth-inference]]
 
 Now you can use the model for inference by loading the base model in the [`StableDiffusionPipeline`]:
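For orientation, here is a minimal sketch of what inference with both sets of LoRA layers could look like after training. It assumes `load_lora_weights`, the loading counterpart of the `LoraLoaderMixin.save_lora_weights` call this commit introduces (verify the method name against your diffusers version); the output directory and prompt are placeholders:

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Restores the UNet LoRA layers and, when present, the text encoder LoRA layers
# from the directory train_dreambooth_lora.py saved its weights to.
pipe.load_lora_weights("path-to-save-model")

image = pipe("A photo of sks dog in a bucket", num_inference_steps=25).images[0]
image.save("dog-bucket.png")
```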

examples/dreambooth/README.md

Lines changed: 31 additions & 12 deletions
@@ -45,15 +45,28 @@ write_basic_config()
 
 ### Dog toy example
 
-Now let's get our dataset. Download images from [here](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) and save them in a directory. This will be our training data.
+Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.
 
-And launch the training using
+Let's first download it locally:
+
+```python
+from huggingface_hub import snapshot_download
+
+local_dir = "./dog"
+snapshot_download(
+    "diffusers/dog-example",
+    local_dir=local_dir, repo_type="dataset",
+    ignore_patterns=".gitattributes",
+)
+```
+
+And launch the training using:
 
 **___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**
 
 ```bash
 export MODEL_NAME="CompVis/stable-diffusion-v1-4"
-export INSTANCE_DIR="path-to-instance-images"
+export INSTANCE_DIR="dog"
 export OUTPUT_DIR="path-to-save-model"
 
 accelerate launch train_dreambooth.py \
@@ -77,7 +90,7 @@ According to the paper, it's recommended to generate `num_epochs * num_samples`
 
 ```bash
 export MODEL_NAME="CompVis/stable-diffusion-v1-4"
-export INSTANCE_DIR="path-to-instance-images"
+export INSTANCE_DIR="dog"
 export CLASS_DIR="path-to-class-images"
 export OUTPUT_DIR="path-to-save-model"
 
@@ -108,7 +121,7 @@ To install `bitandbytes` please refer to this [readme](https://github.com/TimDet
 
 ```bash
 export MODEL_NAME="CompVis/stable-diffusion-v1-4"
-export INSTANCE_DIR="path-to-instance-images"
+export INSTANCE_DIR="dog"
 export CLASS_DIR="path-to-class-images"
 export OUTPUT_DIR="path-to-save-model"
 
@@ -141,7 +154,7 @@ It is possible to run dreambooth on a 12GB GPU by using the following optimizati
 
 ```bash
 export MODEL_NAME="CompVis/stable-diffusion-v1-4"
-export INSTANCE_DIR="path-to-instance-images"
+export INSTANCE_DIR="dog"
 export CLASS_DIR="path-to-class-images"
 export OUTPUT_DIR="path-to-save-model"
 
@@ -185,7 +198,7 @@ does not seem to be compatible with DeepSpeed at the moment.
 
 ```bash
 export MODEL_NAME="CompVis/stable-diffusion-v1-4"
-export INSTANCE_DIR="path-to-instance-images"
+export INSTANCE_DIR="dog"
 export CLASS_DIR="path-to-class-images"
 export OUTPUT_DIR="path-to-save-model"
 
@@ -217,7 +230,7 @@ ___Note: Training text encoder requires more memory, with this option the traini
 
 ```bash
 export MODEL_NAME="CompVis/stable-diffusion-v1-4"
-export INSTANCE_DIR="path-to-instance-images"
+export INSTANCE_DIR="dog"
 export CLASS_DIR="path-to-class-images"
 export OUTPUT_DIR="path-to-save-model"
 
@@ -300,7 +313,7 @@ Now, you can launch the training. Here we will use [Stable Diffusion 1-5](https:
 
 ```bash
 export MODEL_NAME="runwayml/stable-diffusion-v1-5"
-export INSTANCE_DIR="path-to-instance-images"
+export INSTANCE_DIR="dog"
 export OUTPUT_DIR="path-to-save-model"
 ```
 

@@ -342,6 +355,12 @@ The final LoRA embedding weights have been uploaded to [patrickvonplaten/lora_dr
342355
The training results are summarized [here](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5).
343356
You can use the `Step` slider to see how the model learned the features of our subject while the model trained.
344357

358+
Optionally, we can also train additional LoRA layers for the text encoder. Specify the `train_text_encoder` argument above for that. If you're interested to know more about how we
359+
enable this support, check out this [PR](https://github.com/huggingface/diffusers/pull/2918).
360+
361+
With the default hyperparameters from the above, the training seems to go in a positive direction. Check out [this panel](https://wandb.ai/sayakpaul/dreambooth-lora/reports/test-23-04-17-17-00-13---Vmlldzo0MDkwNjMy). The trained LoRA layers are available [here](https://huggingface.co/sayakpaul/dreambooth).
362+
363+
345364
### Inference
346365

347366
After training, LoRA weights can be loaded very easily into the original pipeline. First, you need to
@@ -386,7 +405,7 @@ pip install -U -r requirements_flax.txt
 
 ```bash
 export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
-export INSTANCE_DIR="path-to-instance-images"
+export INSTANCE_DIR="dog"
 export OUTPUT_DIR="path-to-save-model"
 
 python train_dreambooth_flax.py \
@@ -405,7 +424,7 @@ python train_dreambooth_flax.py \
 
 ```bash
 export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
-export INSTANCE_DIR="path-to-instance-images"
+export INSTANCE_DIR="dog"
 export CLASS_DIR="path-to-class-images"
 export OUTPUT_DIR="path-to-save-model"
 
@@ -429,7 +448,7 @@ python train_dreambooth_flax.py \
 
 ```bash
 export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
-export INSTANCE_DIR="path-to-instance-images"
+export INSTANCE_DIR="dog"
 export CLASS_DIR="path-to-class-images"
 export OUTPUT_DIR="path-to-save-model"
 

examples/dreambooth/train_dreambooth_lora.py

Lines changed: 83 additions & 18 deletions
@@ -15,6 +15,7 @@
 
 import argparse
 import hashlib
+import itertools
 import logging
 import math
 import os
@@ -43,12 +44,13 @@
     DDPMScheduler,
     DiffusionPipeline,
     DPMSolverMultistepScheduler,
+    StableDiffusionPipeline,
     UNet2DConditionModel,
 )
-from diffusers.loaders import AttnProcsLayers
+from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
 from diffusers.models.attention_processor import LoRAAttnProcessor
 from diffusers.optimization import get_scheduler
-from diffusers.utils import check_min_version, is_wandb_available
+from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, check_min_version, is_wandb_available
 from diffusers.utils.import_utils import is_xformers_available
 
 
@@ -58,7 +60,7 @@
 logger = get_logger(__name__)
 
 
-def save_model_card(repo_id: str, images=None, base_model=str, prompt=str, repo_folder=None):
+def save_model_card(repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None):
     img_str = ""
     for i, image in enumerate(images):
         image.save(os.path.join(repo_folder, f"image_{i}.png"))
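A pre-existing quirk worth noticing while this signature changes: `base_model=str` and `prompt=str` default those parameters to the `str` type object rather than to a string, so callers must always pass them explicitly. A two-line illustration of what the default actually is:

```python
def f(base_model=str):
    return base_model

print(f())  # <class 'str'> -- the type object, not an empty string
```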
@@ -83,6 +85,8 @@ def save_model_card(repo_id: str, images=None, base_model=str, prompt=str, repo_
 
 These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
 {img_str}
+
+LoRA for the text encoder was enabled: {train_text_encoder}.
 """
     with open(os.path.join(repo_folder, "README.md"), "w") as f:
         f.write(yaml + model_card)
@@ -219,6 +223,11 @@ def parse_args(input_args=None):
             " cropped. The images will be resized to the resolution first before cropping."
         ),
     )
+    parser.add_argument(
+        "--train_text_encoder",
+        action="store_true",
+        help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
+    )
     parser.add_argument(
         "--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
    )
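Since `--train_text_encoder` is declared with `action="store_true"`, it defaults to False and flips to True merely by being present, which is why the launch commands in the docs pass the bare flag with no value. A quick demonstration:

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--train_text_encoder", action="store_true")

print(parser.parse_args([]).train_text_encoder)                         # False
print(parser.parse_args(["--train_text_encoder"]).train_text_encoder)   # True
```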
@@ -547,7 +556,13 @@ def main(args):
 
     # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
     # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
-    # TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
+    # TODO (sayakpaul): Remove this check when gradient accumulation with two models is enabled in accelerate.
+    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+        raise ValueError(
+            "Gradient accumulation is not supported when training the text encoder in distributed training. "
+            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+        )
+
     # Make one log on every process with the configuration for debugging.
     logging.basicConfig(
         format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -691,7 +706,7 @@ def main(args):
     # => 32 layers
 
     # Set correct lora layers
-    lora_attn_procs = {}
+    unet_lora_attn_procs = {}
     for name in unet.attn_processors.keys():
         cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
         if name.startswith("mid_block"):
@@ -703,12 +718,33 @@ def main(args):
             block_id = int(name[len("down_blocks.")])
             hidden_size = unet.config.block_out_channels[block_id]
 
-        lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)
-
-    unet.set_attn_processor(lora_attn_procs)
-    lora_layers = AttnProcsLayers(unet.attn_processors)
+        unet_lora_attn_procs[name] = LoRAAttnProcessor(
+            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
+        )
 
-    accelerator.register_for_checkpointing(lora_layers)
+    unet.set_attn_processor(unet_lora_attn_procs)
+    unet_lora_layers = AttnProcsLayers(unet.attn_processors)
+    accelerator.register_for_checkpointing(unet_lora_layers)
+
+    # The text encoder comes from 🤗 transformers, so we cannot directly modify it.
+    # So, instead, we monkey-patch the forward calls of its attention-blocks. For this,
+    # we first load a dummy pipeline with the text encoder and then do the monkey-patching.
+    text_encoder_lora_layers = None
+    if args.train_text_encoder:
+        text_lora_attn_procs = {}
+        for name, module in text_encoder.named_modules():
+            if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
+                text_lora_attn_procs[name] = LoRAAttnProcessor(
+                    hidden_size=module.out_features, cross_attention_dim=None
+                )
+        text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
+        temp_pipeline = StableDiffusionPipeline.from_pretrained(
+            args.pretrained_model_name_or_path, text_encoder=text_encoder
+        )
+        temp_pipeline._modify_text_encoder(text_lora_attn_procs)
+        text_encoder = temp_pipeline.text_encoder
+        accelerator.register_for_checkpointing(unet_lora_layers)
+        del temp_pipeline
 
     if args.scale_lr:
         args.learning_rate = (
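To make the name-based matching in the monkey-patching block concrete: `TEXT_ENCODER_TARGET_MODULES` names the linear projections inside CLIP's attention blocks, and every module whose name contains one of those substrings gets its own `LoRAAttnProcessor` sized by the module's `out_features`. A standalone sketch of just the matching step, with the target list hard-coded for illustration (confirm the exact list against your diffusers version):

```python
from transformers import CLIPTextModel

# Hard-coded here for illustration; diffusers exports this list as TEXT_ENCODER_TARGET_MODULES.
TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "out_proj"]

# Stable Diffusion 1.x uses CLIP ViT-L/14 as its text encoder.
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

matches = [
    (name, module.out_features)
    for name, module in text_encoder.named_modules()
    if any(target in name for target in TARGET_MODULES)
]

# 12 encoder layers x 4 projections = 48 matching linear layers.
print(f"{len(matches)} modules would receive a LoRAAttnProcessor")
print(matches[0])  # e.g. ('text_model.encoder.layers.0.self_attn.k_proj', 768)
```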
@@ -739,8 +775,13 @@ def main(args):
         optimizer_class = torch.optim.AdamW
 
     # Optimizer creation
+    params_to_optimize = (
+        itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
+        if args.train_text_encoder
+        else unet_lora_layers.parameters()
+    )
     optimizer = optimizer_class(
-        lora_layers.parameters(),
+        params_to_optimize,
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
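One subtlety in this pattern: `itertools.chain` produces a single-use iterator, and the optimizer consumes it when it materializes its parameter groups. That is presumably why the gradient-clipping code later in this diff rebuilds the chain on every step instead of reusing `params_to_optimize`. A tiny demonstration:

```python
import itertools

unet_params, text_params = [1, 2], [3, 4]  # stand-ins for the two parameter lists
chained = itertools.chain(unet_params, text_params)

print(list(chained))  # [1, 2, 3, 4]
print(list(chained))  # []  -- exhausted; a chain can only be iterated once
```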
@@ -784,9 +825,14 @@ def main(args):
     )
 
     # Prepare everything with our `accelerator`.
-    lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
-        lora_layers, optimizer, train_dataloader, lr_scheduler
-    )
+    if args.train_text_encoder:
+        unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler
+        )
+    else:
+        unet_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
+            unet_lora_layers, optimizer, train_dataloader, lr_scheduler
+        )
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -845,6 +891,8 @@ def main(args):
 
     for epoch in range(first_epoch, args.num_train_epochs):
         unet.train()
+        if args.train_text_encoder:
+            text_encoder.train()
         for step, batch in enumerate(train_dataloader):
             # Skip steps until we reach the resumed step
             if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
@@ -900,7 +948,11 @@ def main(args):
 
                 accelerator.backward(loss)
                 if accelerator.sync_gradients:
-                    params_to_clip = lora_layers.parameters()
+                    params_to_clip = (
+                        itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
+                        if args.train_text_encoder
+                        else unet_lora_layers.parameters()
+                    )
                     accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
                 optimizer.step()
                 lr_scheduler.step()
@@ -914,7 +966,14 @@ def main(args):
                 if global_step % args.checkpointing_steps == 0:
                     if accelerator.is_main_process:
                         save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
-                        accelerator.save_state(save_path)
+                        # We combine the text encoder and UNet LoRA parameters with a simple
+                        # custom logic. `accelerator.save_state()` won't know that. So,
+                        # use `LoraLoaderMixin.save_lora_weights()`.
+                        LoraLoaderMixin.save_lora_weights(
+                            save_directory=save_path,
+                            unet_lora_layers=unet_lora_layers,
+                            text_encoder_lora_layers=text_encoder_lora_layers,
+                        )
                         logger.info(f"Saved state to {save_path}")
 
             logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
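For anyone inspecting the resulting checkpoints: assuming the defaults of this era of diffusers (a single `pytorch_lora_weights.bin` state dict whose keys are namespaced per model, both of which are assumptions worth verifying against your installed version), a quick look might be:

```python
import torch

# Filename and key prefixes are assumptions based on LoraLoaderMixin defaults.
state_dict = torch.load("checkpoint-500/pytorch_lora_weights.bin", map_location="cpu")

unet_keys = [k for k in state_dict if k.startswith("unet.")]
text_encoder_keys = [k for k in state_dict if k.startswith("text_encoder.")]

print(f"UNet LoRA tensors: {len(unet_keys)}")
print(f"Text encoder LoRA tensors: {len(text_encoder_keys)}")  # 0 unless --train_text_encoder was set
```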
@@ -970,7 +1029,12 @@ def main(args):
     accelerator.wait_for_everyone()
     if accelerator.is_main_process:
         unet = unet.to(torch.float32)
-        unet.save_attn_procs(args.output_dir)
+        text_encoder = text_encoder.to(torch.float32)
+        LoraLoaderMixin.save_lora_weights(
+            save_directory=args.output_dir,
+            unet_lora_layers=unet_lora_layers,
+            text_encoder_lora_layers=text_encoder_lora_layers,
+        )
 
     # Final inference
     # Load previous pipeline
@@ -981,7 +1045,7 @@ def main(args):
     pipeline = pipeline.to(accelerator.device)
 
     # load attention processors
-    pipeline.unet.load_attn_procs(args.output_dir)
+    pipeline.load_attn_procs(args.output_dir)
 
     # run inference
     if args.validation_prompt and args.num_validation_images > 0:
@@ -1010,6 +1074,7 @@ def main(args):
             repo_id,
             images=images,
             base_model=args.pretrained_model_name_or_path,
+            train_text_encoder=args.train_text_encoder,
             prompt=args.instance_prompt,
             repo_folder=args.output_dir,
         )
