Commit b66e691
Commit message: add tests
1 parent: 07d44e7

2 files changed: +161, -22 lines

src/diffusers/loaders/lora_pipeline.py

Lines changed: 8 additions & 22 deletions
```diff
@@ -1835,22 +1835,6 @@ def load_lora_weights(
         if not (has_lora_keys or has_norm_keys):
             raise ValueError("Invalid LoRA checkpoint.")
 
-        def prune_state_dict_(state_dict):
-            pruned_keys = []
-            for key in list(state_dict.keys()):
-                is_lora_key_present = "lora" in key
-                is_norm_key_present = any(norm_key in key for norm_key in supported_norm_keys)
-                if not is_lora_key_present and not is_norm_key_present:
-                    state_dict.pop(key)
-                    pruned_keys.append(key)
-            return pruned_keys
-
-        pruned_keys = prune_state_dict_(state_dict)
-        if len(pruned_keys) > 0:
-            logger.warning(
-                f"The provided LoRA state dict contains additional weights that are not compatible with Flux. The following are the incompatible weights:\n{pruned_keys}"
-            )
-
         transformer_lora_state_dict = {
             k: state_dict.pop(k) for k in list(state_dict.keys()) if "transformer." in k and "lora" in k
         }
```
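The filtering that remains in the context line above relies on popping keys while iterating over a snapshot of the key list, which partitions one state dict in place. A minimal, self-contained sketch of that pattern (the keys below are made up for illustration):

```python
# Illustrative only: popping keys while iterating over list(state_dict.keys())
# moves the matching entries into a new dict and leaves the rest behind.
state_dict = {
    "transformer.x_embedder.lora_A.weight": 1,
    "transformer.norm_out.weight": 2,
    "text_encoder.lora_A.weight": 3,
}
transformer_lora_state_dict = {
    k: state_dict.pop(k) for k in list(state_dict.keys()) if "transformer." in k and "lora" in k
}
print(transformer_lora_state_dict)  # {'transformer.x_embedder.lora_A.weight': 1}
print(state_dict)                   # remaining keys stay behind for other loaders
```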
```diff
@@ -1883,7 +1867,7 @@ def prune_state_dict_(state_dict):
             )
 
         if len(transformer_norm_state_dict) > 0:
-            self._transformer_norm_layers = self._load_norm_into_transformer(
+            transformer._transformer_norm_layers = self._load_norm_into_transformer(
                 transformer_norm_state_dict,
                 transformer=transformer,
                 discard_original_layers=False,
```
```diff
@@ -1977,7 +1961,7 @@ def _load_norm_into_transformer(
         overwritten_layers_state_dict = {}
         if not discard_original_layers:
             for key in state_dict.keys():
-                overwritten_layers_state_dict[key] = transformer_state_dict[key]
+                overwritten_layers_state_dict[key] = transformer_state_dict[key].clone()
 
         logger.info(
             "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will directly update the state_dict of the transformer "
```
````diff
@@ -2237,10 +2221,12 @@ def fuse_lora(
         pipeline.fuse_lora(lora_scale=0.7)
         ```
         """
+
+        transformer = getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer
         if (
-            hasattr(self, "_transformer_norm_layers")
-            and isinstance(self._transformer_norm_layers, dict)
-            and len(self._transformer_norm_layers.keys()) > 0
+            hasattr(transformer, "_transformer_norm_layers")
+            and isinstance(transformer._transformer_norm_layers, dict)
+            and len(transformer._transformer_norm_layers.keys()) > 0
         ):
             logger.info(
                 "The provided state dict contains normalization layers in addition to LoRA layers. The normalization layers will be directly updated the state_dict of the transformer "
````
```diff
@@ -2303,7 +2289,7 @@ def _maybe_expand_transformer_param_shape_or_error_(
         prefix = prefix or cls.transformer_name
         for key in list(state_dict.keys()):
             if key.split(".")[0] == prefix:
-                state_dict[key.replace(f"{prefix}.", "")] = state_dict.pop(key)
+                state_dict[key[len(f"{prefix}.") :]] = state_dict.pop(key)
 
         # Expand transformer parameter shapes if they don't match lora
         has_param_with_shape_update = False
```
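The slicing form strips only the leading prefix, whereas `str.replace` removes every occurrence of the substring anywhere in the key. A quick sketch with a made-up key that happens to contain the prefix twice:

```python
prefix = "transformer"
# Hypothetical key: the prefix string also appears in the middle of the name.
key = "transformer.blocks.0.transformer.proj.weight"

print(key.replace(f"{prefix}.", ""))  # 'blocks.0.proj.weight'              <- inner occurrence removed too
print(key[len(f"{prefix}.") :])       # 'blocks.0.transformer.proj.weight'  <- only the leading prefix stripped
```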

tests/lora/test_lora_layers_flux.py

Lines changed: 153 additions & 0 deletions
```diff
@@ -24,12 +24,15 @@
 from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
 
 from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
+from diffusers.utils import logging
 from diffusers.utils.testing_utils import (
+    CaptureLogger,
     floats_tensor,
     is_peft_available,
     nightly,
     numpy_cosine_similarity_distance,
     require_peft_backend,
+    require_peft_version_greater,
     require_torch_gpu,
     slow,
     torch_device,
```
```diff
@@ -108,6 +111,30 @@ def get_dummy_inputs(self, with_generator=True):
 
         return noise, input_ids, pipeline_inputs
 
+    def get_dummy_tensor_inputs(self, device=None):
+        batch_size = 1
+        num_latent_channels = 4
+        num_image_channels = 3
+        height = width = 4
+        sequence_length = 48
+        embedding_dim = 32
+
+        hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(torch_device)
+        encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
+        pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(torch_device)
+        text_ids = torch.randn((sequence_length, num_image_channels)).to(torch_device)
+        image_ids = torch.randn((height * width, num_image_channels)).to(torch_device)
+        timestep = torch.tensor([1.0]).to(torch_device).expand(batch_size)
+
+        return {
+            "hidden_states": hidden_states,
+            "encoder_hidden_states": encoder_hidden_states,
+            "pooled_projections": pooled_prompt_embeds,
+            "txt_ids": text_ids,
+            "img_ids": image_ids,
+            "timestep": timestep,
+        }
+
     def test_with_alpha_in_state_dict(self):
         components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
         pipe = self.pipeline_class(**components)
```
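These keyword arguments line up with the transformer's forward signature, so the new tests can exercise the denoiser directly instead of running the full pipeline. A minimal usage sketch, assuming `self` is this test case and `pipe` was built from `self.get_dummy_components(...)` as in the tests that follow:

```python
inputs = self.get_dummy_tensor_inputs(torch_device)
with torch.no_grad():
    # Call the transformer directly; [0] picks the sample out of the returned output.
    output = pipe.transformer(**inputs)[0]
```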
```diff
@@ -156,6 +183,132 @@ def test_with_alpha_in_state_dict(self):
         )
         self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))
 
+    def test_with_norm_in_state_dict(self):
+        components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_tensor_inputs(torch_device)
+
+        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+        logger.setLevel(logging.INFO)
+
+        with torch.no_grad():
+            original_output = pipe.transformer(**inputs)[0]
+
+        for norm_layer in ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]:
+            norm_state_dict = {}
+            for name, module in pipe.transformer.named_modules():
+                if norm_layer not in name or not hasattr(module, "weight") or module.weight is None:
+                    continue
+                norm_state_dict[f"transformer.{name}.weight"] = torch.randn(
+                    module.weight.shape, device=module.weight.device, dtype=module.weight.dtype
+                )
+
+            with torch.no_grad():
+                with CaptureLogger(logger) as cap_logger:
+                    pipe.load_lora_weights(norm_state_dict)
+                    lora_load_output = pipe.transformer(**inputs)[0]
+                self.assertTrue(
+                    cap_logger.out.startswith(
+                        "The provided state dict contains normalization layers in addition to LoRA layers"
+                    )
+                )
+
+                pipe.unload_lora_weights()
+                lora_unload_output = pipe.transformer(**inputs)[0]
+
+            self.assertTrue(pipe.transformer._transformer_norm_layers is None)
+            self.assertFalse(np.allclose(original_output, lora_load_output, atol=1e-5, rtol=1e-5))
+            self.assertTrue(np.allclose(original_output, lora_unload_output, atol=1e-5, rtol=1e-5))
+
+        with CaptureLogger(logger) as cap_logger:
+            for key in list(norm_state_dict.keys()):
+                norm_state_dict[key.replace("norm", "norm_k_something_random")] = norm_state_dict.pop(key)
+            pipe.load_lora_weights(norm_state_dict)
+
+        self.assertTrue(
+            cap_logger.out.startswith("Unsupported keys found in state dict when trying to load normalization layers")
+        )
+
+    def test_lora_parameter_expanded_shapes(self):
+        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_tensor_inputs(torch_device)
+
+        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+        logger.setLevel(logging.DEBUG)
+
+        with torch.no_grad():
+            original_output = pipe.transformer(**inputs)[0]
+
+        out_features, in_features = pipe.transformer.x_embedder.weight.shape
+        rank = 4
+
+        dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False)
+        dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
+        lora_state_dict = {
+            "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
+            "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
+        }
+        with CaptureLogger(logger) as cap_logger:
+            pipe.load_lora_weights(lora_state_dict, "adapter-1")
+        inputs["hidden_states"] = torch.cat([inputs["hidden_states"]] * 2, dim=2)
+        with torch.no_grad():
+            expanded_output = pipe.transformer(**inputs)[0]
+        pipe.delete_adapters("adapter-1")
+        self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module"))
+        self.assertFalse(np.allclose(original_output, expanded_output, atol=1e-3, rtol=1e-3))
+
+        components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        dummy_lora_A = torch.nn.Linear(1, rank, bias=False)
+        dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False)
+        lora_state_dict = {
+            "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight,
+            "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight,
+        }
+        # We should error out because lora input features is less than original. We only
+        # support expanding the module, not shrinking it
+        with self.assertRaises(NotImplementedError):
+            pipe.load_lora_weights(lora_state_dict, "adapter-1")
+
+    @require_peft_version_greater("0.13.2")
+    def test_lora_B_bias(self):
+        components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
+        pipe = self.pipeline_class(**components)
+        pipe = pipe.to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_tensor_inputs(torch_device)
+
+        logger = logging.get_logger("diffusers.loaders.lora_pipeline")
+        logger.setLevel(logging.INFO)
+
+        with torch.no_grad():
+            original_output = pipe.transformer(**inputs)[0]
+
+        denoiser_lora_config.lora_bias = False
+        pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+        with torch.no_grad():
+            lora_bias_false_output = pipe.transformer(**inputs)[0]
+        pipe.delete_adapters("adapter-1")
+
+        denoiser_lora_config.lora_bias = True
+        pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1")
+        with torch.no_grad():
+            lora_bias_true_output = pipe.transformer(**inputs)[0]
+
+        self.assertFalse(np.allclose(original_output, lora_bias_false_output, atol=1e-3, rtol=1e-3))
+        self.assertFalse(np.allclose(original_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
+        self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3))
+
     @unittest.skip("Not supported in Flux.")
     def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
         pass
```
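For context on `test_lora_parameter_expanded_shapes` above: a LoRA whose `lora_A` expects a wider input than the base layer can only be attached if the base `nn.Linear` is expanded first. A self-contained sketch of one way such an expansion can preserve the original behavior (illustrative only, not the diffusers implementation):

```python
import torch

# Expand a Linear from 4 -> 8 input features; the new columns stay zero, so the
# original mapping is preserved when the extra inputs are zero-padded.
old = torch.nn.Linear(4, 8, bias=True)
new = torch.nn.Linear(8, 8, bias=True)

with torch.no_grad():
    new.weight.zero_()
    new.weight[:, :4].copy_(old.weight)  # original columns preserved
    new.bias.copy_(old.bias)

x = torch.randn(1, 4)
x_padded = torch.nn.functional.pad(x, (0, 4))  # append 4 zero features
assert torch.allclose(old(x), new(x_padded), atol=1e-6)
```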
