
Commit a8df0f1

Modular APG (#10173)
1 parent ace53e2 commit a8df0f1

File tree

2 files changed: +232 -1 lines changed


src/diffusers/guider.py

Lines changed: 230 additions & 0 deletions
@@ -188,6 +188,7 @@ def apply_guidance(
         self,
         model_output: torch.Tensor,
         timestep: int = None,
+        latents: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if not self.do_classifier_free_guidance:
             return model_output
@@ -476,6 +477,7 @@ def apply_guidance(
         self,
         model_output: torch.Tensor,
         timestep: int,
+        latents: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         if not self.do_perturbed_attention_guidance:
             return model_output
@@ -501,3 +503,231 @@ def apply_guidance(
             noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
 
         return noise_pred
+
+
+class MomentumBuffer:
+    def __init__(self, momentum: float):
+        self.momentum = momentum
+        self.running_average = 0
+
+    def update(self, update_value: torch.Tensor):
+        new_average = self.momentum * self.running_average
+        self.running_average = update_value + new_average
+
+
+class APGGuider:
+    """
+    This class is used to guide the pipeline with APG (Adaptive Projected Guidance).
+    """
+
+    def normalized_guidance(
+        self,
+        pred_cond: torch.Tensor,
+        pred_uncond: torch.Tensor,
+        guidance_scale: float,
+        momentum_buffer: MomentumBuffer = None,
+        norm_threshold: float = 0.0,
+        eta: float = 1.0,
+    ):
+        """
+        Based on the findings of [Eliminating Oversaturation and Artifacts of High Guidance Scales in Diffusion
+        Models](https://arxiv.org/pdf/2410.02416)
+        """
+        diff = pred_cond - pred_uncond
+        if momentum_buffer is not None:
+            momentum_buffer.update(diff)
+            diff = momentum_buffer.running_average
+        if norm_threshold > 0:
+            ones = torch.ones_like(diff)
+            diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True)
+            scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
+            diff = diff * scale_factor
+        v0, v1 = diff.double(), pred_cond.double()
+        v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
+        v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
+        v0_orthogonal = v0 - v0_parallel
+        diff_parallel, diff_orthogonal = v0_parallel.to(diff.dtype), v0_orthogonal.to(diff.dtype)
+        normalized_update = diff_orthogonal + eta * diff_parallel
+        pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
+        return pred_guided
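
The projection above is the heart of APG: the CFG direction `diff` is decomposed into a component parallel to the (normalized) conditional prediction and an orthogonal remainder, and only the parallel part, which the paper identifies as the main driver of oversaturation, is down-weighted by `eta`. A minimal standalone sketch of that decomposition, with illustrative shapes and values rather than anything taken from this commit:

    import torch

    # Latent-shaped predictions: (batch, channels, height, width).
    pred_cond = torch.randn(1, 4, 64, 64)
    pred_uncond = torch.randn(1, 4, 64, 64)
    guidance_scale, eta = 7.5, 0.5

    diff = pred_cond - pred_uncond
    # Unit-normalize the conditional prediction over the non-batch dims.
    v1 = torch.nn.functional.normalize(pred_cond.double(), dim=[-1, -2, -3])
    parallel = (diff.double() * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
    orthogonal = diff.double() - parallel
    # Keep the orthogonal part, down-weight the parallel part by eta.
    update = (orthogonal + eta * parallel).to(diff.dtype)
    guided = pred_cond + (guidance_scale - 1) * update
    # With eta = 1.0, update equals diff and this reduces to plain CFG.

The `MomentumBuffer` consumed here holds a running average of `diff` (the paper works with negative momentum values), so the guidance direction is smoothed across denoising steps instead of being recomputed from the current step alone.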
+
+    @property
+    def adaptive_projected_guidance_momentum(self):
+        return self._adaptive_projected_guidance_momentum
+
+    @property
+    def adaptive_projected_guidance_rescale_factor(self):
+        return self._adaptive_projected_guidance_rescale_factor
+
+    @property
+    def do_classifier_free_guidance(self):
+        return self._guidance_scale > 1.0 and not self._disable_guidance
+
+    @property
+    def guidance_rescale(self):
+        return self._guidance_rescale
+
+    @property
+    def guidance_scale(self):
+        return self._guidance_scale
+
+    @property
+    def batch_size(self):
+        return self._batch_size
+
+    def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]):
+        disable_guidance = guider_kwargs.get("disable_guidance", False)
+        guidance_scale = guider_kwargs.get("guidance_scale", None)
+        if guidance_scale is None:
+            raise ValueError("guidance_scale is not provided in guider_kwargs")
+        adaptive_projected_guidance_momentum = guider_kwargs.get("adaptive_projected_guidance_momentum", None)
+        adaptive_projected_guidance_rescale_factor = guider_kwargs.get(
+            "adaptive_projected_guidance_rescale_factor", 15.0
+        )
+        guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0)
+        batch_size = guider_kwargs.get("batch_size", None)
+        if batch_size is None:
+            raise ValueError("batch_size is not provided in guider_kwargs")
+        self._adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
+        self._adaptive_projected_guidance_rescale_factor = adaptive_projected_guidance_rescale_factor
+        self._guidance_scale = guidance_scale
+        self._guidance_rescale = guidance_rescale
+        self._batch_size = batch_size
+        self._disable_guidance = disable_guidance
+        if adaptive_projected_guidance_momentum is not None:
+            self.momentum_buffer = MomentumBuffer(adaptive_projected_guidance_momentum)
+        else:
+            self.momentum_buffer = None
+        self.scheduler = pipeline.scheduler
+
+    def reset_guider(self, pipeline):
+        pass
+
+    def maybe_update_guider(self, pipeline, timestep):
+        pass
+
+    def maybe_update_input(self, pipeline, cond_input):
+        pass
+
+    def _maybe_split_prepared_input(self, cond):
+        """
+        Process and potentially split the conditional input for Classifier-Free Guidance (CFG).
+
+        This method handles inputs that may already have CFG applied (i.e. when `cond` is the output of
+        `prepare_input`). It determines whether to split the input based on its batch size relative to the expected
+        batch size.
+
+        Args:
+            cond (torch.Tensor): The conditional input tensor to process.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+                - The negative conditional input (uncond_input)
+                - The positive conditional input (cond_input)
+        """
+        if cond.shape[0] == self.batch_size * 2:
+            neg_cond = cond[0 : self.batch_size]
+            cond = cond[self.batch_size :]
+            return neg_cond, cond
+        elif cond.shape[0] == self.batch_size:
+            return cond, cond
+        else:
+            raise ValueError(f"Unsupported input shape: {cond.shape}")
+
+    def _is_prepared_input(self, cond):
+        """
+        Check if the input is already prepared for Classifier-Free Guidance (CFG).
+
+        Args:
+            cond (torch.Tensor): The conditional input tensor to check.
+
+        Returns:
+            bool: True if the input is already prepared, False otherwise.
+        """
+        cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond
+
+        return cond_tensor.shape[0] == self.batch_size * 2
+
+    def prepare_input(
+        self,
+        cond_input: Union[torch.Tensor, List[torch.Tensor]],
+        negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
+    ) -> Union[torch.Tensor, List[torch.Tensor]]:
+        """
+        Prepare the input for CFG.
+
+        Args:
+            cond_input (Union[torch.Tensor, List[torch.Tensor]]):
+                The conditional input. It can be a single tensor or a list of tensors. It must have the same length
+                as `negative_cond_input`.
+            negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]): The negative conditional input. It can be
+                a single tensor or a list of tensors. It must have the same length as `cond_input`.
+
+        Returns:
+            Union[torch.Tensor, List[torch.Tensor]]: The prepared input.
+        """
+
+        # we check if cond_input already has CFG applied, and split if it is the case.
+        if self._is_prepared_input(cond_input) and self.do_classifier_free_guidance:
+            return cond_input
+
+        if self._is_prepared_input(cond_input) and not self.do_classifier_free_guidance:
+            if isinstance(cond_input, list):
+                negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input])
+            else:
+                negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input)
+
+        if not self._is_prepared_input(cond_input) and negative_cond_input is None:
+            raise ValueError(
+                "`negative_cond_input` is required when cond_input does not already contain negative conditional input"
+            )
+
+        if isinstance(cond_input, (list, tuple)):
+            if not self.do_classifier_free_guidance:
+                return cond_input
+
+            if len(negative_cond_input) != len(cond_input):
+                raise ValueError("The length of negative_cond_input and cond_input must be the same.")
+            prepared_input = []
+            for neg_cond, cond in zip(negative_cond_input, cond_input):
+                if neg_cond.shape[0] != cond.shape[0]:
+                    raise ValueError("The batch size of negative_cond_input and cond_input must be the same.")
+                prepared_input.append(torch.cat([neg_cond, cond], dim=0))
+            return prepared_input
+
+        elif isinstance(cond_input, torch.Tensor):
+            if not self.do_classifier_free_guidance:
+                return cond_input
+            else:
+                return torch.cat([negative_cond_input, cond_input], dim=0)
+
+        else:
+            raise ValueError(f"Unsupported input type: {type(cond_input)}")
+
+    def apply_guidance(
+        self,
+        model_output: torch.Tensor,
+        timestep: int = None,
+        latents: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if not self.do_classifier_free_guidance:
+            return model_output
+
+        if latents is None:
+            raise ValueError("APG requires `latents` to convert model output to denoised prediction (x0).")
+
+        sigma = self.scheduler.sigmas[self.scheduler.step_index]
+        noise_pred = latents - sigma * model_output
+        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+        noise_pred = self.normalized_guidance(
+            noise_pred_text,
+            noise_pred_uncond,
+            self.guidance_scale,
+            self.momentum_buffer,
+            self.adaptive_projected_guidance_rescale_factor,
+        )
+        noise_pred = (latents - noise_pred) / sigma
+
+        if self.guidance_rescale > 0.0:
+            # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
+            noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
+        return noise_pred
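
Note that `apply_guidance` applies the projection in x0 (denoised-sample) space rather than directly on the noise prediction: the ε-style model output is first converted with `x0 = latents - sigma * eps`, guided there, then mapped back via `eps = (latents - x0) / sigma`; this conversion is why the method now requires `latents`. Also note that `adaptive_projected_guidance_rescale_factor` (default 15.0) is passed positionally into `normalized_guidance` as `norm_threshold`, capping the L2 norm of the guidance direction. A quick self-contained check of the round trip, with arbitrary values not taken from the commit:

    import torch

    sigma = 0.7
    latents = torch.randn(2, 4, 8, 8)
    eps = torch.randn(2, 4, 8, 8)

    x0 = latents - sigma * eps           # model output -> denoised prediction
    eps_back = (latents - x0) / sigma    # guided x0 -> model-output space
    assert torch.allclose(eps, eps_back, atol=1e-6)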

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py

Lines changed: 2 additions & 1 deletion
@@ -926,6 +926,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
                 noise_pred = pipeline.guider.apply_guidance(
                     noise_pred,
                     timestep=t,
+                    latents=latents,
                 )
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
@@ -1213,7 +1214,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
                     return_dict=False,
                 )[0]
                 # perform guidance
-                noise_pred = pipeline.guider.apply_guidance(noise_pred, timestep=t)
+                noise_pred = pipeline.guider.apply_guidance(noise_pred, timestep=t, latents=latents)
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
                 latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
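
Taken together, a hypothetical minimal driver for `APGGuider` outside the full pipeline might look as follows. This is a sketch under two assumptions: that the class is importable from `diffusers.guider` (the file added above), and that mocking only the scheduler fields the guider reads (`sigmas`, `step_index`) is enough; all shapes and values are illustrative:

    import torch
    from types import SimpleNamespace

    from diffusers.guider import APGGuider  # assumed import path, per the file added in this commit

    # Stand-ins exposing only what APGGuider touches: pipeline.scheduler.sigmas / .step_index.
    scheduler = SimpleNamespace(sigmas=torch.tensor([14.6, 7.1, 3.5, 1.0, 0.0]), step_index=0)
    pipeline = SimpleNamespace(scheduler=scheduler)

    guider = APGGuider()
    guider.set_guider(
        pipeline,
        {
            "guidance_scale": 7.5,
            "batch_size": 1,
            "adaptive_projected_guidance_momentum": -0.5,  # negative momentum, as in the APG paper
            # "adaptive_projected_guidance_rescale_factor" defaults to 15.0
        },
    )

    # prepare_input concatenates [negative, positive] along the batch dimension.
    cond = torch.randn(1, 77, 2048)
    embeds = guider.prepare_input(cond, torch.zeros_like(cond))  # shape (2, 77, 2048)

    # Stacked uncond/cond predictions, as a UNet fed `embeds` would produce.
    model_output = torch.randn(2, 4, 128, 128)
    latents = torch.randn(1, 4, 128, 128)
    guided = guider.apply_guidance(model_output, timestep=None, latents=latents)
    print(guided.shape)  # torch.Size([1, 4, 128, 128])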
