
Commit 02eeb8e

[LoRA] Handle DoRA better (#9547)
* handle dora.
* print test
* debug
* fix
* fix-copies
* update logits
* add warning in the test.
* make is_dora check consistent.
* fix-copies
1 parent 66eef9a commit 02eeb8e

File tree

src/diffusers/loaders/lora_pipeline.py
tests/lora/test_lora_layers_sdxl.py

2 files changed: +46 -13 lines
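
The change repeated across the hunks below is a small guard in each `lora_state_dict` variant: detect `dora_scale` entries in the loaded state dict, warn that DoRA checkpoints are not compatible with Diffusers at the moment, and drop those keys before anything else consumes the dict. A minimal standalone sketch of the pattern, using an invented toy state dict (the key names are illustrative, not real checkpoint keys):

# Toy LoRA state dict: the "dora_scale" entry mimics what a kohya-style DoRA
# checkpoint carries alongside the usual lora_A / lora_B tensors.
state_dict = {
    "unet.down_blocks.0.attn.to_q.lora_A.weight": 0.1,
    "unet.down_blocks.0.attn.to_q.lora_B.weight": 0.2,
    "unet.down_blocks.0.attn.to_q.dora_scale": 0.3,
}

is_dora_scale_present = any("dora_scale" in k for k in state_dict)
if is_dora_scale_present:
    # The library routes this message through logger.warning(); print() keeps the sketch self-contained.
    print("DoRA checkpoint detected; filtering out 'dora_scale' keys.")
    state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

# With the dora_scale keys removed, the simplified check used in load_lora_weights passes.
assert all("lora" in key for key in state_dict.keys())

Because the filtering now happens inside lora_state_dict, the load_lora_weights validity check no longer needs the `or "dora_scale" in key` escape hatch, which is why every is_correct_format line in the diff is simplified to the plain "lora" test.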

src/diffusers/loaders/lora_pipeline.py

Lines changed: 33 additions & 6 deletions
@@ -99,7 +99,7 @@ def load_lora_weights(
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict, network_alphas = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

-        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+        is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")

@@ -211,6 +211,11 @@ def lora_state_dict(
             user_agent=user_agent,
             allow_pickle=allow_pickle,
         )
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

         network_alphas = None
         # TODO: replace it with a method from `state_dict_utils`
@@ -562,7 +567,8 @@ def load_lora_weights(
             unet_config=self.unet.config,
             **kwargs,
         )
-        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")

@@ -684,6 +690,11 @@ def lora_state_dict(
             user_agent=user_agent,
             allow_pickle=allow_pickle,
         )
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

         network_alphas = None
         # TODO: replace it with a method from `state_dict_utils`
@@ -1089,6 +1100,12 @@ def lora_state_dict(
             allow_pickle=allow_pickle,
         )

+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
         return state_dict

     def load_lora_weights(
@@ -1125,7 +1142,7 @@ def load_lora_weights(
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

-        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+        is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")

@@ -1587,9 +1604,13 @@ def lora_state_dict(
             user_agent=user_agent,
             allow_pickle=allow_pickle,
         )
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

         # TODO (sayakpaul): to a follow-up to clean and try to unify the conditions.
-
         is_kohya = any(".lora_down.weight" in k for k in state_dict)
         if is_kohya:
             state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
@@ -1659,7 +1680,7 @@ def load_lora_weights(
             pretrained_model_name_or_path_or_dict, return_alphas=True, **kwargs
         )

-        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+        is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")

@@ -2374,6 +2395,12 @@ def lora_state_dict(
             allow_pickle=allow_pickle,
         )

+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
         return state_dict

     def load_lora_weights(
@@ -2405,7 +2432,7 @@ def load_lora_weights(
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
        state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

-        is_correct_format = all("lora" in key or "dora_scale" in key for key in state_dict.keys())
+        is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
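
At the user level, a DoRA checkpoint trained with kohya scripts now loads with a warning, its dora_scale tensors dropped instead of being passed through to the loader. A hedged usage sketch mirroring the integration test in the next file (the SDXL base checkpoint id is an assumption; the LoRA repo id comes from the test, and running this needs the model weights and a GPU):

import torch
from diffusers import StableDiffusionXLPipeline

# Assumed base checkpoint; the DoRA LoRA repo id below is the one used by the integration test.
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
)
# Since this commit, the dora_scale keys in this checkpoint are filtered out with a warning
# rather than being handed to the rest of the LoRA loading path.
pipeline.load_lora_weights("hf-internal-testing/dora-trained-on-kohya")
pipeline.enable_model_cpu_offload()

image = pipeline(
    "photo of ohwx dog",
    num_inference_steps=10,
    generator=torch.manual_seed(0),
).images[0]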

tests/lora/test_lora_layers_sdxl.py

Lines changed: 13 additions & 7 deletions
@@ -33,8 +33,10 @@
     StableDiffusionXLPipeline,
     T2IAdapter,
 )
+from diffusers.utils import logging
 from diffusers.utils.import_utils import is_accelerate_available
 from diffusers.utils.testing_utils import (
+    CaptureLogger,
     load_image,
     nightly,
     numpy_cosine_similarity_distance,
@@ -620,14 +622,18 @@ def test_integration_logits_for_dora_lora(self):
         pipeline.load_lora_weights("hf-internal-testing/dora-trained-on-kohya")
         pipeline.enable_model_cpu_offload()

-        images = pipeline(
-            "photo of ohwx dog",
-            num_inference_steps=10,
-            generator=torch.manual_seed(0),
-            output_type="np",
-        ).images
+        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+        logger.setLevel(30)
+        with CaptureLogger(logger) as cap_logger:
+            images = pipeline(
+                "photo of ohwx dog",
+                num_inference_steps=10,
+                generator=torch.manual_seed(0),
+                output_type="np",
+            ).images
+        assert "It seems like you are using a DoRA checkpoint" in cap_logger.out

         predicted_slice = images[0, -3:, -3:, -1].flatten()
-        expected_slice_scale = np.array([0.3932, 0.3742, 0.4429, 0.3737, 0.3504, 0.433, 0.3948, 0.3769, 0.4516])
+        expected_slice_scale = np.array([0.1817, 0.0697, 0.2346, 0.0900, 0.1261, 0.2279, 0.1767, 0.1991, 0.2886])
         max_diff = numpy_cosine_similarity_distance(expected_slice_scale, predicted_slice)
         assert max_diff < 1e-3
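
For reference, CaptureLogger (imported above from diffusers.utils.testing_utils) is a context manager that collects what a given logger emits into its .out attribute, which is how the test asserts that the DoRA warning actually fired. A minimal sketch of the pattern detached from the pipeline call (the warning text is abbreviated):

from diffusers.utils import logging
from diffusers.utils.testing_utils import CaptureLogger

logger = logging.get_logger("diffusers.loaders.lora_pipeline")
logger.setLevel(30)  # 30 == logging.WARNING, the same level the test sets

with CaptureLogger(logger) as cap_logger:
    # Anything this logger emits inside the block is captured into cap_logger.out.
    logger.warning("It seems like you are using a DoRA checkpoint ...")

assert "DoRA checkpoint" in cap_logger.out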
