|
24 | 24 | from transformers import AutoTokenizer, CLIPTextModel, CLIPTokenizer, T5EncoderModel
|
25 | 25 |
|
26 | 26 | from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline, FluxTransformer2DModel
|
| 27 | +from diffusers.utils import logging |
27 | 28 | from diffusers.utils.testing_utils import (
|
| 29 | + CaptureLogger, |
28 | 30 | floats_tensor,
|
29 | 31 | is_peft_available,
|
30 | 32 | nightly,
|
31 | 33 | numpy_cosine_similarity_distance,
|
32 | 34 | require_peft_backend,
|
| 35 | + require_peft_version_greater, |
33 | 36 | require_torch_gpu,
|
34 | 37 | slow,
|
35 | 38 | torch_device,
|
@@ -108,6 +111,30 @@ def get_dummy_inputs(self, with_generator=True):
|
108 | 111 |
|
109 | 112 | return noise, input_ids, pipeline_inputs
|
110 | 113 |
|
| 114 | + def get_dummy_tensor_inputs(self, device=torch_device): |
| 115 | + batch_size = 1 |
| 116 | + num_latent_channels = 4 |
| 117 | + num_image_channels = 3 |
| 118 | + height = width = 4 |
| 119 | + sequence_length = 48 |
| 120 | + embedding_dim = 32 |
| 121 | + |
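| | + # Random tensors shaped like the Flux transformer's packed inputs: |
| | + # latents flattened to height*width tokens, plus text/image ids. |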
| 122 | + hidden_states = torch.randn((batch_size, height * width, num_latent_channels)).to(device) |
| 123 | + encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(device) |
| 124 | + pooled_prompt_embeds = torch.randn((batch_size, embedding_dim)).to(device) |
| 125 | + text_ids = torch.randn((sequence_length, num_image_channels)).to(device) |
| 126 | + image_ids = torch.randn((height * width, num_image_channels)).to(device) |
| 127 | + timestep = torch.tensor([1.0]).to(device).expand(batch_size) |
| 128 | + |
| 129 | + return { |
| 130 | + "hidden_states": hidden_states, |
| 131 | + "encoder_hidden_states": encoder_hidden_states, |
| 132 | + "pooled_projections": pooled_prompt_embeds, |
| 133 | + "txt_ids": text_ids, |
| 134 | + "img_ids": image_ids, |
| 135 | + "timestep": timestep, |
| 136 | + } |
| 137 | + |
111 | 138 | def test_with_alpha_in_state_dict(self):
|
112 | 139 | components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler)
|
113 | 140 | pipe = self.pipeline_class(**components)
|
@@ -156,6 +183,132 @@ def test_with_alpha_in_state_dict(self):
|
156 | 183 | )
|
157 | 184 | self.assertFalse(np.allclose(images_lora_with_alpha, images_lora, atol=1e-3, rtol=1e-3))
|
158 | 185 |
|
| 186 | + def test_with_norm_in_state_dict(self): |
| 187 | + components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) |
| 188 | + pipe = self.pipeline_class(**components) |
| 189 | + pipe = pipe.to(torch_device) |
| 190 | + pipe.set_progress_bar_config(disable=None) |
| 191 | + |
| 192 | + inputs = self.get_dummy_tensor_inputs(torch_device) |
| 193 | + |
| 194 | + logger = logging.get_logger("diffusers.loaders.lora_pipeline") |
| 195 | + logger.setLevel(logging.INFO) |
| 196 | + |
| 197 | + with torch.no_grad(): |
| 198 | + original_output = pipe.transformer(**inputs)[0] |
| 199 | + |
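| | + # For each norm layer, build a state dict of randomized norm weights; |
| | + # loading it should hot-swap the norms in addition to any LoRA layers. |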
| 200 | + for norm_layer in ["norm_q", "norm_k", "norm_added_q", "norm_added_k"]: |
| 201 | + norm_state_dict = {} |
| 202 | + for name, module in pipe.transformer.named_modules(): |
| 203 | + if norm_layer not in name or not hasattr(module, "weight") or module.weight is None: |
| 204 | + continue |
| 205 | + norm_state_dict[f"transformer.{name}.weight"] = torch.randn( |
| 206 | + module.weight.shape, device=module.weight.device, dtype=module.weight.dtype |
| 207 | + ) |
| 208 | + |
| 209 | + with torch.no_grad(): |
| 210 | + with CaptureLogger(logger) as cap_logger: |
| 211 | + pipe.load_lora_weights(norm_state_dict) |
| 212 | + lora_load_output = pipe.transformer(**inputs)[0] |
| 213 | + self.assertTrue( |
| 214 | + cap_logger.out.startswith( |
| 215 | + "The provided state dict contains normalization layers in addition to LoRA layers" |
| 216 | + ) |
| 217 | + ) |
| 218 | + |
| 219 | + pipe.unload_lora_weights() |
| 220 | + lora_unload_output = pipe.transformer(**inputs)[0] |
| 221 | + |
| 222 | + self.assertTrue(pipe.transformer._transformer_norm_layers is None) |
| 223 | + self.assertFalse(np.allclose(original_output, lora_load_output, atol=1e-5, rtol=1e-5)) |
| 224 | + self.assertTrue(np.allclose(original_output, lora_unload_output, atol=1e-5, rtol=1e-5)) |
| 225 | + |
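| | + # Mangle the keys so they no longer match any norm layer; loading |
| | + # should then log that unsupported keys were found. |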
| 226 | + with CaptureLogger(logger) as cap_logger: |
| 227 | + for key in list(norm_state_dict.keys()): |
| 228 | + norm_state_dict[key.replace("norm", "norm_k_something_random")] = norm_state_dict.pop(key) |
| 229 | + pipe.load_lora_weights(norm_state_dict) |
| 230 | + |
| 231 | + self.assertTrue( |
| 232 | + cap_logger.out.startswith("Unsupported keys found in state dict when trying to load normalization layers") |
| 233 | + ) |
| 234 | + |
| 235 | + def test_lora_parameter_expanded_shapes(self): |
| 236 | + components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) |
| 237 | + pipe = self.pipeline_class(**components) |
| 238 | + pipe = pipe.to(torch_device) |
| 239 | + pipe.set_progress_bar_config(disable=None) |
| 240 | + |
| 241 | + inputs = self.get_dummy_tensor_inputs(torch_device) |
| 242 | + |
| 243 | + logger = logging.get_logger("diffusers.loaders.lora_pipeline") |
| 244 | + logger.setLevel(logging.DEBUG) |
| 245 | + |
| 246 | + with torch.no_grad(): |
| 247 | + original_output = pipe.transformer(**inputs)[0] |
| 248 | + |
| 249 | + out_features, in_features = pipe.transformer.x_embedder.weight.shape |
| 250 | + rank = 4 |
| 251 | + |
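| | + # A LoRA A matrix expecting twice the input features; loading it |
| | + # should expand x_embedder in place rather than erroring. |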
| 252 | + dummy_lora_A = torch.nn.Linear(2 * in_features, rank, bias=False) |
| 253 | + dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False) |
| 254 | + lora_state_dict = { |
| 255 | + "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight, |
| 256 | + "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight, |
| 257 | + } |
| 258 | + with CaptureLogger(logger) as cap_logger: |
| 259 | + pipe.load_lora_weights(lora_state_dict, "adapter-1") |
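| | + # Double the channel dim so the input matches the expanded in_features. |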
| 260 | + inputs["hidden_states"] = torch.cat([inputs["hidden_states"]] * 2, dim=2) |
| 261 | + with torch.no_grad(): |
| 262 | + expanded_output = pipe.transformer(**inputs)[0] |
| 263 | + pipe.delete_adapters("adapter-1") |
| 264 | + self.assertTrue(cap_logger.out.startswith("Expanding the nn.Linear input/output features for module")) |
| 265 | + self.assertFalse(np.allclose(original_output, expanded_output, atol=1e-3, rtol=1e-3)) |
| 266 | + |
| 267 | + components, _, _ = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) |
| 268 | + pipe = self.pipeline_class(**components) |
| 269 | + pipe = pipe.to(torch_device) |
| 270 | + pipe.set_progress_bar_config(disable=None) |
| 271 | + dummy_lora_A = torch.nn.Linear(1, rank, bias=False) |
| 272 | + dummy_lora_B = torch.nn.Linear(rank, out_features, bias=False) |
| 273 | + lora_state_dict = { |
| 274 | + "transformer.x_embedder.lora_A.weight": dummy_lora_A.weight, |
| 275 | + "transformer.x_embedder.lora_B.weight": dummy_lora_B.weight, |
| 276 | + } |
| 277 | + # We should error out because the LoRA's input features are fewer than the |
| 278 | + # original's. We only support expanding the module, not shrinking it. |
| 279 | + with self.assertRaises(NotImplementedError): |
| 280 | + pipe.load_lora_weights(lora_state_dict, "adapter-1") |
| 281 | + |
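| | + # Bias on the LoRA B matrix requires a PEFT release newer than 0.13.2. |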
| 282 | + @require_peft_version_greater("0.13.2") |
| 283 | + def test_lora_B_bias(self): |
| 284 | + components, _, denoiser_lora_config = self.get_dummy_components(FlowMatchEulerDiscreteScheduler) |
| 285 | + pipe = self.pipeline_class(**components) |
| 286 | + pipe = pipe.to(torch_device) |
| 287 | + pipe.set_progress_bar_config(disable=None) |
| 288 | + |
| 289 | + inputs = self.get_dummy_tensor_inputs(torch_device) |
| 290 | + |
| 291 | + logger = logging.get_logger("diffusers.loaders.lora_pipeline") |
| 292 | + logger.setLevel(logging.INFO) |
| 293 | + |
| 294 | + with torch.no_grad(): |
| 295 | + original_output = pipe.transformer(**inputs)[0] |
| 296 | + |
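| | + # Run once with lora_bias disabled and once enabled; all three outputs |
| | + # (original, bias off, bias on) should differ from one another. |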
| 297 | + denoiser_lora_config.lora_bias = False |
| 298 | + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") |
| 299 | + with torch.no_grad(): |
| 300 | + lora_bias_false_output = pipe.transformer(**inputs)[0] |
| 301 | + pipe.delete_adapters("adapter-1") |
| 302 | + |
| 303 | + denoiser_lora_config.lora_bias = True |
| 304 | + pipe.transformer.add_adapter(denoiser_lora_config, "adapter-1") |
| 305 | + with torch.no_grad(): |
| 306 | + lora_bias_true_output = pipe.transformer(**inputs)[0] |
| 307 | + |
| 308 | + self.assertFalse(np.allclose(original_output, lora_bias_false_output, atol=1e-3, rtol=1e-3)) |
| 309 | + self.assertFalse(np.allclose(original_output, lora_bias_true_output, atol=1e-3, rtol=1e-3)) |
| 310 | + self.assertFalse(np.allclose(lora_bias_false_output, lora_bias_true_output, atol=1e-3, rtol=1e-3)) |
| 311 | + |
159 | 312 | @unittest.skip("Not supported in Flux.")
|
160 | 313 | def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
|
161 | 314 | pass
|
|