
Sage Attention for diffuser library #11168

Open
@ukaprch

Description


**Is your feature request related to a problem?** No

**Describe the solution you'd like.**
Incorporate a way to enable SageAttention in the diffusers library: the Flux pipeline, the Wan pipeline, etc.

**Describe alternatives you've considered.**
None

**Additional context.**
When I incorporated SageAttention into the Flux pipeline (text-to-image), I measured a 16% speedup versus running without SageAttention.
My environment was otherwise identical across my 4-image benchmark runs; the only difference was whether SageAttention was enabled.
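A comparison like this can be reproduced with a small timing harness. The sketch below is illustrative only; `pipe` and `pipe_sage` in the usage comment are hypothetical stand-ins for the same pipeline with and without SageAttention enabled:

```python
import time

def average_seconds(fn, runs=4):
    """Average wall-clock time of `fn` over `runs` calls (4 matches the benchmark above)."""
    total = 0.0
    for _ in range(runs):
        start = time.perf_counter()
        fn()
        total += time.perf_counter() - start
    return total / runs

# Hypothetical usage: time the same generation with and without SageAttention.
# baseline = average_seconds(lambda: pipe(prompt))       # stock attention
# patched  = average_seconds(lambda: pipe_sage(prompt))  # SageAttention enabled
# speedup  = (baseline - patched) / baseline             # ~0.16 in the run above
```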

How should SageAttention be incorporated? Keep in mind that it applies only to the transformer. With that in mind, I did the following in FluxPipeline. Ideally this would be controlled by a flag of some kind so that it can be toggled on or off.

We need some indicator to decide whether to enable it, and the swap must happen before the denoising loop in the model pipeline:
```python
import torch.nn.functional as F

sage_function = False
try:
    from sageattention import sageattn
    # Route all scaled-dot-product attention calls through SageAttention.
    self.transformer.scaled_dot_product_attention = F.scaled_dot_product_attention = sageattn
    sage_function = True
except ImportError:
    pass  # SageAttention not installed; run with stock attention

# 6. Denoising loop
with self.progress_bar(total=num_inference_steps) as progress_bar:
    for i, t in enumerate(timesteps):
        if self.interrupt:
            continue
```
After the denoising loop we must remove SageAttention, otherwise we get a VAE error: SageAttention accepts only torch.float16 or torch.bfloat16 inputs, which the VAE does not use:

```python
if output_type == "latent":
    image = latents
else:
    if sage_function:
        # Restore the stock PyTorch attention kernel before the VAE decode.
        self.transformer.scaled_dot_product_attention = F.scaled_dot_product_attention = torch._C._nn.scaled_dot_product_attention
```
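The patch/restore pair above could also be wrapped in a context manager so the original kernel is restored even if the denoising loop raises. This is a sketch of that pattern, not diffusers API; a plain namespace stands in for `torch.nn.functional` so the snippet runs even without torch or sageattention installed:

```python
import types
from contextlib import contextmanager

# Stand-in for torch.nn.functional; in the real pipeline you would pass
# torch.nn.functional itself (and mirror the patch onto the transformer).
F = types.SimpleNamespace(scaled_dot_product_attention=lambda q, k, v: "sdpa")

@contextmanager
def sage_attention(module):
    """Temporarily swap scaled_dot_product_attention for sageattn.

    The original implementation is restored on exit, so later stages
    (e.g. the VAE decode) see the stock kernel again.
    """
    original = module.scaled_dot_product_attention
    try:
        from sageattention import sageattn  # optional dependency
        module.scaled_dot_product_attention = sageattn
    except ImportError:
        pass  # SageAttention unavailable; run unpatched
    try:
        yield
    finally:
        module.scaled_dot_product_attention = original
```

Wrapping the denoising loop in `with sage_attention(F):` would then make the restore automatic instead of relying on a `sage_function` flag.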
Hopefully this helps.
