Skip to content

Commit fa3a910

Browse files
authored
[LoRA] depcrecate save_attn_procs(). (#10126)
depcrecate save_attn_procs().
1 parent 188bca3 commit fa3a910

File tree

2 files changed

+21
-0
lines changed

2 files changed

+21
-0
lines changed

src/diffusers/loaders/unet.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,6 +492,9 @@ def save_attn_procs(
492492
)
493493
state_dict = {k: v for k, v in state_dict.items() if isinstance(v, torch.Tensor)}
494494
else:
495+
deprecation_message = "Using the `save_attn_procs()` method has been deprecated and will be removed in a future version. Please use `save_lora_adapter()`."
496+
deprecate("save_attn_procs", "0.40.0", deprecation_message)
497+
495498
if not USE_PEFT_BACKEND:
496499
raise ValueError("PEFT backend is required for saving LoRAs using the `save_attn_procs()` method.")
497500

tests/models/unets/test_models_unet_2d_condition.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1119,6 +1119,24 @@ def test_load_attn_procs_raise_warning(self):
11191119
lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4
11201120
), "Loading from a saved checkpoint should produce identical results."
11211121

1122+
@require_peft_backend
1123+
def test_save_attn_procs_raise_warning(self):
1124+
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
1125+
model = self.model_class(**init_dict)
1126+
model.to(torch_device)
1127+
1128+
unet_lora_config = get_unet_lora_config()
1129+
model.add_adapter(unet_lora_config)
1130+
1131+
assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet."
1132+
1133+
with tempfile.TemporaryDirectory() as tmpdirname:
1134+
with self.assertWarns(FutureWarning) as warning:
1135+
model.save_attn_procs(tmpdirname)
1136+
1137+
warning_message = str(warning.warnings[0].message)
1138+
assert "Using the `save_attn_procs()` method has been deprecated" in warning_message
1139+
11221140

11231141
@slow
11241142
class UNet2DConditionModelIntegrationTests(unittest.TestCase):

0 commit comments

Comments
 (0)