Skip to content

Commit a528e6c

Browse files
committed
Adaptive Projected Guidance
1 parent 31058cd commit a528e6c

File tree

2 files changed

+166
-2
lines changed

2 files changed

+166
-2
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py

+82-1
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,47 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
7878
return noise_cfg
7979

8080

81+
class MomentumBuffer:
    """
    Momentum accumulator for the guidance update used by adaptive projected guidance.

    Despite its name, `running_average` is not normalized: after every call to `update` it holds
    `update_value + momentum * previous_running_average`.

    Args:
        momentum (`float`):
            Weight applied to the previously accumulated value on each update. It may be negative.
    """

    def __init__(self, momentum: float):
        self.momentum = momentum
        # Starts as the scalar 0 so that the first `update` leaves just the first `update_value`.
        self.running_average = 0

    def update(self, update_value: torch.Tensor) -> None:
        """Fold `update_value` into the buffer: `running_average = update_value + momentum * running_average`."""
        self.running_average = update_value + self.momentum * self.running_average
89+
90+
91+
def normalized_guidance(
    pred_cond: torch.Tensor,
    pred_uncond: torch.Tensor,
    guidance_scale: float,
    momentum_buffer: Optional["MomentumBuffer"] = None,
    eta: float = 1.0,
    norm_threshold: Optional[float] = 0.0,
):
    """
    Based on the findings of [Eliminating Oversaturation and Artifacts of High Guidance Scales
    in Diffusion Models](https://arxiv.org/pdf/2410.02416)

    The guidance update `pred_cond - pred_uncond` is optionally accumulated with momentum, optionally clipped to
    a maximum L2 norm, and then split into the component parallel to `pred_cond` and the component orthogonal to
    it. The parallel component is weighted by `eta` before the update is applied.

    All reductions are taken over the last three dimensions, so the inputs need at least three dimensions
    (presumably `(batch, channels, height, width)` — confirm against the caller).

    Args:
        pred_cond (`torch.Tensor`):
            Conditional model prediction.
        pred_uncond (`torch.Tensor`):
            Unconditional model prediction, same shape as `pred_cond`.
        guidance_scale (`float`):
            Guidance scale; the update is multiplied by `guidance_scale - 1`.
        momentum_buffer (`MomentumBuffer`, *optional*):
            If given, it is updated in place with the current difference and its `running_average` is used as
            the update instead of the raw difference.
        eta (`float`, *optional*, defaults to 1.0):
            Weight of the component of the update that is parallel to `pred_cond`. With `eta=1.0` and no
            momentum or clipping the result equals standard classifier-free guidance.
        norm_threshold (`float`, *optional*, defaults to 0.0):
            Upper bound on the L2 norm of the update. `None` or a value `<= 0` disables the clipping.

    Returns:
        `torch.Tensor`: `pred_cond + (guidance_scale - 1) * update`, in the dtype of the update.
    """
    reduce_dims = [-1, -2, -3]

    diff = pred_cond - pred_uncond
    if momentum_buffer is not None:
        momentum_buffer.update(diff)
        diff = momentum_buffer.running_average

    # `None` is accepted by the pipelines' `Optional[float]` argument, so treat it as "no clipping"
    # instead of failing on the comparison.
    if norm_threshold is not None and norm_threshold > 0:
        diff_norm = diff.norm(p=2, dim=reduce_dims, keepdim=True)
        # Only ever shrink the update; broadcasting applies one factor per leading index.
        scale_factor = torch.minimum(torch.ones_like(diff_norm), norm_threshold / diff_norm)
        diff = diff * scale_factor

    # The projection is done in float64 and cast back to the working dtype afterwards.
    v0, v1 = diff.double(), pred_cond.double()
    v1 = torch.nn.functional.normalize(v1, dim=reduce_dims)
    v0_parallel = (v0 * v1).sum(dim=reduce_dims, keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    diff_parallel, diff_orthogonal = v0_parallel.to(diff.dtype), v0_orthogonal.to(diff.dtype)

    normalized_update = diff_orthogonal + eta * diff_parallel
    pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
    return pred_guided
120+
121+
81122
def retrieve_timesteps(
82123
scheduler,
83124
num_inference_steps: Optional[int] = None,
@@ -730,6 +771,18 @@ def guidance_scale(self):
730771
def guidance_rescale(self):
731772
return self._guidance_rescale
732773

774+
@property
def adaptive_projected_guidance(self):
    """Value of the `adaptive_projected_guidance` argument stored by `__call__`."""
    return self._adaptive_projected_guidance
777+
778+
@property
def adaptive_projected_guidance_momentum(self):
    """Value of the `adaptive_projected_guidance_momentum` argument stored by `__call__`."""
    return self._adaptive_projected_guidance_momentum
781+
782+
@property
def adaptive_projected_guidance_rescale_factor(self):
    """Value of the `adaptive_projected_guidance_rescale_factor` argument stored by `__call__`."""
    return self._adaptive_projected_guidance_rescale_factor
785+
733786
@property
734787
def clip_skip(self):
735788
return self._clip_skip
@@ -777,6 +830,9 @@ def __call__(
777830
return_dict: bool = True,
778831
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
779832
guidance_rescale: float = 0.0,
833+
adaptive_projected_guidance: Optional[bool] = None,
834+
adaptive_projected_guidance_momentum: Optional[float] = -0.5,
835+
adaptive_projected_guidance_rescale_factor: Optional[float] = 15.0,
780836
clip_skip: Optional[int] = None,
781837
callback_on_step_end: Optional[
782838
Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
@@ -847,6 +903,13 @@ def __call__(
847903
Guidance rescale factor from [Common Diffusion Noise Schedules and Sample Steps are
848904
Flawed](https://arxiv.org/pdf/2305.08891.pdf). Guidance rescale factor should fix overexposure when
849905
using zero terminal SNR.
906+
adaptive_projected_guidance (`bool`, *optional*):
907+
Use adaptive projected guidance from [Eliminating Oversaturation and Artifacts of High Guidance Scales
908+
in Diffusion Models](https://arxiv.org/pdf/2410.02416)
909+
adaptive_projected_guidance_momentum (`float`, *optional*, defaults to `-0.5`):
910+
Momentum to use with adaptive projected guidance. Use `None` to disable momentum.
911+
adaptive_projected_guidance_rescale_factor (`float`, *optional*, defaults to `15.0`):
912+
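Norm threshold for adaptive projected guidance: the guidance update is scaled down so that its L2 norm does not exceed this value. A value `<= 0` disables the rescaling.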
850913
clip_skip (`int`, *optional*):
851914
Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that
852915
the output of the pre-final layer will be used for computing the prompt embeddings.
@@ -910,6 +973,9 @@ def __call__(
910973

911974
self._guidance_scale = guidance_scale
912975
self._guidance_rescale = guidance_rescale
976+
self._adaptive_projected_guidance = adaptive_projected_guidance
977+
self._adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
978+
self._adaptive_projected_guidance_rescale_factor = adaptive_projected_guidance_rescale_factor
913979
self._clip_skip = clip_skip
914980
self._cross_attention_kwargs = cross_attention_kwargs
915981
self._interrupt = False
@@ -992,6 +1058,11 @@ def __call__(
9921058
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
9931059
).to(device=device, dtype=latents.dtype)
9941060

1061+
if adaptive_projected_guidance and adaptive_projected_guidance_momentum is not None:
1062+
momentum_buffer = MomentumBuffer(adaptive_projected_guidance_momentum)
1063+
else:
1064+
momentum_buffer = None
1065+
9951066
# 7. Denoising loop
9961067
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
9971068
self._num_timesteps = len(timesteps)
@@ -1018,7 +1089,17 @@ def __call__(
10181089
# perform guidance
10191090
if self.do_classifier_free_guidance:
10201091
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1021-
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1092+
if adaptive_projected_guidance:
1093+
noise_pred = normalized_guidance(
1094+
noise_pred_text,
1095+
noise_pred_uncond,
1096+
self.guidance_scale,
1097+
momentum_buffer,
1098+
eta,
1099+
adaptive_projected_guidance_rescale_factor,
1100+
)
1101+
else:
1102+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
10221103

10231104
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
10241105
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

+84-1
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,49 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
100100
return noise_cfg
101101

102102

103+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.MomentumBuffer
104+
class MomentumBuffer:
    """
    Momentum accumulator for the guidance update used by adaptive projected guidance.

    Despite its name, `running_average` is not normalized: after every call to `update` it holds
    `update_value + momentum * previous_running_average`.

    Args:
        momentum (`float`):
            Weight applied to the previously accumulated value on each update. It may be negative.
    """

    def __init__(self, momentum: float):
        self.momentum = momentum
        # Starts as the scalar 0 so that the first `update` leaves just the first `update_value`.
        self.running_average = 0

    def update(self, update_value: torch.Tensor) -> None:
        """Fold `update_value` into the buffer: `running_average = update_value + momentum * running_average`."""
        self.running_average = update_value + self.momentum * self.running_average
112+
113+
114+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.normalized_guidance
115+
def normalized_guidance(
    pred_cond: torch.Tensor,
    pred_uncond: torch.Tensor,
    guidance_scale: float,
    momentum_buffer: Optional["MomentumBuffer"] = None,
    eta: float = 1.0,
    norm_threshold: Optional[float] = 0.0,
):
    """
    Based on the findings of [Eliminating Oversaturation and Artifacts of High Guidance Scales
    in Diffusion Models](https://arxiv.org/pdf/2410.02416)

    The guidance update `pred_cond - pred_uncond` is optionally accumulated with momentum, optionally clipped to
    a maximum L2 norm, and then split into the component parallel to `pred_cond` and the component orthogonal to
    it. The parallel component is weighted by `eta` before the update is applied.

    All reductions are taken over the last three dimensions, so the inputs need at least three dimensions
    (presumably `(batch, channels, height, width)` — confirm against the caller).

    Args:
        pred_cond (`torch.Tensor`):
            Conditional model prediction.
        pred_uncond (`torch.Tensor`):
            Unconditional model prediction, same shape as `pred_cond`.
        guidance_scale (`float`):
            Guidance scale; the update is multiplied by `guidance_scale - 1`.
        momentum_buffer (`MomentumBuffer`, *optional*):
            If given, it is updated in place with the current difference and its `running_average` is used as
            the update instead of the raw difference.
        eta (`float`, *optional*, defaults to 1.0):
            Weight of the component of the update that is parallel to `pred_cond`. With `eta=1.0` and no
            momentum or clipping the result equals standard classifier-free guidance.
        norm_threshold (`float`, *optional*, defaults to 0.0):
            Upper bound on the L2 norm of the update. `None` or a value `<= 0` disables the clipping.

    Returns:
        `torch.Tensor`: `pred_cond + (guidance_scale - 1) * update`, in the dtype of the update.
    """
    reduce_dims = [-1, -2, -3]

    diff = pred_cond - pred_uncond
    if momentum_buffer is not None:
        momentum_buffer.update(diff)
        diff = momentum_buffer.running_average

    # `None` is accepted by the pipelines' `Optional[float]` argument, so treat it as "no clipping"
    # instead of failing on the comparison.
    if norm_threshold is not None and norm_threshold > 0:
        diff_norm = diff.norm(p=2, dim=reduce_dims, keepdim=True)
        # Only ever shrink the update; broadcasting applies one factor per leading index.
        scale_factor = torch.minimum(torch.ones_like(diff_norm), norm_threshold / diff_norm)
        diff = diff * scale_factor

    # The projection is done in float64 and cast back to the working dtype afterwards.
    v0, v1 = diff.double(), pred_cond.double()
    v1 = torch.nn.functional.normalize(v1, dim=reduce_dims)
    v0_parallel = (v0 * v1).sum(dim=reduce_dims, keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    diff_parallel, diff_orthogonal = v0_parallel.to(diff.dtype), v0_orthogonal.to(diff.dtype)

    normalized_update = diff_orthogonal + eta * diff_parallel
    pred_guided = pred_cond + (guidance_scale - 1) * normalized_update
    return pred_guided
144+
145+
103146
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
104147
def retrieve_timesteps(
105148
scheduler,
@@ -789,6 +832,18 @@ def guidance_scale(self):
789832
def guidance_rescale(self):
790833
return self._guidance_rescale
791834

835+
@property
def adaptive_projected_guidance(self):
    """Value of the `adaptive_projected_guidance` argument stored by `__call__`."""
    return self._adaptive_projected_guidance
838+
839+
@property
def adaptive_projected_guidance_momentum(self):
    """Value of the `adaptive_projected_guidance_momentum` argument stored by `__call__`."""
    return self._adaptive_projected_guidance_momentum
842+
843+
@property
def adaptive_projected_guidance_rescale_factor(self):
    """Value of the `adaptive_projected_guidance_rescale_factor` argument stored by `__call__`."""
    return self._adaptive_projected_guidance_rescale_factor
846+
792847
@property
793848
def clip_skip(self):
794849
return self._clip_skip
@@ -845,6 +900,9 @@ def __call__(
845900
return_dict: bool = True,
846901
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
847902
guidance_rescale: float = 0.0,
903+
adaptive_projected_guidance: Optional[bool] = None,
904+
adaptive_projected_guidance_momentum: Optional[float] = -0.5,
905+
adaptive_projected_guidance_rescale_factor: Optional[float] = 15.0,
848906
original_size: Optional[Tuple[int, int]] = None,
849907
crops_coords_top_left: Tuple[int, int] = (0, 0),
850908
target_size: Optional[Tuple[int, int]] = None,
@@ -956,6 +1014,13 @@ def __call__(
9561014
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
9571015
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
9581016
Guidance rescale factor should fix overexposure when using zero terminal SNR.
1017+
adaptive_projected_guidance (`bool`, *optional*):
1018+
Use adaptive projected guidance from [Eliminating Oversaturation and Artifacts of High Guidance Scales
1019+
in Diffusion Models](https://arxiv.org/pdf/2410.02416)
1020+
adaptive_projected_guidance_momentum (`float`, *optional*, defaults to `-0.5`):
1021+
Momentum to use with adaptive projected guidance. Use `None` to disable momentum.
1022+
adaptive_projected_guidance_rescale_factor (`float`, *optional*, defaults to `15.0`):
1023+
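Norm threshold for adaptive projected guidance: the guidance update is scaled down so that its L2 norm does not exceed this value. A value `<= 0` disables the rescaling.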
9591024
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
9601025
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
9611026
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
@@ -1049,6 +1114,9 @@ def __call__(
10491114

10501115
self._guidance_scale = guidance_scale
10511116
self._guidance_rescale = guidance_rescale
1117+
self._adaptive_projected_guidance = adaptive_projected_guidance
1118+
self._adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
1119+
self._adaptive_projected_guidance_rescale_factor = adaptive_projected_guidance_rescale_factor
10521120
self._clip_skip = clip_skip
10531121
self._cross_attention_kwargs = cross_attention_kwargs
10541122
self._denoising_end = denoising_end
@@ -1181,6 +1249,11 @@ def __call__(
11811249
guidance_scale_tensor, embedding_dim=self.unet.config.time_cond_proj_dim
11821250
).to(device=device, dtype=latents.dtype)
11831251

1252+
if adaptive_projected_guidance and adaptive_projected_guidance_momentum is not None:
1253+
momentum_buffer = MomentumBuffer(adaptive_projected_guidance_momentum)
1254+
else:
1255+
momentum_buffer = None
1256+
11841257
self._num_timesteps = len(timesteps)
11851258
with self.progress_bar(total=num_inference_steps) as progress_bar:
11861259
for i, t in enumerate(timesteps):
@@ -1209,7 +1282,17 @@ def __call__(
12091282
# perform guidance
12101283
if self.do_classifier_free_guidance:
12111284
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1212-
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
1285+
if adaptive_projected_guidance:
1286+
noise_pred = normalized_guidance(
1287+
noise_pred_text,
1288+
noise_pred_uncond,
1289+
self.guidance_scale,
1290+
momentum_buffer,
1291+
eta,
1292+
adaptive_projected_guidance_rescale_factor,
1293+
)
1294+
else:
1295+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_text - noise_pred_uncond)
12131296

12141297
if self.do_classifier_free_guidance and self.guidance_rescale > 0.0:
12151298
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf

0 commit comments

Comments
 (0)