@@ -395,6 +395,65 @@ def test_simple_inference_with_text_lora_save_load(self):
395
395
"Loading from saved checkpoints should give same results." ,
396
396
)
397
397
398
+ def test_simple_inference_with_partial_text_lora (self ):
399
+ """
400
+ Tests a simple inference with lora attached on the text encoder
401
+ with different ranks and some adapters removed
402
+ and makes sure it works as expected
403
+ """
404
+ for scheduler_cls in [DDIMScheduler , LCMScheduler ]:
405
+ components , _ , _ = self .get_dummy_components (scheduler_cls )
406
+ text_lora_config = LoraConfig (
407
+ r = 4 ,
408
+ rank_pattern = {"q_proj" : 1 , "k_proj" : 2 , "v_proj" : 3 },
409
+ lora_alpha = 4 ,
410
+ target_modules = ["q_proj" , "k_proj" , "v_proj" , "out_proj" ],
411
+ init_lora_weights = False ,
412
+ use_dora = False ,
413
+ )
414
+ pipe = self .pipeline_class (** components )
415
+ pipe = pipe .to (torch_device )
416
+ pipe .set_progress_bar_config (disable = None )
417
+ _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
418
+
419
+ output_no_lora = pipe (** inputs , generator = torch .manual_seed (0 )).images
420
+ self .assertTrue (output_no_lora .shape == (1 , 64 , 64 , 3 ))
421
+
422
+ pipe .text_encoder .add_adapter (text_lora_config )
423
+ self .assertTrue (check_if_lora_correctly_set (pipe .text_encoder ), "Lora not correctly set in text encoder" )
424
+ state_dict = {
425
+ f"text_encoder.{ module_name } " : param
426
+ for module_name , param in get_peft_model_state_dict (pipe .text_encoder ).items ()
427
+ }
428
+
429
+ if self .has_two_text_encoders :
430
+ pipe .text_encoder_2 .add_adapter (text_lora_config )
431
+ self .assertTrue (
432
+ check_if_lora_correctly_set (pipe .text_encoder_2 ), "Lora not correctly set in text encoder 2"
433
+ )
434
+ state_dict .update (
435
+ {
436
+ f"text_encoder.{ module_name } " : param
437
+ for module_name , param in get_peft_model_state_dict (pipe .text_encoder_2 ).items ()
438
+ }
439
+ )
440
+
441
+ # Discard half of the adapters.
442
+ rng = np .random .default_rng (0 )
443
+ key2adapters = {k : k .rsplit ("." , 2 )[0 ] for k in state_dict .keys ()}
444
+ adapters = list (set (key2adapters .values ()))
445
+ adapters = set (rng .choice (adapters , size = len (adapters ) // 2 , replace = False ))
446
+ state_dict = {k : state_dict [k ] for k , adapter in key2adapters .items () if adapter in adapters }
447
+
448
+ # Unload lora and load it back using the pipe.load_lora_weights machinery
449
+ pipe .unload_lora_weights ()
450
+ pipe .load_lora_weights (state_dict )
451
+
452
+ output_lora = pipe (** inputs , generator = torch .manual_seed (0 )).images
453
+ self .assertTrue (
454
+ not np .allclose (output_lora , output_no_lora , atol = 1e-3 , rtol = 1e-3 ), "Lora should change the output"
455
+ )
456
+
398
457
def test_simple_inference_save_pretrained (self ):
399
458
"""
400
459
Tests a simple usecase where users could use saving utilities for LoRA through save_pretrained
0 commit comments