[flux dreambooth lora training] make LoRA target modules configurable + small bug fix #9646
@@ -554,6 +554,15 @@ def parse_args(input_args=None):
         "--adam_weight_decay_text_encoder", type=float, default=1e-03, help="Weight decay to use for text_encoder"
     )
 
+    parser.add_argument(
+        "--lora_layers",
+        type=str,
+        default=None,
+        help=(
+            'The transformer modules to apply LoRA training on. Please specify the layers in a comma separated string, e.g. "to_k,to_q,to_v,to_out.0" will result in LoRA training of attention layers only'
+        ),
+    )
+
     parser.add_argument(
         "--adam_epsilon",
         type=float,
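For readers skimming the diff, here is a small, self-contained sketch (not part of the PR) of how a value passed via the new `--lora_layers` flag is expected to be consumed; the example value and the `print` call are purely illustrative.

```python
# Illustrative sketch: parse a comma-separated --lora_layers value the same way
# the training script does further down (split on commas, strip whitespace).
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--lora_layers",
    type=str,
    default=None,
    help="Comma-separated list of transformer modules to apply LoRA training to.",
)
args = parser.parse_args(["--lora_layers", "to_k, to_q, to_v, to_out.0"])

target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
print(target_modules)  # ['to_k', 'to_q', 'to_v', 'to_out.0']
```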
@@ -1186,12 +1195,30 @@ def main(args):
     if args.train_text_encoder:
         text_encoder_one.gradient_checkpointing_enable()
 
-    # now we will add new LoRA weights to the attention layers
+    if args.lora_layers is not None:
+        target_modules = [layer.strip() for layer in args.lora_layers.split(",")]
+    else:
+        target_modules = [
+            "attn.to_k",
+            "attn.to_q",
+            "attn.to_v",
+            "attn.to_out.0",
+            "attn.add_k_proj",
+            "attn.add_q_proj",
+            "attn.add_v_proj",
+            "attn.to_add_out",
+            "ff.net.0.proj",
+            "ff.net.2",
+            "ff_context.net.0.proj",
+            "ff_context.net.2",
+        ]
Comment on lines +1201 to +1214
Seems like a bit breaking, no? Better to not do it and instead make a note in the README? WDYT?

Breaking, or just changing the default behavior? I think it's geared more towards the latter, but I think it's in line with the other trainers & makes sense for Transformer-based models, so maybe a

Yeah, maybe a warning note at the beginning of the README should cut it. With this change, we're likely also increasing the total training wall-clock time in the default setting, so that is worth noting.
+
+    # now we will add new LoRA weights to the transformer layers
     transformer_lora_config = LoraConfig(
         r=args.rank,
         lora_alpha=args.rank,
         init_lora_weights="gaussian",
-        target_modules=["to_k", "to_q", "to_v", "to_out.0"],
+        target_modules=target_modules,
     )
     transformer.add_adapter(transformer_lora_config)
     if args.train_text_encoder:
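To make the review discussion above concrete, the following sketch contrasts the old and new default target sets and builds the same kind of `LoraConfig` as the diff. It assumes `peft` is installed; the hard-coded rank of 16 is a stand-in for `args.rank`, not a value taken from the PR.

```python
# Sketch of the default-behavior change flagged in the review: the default now
# targets feed-forward projections in addition to the attention projections,
# so more LoRA parameters are trained unless --lora_layers narrows the list.
from peft import LoraConfig

old_default = ["to_k", "to_q", "to_v", "to_out.0"]
new_default = [
    "attn.to_k", "attn.to_q", "attn.to_v", "attn.to_out.0",
    "attn.add_k_proj", "attn.add_q_proj", "attn.add_v_proj", "attn.to_add_out",
    "ff.net.0.proj", "ff.net.2", "ff_context.net.0.proj", "ff_context.net.2",
]

# Same config shape as in the diff, with an arbitrary rank for illustration.
transformer_lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    init_lora_weights="gaussian",
    target_modules=new_default,
)
print(f"{len(old_default)} -> {len(new_default)} targeted module-name patterns")
```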
@@ -1367,10 +1394,9 @@ def load_model_hook(models, input_dir):
                 f" {args.text_encoder_lr} and learning_rate: {args.learning_rate}. "
                 f"When using prodigy only learning_rate is used as the initial learning rate."
             )
-            # changes the learning rate of text_encoder_parameters_one and text_encoder_parameters_two to be
+            # changes the learning rate of text_encoder_parameters_one to be
             # --learning_rate
             params_to_optimize[1]["lr"] = args.learning_rate
-            params_to_optimize[2]["lr"] = args.learning_rate
 
     optimizer = optimizer_class(
         params_to_optimize,
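The small bug fix above is easier to see with the optimizer param-group layout spelled out. The sketch below is illustrative, not lifted from the script: it assumes the Flux trainer builds exactly two param groups when `--train_text_encoder` is set (the transformer LoRA params and the single CLIP text encoder), which is what the removed `params_to_optimize[2]` line conflicted with.

```python
# Illustrative sketch of why `params_to_optimize[2]` was a bug: with only two
# param groups, index 2 is out of range. The parameters below are dummy stand-ins.
import torch

transformer_lora_params = [torch.nn.Parameter(torch.zeros(4, 4))]
text_encoder_one_lora_params = [torch.nn.Parameter(torch.zeros(4, 4))]

learning_rate = 1.0      # Prodigy is typically run with an initial lr of 1.0
text_encoder_lr = 1e-5   # ignored under Prodigy, hence the override below

params_to_optimize = [
    {"params": transformer_lora_params, "lr": learning_rate},
    {"params": text_encoder_one_lora_params, "lr": text_encoder_lr},
]

# The fix: only group 1 (the text encoder) needs its lr reset to --learning_rate;
# indexing [2] would raise IndexError because only two groups exist.
params_to_optimize[1]["lr"] = learning_rate
assert len(params_to_optimize) == 2
```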
Let's provide the author courtesy here.
@linoytsaban did we resolve this?