@@ -188,6 +188,7 @@ def apply_guidance(
         self,
         model_output: torch.Tensor,
         timestep: int = None,
+        latents: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if not self.do_classifier_free_guidance:
             return model_output
@@ -476,6 +477,7 @@ def apply_guidance(
         self,
         model_output: torch.Tensor,
         timestep: int,
+        latents: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if not self.do_perturbed_attention_guidance:
             return model_output
@@ -501,3 +503,231 @@ def apply_guidance(
             noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

         return noise_pred
+
+
+class MomentumBuffer:
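+    # Running average of the guidance difference, updated as
+    #   running_average <- update_value + momentum * running_average
+    # (a negative momentum makes successive updates partially cancel).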
+    def __init__(self, momentum: float):
+        self.momentum = momentum
+        self.running_average = 0
+
+    def update(self, update_value: torch.Tensor):
+        new_average = self.momentum * self.running_average
+        self.running_average = update_value + new_average
+
+
+class APGGuider:
+    """
+    This class is used to guide the pipeline with APG (Adaptive Projected Guidance).
+    """
+
+    def normalized_guidance(
+        self,
+        pred_cond: torch.Tensor,
+        pred_uncond: torch.Tensor,
+        guidance_scale: float,
+        momentum_buffer: MomentumBuffer = None,
+        norm_threshold: float = 0.0,
+        eta: float = 1.0,
+    ):
+        """
+        Based on the findings of [Eliminating Oversaturation and Artifacts of High Guidance Scales
+        in Diffusion Models](https://arxiv.org/pdf/2410.02416)
+        """
+        diff = pred_cond - pred_uncond
+        if momentum_buffer is not None:
+            momentum_buffer.update(diff)
+            diff = momentum_buffer.running_average
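+        # Optionally rescale the (averaged) difference so its per-sample
+        # L2 norm does not exceed `norm_threshold`.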
+        if norm_threshold > 0:
+            ones = torch.ones_like(diff)
+            diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True)
+            scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
+            diff = diff * scale_factor
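+        # Decompose the difference into components parallel and orthogonal to
+        # the normalized conditional prediction; the APG paper associates the
+        # parallel component with oversaturation, so `eta` down-weights it
+        # while the orthogonal component is kept as-is.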
+        v0, v1 = diff.double(), pred_cond.double()
+        v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
+        v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
+        v0_orthogonal = v0 - v0_parallel
+        diff_parallel, diff_orthogonal = v0_parallel.to(diff.dtype), v0_orthogonal.to(diff.dtype)
+        normalized_update = diff_orthogonal + eta * diff_parallel
+        pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
+        return pred_guided
+
+    @property
+    def adaptive_projected_guidance_momentum(self):
+        return self._adaptive_projected_guidance_momentum
+
+    @property
+    def adaptive_projected_guidance_rescale_factor(self):
+        return self._adaptive_projected_guidance_rescale_factor
+
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1.0 and not self._disable_guidance
+
+    @property
+    def guidance_rescale(self):
+        return self._guidance_rescale
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def batch_size(self):
+        return self._batch_size
+
+    def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]):
+        disable_guidance = guider_kwargs.get("disable_guidance", False)
+        guidance_scale = guider_kwargs.get("guidance_scale", None)
+        if guidance_scale is None:
+            raise ValueError("guidance_scale is not provided in guider_kwargs")
+        adaptive_projected_guidance_momentum = guider_kwargs.get("adaptive_projected_guidance_momentum", None)
+        adaptive_projected_guidance_rescale_factor = guider_kwargs.get(
+            "adaptive_projected_guidance_rescale_factor", 15.0
+        )
+        guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0)
+        batch_size = guider_kwargs.get("batch_size", None)
+        if batch_size is None:
+            raise ValueError("batch_size is not provided in guider_kwargs")
+        self._adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
+        self._adaptive_projected_guidance_rescale_factor = adaptive_projected_guidance_rescale_factor
+        self._guidance_scale = guidance_scale
+        self._guidance_rescale = guidance_rescale
+        self._batch_size = batch_size
+        self._disable_guidance = disable_guidance
+        if adaptive_projected_guidance_momentum is not None:
+            self.momentum_buffer = MomentumBuffer(adaptive_projected_guidance_momentum)
+        else:
+            self.momentum_buffer = None
+        self.scheduler = pipeline.scheduler
+
+    def reset_guider(self, pipeline):
+        pass
+
+    def maybe_update_guider(self, pipeline, timestep):
+        pass
+
+    def maybe_update_input(self, pipeline, cond_input):
+        pass
+
+    def _maybe_split_prepared_input(self, cond):
+        """
+        Process and potentially split the conditional input for Classifier-Free Guidance (CFG).
+
+        This method handles inputs that may already have CFG applied (i.e., when `cond` is the output of
+        `prepare_input`). It determines whether to split the input based on its batch size relative to the
+        expected batch size.
+
+        Args:
+            cond (torch.Tensor): The conditional input tensor to process.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+                - The negative conditional input (uncond_input)
+                - The positive conditional input (cond_input)
+        """
+        if cond.shape[0] == self.batch_size * 2:
+            neg_cond = cond[0 : self.batch_size]
+            cond = cond[self.batch_size :]
+            return neg_cond, cond
+        elif cond.shape[0] == self.batch_size:
+            return cond, cond
+        else:
+            raise ValueError(f"Unsupported input shape: {cond.shape}")
+
+    def _is_prepared_input(self, cond):
+        """
+        Check if the input is already prepared for Classifier-Free Guidance (CFG).
+
+        Args:
+            cond (torch.Tensor): The conditional input tensor to check.
+
+        Returns:
+            bool: True if the input is already prepared, False otherwise.
+        """
+        cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond
+
+        return cond_tensor.shape[0] == self.batch_size * 2
+
+    def prepare_input(
+        self,
+        cond_input: Union[torch.Tensor, List[torch.Tensor]],
+        negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        """
+        Prepare the input for CFG.
+
+        Args:
+            cond_input (Union[torch.Tensor, List[torch.Tensor]]):
+                The conditional input. It can be a single tensor or a list of tensors. It must have the same
+                length as `negative_cond_input`.
+            negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]):
+                The negative conditional input. It can be a single tensor or a list of tensors. It must have
+                the same length as `cond_input`.
+
+        Returns:
+            Union[torch.Tensor, List[torch.Tensor]]: The prepared input.
+        """
+
+        # Check whether cond_input already has CFG applied, and split it if that is the case.
+        if self._is_prepared_input(cond_input) and self.do_classifier_free_guidance:
+            return cond_input
+
+        if self._is_prepared_input(cond_input) and not self.do_classifier_free_guidance:
+            if isinstance(cond_input, list):
+                negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input])
+            else:
+                negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input)
+
+        if not self._is_prepared_input(cond_input) and negative_cond_input is None:
+            raise ValueError(
+                "`negative_cond_input` is required when `cond_input` does not already contain the negative conditional input"
+            )
+
+        if isinstance(cond_input, (list, tuple)):
+            if not self.do_classifier_free_guidance:
+                return cond_input
+
+            if len(negative_cond_input) != len(cond_input):
+                raise ValueError("The length of negative_cond_input and cond_input must be the same.")
+            prepared_input = []
+            for neg_cond, cond in zip(negative_cond_input, cond_input):
+                if neg_cond.shape[0] != cond.shape[0]:
+                    raise ValueError("The batch size of negative_cond_input and cond_input must be the same.")
+                prepared_input.append(torch.cat([neg_cond, cond], dim=0))
+            return prepared_input
+
+        elif isinstance(cond_input, torch.Tensor):
+            if not self.do_classifier_free_guidance:
+                return cond_input
+            else:
+                return torch.cat([negative_cond_input, cond_input], dim=0)
+
+        else:
+            raise ValueError(f"Unsupported input type: {type(cond_input)}")
+
+    def apply_guidance(
+        self,
+        model_output: torch.Tensor,
+        timestep: int = None,
+        latents: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if not self.do_classifier_free_guidance:
+            return model_output
+
+        if latents is None:
+            raise ValueError("APG requires `latents` to convert model output to denoised prediction (x0).")
+
+        sigma = self.scheduler.sigmas[self.scheduler.step_index]
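+        # Convert the noise prediction to a denoised (x0) prediction, apply APG
+        # in x0 space, then map the guided x0 back to a noise prediction below.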
+        noise_pred = latents - sigma * model_output
+        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+        noise_pred = self.normalized_guidance(
+            noise_pred_text,
+            noise_pred_uncond,
+            self.guidance_scale,
+            self.momentum_buffer,
+            self.adaptive_projected_guidance_rescale_factor,
+        )
+        noise_pred = (latents - noise_pred) / sigma
+
+        if self.guidance_rescale > 0.0:
+            # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+            noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+        return noise_pred
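
For reference, here is a minimal, self-contained sketch of how the new APGGuider could be driven outside a full pipeline. The dummy pipeline and scheduler, the tensor shapes, and every parameter value below are illustrative assumptions rather than part of this change; the guider itself only reads `pipeline.scheduler.sigmas` and `pipeline.scheduler.step_index`, and the snippet assumes the classes above (plus their typing imports) are in scope.

import torch

# Stand-ins for the only pipeline attributes APGGuider touches.
class DummyScheduler:
    def __init__(self):
        self.sigmas = torch.tensor([14.6, 7.3, 0.0])  # illustrative sigma schedule
        self.step_index = 0

class DummyPipeline:
    def __init__(self):
        self.scheduler = DummyScheduler()

guider = APGGuider()
guider.set_guider(
    DummyPipeline(),
    {
        "guidance_scale": 7.5,                           # illustrative
        "batch_size": 1,
        "adaptive_projected_guidance_momentum": -0.5,    # assumed value
        "adaptive_projected_guidance_rescale_factor": 15.0,
        "guidance_rescale": 0.0,
    },
)

cond = torch.randn(1, 4, 64, 64)    # stand-in for conditional embeddings
uncond = torch.randn(1, 4, 64, 64)  # stand-in for negative embeddings
latents = torch.randn(1, 4, 64, 64)

model_input = guider.prepare_input(cond, uncond)  # concatenated [uncond, cond]
model_output = torch.randn(2, 4, 64, 64)          # pretend denoiser output for the CFG batch
guided = guider.apply_guidance(model_output, timestep=0, latents=latents)
print(guided.shape)  # torch.Size([1, 4, 64, 64])

Note that `latents` is passed un-duplicated: the subtraction in `apply_guidance` broadcasts it over the CFG batch of two before `chunk(2)` splits the result into unconditional and conditional halves.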