[DreamBooth] add text encoder LoRA support in the DreamBooth training script #3130

Merged · 10 commits · Apr 20, 2023 · Changes from 8 commits

13 changes: 12 additions & 1 deletion docs/source/en/training/dreambooth.mdx
@@ -60,7 +60,18 @@ DreamBooth finetuning is very sensitive to hyperparameters and easy to overfit.

<frameworkcontent>
<pt>
Let's try DreamBooth with a [few images of a dog](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ); download and save them to a directory and then set the `INSTANCE_DIR` environment variable to that path:
Let's try DreamBooth with a
[few images of a dog](https://huggingface.co/datasets/diffusers/dog-example);
download and save them to a directory and then set the `INSTANCE_DIR` environment variable to that path:

```python
from huggingface_hub import snapshot_download

local_dir = "./path_to_training_images"
snapshot_download(
    "diffusers/dog-example",
    local_dir=local_dir,
    repo_type="dataset",
    ignore_patterns=".gitattributes",
)
```

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
9 changes: 8 additions & 1 deletion docs/source/en/training/lora.mdx
@@ -16,7 +16,9 @@ specific language governing permissions and limitations under the License.

<Tip warning={true}>

Currently, LoRA is only supported for the attention layers of the [`UNet2DConditionalModel`].
Currently, LoRA is only supported for the attention layers of the [`UNet2DConditionModel`]. We also
support LoRA fine-tuning of the text encoder for DreamBooth in a limited capacity; for more details on how that
support works, refer to the discussion in [this PR](https://github.com/huggingface/diffusers/pull/2918).

</Tip>
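
For intuition, the gist of that text encoder support is to attach `LoRAAttnProcessor` layers to the text encoder's attention projection modules and monkey-patch their forward passes. Below is a rough sketch of the first half of that idea, mirroring the training script changes later in this PR; `TEXT_ENCODER_TARGET_MODULES` and the `LoRAAttnProcessor` signature are taken from the diffusers version this PR targets and may differ in other releases.

```python
from transformers import CLIPTextModel

from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.utils import TEXT_ENCODER_TARGET_MODULES

# Load only the text encoder of a Stable Diffusion checkpoint.
text_encoder = CLIPTextModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="text_encoder"
)

# Collect a LoRA processor for every attention projection (q/k/v/out) module.
text_lora_attn_procs = {}
for name, module in text_encoder.named_modules():
    if any(x in name for x in TEXT_ENCODER_TARGET_MODULES):
        text_lora_attn_procs[name] = LoRAAttnProcessor(
            hidden_size=module.out_features, cross_attention_dim=None
        )

# The training script then monkey-patches these processors into the text encoder's
# forward calls via the pipeline's `_modify_text_encoder` helper (see the script diff below).
```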

@@ -175,6 +177,11 @@ accelerate launch train_dreambooth_lora.py \
--push_to_hub
```

It's also possible to additionally fine-tune the text encoder with LoRA. In most cases this leads
to better results at a slight increase in compute. To enable fine-tuning the text encoder with LoRA,

Review comment: Would it be possible to add more experimental details of unet_lora and text_encoder_lora (and combined) in separate blogs (like dreambooth)?

Reply from the PR author: We are leaving it to the community for now. But we will update if we find more :)

specify the `--train_text_encoder` argument when launching the `train_dreambooth_lora.py` script.


### Inference[[dreambooth-inference]]

Now you can use the model for inference by loading the base model in the [`StableDiffusionPipeline`]:
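
The snippet below is a minimal sketch of that inference flow, assuming the LoRA weights were saved to the training `--output_dir`; the model id, output path, and prompt are placeholders, and the exact LoRA-loading call can vary between diffusers versions.

```python
import torch

from diffusers import StableDiffusionPipeline

base_model = "runwayml/stable-diffusion-v1-5"
lora_dir = "path-to-save-model"  # the --output_dir used during training

# Load the base model the LoRA layers were trained against.
pipe = StableDiffusionPipeline.from_pretrained(base_model, torch_dtype=torch.float16)
pipe.to("cuda")

# Load the trained LoRA layers; in recent diffusers releases this picks up both the
# UNet LoRA layers and, if they were trained, the text encoder LoRA layers.
pipe.load_lora_weights(lora_dir)

image = pipe("A photo of sks dog in a bucket", num_inference_steps=25).images[0]
image.save("dog_bucket.png")
```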
43 changes: 31 additions & 12 deletions examples/dreambooth/README.md
@@ -45,15 +45,28 @@ write_basic_config()

### Dog toy example

Now let's get our dataset. Download images from [here](https://drive.google.com/drive/folders/1BO_dyz-p65qhBRRMRA4TbZ8qW4rB99JZ) and save them in a directory. This will be our training data.
Now let's get our dataset. For this example we will use some dog images: https://huggingface.co/datasets/diffusers/dog-example.

And launch the training using
Let's first download it locally:

```python
from huggingface_hub import snapshot_download

local_dir = "./dog"
snapshot_download(
    "diffusers/dog-example",
    local_dir=local_dir,
    repo_type="dataset",
    ignore_patterns=".gitattributes",
)
```

And launch the training using:

**___Note: Change the `resolution` to 768 if you are using the [stable-diffusion-2](https://huggingface.co/stabilityai/stable-diffusion-2) 768x768 model.___**

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="path-to-save-model"

accelerate launch train_dreambooth.py \
@@ -77,7 +90,7 @@ According to the paper, it's recommended to generate `num_epochs * num_samples`

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export INSTANCE_DIR="dog"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"

@@ -108,7 +121,7 @@ To install `bitsandbytes` please refer to this [readme](https://github.com/TimDet

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export INSTANCE_DIR="dog"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"

@@ -141,7 +154,7 @@ It is possible to run dreambooth on a 12GB GPU by using the following optimizati

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export INSTANCE_DIR="dog"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"

@@ -185,7 +198,7 @@ does not seem to be compatible with DeepSpeed at the moment.

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export INSTANCE_DIR="dog"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"

@@ -217,7 +230,7 @@ ___Note: Training text encoder requires more memory, with this option the traini

```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export INSTANCE_DIR="dog"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"

@@ -300,7 +313,7 @@ Now, you can launch the training. Here we will use [Stable Diffusion 1-5](https:

```bash
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export INSTANCE_DIR="path-to-instance-images"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="path-to-save-model"
```

@@ -342,6 +355,12 @@ The final LoRA embedding weights have been uploaded to [patrickvonplaten/lora_dr
The training results are summarized [here](https://api.wandb.ai/report/patrickvonplaten/xm6cd5q5).
You can use the `Step` slider to see how the model learned the features of our subject over the course of training.

Optionally, we can also train additional LoRA layers for the text encoder. Specify the `--train_text_encoder` argument above for that. If you're interested in how this support is implemented, check out this [PR](https://github.com/huggingface/diffusers/pull/2918).

With the default hyperparameters from above, the training seems to go in a positive direction. Check out [this panel](https://wandb.ai/sayakpaul/dreambooth-lora/reports/test-23-04-17-17-00-13---Vmlldzo0MDkwNjMy). The trained LoRA layers are available [here](https://huggingface.co/sayakpaul/dreambooth).
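
If you want to sanity-check what such a checkpoint contains, one rough approach is to inspect the saved state dict directly. The filename and the `unet`/`text_encoder` key prefixes below are assumptions about how `LoraLoaderMixin.save_lora_weights()` organizes the weights; adjust them to whatever is actually written to your output directory.

```python
import torch

# Hypothetical path: the --output_dir used during training.
state_dict = torch.load(
    "path-to-save-model/pytorch_lora_weights.bin", map_location="cpu"
)

unet_keys = [k for k in state_dict if k.startswith("unet")]
text_encoder_keys = [k for k in state_dict if k.startswith("text_encoder")]

print(f"UNet LoRA tensors: {len(unet_keys)}")
# Non-empty only when --train_text_encoder was passed during training.
print(f"Text encoder LoRA tensors: {len(text_encoder_keys)}")
```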


### Inference

After training, LoRA weights can be loaded very easily into the original pipeline. First, you need to
@@ -386,7 +405,7 @@ pip install -U -r requirements_flax.txt

```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export INSTANCE_DIR="path-to-instance-images"
export INSTANCE_DIR="dog"
export OUTPUT_DIR="path-to-save-model"

python train_dreambooth_flax.py \
@@ -405,7 +424,7 @@ python train_dreambooth_flax.py \

```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export INSTANCE_DIR="path-to-instance-images"
export INSTANCE_DIR="dog"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"

@@ -429,7 +448,7 @@ python train_dreambooth_flax.py \

```bash
export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export INSTANCE_DIR="path-to-instance-images"
export INSTANCE_DIR="dog"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"

101 changes: 83 additions & 18 deletions examples/dreambooth/train_dreambooth_lora.py
@@ -15,6 +15,7 @@

import argparse
import hashlib
import itertools
import logging
import math
import os
@@ -43,12 +44,13 @@
DDPMScheduler,
DiffusionPipeline,
DPMSolverMultistepScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers
from diffusers.loaders import AttnProcsLayers, LoraLoaderMixin
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.optimization import get_scheduler
from diffusers.utils import check_min_version, is_wandb_available
from diffusers.utils import TEXT_ENCODER_TARGET_MODULES, check_min_version, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available


@@ -58,7 +60,7 @@
logger = get_logger(__name__)


def save_model_card(repo_id: str, images=None, base_model=str, prompt=str, repo_folder=None):
def save_model_card(repo_id: str, images=None, base_model=str, train_text_encoder=False, prompt=str, repo_folder=None):
img_str = ""
for i, image in enumerate(images):
image.save(os.path.join(repo_folder, f"image_{i}.png"))
@@ -83,6 +85,8 @@ def save_model_card(repo_id: str, images=None, base_model=str, prompt=str, repo_

These are LoRA adaption weights for {base_model}. The weights were trained on {prompt} using [DreamBooth](https://dreambooth.github.io/). You can find some example images in the following. \n
{img_str}

LoRA for the text encoder was enabled: {train_text_encoder}.
"""
with open(os.path.join(repo_folder, "README.md"), "w") as f:
f.write(yaml + model_card)
@@ -219,6 +223,11 @@ def parse_args(input_args=None):
" cropped. The images will be resized to the resolution first before cropping."
),
)
parser.add_argument(
"--train_text_encoder",
action="store_true",
help="Whether to train the text encoder. If set, the text encoder should be float32 precision.",
)
parser.add_argument(
"--train_batch_size", type=int, default=4, help="Batch size (per device) for the training dataloader."
)
@@ -547,7 +556,13 @@ def main(args):

# Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
# This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
# TODO (patil-suraj): Remove this check when gradient accumulation with two models is enabled in accelerate.
# TODO (sayakpaul): Remove this check when gradient accumulation with two models is enabled in accelerate.
if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
raise ValueError(
"Gradient accumulation is not supported when training the text encoder in distributed training. "
"Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
)

# Make one log on every process with the configuration for debugging.
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
@@ -691,7 +706,7 @@ def main(args):
# => 32 layers

# Set correct lora layers
lora_attn_procs = {}
unet_lora_attn_procs = {}
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"):
@@ -703,12 +718,33 @@ def main(args):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]

lora_attn_procs[name] = LoRAAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim)

unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)
unet_lora_attn_procs[name] = LoRAAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
)

accelerator.register_for_checkpointing(lora_layers)
unet.set_attn_processor(unet_lora_attn_procs)
unet_lora_layers = AttnProcsLayers(unet.attn_processors)
accelerator.register_for_checkpointing(unet_lora_layers)

# The text encoder comes from 🤗 transformers, so we cannot directly modify it.
# So, instead, we monkey-patch the forward calls of its attention-blocks. For this,
# we first load a dummy pipeline with the text encoder and then do the monkey-patching.
text_encoder_lora_layers = None
if args.train_text_encoder:
text_lora_attn_procs = {}
for name, module in text_encoder.named_modules():
if any([x in name for x in TEXT_ENCODER_TARGET_MODULES]):
text_lora_attn_procs[name] = LoRAAttnProcessor(
hidden_size=module.out_features, cross_attention_dim=None
)
text_encoder_lora_layers = AttnProcsLayers(text_lora_attn_procs)
temp_pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path, text_encoder=text_encoder
)
temp_pipeline._modify_text_encoder(text_lora_attn_procs)
text_encoder = temp_pipeline.text_encoder
accelerator.register_for_checkpointing(text_encoder_lora_layers)
del temp_pipeline

if args.scale_lr:
args.learning_rate = (
Expand Down Expand Up @@ -739,8 +775,13 @@ def main(args):
optimizer_class = torch.optim.AdamW

# Optimizer creation
params_to_optimize = (
itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
if args.train_text_encoder
else unet_lora_layers.parameters()
)
optimizer = optimizer_class(
lora_layers.parameters(),
params_to_optimize,
lr=args.learning_rate,
betas=(args.adam_beta1, args.adam_beta2),
weight_decay=args.adam_weight_decay,
@@ -784,9 +825,14 @@ def main(args):
)

# Prepare everything with our `accelerator`.
lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
lora_layers, optimizer, train_dataloader, lr_scheduler
)
if args.train_text_encoder:
unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet_lora_layers, text_encoder_lora_layers, optimizer, train_dataloader, lr_scheduler
)
else:
unet_lora_layers, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
unet_lora_layers, optimizer, train_dataloader, lr_scheduler
)

# We need to recalculate our total training steps as the size of the training dataloader may have changed.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -845,6 +891,8 @@ def main(args):

for epoch in range(first_epoch, args.num_train_epochs):
unet.train()
if args.train_text_encoder:
text_encoder.train()
for step, batch in enumerate(train_dataloader):
# Skip steps until we reach the resumed step
if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
@@ -900,7 +948,11 @@ def main(args):

accelerator.backward(loss)
if accelerator.sync_gradients:
params_to_clip = lora_layers.parameters()
params_to_clip = (
itertools.chain(unet_lora_layers.parameters(), text_encoder_lora_layers.parameters())
if args.train_text_encoder
else unet_lora_layers.parameters()
)
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
@@ -914,7 +966,14 @@ def main(args):
if global_step % args.checkpointing_steps == 0:
if accelerator.is_main_process:
save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
accelerator.save_state(save_path)
# We combine the text encoder and UNet LoRA parameters with a simple
# custom logic. `accelerator.save_state()` won't know that. So,
# use `LoraLoaderMixin.save_lora_weights()`.
LoraLoaderMixin.save_lora_weights(
save_directory=save_path,
unet_lora_layers=unet_lora_layers,
text_encoder_lora_layers=text_encoder_lora_layers,
)
logger.info(f"Saved state to {save_path}")

logs = {"loss": loss.detach().item(), "lr": lr_scheduler.get_last_lr()[0]}
@@ -970,7 +1029,12 @@ def main(args):
accelerator.wait_for_everyone()
if accelerator.is_main_process:
unet = unet.to(torch.float32)
unet.save_attn_procs(args.output_dir)
text_encoder = text_encoder.to(torch.float32)
LoraLoaderMixin.save_lora_weights(
save_directory=args.output_dir,
unet_lora_layers=unet_lora_layers,
text_encoder_lora_layers=text_encoder_lora_layers,
)

# Final inference
# Load previous pipeline
@@ -981,7 +1045,7 @@ def main(args):
pipeline = pipeline.to(accelerator.device)

# load attention processors
pipeline.unet.load_attn_procs(args.output_dir)
pipeline.load_attn_procs(args.output_dir)

# run inference
if args.validation_prompt and args.num_validation_images > 0:
@@ -1010,6 +1074,7 @@ def main(args):
repo_id,
images=images,
base_model=args.pretrained_model_name_or_path,
train_text_encoder=args.train_text_encoder,
prompt=args.instance_prompt,
repo_folder=args.output_dir,
)