
Commit 743a569

[flux dreambooth lora training] make LoRA target modules configurable + small bug fix (#9646)
* make lora target modules configurable and change the default
* style
* make lora target modules configurable and change the default
* fix bug when using prodigy and training te
* fix mixed precision training as proposed in #9565 for full dreambooth as well
* add test and notes
* style
* address sayaks comments
* style
* fix test

Co-authored-by: Sayak Paul <[email protected]>
1 parent db5b6a9 commit 743a569

File tree

4 files changed, +87 -5 lines changed


examples/dreambooth/README_flux.md

Lines changed: 15 additions & 0 deletions
@@ -170,6 +170,21 @@ accelerate launch train_dreambooth_lora_flux.py \
   --push_to_hub
 ```
 
+### Target Modules
+When LoRA was first adapted from language models to diffusion models, it was applied to the cross-attention layers in the UNet that relate the image representations with the prompts that describe them.
+More recently, SOTA text-to-image diffusion models replaced the UNet with a diffusion transformer (DiT). With this change, we may also want to explore
+applying LoRA training to different types of layers and blocks. To allow more flexibility and control over the targeted modules we added `--lora_layers`, in which you can specify in a comma-separated string
+the exact modules for LoRA training. Here are some examples of target modules you can provide:
+- to train attention layers only: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0"`
+- to train the same modules as in the fal trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2"`
+- to train the same modules as in the ostris ai-toolkit / replicate trainer: `--lora_layers="attn.to_k,attn.to_q,attn.to_v,attn.to_out.0,attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,ff.net.0.proj,ff.net.2,ff_context.net.0.proj,ff_context.net.2,norm1_context.linear,norm1.linear,norm.linear,proj_mlp,proj_out"`
+> [!NOTE]
+> `--lora_layers` can also be used to specify which **blocks** to apply LoRA training to. To do so, simply add a block prefix to each layer in the comma-separated string:
+> **single DiT blocks**: to target the ith single transformer block, add the prefix `single_transformer_blocks.i`, e.g. `single_transformer_blocks.i.attn.to_k`
+> **MMDiT blocks**: to target the ith MMDiT block, add the prefix `transformer_blocks.i`, e.g. `transformer_blocks.i.attn.to_k`
+> [!NOTE]
+> Keep in mind that while training more layers can improve quality and expressiveness, it also increases the size of the output LoRA weights.
+
 ### Text Encoder Training
 
 Alongside the transformer, fine-tuning of the CLIP text encoder is also supported.
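
The note above describes layer- vs. block-level targeting in prose. As a rough illustration (not part of this commit), the sketch below shows how a comma-separated `--lora_layers` value is expected to resolve against module paths, assuming PEFT's list-style `target_modules` entries behave as suffix matches; the candidate module names are illustrative stand-ins for Flux transformer submodules.

```python
# Sketch: how --lora_layers entries are expected to match module paths.
# Assumes PEFT's list-style target_modules behave as suffix matches.
lora_layers = "attn.to_k,single_transformer_blocks.0.attn.to_q"
target_modules = [layer.strip() for layer in lora_layers.split(",")]

candidate_modules = [
    "transformer_blocks.0.attn.to_k",          # MMDiT block, matched by "attn.to_k"
    "single_transformer_blocks.0.attn.to_k",   # single DiT block, matched by "attn.to_k"
    "single_transformer_blocks.0.attn.to_q",   # matched by the block-prefixed entry
    "single_transformer_blocks.1.attn.to_q",   # not matched: the prefix pins block 0 only
]

def is_targeted(name: str, targets: list[str]) -> bool:
    # Mirrors suffix matching: exact name or a ".<target>" suffix.
    return any(name == t or name.endswith(f".{t}") for t in targets)

for name in candidate_modules:
    print(f"{name}: {is_targeted(name, target_modules)}")
```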

examples/dreambooth/test_dreambooth_lora_flux.py

Lines changed: 38 additions & 0 deletions
@@ -37,6 +37,7 @@ class DreamBoothLoRAFlux(ExamplesTestsAccelerate):
     instance_prompt = "photo"
     pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
     script_path = "examples/dreambooth/train_dreambooth_lora_flux.py"
+    transformer_layer_type = "single_transformer_blocks.0.attn.to_k"
 
     def test_dreambooth_lora_flux(self):
         with tempfile.TemporaryDirectory() as tmpdir:
@@ -136,6 +137,43 @@ def test_dreambooth_lora_latent_caching(self):
         starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
         self.assertTrue(starts_with_transformer)
 
+    def test_dreambooth_lora_layers(self):
+        with tempfile.TemporaryDirectory() as tmpdir:
+            test_args = f"""
+                {self.script_path}
+                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
+                --instance_data_dir {self.instance_data_dir}
+                --instance_prompt {self.instance_prompt}
+                --resolution 64
+                --train_batch_size 1
+                --gradient_accumulation_steps 1
+                --max_train_steps 2
+                --cache_latents
+                --learning_rate 5.0e-04
+                --scale_lr
+                --lora_layers {self.transformer_layer_type}
+                --lr_scheduler constant
+                --lr_warmup_steps 0
+                --output_dir {tmpdir}
+                """.split()
+
+            run_command(self._launch_args + test_args)
+            # save_pretrained smoke test
+            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
+
+            # make sure the state_dict has the correct naming in the parameters.
+            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
+            is_lora = all("lora" in k for k in lora_state_dict.keys())
+            self.assertTrue(is_lora)
+
+            # when not training the text encoder, all the parameters in the state dict should start
+            # with `"transformer"` in their names. In this test, only params of
+            # transformer.single_transformer_blocks.0.attn.to_k should be in the state dict.
+            starts_with_transformer = all(
+                key.startswith("transformer.single_transformer_blocks.0.attn.to_k") for key in lora_state_dict.keys()
+            )
+            self.assertTrue(starts_with_transformer)
+
     def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self):
         with tempfile.TemporaryDirectory() as tmpdir:
             test_args = f"""
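
As a hedged follow-up to the smoke test above (not part of this commit), the trained `pytorch_lora_weights.safetensors` could be loaded for a quick inference check with `FluxPipeline.load_lora_weights`; the model id, output path, and prompt below are placeholders.

```python
# Sketch: load the LoRA produced by the training script and run one generation.
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("path/to/output_dir")  # reads pytorch_lora_weights.safetensors
pipe.to("cuda")

image = pipe("a photo of sks dog", num_inference_steps=28, guidance_scale=3.5).images[0]
image.save("lora_smoke_test.png")
```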

examples/dreambooth/train_dreambooth_flux.py

Lines changed: 4 additions & 2 deletions
@@ -161,7 +161,7 @@ def log_validation(
         f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
         f" {args.validation_prompt}."
     )
-    pipeline = pipeline.to(accelerator.device, dtype=torch_dtype)
+    pipeline = pipeline.to(accelerator.device)
     pipeline.set_progress_bar_config(disable=True)
 
     # run inference
@@ -1579,7 +1579,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 )
 
                 # handle guidance
-                if transformer.config.guidance_embeds:
+                if accelerator.unwrap_model(transformer).config.guidance_embeds:
                     guidance = torch.tensor([args.guidance_scale], device=accelerator.device)
                     guidance = guidance.expand(model_input.shape[0])
                 else:
@@ -1693,6 +1693,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 # create pipeline
                 if not args.train_text_encoder:
                     text_encoder_one, text_encoder_two = load_text_encoders(text_encoder_cls_one, text_encoder_cls_two)
+                    text_encoder_one.to(weight_dtype)
+                    text_encoder_two.to(weight_dtype)
                 else:  # even when training the text encoder we're only training text encoder one
                     text_encoder_two = text_encoder_cls_two.from_pretrained(
                         args.pretrained_model_name_or_path,
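
The mixed-precision changes above (following #9565) come down to a dtype split: frozen text encoders are cast to the low-precision weight dtype used for validation, while trainable parameters stay in float32. Below is a minimal, self-contained sketch of that split, not code from this repository; the `nn.Linear` modules merely stand in for the real text encoders and transformer.

```python
# Sketch of the dtype split behind the mixed-precision fix.
import torch
import torch.nn as nn

mixed_precision = "bf16"  # assumed accelerate setting
weight_dtype = {"fp16": torch.float16, "bf16": torch.bfloat16}.get(mixed_precision, torch.float32)

frozen_text_encoder = nn.Linear(8, 8)    # stand-in for a frozen text encoder
trainable_transformer = nn.Linear(8, 8)  # stand-in for the trainable transformer

frozen_text_encoder.to(dtype=weight_dtype)     # frozen modules follow the inference dtype
trainable_transformer.to(dtype=torch.float32)  # trainable weights stay in float32

print(next(frozen_text_encoder.parameters()).dtype)    # torch.bfloat16
print(next(trainable_transformer.parameters()).dtype)  # torch.float32
```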

examples/dreambooth/train_dreambooth_lora_flux.py

Lines changed: 30 additions & 3 deletions
@@ -554,6 +554,15 @@ def parse_args(input_args=None):
         "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
     )
 
+    parser.add_argument(
+        "--lora_layers",
+        type=str,
+        default=None,
+        help=(
+            'The transformer modules to apply LoRA training on. Please specify the layers in a comma separated string. E.g. "to_k,to_q,to_v,to_out.0" will result in LoRA training of attention layers only.'
+        ),
+    )
+
     parser.add_argument(
         "--adam_epsilon",
         type=float,
@@ -1186,12 +1195,30 @@ def main(args):
     if args.train_text_encoder:
         text_encoder_one.gradient_checkpointing_enable()
 
-    # now we will add new LoRA weights to the attention layers
+    if args.lora_layers is not None:
+        target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
+    else:
+        target_modules = [
+            "attn.to_k",
+            "attn.to_q",
+            "attn.to_v",
+            "attn.to_out.0",
+            "attn.add_k_proj",
+            "attn.add_q_proj",
+            "attn.add_v_proj",
+            "attn.to_add_out",
+            "ff.net.0.proj",
+            "ff.net.2",
+            "ff_context.net.0.proj",
+            "ff_context.net.2",
+        ]
+
+    # now we will add new LoRA weights to the transformer layers
     transformer_lora_config = LoraConfig(
         r=args.rank,
         lora_alpha=args.rank,
         init_lora_weights="gaussian",
-        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+        target_modules=target_modules,
     )
     transformer.add_adapter(transformer_lora_config)
     if args.train_text_encoder:
@@ -1367,7 +1394,7 @@ def load_model_hook(models, input_dir):
                 f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
                 f"When using prodigy only learning_rate is used as the initial learning rate."
             )
-            # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
+            # changes the learning rate of text_encoder_parameters_one to be
             # --learning_rate
             params_to_optimize[1]["lr"] = args.learning_rate
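
To make the `target_modules` change concrete, here is a small sketch (not from this commit) that applies a `LoraConfig` with a suffix-style entry to a toy model and lists which parameters end up with LoRA weights. The `Block` and `Model` classes below are placeholders, not diffusers classes; the check simply mirrors what `transformer.add_adapter(transformer_lora_config)` does on the Flux transformer.

```python
# Sketch: verify which nested submodules a suffix-style target picks up.
import torch.nn as nn
from peft import LoraConfig, get_peft_model

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn = nn.ModuleDict({"to_q": nn.Linear(16, 16), "to_k": nn.Linear(16, 16)})

class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer_blocks = nn.ModuleList([Block() for _ in range(2)])

model = Model()
lora_config = LoraConfig(r=4, lora_alpha=4, init_lora_weights="gaussian", target_modules=["attn.to_k"])
peft_model = get_peft_model(model, lora_config)

# Only the to_k projections in each block should carry LoRA parameters.
print([name for name, _ in peft_model.named_parameters() if "lora_" in name])
```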

0 commit comments
