Description
**Is your feature request related to a problem?** No
**Describe the solution you'd like.**
Incorporate a way to add sage attention to the diffusers library: Flux pipeline, Wan pipeline, etc.
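For illustration only, a user-facing switch might look something like the sketch below; `enable_sage_attention` is a hypothetical method name, not an existing diffusers API:

```python
import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Hypothetical opt-in switch: only the transformer's attention would use sageattn,
# while the VAE and text encoders keep PyTorch's default SDPA.
pipe.enable_sage_attention()

image = pipe("a photo of an astronaut riding a horse", num_inference_steps=28).images[0]
```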
**Describe alternatives you've considered.**
None
**Additional context.**
When I incorporated sage attention into the Flux pipeline (text to image), I measured a 16% speed advantage over running without it.
The environment was identical apart from including or excluding sage attention in my 4-image benchmark run.
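The comparison was plain wall-clock timing of the same call with and without the patch; a minimal sketch of that kind of measurement (prompt, step count, and seed here are placeholders, not my exact benchmark settings):

```python
import time

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

generator = torch.Generator("cuda").manual_seed(0)
start = time.perf_counter()
for _ in range(4):  # 4-image benchmark
    pipe("a photo of a cat", num_inference_steps=28, generator=generator)
torch.cuda.synchronize()
print(f"4 images in {time.perf_counter() - start:.1f}s")
```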
How could sage attention be incorporated? Keep in mind that it only applies to the transformer. With that in mind, here is what I did in FluxPipeline. Ideally this would be controlled by a flag of some sort so the user can opt in or out:
Some indicator is needed to decide whether to enable it, and the patch has to be applied before the denoising loop in the pipeline.
```python
import torch.nn.functional as F

sage_function = False
try:
    from sageattention import sageattn

    # Route attention through sageattn for the transformer.
    # Note: reassigning F.scaled_dot_product_attention patches it globally.
    self.transformer.scaled_dot_product_attention = F.scaled_dot_product_attention = sageattn
    sage_function = True
except ImportError:
    pass

# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
    for i, t in enumerate(timesteps):
        if self.interrupt:
            continue
```
After the denoising loop we have to remove sage attention again, otherwise we get a VAE error: sageattn only accepts torch.float16 or torch.bfloat16 inputs, which does not match what the VAE wants. So before decoding:

```python
if output_type == "latent":
    image = latents
else:
    if sage_function:
        # Restore PyTorch's default scaled_dot_product_attention before the VAE decode.
        self.transformer.scaled_dot_product_attention = F.scaled_dot_product_attention = (
            torch._C._nn.scaled_dot_product_attention
        )
```
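A tidier way to express the same gating could be a small context manager around the denoising loop, so the original SDPA is always restored before the VAE runs. This is only a sketch; neither the context manager nor the `use_sage_attention` flag exists in diffusers today:

```python
import contextlib

import torch.nn.functional as F


@contextlib.contextmanager
def maybe_sage_attention(enabled: bool):
    """Temporarily route F.scaled_dot_product_attention to sageattn.

    Note: this monkey-patches torch.nn.functional globally, so it affects every
    module that calls F.scaled_dot_product_attention while it is active.
    """
    if not enabled:
        yield
        return
    try:
        from sageattention import sageattn
    except ImportError:
        # sageattention not installed: fall back to the default attention.
        yield
        return
    original = F.scaled_dot_product_attention
    F.scaled_dot_product_attention = sageattn
    try:
        yield
    finally:
        # Always restore the default SDPA so the VAE decode is unaffected.
        F.scaled_dot_product_attention = original


# Inside the pipeline, the denoising loop would then be wrapped like:
# with maybe_sage_attention(use_sage_attention):
#     with self.progress_bar(total=num_inference_steps) as progress_bar:
#         for i, t in enumerate(timesteps):
#             ...
```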
Hopefully this helps.