Skip to content

Modular APG #10173

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
230 changes: 230 additions & 0 deletions src/diffusers/guider.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def apply_guidance(
self,
model_output: torch.Tensor,
timestep: int = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if not self.do_classifier_free_guidance:
return model_output
Expand Down Expand Up @@ -476,6 +477,7 @@ def apply_guidance(
self,
model_output: torch.Tensor,
timestep: int,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if not self.do_perturbed_attention_guidance:
return model_output
Expand All @@ -501,3 +503,231 @@ def apply_guidance(
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

return noise_pred


class MomentumBuffer:
def __init__(self, momentum: float):
self.momentum = momentum
self.running_average = 0

def update(self, update_value: torch.Tensor):
new_average = self.momentum * self.running_average
self.running_average = update_value + new_average


class APGGuider:
"""
This class is used to guide the pipeline with APG (Adaptive Projected Guidance).
"""

def normalized_guidance(
self,
pred_cond: torch.Tensor,
pred_uncond: torch.Tensor,
guidance_scale: float,
momentum_buffer: MomentumBuffer = None,
norm_threshold: float = 0.0,
eta: float = 1.0,
):
"""
Based on the findings of [Eliminating Oversaturation and Artifacts of High Guidance Scales
in Diffusion Models](https://arxiv.org/pdf/2410.02416)
"""
diff = pred_cond - pred_uncond
if momentum_buffer is not None:
momentum_buffer.update(diff)
diff = momentum_buffer.running_average
if norm_threshold > 0:
ones = torch.ones_like(diff)
diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True)
scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
diff = diff * scale_factor
v0, v1 = diff.double(), pred_cond.double()
v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
v0_orthogonal = v0 - v0_parallel
diff_parallel, diff_orthogonal = v0_parallel.to(diff.dtype), v0_orthogonal.to(diff.dtype)
normalized_update = diff_orthogonal + eta * diff_parallel
pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
return pred_guided

@property
def adaptive_projected_guidance_momentum(self):
return self._adaptive_projected_guidance_momentum

@property
def adaptive_projected_guidance_rescale_factor(self):
return self._adaptive_projected_guidance_rescale_factor

@property
def do_classifier_free_guidance(self):
return self._guidance_scale > 1.0 and not self._disable_guidance

@property
def guidance_rescale(self):
return self._guidance_rescale

@property
def guidance_scale(self):
return self._guidance_scale

@property
def batch_size(self):
return self._batch_size

def set_guider(self, pipeline, guider_kwargs: Dict[str, Any]):
disable_guidance = guider_kwargs.get("disable_guidance", False)
guidance_scale = guider_kwargs.get("guidance_scale", None)
if guidance_scale is None:
raise ValueError("guidance_scale is not provided in guider_kwargs")
adaptive_projected_guidance_momentum = guider_kwargs.get("adaptive_projected_guidance_momentum", None)
adaptive_projected_guidance_rescale_factor = guider_kwargs.get(
"adaptive_projected_guidance_rescale_factor", 15.0
)
guidance_rescale = guider_kwargs.get("guidance_rescale", 0.0)
batch_size = guider_kwargs.get("batch_size", None)
if batch_size is None:
raise ValueError("batch_size is not provided in guider_kwargs")
self._adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
self._adaptive_projected_guidance_rescale_factor = adaptive_projected_guidance_rescale_factor
self._guidance_scale = guidance_scale
self._guidance_rescale = guidance_rescale
self._batch_size = batch_size
self._disable_guidance = disable_guidance
if adaptive_projected_guidance_momentum is not None:
self.momentum_buffer = MomentumBuffer(adaptive_projected_guidance_momentum)
else:
self.momentum_buffer = None
self.scheduler = pipeline.scheduler

def reset_guider(self, pipeline):
pass

def maybe_update_guider(self, pipeline, timestep):
pass

def maybe_update_input(self, pipeline, cond_input):
pass

def _maybe_split_prepared_input(self, cond):
"""
Process and potentially split the conditional input for Classifier-Free Guidance (CFG).

This method handles inputs that may already have CFG applied (i.e. when `cond` is output of `prepare_input`).
It determines whether to split the input based on its batch size relative to the expected batch size.

Args:
cond (torch.Tensor): The conditional input tensor to process.

Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- The negative conditional input (uncond_input)
- The positive conditional input (cond_input)
"""
if cond.shape[0] == self.batch_size * 2:
neg_cond = cond[0 : self.batch_size]
cond = cond[self.batch_size :]
return neg_cond, cond
elif cond.shape[0] == self.batch_size:
return cond, cond
else:
raise ValueError(f"Unsupported input shape: {cond.shape}")

def _is_prepared_input(self, cond):
"""
Check if the input is already prepared for Classifier-Free Guidance (CFG).

Args:
cond (torch.Tensor): The conditional input tensor to check.

Returns:
bool: True if the input is already prepared, False otherwise.
"""
cond_tensor = cond[0] if isinstance(cond, (list, tuple)) else cond

return cond_tensor.shape[0] == self.batch_size * 2

def prepare_input(
self,
cond_input: Union[torch.Tensor, List[torch.Tensor]],
negative_cond_input: Optional[Union[torch.Tensor, List[torch.Tensor]]] = None,
) -> Union[torch.Tensor, List[torch.Tensor]]:
"""
Prepare the input for CFG.

Args:
cond_input (Union[torch.Tensor, List[torch.Tensor]]):
The conditional input. It can be a single tensor or a
list of tensors. It must have the same length as `negative_cond_input`.
negative_cond_input (Union[torch.Tensor, List[torch.Tensor]]): The negative conditional input. It can be a
single tensor or a list of tensors. It must have the same length as `cond_input`.

Returns:
Union[torch.Tensor, List[torch.Tensor]]: The prepared input.
"""

# we check if cond_input already has CFG applied, and split if it is the case.
if self._is_prepared_input(cond_input) and self.do_classifier_free_guidance:
return cond_input

if self._is_prepared_input(cond_input) and not self.do_classifier_free_guidance:
if isinstance(cond_input, list):
negative_cond_input, cond_input = zip(*[self._maybe_split_prepared_input(cond) for cond in cond_input])
else:
negative_cond_input, cond_input = self._maybe_split_prepared_input(cond_input)

if not self._is_prepared_input(cond_input) and negative_cond_input is None:
raise ValueError(
"`negative_cond_input` is required when cond_input does not already contains negative conditional input"
)

if isinstance(cond_input, (list, tuple)):
if not self.do_classifier_free_guidance:
return cond_input

if len(negative_cond_input) != len(cond_input):
raise ValueError("The length of negative_cond_input and cond_input must be the same.")
prepared_input = []
for neg_cond, cond in zip(negative_cond_input, cond_input):
if neg_cond.shape[0] != cond.shape[0]:
raise ValueError("The batch size of negative_cond_input and cond_input must be the same.")
prepared_input.append(torch.cat([neg_cond, cond], dim=0))
return prepared_input

elif isinstance(cond_input, torch.Tensor):
if not self.do_classifier_free_guidance:
return cond_input
else:
return torch.cat([negative_cond_input, cond_input], dim=0)

else:
raise ValueError(f"Unsupported input type: {type(cond_input)}")

def apply_guidance(
self,
model_output: torch.Tensor,
timestep: int = None,
latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if not self.do_classifier_free_guidance:
return model_output

if latents is None:
raise ValueError("APG requires `latents` to convert model output to denoised prediction (x0).")

sigma = self.scheduler.sigmas[self.scheduler.step_index]
noise_pred = latents - sigma * model_output
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = self.normalized_guidance(
noise_pred_text,
noise_pred_uncond,
self.guidance_scale,
self.momentum_buffer,
self.adaptive_projected_guidance_rescale_factor,
)
noise_pred = (latents - noise_pred) / sigma

if self.guidance_rescale > 0.0:
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)
return noise_pred
Original file line number Diff line number Diff line change
Expand Up @@ -926,6 +926,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
noise_pred = pipeline.guider.apply_guidance(
noise_pred,
timestep=t,
latents=latents,
)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
Expand Down Expand Up @@ -1213,7 +1214,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
return_dict=False,
)[0]
# perform guidance
noise_pred = pipeline.guider.apply_guidance(noise_pred, timestep=t)
noise_pred = pipeline.guider.apply_guidance(noise_pred, timestep=t, latents=latents)
# compute the previous noisy sample x_t -> x_t-1
latents_dtype = latents.dtype
latents = pipeline.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
Expand Down