Commit ba31a14

linoytsaban authored and sayakpaul committed
[SD 3.5 Dreambooth LoRA] support configurable training block & layers (#9762)
[SD 3.5 Dreambooth LoRA] support configurable training block & layers (#9762)

* configurable layers
* configurable layers
* update README
* style
* add test
* style
* add layer test, update readme, add nargs
* readme
* test style
* remove print, change nargs
* test arg change
* style
* revert nargs 2/2
* address sayaks comments
* style
* address sayaks comments
1 parent dd6de12 commit ba31a14

3 files changed: +143 −1 lines changed


examples/dreambooth/README_sd3.md

Lines changed: 34 additions & 0 deletions
@@ -147,6 +147,40 @@ accelerate launch train_dreambooth_lora_sd3.py \
  --push_to_hub
```

### Targeting Specific Blocks & Layers
As image generation models get bigger & more powerful, more fine-tuners are finding that training only part of the
transformer blocks (sometimes as few as two) can be enough to get great results.
In some cases, it can be even better to keep some of the blocks/layers frozen.

For **SD3.5-Large** specifically, you may find this information useful (taken from: [Stable Diffusion 3.5 Large Fine-tuning Tutorial](https://stabilityai.notion.site/Stable-Diffusion-3-5-Large-Fine-tuning-Tutorial-11a61cdcd1968027a15bdbd7c40be8c6#12461cdcd19680788a23c650dab26b93)):
> [!NOTE]
> A commonly believed heuristic that we verified once again during the construction of the SD3.5 family of models is that later/higher layers (i.e. `30 - 37`)* impact tertiary details more heavily. Conversely, earlier layers (i.e. `12 - 24`)* influence the overall composition/primary form more.
> So, freezing other layers/targeting specific layers is a viable approach.
> `*`These suggested layers are speculative and not 100% guaranteed. The tips here are more or less a general idea for next steps.
> **Photorealism**
> In preliminary testing, we observed that freezing the last few layers of the architecture significantly improved model training when using a photorealistic dataset, preventing the detail degradation introduced by a small dataset.
> **Anatomy preservation**
> To dampen any possible degradation of anatomy, training only the attention layers and **not** the adaptive linear layers could help. For reference, below is one of the transformer blocks.

We've added `--lora_layers` and `--lora_blocks` to make LoRA training modules configurable.
- with `--lora_blocks` you can specify the block numbers for training. E.g. passing -
```diff
+ --lora_blocks "12,13,14,15,16,17,18,19,20,21,22,23,24,30,31,32,33,34,35,36,37"
```
will trigger LoRA training of transformer blocks 12-24 and 30-37. By default, all blocks are trained.
- with `--lora_layers` you can specify the types of layers you wish to train.
By default, the trained layers are -
`attn.add_k_proj,attn.add_q_proj,attn.add_v_proj,attn.to_add_out,attn.to_k,attn.to_out.0,attn.to_q,attn.to_v`
If you wish to have a leaner LoRA, or to prioritize training more blocks over more layer types, you could pass -
```diff
+ --lora_layers attn.to_k,attn.to_q,attn.to_v,attn.to_out.0
```
This will reduce LoRA size by roughly 50% for the same rank compared to the default.
However, if you're after compact LoRAs, it's our impression that keeping the default setting for `--lora_layers` and
freezing some of the early & later blocks is usually better.
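
To make the mapping from flags to trained modules concrete, here is a minimal, illustrative Python sketch. The helper `build_target_modules` is hypothetical (it is not part of the training script); it simply mirrors the parsing logic that `train_dreambooth_lora_sd3.py` uses to build the PEFT `target_modules`:

```python
from typing import List, Optional

# Layer types trained when --lora_layers is not passed (from train_dreambooth_lora_sd3.py).
DEFAULT_LAYERS = [
    "attn.add_k_proj",
    "attn.add_q_proj",
    "attn.add_v_proj",
    "attn.to_add_out",
    "attn.to_k",
    "attn.to_out.0",
    "attn.to_q",
    "attn.to_v",
]


def build_target_modules(lora_layers: Optional[str], lora_blocks: Optional[str]) -> List[str]:
    """Hypothetical helper mirroring how the script expands the two flags."""
    # --lora_layers: comma-separated layer types, e.g. "attn.to_k,attn.to_q".
    target_modules = [layer.strip() for layer in lora_layers.split(",")] if lora_layers else list(DEFAULT_LAYERS)
    # --lora_blocks: comma-separated block indices; restricts training to those transformer blocks.
    if lora_blocks:
        blocks = [int(block.strip()) for block in lora_blocks.split(",")]
        target_modules = [f"transformer_blocks.{block}.{layer}" for block in blocks for layer in target_modules]
    return target_modules


print(build_target_modules("attn.to_k,attn.to_q,attn.to_v,attn.to_out.0", "12,30"))
# ['transformer_blocks.12.attn.to_k', ..., 'transformer_blocks.30.attn.to_out.0']
```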

### Text Encoder Training
Alongside the transformer, LoRA fine-tuning of the CLIP text encoders is now also supported.
To do so, just specify `--train_text_encoder` while launching training. Please keep the following points in mind:

examples/dreambooth/test_dreambooth_lora_sd3.py

Lines changed: 71 additions & 0 deletions
@@ -38,6 +38,9 @@ class DreamBoothLoRASD3(ExamplesTestsAccelerate):
    pretrained_model_name_or_path = "hf-internal-testing/tiny-sd3-pipe"
    script_path = "examples/dreambooth/train_dreambooth_lora_sd3.py"

    transformer_block_idx = 0
    layer_type = "attn.to_k"

    def test_dreambooth_lora_sd3(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""

@@ -136,6 +139,74 @@ def test_dreambooth_lora_latent_caching(self):
            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_block(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --lora_blocks {self.transformer_block_idx}
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            # In this test, only params of transformer block 0 should be in the state dict
            starts_with_transformer = all(
                key.startswith("transformer.transformer_blocks.0") for key in lora_state_dict.keys()
            )
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_layer(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --lora_layers {self.layer_type}
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # In this test, only transformer params of attention layers `attn.to_k` should be in the state dict
            starts_with_transformer = all("attn.to_k" in key for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""

examples/dreambooth/train_dreambooth_lora_sd3.py

Lines changed: 38 additions & 1 deletion
@@ -571,6 +571,25 @@ def parse_args(input_args=None):
        "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
    )

    parser.add_argument(
        "--lora_layers",
        type=str,
        default=None,
        help=(
            "The transformer block layers to apply LoRA training on. Please specify the layers in a comma separated string. "
            "For examples refer to https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sd3.md"
        ),
    )
    parser.add_argument(
        "--lora_blocks",
        type=str,
        default=None,
        help=(
            "The transformer blocks to apply LoRA training on. Please specify the block numbers in a comma separated manner. "
            'E.g. - "--lora_blocks 12,30" will result in LoRA training of transformer blocks 12 and 30. For more examples refer to https://github.com/huggingface/diffusers/blob/main/examples/dreambooth/README_sd3.md'
        ),
    )

    parser.add_argument(
        "--adam_epsilon",
        type=float,

@@ -1222,13 +1241,31 @@ def main(args):
        if args.train_text_encoder:
            text_encoder_one.gradient_checkpointing_enable()
            text_encoder_two.gradient_checkpointing_enable()
    if args.lora_layers is not None:
        target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
    else:
        target_modules = [
            "attn.add_k_proj",
            "attn.add_q_proj",
            "attn.add_v_proj",
            "attn.to_add_out",
            "attn.to_k",
            "attn.to_out.0",
            "attn.to_q",
            "attn.to_v",
        ]
    if args.lora_blocks is not None:
        target_blocks = [int(block.strip()) for block in args.lora_blocks.split(",")]
        target_modules = [
            f"transformer_blocks.{block}.{module}" for block in target_blocks for module in target_modules
        ]

    # now we will add new LoRA weights to the attention layers
    transformer_lora_config = LoraConfig(
        r=args.rank,
        lora_alpha=args.rank,
        init_lora_weights="gaussian",
-       target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+       target_modules=target_modules,
    )
    transformer.add_adapter(transformer_lora_config)
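
After `add_adapter`, a quick sanity check (not part of this commit) can confirm which modules actually receive LoRA weights for a given flag combination. This is a hedged sketch: the tiny checkpoint and its `transformer` subfolder are assumptions borrowed from the test file above.

```python
# Hedged sketch: list the modules that receive LoRA adapters for a given
# --lora_blocks / --lora_layers choice. The checkpoint and subfolder are
# assumptions based on the tiny model used by the tests.
from diffusers import SD3Transformer2DModel
from peft import LoraConfig

transformer = SD3Transformer2DModel.from_pretrained(
    "hf-internal-testing/tiny-sd3-pipe", subfolder="transformer"
)

# Same expansion the script performs for `--lora_blocks 0 --lora_layers attn.to_k`.
target_modules = [f"transformer_blocks.{block}.attn.to_k" for block in [0]]
transformer.add_adapter(
    LoraConfig(r=4, lora_alpha=4, init_lora_weights="gaussian", target_modules=target_modules)
)

# PEFT attaches `lora_A`/`lora_B` to every injected layer, so the names below are
# exactly the modules that will be trained and saved into the LoRA state dict.
adapted = [name for name, module in transformer.named_modules() if hasattr(module, "lora_A")]
print(adapted)  # expected: ['transformer_blocks.0.attn.to_k']
```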
