Commit e08cf74

make is_dora check consistent.
1 parent e3c250e commit e08cf74

1 file changed: +11 −13 lines

src/diffusers/loaders/lora_pipeline.py

Lines changed: 11 additions & 13 deletions
@@ -1100,6 +1100,12 @@ def lora_state_dict(
             allow_pickle=allow_pickle,
         )

+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
         return state_dict

     def load_lora_weights(
@@ -1136,12 +1142,6 @@ def load_lora_weights(
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

-        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
-        if is_dora_scale_present:
-            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
-            logger.warning(warn_msg)
-            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
-
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
@@ -1611,7 +1611,6 @@ def lora_state_dict(
         state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

         # TODO (sayakpaul): to a follow-up to clean and try to unify the conditions.
-
         is_kohya = any(".lora_down.weight" in k for k in state_dict)
         if is_kohya:
             state_dict = _convert_kohya_flux_lora_to_diffusers(state_dict)
@@ -2395,6 +2394,11 @@ def lora_state_dict(
             user_agent=user_agent,
             allow_pickle=allow_pickle,
         )
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

         return state_dict

@@ -2427,12 +2431,6 @@ def load_lora_weights(
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         state_dict = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)

-        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
-        if is_dora_scale_present:
-            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
-            logger.warning(warn_msg)
-            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
-
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")
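
For reference, here is a minimal standalone sketch of the pattern this commit consolidates into lora_state_dict, so the DoRA check runs once where the state dict is produced rather than in every load_lora_weights caller. It is illustrative only: the helper name filter_dora_scale and the example checkpoint path are assumptions, not part of the diffusers API.

import logging

from safetensors.torch import load_file  # assumption: checkpoint stored as a .safetensors file

logger = logging.getLogger(__name__)


def filter_dora_scale(state_dict):
    # Same check the commit moves into lora_state_dict: warn once and drop
    # any keys containing "dora_scale", since DoRA checkpoints are not yet supported.
    is_dora_scale_present = any("dora_scale" in k for k in state_dict)
    if is_dora_scale_present:
        logger.warning(
            "DoRA checkpoint detected; filtering out keys associated with 'dora_scale'."
        )
        state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
    return state_dict


# Hypothetical usage: filter once at load time so every downstream consumer
# (e.g. load_lora_weights) receives an already-cleaned state dict.
state_dict = filter_dora_scale(load_file("pytorch_lora_weights.safetensors"))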
