Adaptive Projected Guidance #9626


Closed · wants to merge 3 commits
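A minimal usage sketch of the new arguments (hedged: the model id and scheduler choice below are illustrative assumptions, and this only works with the patch in this PR applied; the APG branch reads `scheduler.sigmas`, so a sigma-based scheduler is used):

import torch
from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
# APG reads scheduler.sigmas / step_index, so swap in a sigma-based scheduler.
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)

image = pipe(
    "a photo of an astronaut riding a horse on mars",
    guidance_scale=15.0,  # a high CFG scale that would normally oversaturate
    adaptive_projected_guidance=True,
    adaptive_projected_guidance_momentum=-0.5,
    adaptive_projected_guidance_rescale_factor=15.0,
).images[0]
image.save("apg.png")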
Changes from all commits
File: src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -90,6 +90,47 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    return noise_cfg


class MomentumBuffer:
    """Momentum-weighted running average of the guidance difference used by APG."""

    def __init__(self, momentum: float):
        self.momentum = momentum
        self.running_average = 0

    def update(self, update_value: torch.Tensor):
        # running_average <- update_value + momentum * running_average;
        # the default momentum is negative, damping oscillations across steps.
        new_average = self.momentum * self.running_average
        self.running_average = update_value + new_average
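A tiny worked example of the recurrence above (values illustrative; assumes `torch` is imported):

buf = MomentumBuffer(momentum=-0.5)
buf.update(torch.ones(1))  # running_average = 1.0
buf.update(torch.ones(1))  # running_average = 1.0 + (-0.5) * 1.0 = 0.5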


def normalized_guidance(
    pred_cond: torch.Tensor,
    pred_uncond: torch.Tensor,
    guidance_scale: float,
    momentum_buffer: Optional[MomentumBuffer] = None,
    eta: float = 1.0,
    norm_threshold: float = 0.0,
):
    """
    Based on the findings of [Eliminating Oversaturation and Artifacts of High Guidance Scales
    in Diffusion Models](https://arxiv.org/pdf/2410.02416).
    """
    diff = pred_cond - pred_uncond
    # Optionally smooth the guidance difference across steps with (negative) momentum.
    if momentum_buffer is not None:
        momentum_buffer.update(diff)
        diff = momentum_buffer.running_average
    # Rescale: clip the per-sample L2 norm of the difference to `norm_threshold`.
    if norm_threshold > 0:
        ones = torch.ones_like(diff)
        diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True)
        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
        diff = diff * scale_factor
    # Project the difference onto the conditional prediction; the parallel component
    # drives oversaturation, so it is down-weighted by `eta`.
    v0, v1 = diff.double(), pred_cond.double()
    v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
    v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    diff_parallel, diff_orthogonal = v0_parallel.to(diff.dtype), v0_orthogonal.to(diff.dtype)
    normalized_update = diff_orthogonal + eta * diff_parallel
    pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
    return pred_guided
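A quick sanity-check sketch of the helper above (tensor shapes are illustrative):

cond = torch.randn(2, 4, 64, 64)
uncond = torch.randn(2, 4, 64, 64)
buf = MomentumBuffer(momentum=-0.5)
out = normalized_guidance(cond, uncond, guidance_scale=7.5, momentum_buffer=buf, eta=0.0, norm_threshold=15.0)
assert out.shape == cond.shape
# with eta=0.0 only the component orthogonal to pred_cond is amplified,
# which is the part that steers content without pushing saturation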


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
@@ -742,6 +783,18 @@ def guidance_scale(self):
    def guidance_rescale(self):
        return self._guidance_rescale

    @property
    def adaptive_projected_guidance(self):
        return self._adaptive_projected_guidance

    @property
    def adaptive_projected_guidance_momentum(self):
        return self._adaptive_projected_guidance_momentum

    @property
    def adaptive_projected_guidance_rescale_factor(self):
        return self._adaptive_projected_guidance_rescale_factor

    @property
    def clip_skip(self):
        return self._clip_skip
@@ -789,6 +842,9 @@ def __call__(
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        adaptive_projected_guidance: Optional[bool] = None,
        adaptive_projected_guidance_momentum: Optional[float] = -0.5,
        adaptive_projected_guidance_rescale_factor: Optional[float] = 15.0,
        clip_skip: Optional[int] = None,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
@@ -859,6 +915,13 @@ def __call__(
                Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
                Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
                using zero terminal SNR.
            adaptive_projected_guidance (`bool`, *optional*):
                Whether to use adaptive projected guidance from [Eliminating Oversaturation and Artifacts of High
                Guidance Scales in Diffusion Models](https://arxiv.org/pdf/2410.02416).
            adaptive_projected_guidance_momentum (`float`, *optional*, defaults to `-0.5`):
                Momentum to use with adaptive projected guidance. Use `None` to disable momentum.
            adaptive_projected_guidance_rescale_factor (`float`, *optional*, defaults to `15.0`):
                Rescale factor (the norm threshold) to use with adaptive projected guidance.
            clip_skip (`int`, *optional*):
                Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
                the output of the pre-final layer will be used for computing the prompt embeddings.
@@ -922,6 +985,9 @@ def __call__(

        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._adaptive_projected_guidance = adaptive_projected_guidance
        self._adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
        self._adaptive_projected_guidance_rescale_factor = adaptive_projected_guidance_rescale_factor
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._interrupt = False
@@ -1004,6 +1070,11 @@ def __call__(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        # APG's momentum buffer is stateful across denoising steps, so create a fresh one per call.
        if adaptive_projected_guidance and adaptive_projected_guidance_momentum is not None:
            momentum_buffer = MomentumBuffer(adaptive_projected_guidance_momentum)
        else:
            momentum_buffer = None

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)
@@ -1029,8 +1100,22 @@

                # perform guidance
                if self.do_classifier_free_guidance:
                    if adaptive_projected_guidance:
                        # APG operates on the data-space prediction: convert the eps-prediction
                        # to x0 (latents - sigma * eps), guide there, then convert back to eps.
                        sigma = self.scheduler.sigmas[self.scheduler.step_index]
                        noise_pred = latents - sigma * noise_pred
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = normalized_guidance(
                            noise_pred_text,
                            noise_pred_uncond,
                            self.guidance_scale,
                            momentum_buffer,
                            eta,
                            adaptive_projected_guidance_rescale_factor,
                        )
                        noise_pred = (latents - noise_pred) / sigma
                    else:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
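Note: the APG branch assumes the scheduler exposes `sigmas` and `step_index` (Euler-style schedulers do; DDIM-style ones do not), and `latents - sigma * noise_pred` relies on broadcasting a batch of one against the stacked uncond/cond predictions. Also, the `eta` forwarded to `normalized_guidance` is the pipeline's DDIM `eta` parameter, which defaults to 0.0, so the parallel component is dropped entirely unless the caller sets it. A hypothetical guard sketch, not part of this diff:

if adaptive_projected_guidance and not hasattr(self.scheduler, "sigmas"):
    raise ValueError("adaptive_projected_guidance requires a sigma-based scheduler such as EulerDiscreteScheduler")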
File: src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -112,6 +112,49 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    return noise_cfg


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.MomentumBuffer
class MomentumBuffer:
    """Momentum-weighted running average of the guidance difference used by APG."""

    def __init__(self, momentum: float):
        self.momentum = momentum
        self.running_average = 0

    def update(self, update_value: torch.Tensor):
        # running_average <- update_value + momentum * running_average;
        # the default momentum is negative, damping oscillations across steps.
        new_average = self.momentum * self.running_average
        self.running_average = update_value + new_average


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.normalized_guidance
def normalized_guidance(
    pred_cond: torch.Tensor,
    pred_uncond: torch.Tensor,
    guidance_scale: float,
    momentum_buffer: Optional[MomentumBuffer] = None,
    eta: float = 1.0,
    norm_threshold: float = 0.0,
):
    """
    Based on the findings of [Eliminating Oversaturation and Artifacts of High Guidance Scales
    in Diffusion Models](https://arxiv.org/pdf/2410.02416).
    """
    diff = pred_cond - pred_uncond
    # Optionally smooth the guidance difference across steps with (negative) momentum.
    if momentum_buffer is not None:
        momentum_buffer.update(diff)
        diff = momentum_buffer.running_average
    # Rescale: clip the per-sample L2 norm of the difference to `norm_threshold`.
    if norm_threshold > 0:
        ones = torch.ones_like(diff)
        diff_norm = diff.norm(p=2, dim=[-1, -2, -3], keepdim=True)
        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
        diff = diff * scale_factor
    # Project the difference onto the conditional prediction; the parallel component
    # drives oversaturation, so it is down-weighted by `eta`.
    v0, v1 = diff.double(), pred_cond.double()
    v1 = torch.nn.functional.normalize(v1, dim=[-1, -2, -3])
    v0_parallel = (v0 * v1).sum(dim=[-1, -2, -3], keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    diff_parallel, diff_orthogonal = v0_parallel.to(diff.dtype), v0_orthogonal.to(diff.dtype)
    normalized_update = diff_orthogonal + eta * diff_parallel
    pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
    return pred_guided


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
@@ -801,6 +844,18 @@ def guidance_scale(self):
    def guidance_rescale(self):
        return self._guidance_rescale

    @property
    def adaptive_projected_guidance(self):
        return self._adaptive_projected_guidance

    @property
    def adaptive_projected_guidance_momentum(self):
        return self._adaptive_projected_guidance_momentum

    @property
    def adaptive_projected_guidance_rescale_factor(self):
        return self._adaptive_projected_guidance_rescale_factor

    @property
    def clip_skip(self):
        return self._clip_skip
@@ -857,6 +912,9 @@ def __call__(
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        adaptive_projected_guidance: Optional[bool] = None,
        adaptive_projected_guidance_momentum: Optional[float] = -0.5,
        adaptive_projected_guidance_rescale_factor: Optional[float] = 15.0,
        original_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Optional[Tuple[int, int]] = None,
@@ -968,6 +1026,13 @@ def __call__(
                Flawed](https://arxiv.org/pdf/2305.08891.pdf). `guidance_scale` is defined as `φ` in equation 16 of
                [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
                Guidance rescale factor should fix overexposure when using zero terminal SNR.
            adaptive_projected_guidance (`bool`, *optional*):
                Whether to use adaptive projected guidance from [Eliminating Oversaturation and Artifacts of High
                Guidance Scales in Diffusion Models](https://arxiv.org/pdf/2410.02416).
            adaptive_projected_guidance_momentum (`float`, *optional*, defaults to `-0.5`):
                Momentum to use with adaptive projected guidance. Use `None` to disable momentum.
            adaptive_projected_guidance_rescale_factor (`float`, *optional*, defaults to `15.0`):
                Rescale factor (the norm threshold) to use with adaptive projected guidance.
            original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
                If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
                `original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
@@ -1061,6 +1126,9 @@ def __call__(

        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._adaptive_projected_guidance = adaptive_projected_guidance
        self._adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
        self._adaptive_projected_guidance_rescale_factor = adaptive_projected_guidance_rescale_factor
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._denoising_end = denoising_end
@@ -1193,6 +1261,11 @@ def __call__(
                guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
            ).to(device=device, dtype=latents.dtype)

        # APG's momentum buffer is stateful across denoising steps, so create a fresh one per call.
        if adaptive_projected_guidance and adaptive_projected_guidance_momentum is not None:
            momentum_buffer = MomentumBuffer(adaptive_projected_guidance_momentum)
        else:
            momentum_buffer = None

        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
@@ -1220,8 +1293,22 @@

                # perform guidance
                if self.do_classifier_free_guidance:
                    if adaptive_projected_guidance:
                        # Same data-space round trip as the non-XL pipeline: eps -> x0,
                        # apply APG, then convert the guided x0 back to an eps-prediction.
                        sigma = self.scheduler.sigmas[self.scheduler.step_index]
                        noise_pred = latents - sigma * noise_pred
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = normalized_guidance(
                            noise_pred_text,
                            noise_pred_uncond,
                            self.guidance_scale,
                            momentum_buffer,
                            eta,
                            adaptive_projected_guidance_rescale_factor,
                        )
                        noise_pred = (latents - noise_pred) / sigma
                    else:
                        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                        noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)

                if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
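For the SDXL pipeline the call shape is the same; a minimal sketch, assuming the patch is applied (model id and settings are illustrative):

import torch
from diffusers import EulerDiscreteScheduler, StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)

image = pipe(
    "a cinematic photo of a lighthouse in a storm",
    guidance_scale=12.0,
    adaptive_projected_guidance=True,
    adaptive_projected_guidance_momentum=-0.5,
    adaptive_projected_guidance_rescale_factor=15.0,
).images[0]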