@@ -100,6 +100,49 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
100
100
return noise_cfg
101
101
102
102
103
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.MomentumBuffer
104
+ class MomentumBuffer :
105
+ def __init__ (self , momentum : float ):
106
+ self .momentum = momentum
107
+ self .running_average = 0
108
+
109
+ def update (self , update_value : torch .Tensor ):
110
+ new_average = self .momentum * self .running_average
111
+ self .running_average = update_value + new_average
112
+
113
+
114
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.normalized_guidance
115
+ def normalized_guidance (
116
+ pred_cond : torch .Tensor ,
117
+ pred_uncond : torch .Tensor ,
118
+ guidance_scale : float ,
119
+ momentum_buffer : MomentumBuffer = None ,
120
+ eta : float = 1.0 ,
121
+ norm_threshold : float = 0.0 ,
122
+ ):
123
+ """
124
+ Based on the findings of [Eliminating Oversaturation and Artifacts of High Guidance Scales
125
+ in Diffusion Models](https://arxiv.org/pdf/2410.02416)
126
+ """
127
+ diff = pred_cond - pred_uncond
128
+ if momentum_buffer is not None :
129
+ momentum_buffer .update (diff )
130
+ diff = momentum_buffer .running_average
131
+ if norm_threshold > 0 :
132
+ ones = torch .ones_like (diff )
133
+ diff_norm = diff .norm (p = 2 , dim = [- 1 , - 2 , - 3 ], keepdim = True )
134
+ scale_factor = torch .minimum (ones , norm_threshold / diff_norm )
135
+ diff = diff * scale_factor
136
+ v0 , v1 = diff .double (), pred_cond .double ()
137
+ v1 = torch .nn .functional .normalize (v1 , dim = [- 1 , - 2 , - 3 ])
138
+ v0_parallel = (v0 * v1 ).sum (dim = [- 1 , - 2 , - 3 ], keepdim = True ) * v1
139
+ v0_orthogonal = v0 - v0_parallel
140
+ diff_parallel , diff_orthogonal = v0_parallel .to (diff .dtype ), v0_orthogonal .to (diff .dtype )
141
+ normalized_update = diff_orthogonal + eta * diff_parallel
142
+ pred_guided = pred_cond + (guidance_scale - 1 ) * normalized_update
143
+ return pred_guided
144
+
145
+
103
146
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
104
147
def retrieve_timesteps (
105
148
scheduler ,
@@ -789,6 +832,18 @@ def guidance_scale(self):
789
832
def guidance_rescale (self ):
790
833
return self ._guidance_rescale
791
834
835
+ @property
836
+ def adaptive_projected_guidance (self ):
837
+ return self ._adaptive_projected_guidance
838
+
839
+ @property
840
+ def adaptive_projected_guidance_momentum (self ):
841
+ return self ._adaptive_projected_guidance_momentum
842
+
843
+ @property
844
+ def adaptive_projected_guidance_rescale_factor (self ):
845
+ return self ._adaptive_projected_guidance_rescale_factor
846
+
792
847
@property
793
848
def clip_skip (self ):
794
849
return self ._clip_skip
@@ -845,6 +900,9 @@ def __call__(
845
900
return_dict : bool = True ,
846
901
cross_attention_kwargs : Optional [Dict [str , Any ]] = None ,
847
902
guidance_rescale : float = 0.0 ,
903
+ adaptive_projected_guidance : Optional [bool ] = None ,
904
+ adaptive_projected_guidance_momentum : Optional [float ] = - 0.5 ,
905
+ adaptive_projected_guidance_rescale_factor : Optional [float ] = 15.0 ,
848
906
original_size : Optional [Tuple [int , int ]] = None ,
849
907
crops_coords_top_left : Tuple [int , int ] = (0 , 0 ),
850
908
target_size : Optional [Tuple [int , int ]] = None ,
@@ -956,6 +1014,13 @@ def __call__(
956
1014
Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
957
1015
[Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
958
1016
Guidance rescale factor should fix overexposure when using zero terminal SNR.
1017
+ adaptive_projected_guidance (`bool`, *optional*):
1018
+ Use adaptive projected guidance from [Eliminating Oversaturation and Artifacts of High Guidance Scales
1019
+ in Diffusion Models](https://arxiv.org/pdf/2410.02416)
1020
+ adaptive_projected_guidance_momentum (`float`, *optional*, defaults to `-0.5`):
1021
+ Momentum to use with adaptive projected guidance. Use `None` to disable momentum.
1022
+ adaptive_projected_guidance_rescale_factor (`float`, *optional*, defaults to `15.0`):
1023
+ Rescale factor to use with adaptive projected guidance.
959
1024
original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
960
1025
If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
961
1026
`original_size` defaults to `(height, width)` if not specified. Part of SDXL's micro-conditioning as
@@ -1049,6 +1114,9 @@ def __call__(
1049
1114
1050
1115
self ._guidance_scale = guidance_scale
1051
1116
self ._guidance_rescale = guidance_rescale
1117
+ self ._adaptive_projected_guidance = adaptive_projected_guidance
1118
+ self ._adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
1119
+ self ._adaptive_projected_guidance_rescale_factor = adaptive_projected_guidance_rescale_factor
1052
1120
self ._clip_skip = clip_skip
1053
1121
self ._cross_attention_kwargs = cross_attention_kwargs
1054
1122
self ._denoising_end = denoising_end
@@ -1181,6 +1249,11 @@ def __call__(
1181
1249
guidance_scale_tensor , embedding_dim = self .unet .config .time_cond_proj_dim
1182
1250
).to (device = device , dtype = latents .dtype )
1183
1251
1252
+ if adaptive_projected_guidance and adaptive_projected_guidance_momentum is not None :
1253
+ momentum_buffer = MomentumBuffer (adaptive_projected_guidance_momentum )
1254
+ else :
1255
+ momentum_buffer = None
1256
+
1184
1257
self ._num_timesteps = len (timesteps )
1185
1258
with self .progress_bar (total = num_inference_steps ) as progress_bar :
1186
1259
for i , t in enumerate (timesteps ):
@@ -1209,7 +1282,17 @@ def __call__(
1209
1282
# perform guidance
1210
1283
if self .do_classifier_free_guidance :
1211
1284
noise_pred_uncond , noise_pred_text = noise_pred .chunk (2 )
1212
- noise_pred = noise_pred_uncond + self .guidance_scale * (noise_pred_text - noise_pred_uncond )
1285
+ if adaptive_projected_guidance :
1286
+ noise_pred = normalized_guidance (
1287
+ noise_pred_text ,
1288
+ noise_pred_uncond ,
1289
+ self .guidance_scale ,
1290
+ momentum_buffer ,
1291
+ eta ,
1292
+ adaptive_projected_guidance_rescale_factor ,
1293
+ )
1294
+ else :
1295
+ noise_pred = noise_pred_uncond + self .guidance_scale * (noise_pred_text - noise_pred_uncond )
1213
1296
1214
1297
if self .do_classifier_free_guidance and self .guidance_rescale > 0.0 :
1215
1298
# Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
0 commit comments