Attention Dispatcher #11368

Open · a-r-r-o-w wants to merge 11 commits into main

Conversation

a-r-r-o-w (Member) commented Apr 19, 2025

Usage

# test.py
import torch
from diffusers import Lumina2Pipeline, attention_backend

pipe = Lumina2Pipeline.from_pretrained("Alpha-VLLM/Lumina-Image-2.0", torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A cat holding a sign that says 'Hello, World!' in a colorful park with flowers and trees"

with attention_backend("sage_varlen"):
    image = pipe(prompt, generator=torch.Generator().manual_seed(42)).images[0]
image.save("output.png")

Example CLI invocations using the DIFFUSERS_ATTN_PROVIDER environment variable:

# fails because flex attention requires head dim to be a power of 2
DIFFUSERS_ATTN_PROVIDER="flex" CUDA_VISIBLE_DEVICES=3 python3 test.py
# dispatches to cudnn internally in pytorch, so it's the same as using "_native_cudnn" (see below)
DIFFUSERS_ATTN_PROVIDER="native" CUDA_VISIBLE_DEVICES=3 python3 test.py
DIFFUSERS_ATTN_PROVIDER="flash_varlen" CUDA_VISIBLE_DEVICES=3 python3 test.py
DIFFUSERS_ATTN_PROVIDER="sage_varlen" CUDA_VISIBLE_DEVICES=3 python3 test.py
DIFFUSERS_ATTN_PROVIDER="_native_cudnn" CUDA_VISIBLE_DEVICES=3 python3 test.py
DIFFUSERS_ATTN_PROVIDER="_native_efficient" CUDA_VISIBLE_DEVICES=3 python3 test.py
DIFFUSERS_ATTN_PROVIDER="xformers" CUDA_VISIBLE_DEVICES=3 python3 test.py

Attention-only benchmark

import torch
from diffusers.models.attention_dispatch import attention_backend, dispatch_attention_fn

torch.manual_seed(0)

# Wan 1.3B/CogVideoX
batch = 1
num_heads = 12
head_dim = 128
dtype = torch.bfloat16

resolutions = [(1, 512, 512), (1, 1024, 1024), (49, 480, 720), (29, 1024, 1024), (81, 480, 832)]
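# latent sequence length: assumes 4x temporal and 8x8 spatial VAE compression plus 2x2 patchification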
seq_lens = [((res[0] - 1) // 4 + 1) * res[1] * res[2] // 8 // 8 // 4 for res in resolutions]
print("Sequence lengths:", seq_lens)

for seq_len in seq_lens:
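    # forward-pass attention FLOPs: two seq_len x seq_len matmuls (QK^T and attn @ V),
    # each costing 2 * seq_len^2 * head_dim FLOPs per head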
    flops = 4 * batch * num_heads * head_dim * seq_len * seq_len

    torch.manual_seed(0)
    query = torch.randn(batch, num_heads, seq_len, head_dim, dtype=dtype, device="cuda")
    key = torch.randn(batch, num_heads, seq_len, head_dim, dtype=dtype, device="cuda")
    value = torch.randn(batch, num_heads, seq_len, head_dim, dtype=dtype, device="cuda")

    results = {}
    
    for backend in ["flash", "flash_varlen", "_native_flash", "_native_cudnn", "_native_efficient", "xformers", "_sage_qk_int8_pv_fp16_cuda"]:
        with attention_backend(backend):
            for _ in range(5):
                # Warmup
                _ = dispatch_attention_fn(query, key, value)

            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)

            start.record()
            result = dispatch_attention_fn(query, key, value)
            end.record()
            torch.cuda.synchronize()

            elapsed_time = start.elapsed_time(end) / 1000
            results[backend] = elapsed_time
    
    tflops_s_flash = flops / results["flash"] / 1e12
    tflops_s_flash_varlen = flops / results["flash_varlen"] / 1e12
    tflops_s_native_flash = flops / results["_native_flash"] / 1e12
    tflops_s_native_cudnn = flops / results["_native_cudnn"] / 1e12
    tflops_s_native_efficient = flops / results["_native_efficient"] / 1e12
    tflops_s_xformers = flops / results["xformers"] / 1e12
    tflops_s_sage_qk_int8_pv_fp16_cuda = flops / results["_sage_qk_int8_pv_fp16_cuda"] / 1e12

    print()
    print(f"Shape: {query.shape}")
    print(f"TFLOPs: {flops / 1e12:.2f}")
    print("===== TFLOPS =====")
    print(f"                     (flash): {tflops_s_flash:.2f}")
    print(f"              (flash_varlen): {tflops_s_flash_varlen:.2f}")
    print(f"              (native_flash): {tflops_s_native_flash:.2f}")
    print(f"              (native_cudnn): {tflops_s_native_cudnn:.2f}")
    print(f"          (native_efficient): {tflops_s_native_efficient:.2f}")
    print(f"                  (xformers): {tflops_s_xformers:.2f}")
    print(f"(_sage_qk_int8_pv_fp16_cuda): {tflops_s_sage_qk_int8_pv_fp16_cuda:.2f}")
    print("==========")

Model benchmark

import argparse
import gc
import pathlib
import traceback

import git
import pandas as pd
import torch
import torch.nn.attention.flex_attention
from diffusers import (
    AllegroPipeline,
    CogVideoXPipeline,
    FluxPipeline,
    HunyuanVideoPipeline,
    LattePipeline,
    LTXPipeline,
    MochiPipeline,
    WanPipeline,
    AttentionBackendName,
    attention_backend,
)
from diffusers.hooks import apply_group_offloading
from diffusers.models import HunyuanVideoTransformer3DModel
from diffusers.utils import export_to_video
from diffusers.utils.logging import set_verbosity_info, set_verbosity_debug
from tabulate import tabulate


repo = git.Repo(path="/home/aryan/work/diffusers")
branch = repo.active_branch

torch.nn.attention.flex_attention.flex_attention = torch.compile(torch.nn.attention.flex_attention.flex_attention, mode="max-autotune", dynamic=False, fullgraph=True)
torch.nn.attention.flex_attention.create_block_mask = torch.compile(torch.nn.attention.flex_attention.create_block_mask, mode="max-autotune", dynamic=False, fullgraph=True)

torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_fp16_accumulation = True
torch.backends.cuda.matmul.allow_fp16_reduced_precision_reduction = True
torch.backends.cuda.matmul.allow_bf16_reduced_precision_reduction = True


def pretty_print_results(results, precision: int = 3):
    def format_value(value):
        if isinstance(value, float):
            return f"{value:.{precision}f}"
        return value

    filtered_table = {k: format_value(v) for k, v in results.items()}
    print(tabulate([filtered_table], headers="keys", tablefmt="pipe", stralign="center"))


def benchmark_fn(f, *args, **kwargs):
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
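    # CUDA events are recorded on the current stream around the call; reading
    # elapsed_time after the final synchronize gives the GPU time spent in f.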

    start.record()
    output = f(*args, **kwargs)
    end.record()
    torch.cuda.synchronize()
    elapsed_time = round(start.elapsed_time(end) / 1000, 3)

    return elapsed_time, output


def prepare_allegro(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "rhymes-ai/Allegro"
    cache_dir = None

    pipe = AllegroPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")
    pipe.vae.enable_tiling()

    if compile:
        pipe.transformer = torch.compile(
            pipe.transformer, mode="max-autotune-no-cudagraphs", fullgraph=True, dynamic=False
        )

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": "A seaside harbor with bright sunlight and sparkling seawater, with many boats in the water. From an aerial view, the boats vary in size and color, some moving and some stationary. Fishing boats in the water suggest that this location might be a popular spot for docking fishing boats.",
        "height": 720,
        "width": 1280,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
        **kwargs,
    }

    return pipe, generation_kwargs


def prepare_cogvideox_1_0(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "THUDM/CogVideoX-5b"
    cache_dir = None

    pipe = CogVideoXPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")

    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
        prompt=(
            "A panda, dressed in a small, red jacket and a tiny hat, sits on a wooden stool in a serene bamboo forest. "
            "The panda's fluffy paws strum a miniature acoustic guitar, producing soft, melodic tunes. Nearby, a few other "
            "pandas gather, watching curiously and some clapping in rhythm. Sunlight filters through the tall bamboo, "
            "casting a gentle glow on the scene. The panda's face is expressive, showing concentration and joy as it plays. "
            "The background includes a small, flowing stream and vibrant green foliage, enhancing the peaceful and magical "
            "atmosphere of this unique musical performance."
        ),
        device="cuda",
        dtype=dtype,
    )

    pipe.text_encoder.to("cpu")

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt_embeds": prompt_embeds,
        "negative_prompt_embeds": negative_prompt_embeds,
        "height": 480,
        "width": 720,
        "num_frames": 49,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
        **kwargs,
    }

    return pipe, generation_kwargs


def prepare_flux(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "black-forest-labs/FLUX.1-dev"
    cache_dir = "/raid/.cache/huggingface"

    pipe = FluxPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.vae.enable_tiling()

    pipe.text_encoder.to("cuda")
    pipe.text_encoder_2.to("cuda")
    prompt_embeds, pooled_prompt_embeds, _ = pipe.encode_prompt(
        prompt="A cat holding a sign that says hello world", prompt_2=None, device="cuda"
    )
    pipe.text_encoder.to("cpu")
    pipe.text_encoder_2.to("cpu")
    del pipe.text_encoder
    del pipe.text_encoder_2
    pipe.text_encoder = None
    pipe.text_encoder_2 = None
    pipe.to("cuda")

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt_embeds": prompt_embeds,
        "pooled_prompt_embeds": pooled_prompt_embeds,
        "height": 768,
        "width": 768,
        "num_inference_steps": 50,
        "guidance_scale": 5.0,
        **kwargs,
    }

    return pipe, generation_kwargs


def prepare_hunyuan_video(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "hunyuanvideo-community/HunyuanVideo"
    cache_dir = None

    transformer = HunyuanVideoTransformer3DModel.from_pretrained(
        model_id, subfolder="transformer", torch_dtype=torch.bfloat16
    )
    pipe = HunyuanVideoPipeline.from_pretrained(
        model_id, transformer=transformer, torch_dtype=torch.float16, cache_dir=cache_dir
    )
    pipe.to("cuda")

    prompt_embeds, pooled_prompt_embeds, prompt_attention_mask = pipe.encode_prompt(
        prompt="A cat wearing sunglasses and working as a lifeguard at pool.", device="cuda", dtype=torch.float16
    )
    pipe.text_encoder.to("cpu")
    pipe.text_encoder_2.to("cpu")

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt_embeds": prompt_embeds,
        "pooled_prompt_embeds": pooled_prompt_embeds,
        "prompt_attention_mask": prompt_attention_mask,
        "height": 320,
        "width": 512,
        "num_frames": 61,
        "num_inference_steps": 30,
    }

    return pipe, generation_kwargs


def prepare_latte(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "maxin-cn/Latte-1"
    cache_dir = None

    pipe = LattePipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")

    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
        prompt="A cat wearing sunglasses and working as a lifeguard at pool.",
        do_classifier_free_guidance=True,
        num_videos_per_prompt=1,
        device="cuda",
    )
    pipe.text_encoder.to("cpu")

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt_embeds": prompt_embeds,
        "negative_prompt_embeds": negative_prompt_embeds,
        "height": 512,
        "width": 512,
        "video_length": 16,
        "num_inference_steps": 50,
    }

    return pipe, generation_kwargs


def prepare_ltx_video(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "a-r-r-o-w/LTX-Video-diffusers"
    cache_dir = None

    pipe = LTXPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")

    (
        prompt_embeds,
        prompt_attention_mask,
        negative_prompt_embeds,
        negative_prompt_attention_mask,
    ) = pipe.encode_prompt(
        prompt="A woman with long brown hair and light skin smiles at another woman with long blonde hair. The woman with brown hair wears a black jacket and has a small, barely noticeable mole on her right cheek. The camera angle is a close-up, focused on the woman with brown hair's face. The lighting is warm and natural, likely from the setting sun, casting a soft glow on the scene. The scene appears to be real-life footage",
        negative_prompt="worst quality, inconsistent motion, blurry, jittery, distorted",
        do_classifier_free_guidance=True,
        num_videos_per_prompt=1,
        device="cuda",
    )
    pipe.text_encoder.to("cpu")

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)
    
    generation_kwargs = {
        "prompt_embeds": prompt_embeds,
        "prompt_attention_mask": prompt_attention_mask,
        "negative_prompt_embeds": negative_prompt_embeds,
        "negative_prompt_attention_mask": negative_prompt_attention_mask,
        "width": 768,
        "height": 512,
        "num_frames": 161,
        "num_inference_steps": 50,
    }

    return pipe, generation_kwargs


def prepare_mochi(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "genmo/mochi-1-preview"
    cache_dir = None

    pipe = MochiPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    pipe.to("cuda")
    pipe.vae.enable_tiling()

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)

    generation_kwargs = {
        "prompt": "Close-up of a chameleon's eye, with its scaly skin changing color. Ultra high resolution 4k.",
        "height": 480,
        "width": 848,
        "num_frames": 85,
        "num_inference_steps": 50,
    }

    return pipe, generation_kwargs


def prepare_wan(dtype: torch.dtype, compile: bool = False, **kwargs):
    model_id = "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"
    cache_dir = None

    pipe = WanPipeline.from_pretrained(model_id, torch_dtype=dtype, cache_dir=cache_dir)
    
    prompt = "A cat and a dog baking a cake together in a kitchen. The cat is carefully measuring flour, while the dog is stirring the batter with a wooden spoon. The kitchen is cozy, with sunlight streaming through the window."
    negative_prompt = "worst quality, low quality, blurry, distorted, out of focus, bad composition"
    
    pipe.text_encoder.to("cuda")
    prompt_embeds, negative_prompt_embeds = pipe.encode_prompt(
        prompt=prompt,
        negative_prompt=negative_prompt,
        do_classifier_free_guidance=True,
        num_videos_per_prompt=1,
        device="cuda",
    )
    pipe.text_encoder.to("cpu")
    del pipe.text_encoder
    pipe.text_encoder = None

    pipe.to("cuda")

    for key, value in list(kwargs.items()):
        if torch.is_tensor(value):
            kwargs[key] = value.to(device="cuda", dtype=dtype)
    
    generation_kwargs = {
        "prompt_embeds": prompt_embeds,
        "negative_prompt_embeds": negative_prompt_embeds,
        "height": 480,
        "width": 832,
        "num_frames": 81,
        "guidance_scale": 5.0,
        "num_inference_steps": 30,
        **kwargs,
    }

    return pipe, generation_kwargs


def decode_allegro(pipe: AllegroPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    video = pipe.decode_latents(latents)
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_cogvideox_1_0(pipe: CogVideoXPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    video = pipe.decode_latents(latents)
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_flux(pipe: FluxPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    height = kwargs["height"]
    width = kwargs["width"]
    filename = f"{filename.as_posix()}.png"
    latents = pipe._unpack_latents(latents, height, width, pipe.vae_scale_factor)
    latents = (latents / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
    image = pipe.vae.decode(latents, return_dict=False)[0]
    image = pipe.image_processor.postprocess(image, output_type="pil")[0]
    image.save(filename)
    return filename


def decode_hunyuan_video(pipe: HunyuanVideoPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    latents = latents.to(pipe.vae.dtype) / pipe.vae.config.scaling_factor
    video = pipe.vae.decode(latents, return_dict=False)[0]
    video = pipe.video_processor.postprocess_video(video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_latte(pipe: LattePipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    video = pipe.decode_latents(latents, video_length=kwargs["video_length"])
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_ltx_video(pipe: LTXPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    latent_num_frames = (kwargs["num_frames"] - 1) // pipe.vae_temporal_compression_ratio + 1
    latent_height = kwargs["height"] // pipe.vae_spatial_compression_ratio
    latent_width = kwargs["width"] // pipe.vae_spatial_compression_ratio

    latents = pipe._unpack_latents(
        latents,
        latent_num_frames,
        latent_height,
        latent_width,
        pipe.transformer_spatial_patch_size,
        pipe.transformer_temporal_patch_size,
    )
    latents = pipe._denormalize_latents(
        latents, pipe.vae.latents_mean, pipe.vae.latents_std, pipe.vae.config.scaling_factor
    )
    latents = latents.to(pipe.vae.dtype)

    timestep = None
    video = pipe.vae.decode(latents, timestep, return_dict=False)[0]
    video = pipe.video_processor.postprocess_video(video, output_type="pil")[0]
    export_to_video(video, filename, fps=24)
    return filename


def decode_mochi(pipe: MochiPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    latents_mean = torch.tensor(pipe.vae.config.latents_mean).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
    latents_std = torch.tensor(pipe.vae.config.latents_std).view(1, 12, 1, 1, 1).to(latents.device, latents.dtype)
    latents = latents * latents_std / pipe.vae.config.scaling_factor + latents_mean
    video = pipe.vae.decode(latents, return_dict=False)[0]
    video = pipe.video_processor.postprocess_video(video=video, output_type="pil")[0]
    export_to_video(video, filename, fps=8)
    return filename


def decode_wan(pipe: WanPipeline, latents: torch.Tensor, filename: pathlib.Path, **kwargs):
    filename = f"{filename.as_posix()}.mp4"
    latents = latents.to(pipe.vae.dtype)
    latents_mean = (
        torch.tensor(pipe.vae.config.latents_mean)
        .view(1, pipe.vae.config.z_dim, 1, 1, 1)
        .to(latents.device, latents.dtype)
    )
    latents_std = 1.0 / torch.tensor(pipe.vae.config.latents_std).view(1, pipe.vae.config.z_dim, 1, 1, 1).to(
        latents.device, latents.dtype
    )
    latents = latents / latents_std + latents_mean
    video = pipe.vae.decode(latents, return_dict=False)[0]
    video = pipe.video_processor.postprocess_video(video, output_type="pil")[0]
    export_to_video(video, filename, fps=16)
    return filename


def reset_memory():
    gc.collect()
    torch.cuda.empty_cache()
    torch.cuda.ipc_collect()
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.reset_accumulated_memory_stats()


MODEL_MAPPING = {
    "allegro": {
        "prepare": prepare_allegro,
        "decode": decode_allegro,
    },
    "cogvideox-1.0": {
        "prepare": prepare_cogvideox_1_0,
        "decode": decode_cogvideox_1_0,
    },
    "flux": {
        "prepare": prepare_flux,
        "decode": decode_flux,
    },
    "hunyuan_video": {
        "prepare": prepare_hunyuan_video,
        "decode": decode_hunyuan_video,
    },
    "latte": {
        "prepare": prepare_latte,
        "decode": decode_latte,
    },
    "ltx_video": {
        "prepare": prepare_ltx_video,
        "decode": decode_ltx_video,
    },
    "mochi": {
        "prepare": prepare_mochi,
        "decode": decode_mochi,
    },
    "wan": {
        "prepare": prepare_wan,
        "decode": decode_wan,
    }
}

STR_TO_COMPUTE_DTYPE = {
    "bf16": torch.bfloat16,
    "fp16": torch.float16,
    "fp32": torch.float32,
}


def run_inference(pipe, generation_kwargs):
    generator = torch.Generator().manual_seed(181201)
    output = pipe(generator=generator, output_type="latent", **generation_kwargs)[0]
    torch.cuda.synchronize()
    return output


from diffusers.hooks import ModelHook, HookRegistry
from accelerate.utils import send_to_device

class MoveToCUDAHook(ModelHook):
    def pre_forward(self, module, *args, **kwargs):
        args = send_to_device(args, "cuda")
        kwargs = send_to_device(kwargs, "cuda")
        return args, kwargs

    def post_forward(self, module, output):
        output = send_to_device(output, "cpu")
        return output


@torch.no_grad()
def main(model_id: str, output_dir: str, dtype: str, offloading_type: str, num_blocks_per_group: int, use_stream: bool, compile: bool, attn_provider: str, num_images_per_prompt: int):
    if attn_provider == "flex":
        import torch.nn.attention.flex_attention as flex_attention

        flex_attention.flex_attention = torch.compile(flex_attention.flex_attention, mode="max-autotune-no-cudagraphs", fullgraph=True)
        flex_attention.create_block_mask = torch.compile(flex_attention.create_block_mask, mode="max-autotune-no-cudagraphs", fullgraph=True)

    if model_id not in MODEL_MAPPING.keys():
        raise ValueError("Unsupported `model_id` specified.")

    output_dir = pathlib.Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    csv_filename = output_dir / f"{model_id}.csv"

    compute_dtype = STR_TO_COMPUTE_DTYPE[dtype]
    model = MODEL_MAPPING[model_id]
    reset_memory()

    try:
        # 1. Prepare inputs and generation kwargs
        pipe, generation_kwargs = model["prepare"](dtype=compute_dtype)

        extra_keys = {}
        if model_id == "wan":
            extra_keys = {"num_videos_per_prompt": num_images_per_prompt}
        else:
            extra_keys = {"num_images_per_prompt": num_images_per_prompt}
        generation_kwargs.update(extra_keys)

        # 2. Apply group offloading
        if offloading_type == "model":
            pipe.enable_model_cpu_offload()
        elif offloading_type == "sequential":
            pipe.enable_sequential_cpu_offload()
        elif offloading_type in ["block_level", "leaf_level"]:
            apply_group_offloading(
                pipe.transformer,
                offload_type=offloading_type,
                num_blocks_per_group=num_blocks_per_group,
                offload_device=torch.device("cpu"),
                onload_device=torch.device("cuda"),
                non_blocking=True,
                use_stream=use_stream,
            )
        else:
            pipe.transformer.to("cuda")
            # registry = HookRegistry.check_if_exists_or_initialize(pipe.transformer)
            # registry.register_hook(MoveToCUDAHook(), "MoveToCUDAHook")
        
        pipe.vae.to("cuda")
        torch.cuda.synchronize()

        reset_memory()
        model_max_memory_reserved = round(torch.cuda.max_memory_allocated() / 1024**3, 3)

        if compile:
            pipe.transformer = torch.compile(
                pipe.transformer, mode="max-autotune", fullgraph=True, dynamic=False
            )

        registry_vae = HookRegistry.check_if_exists_or_initialize(pipe.vae.decoder)
        registry_vae.register_hook(MoveToCUDAHook(), "MoveToCUDAHook")

        # 3. Warmup
        num_warmups = 1
        original_num_inference_steps = generation_kwargs["num_inference_steps"]
        generation_kwargs["num_inference_steps"] = 2
        with attention_backend(attn_provider):
            for _ in range(num_warmups):
                run_inference(pipe, generation_kwargs)
        generation_kwargs["num_inference_steps"] = original_num_inference_steps

        # 4. Benchmark
        with attention_backend(attn_provider):
            time, latents = benchmark_fn(run_inference, pipe, generation_kwargs)
        inference_max_memory_reserved = round(torch.cuda.max_memory_allocated() / 1024**3, 3)

        # 5. Decode latents
        filename = output_dir / f"{model_id}---attn_provider-{attn_provider}---dtype-{dtype}---offloading_type-{offloading_type}---num_blocks_per_group-{num_blocks_per_group}---use_stream-{use_stream}---compile-{compile}"
        filename = model["decode"](
            pipe,
            latents,
            filename,
            height=generation_kwargs["height"],
            width=generation_kwargs["width"],
            num_frames=generation_kwargs.get("num_frames", None),
            video_length=generation_kwargs.get("video_length", None),
        )

        # 6. Save artifacts
        info = {
            "model_id": model_id,
            "attn_provider": attn_provider,
            "time": time,
            "offloading_type": offloading_type,
            "use_stream": use_stream,
            "num_blocks": num_blocks_per_group,
            "model_memory": model_max_memory_reserved,
            "inference_memory": inference_max_memory_reserved,
            "compile": compile,
            "compute_dtype": dtype,
            "branch": branch,
            "filename": filename,
            "exception": None,
        }

    except Exception as e:
        print(f"An error occurred: {e}")
        traceback.print_exc()

        # 6. Save artifacts
        info = {
            "model_id": model_id,
            "attn_provider": attn_provider,
            "time": None,
            "offloading_type": offloading_type,
            "use_stream": use_stream,
            "num_blocks": num_blocks_per_group,
            "model_memory": None,
            "inference_memory": None,
            "compile": compile,
            "compute_dtype": dtype,
            "branch": branch,
            "filename": None,
            "exception": str(e),
        }

    pretty_print_results(info, precision=3)

    df = pd.DataFrame([info])
    df.to_csv(csv_filename.as_posix(), mode="a", index=False, header=not csv_filename.is_file())


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_id",
        type=str,
        default="flux",
        choices=["flux", "cogvideox-1.0", "latte", "allegro", "hunyuan_video", "mochi", "ltx_video", "wan"],
        help="Model to run benchmark for.",
    )
    parser.add_argument("--attn_provider", type=str, default="native", choices=[x.value for x in AttentionBackendName.__members__.values()])
    parser.add_argument("--num_images_per_prompt", type=int, default=1, help="Number of images to generate per prompt.")
    parser.add_argument(
        "--output_dir", required=True, type=str, help="Path where the benchmark artifacts and outputs are the be saved."
    )
    parser.add_argument("--dtype", type=str, help="torch.dtype to use for inference")
    parser.add_argument("--offloading_type", type=str, default="none", choices=["none", "model", "block_level", "leaf_level"], help="Type of offloading to use.")
    parser.add_argument("--num_blocks_per_group", type=int, default=None, help="Number of layers per group for group offloading.")
    parser.add_argument("--use_stream", action="store_true", default=False, help="Whether to use CUDA streams for offloading.")
    parser.add_argument(
        "--compile",
        action="store_true",
        default=False,
        help="Whether to torch.compile the denoiser.",
    )
    parser.add_argument("-v", "--verbose", action="store_true", help="Enable verbose logging.")
    args = parser.parse_args()

    if args.verbose:
        set_verbosity_debug()
    else:
        set_verbosity_info()

    main(
        args.model_id,
        args.output_dir,
        args.dtype,
        args.offloading_type,
        args.num_blocks_per_group,
        args.use_stream,
        args.compile,
        args.attn_provider,
        args.num_images_per_prompt,
    )
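
For reference, a typical invocation of the script above might look like the following (the script filename, device index, and output directory are placeholders):

CUDA_VISIBLE_DEVICES=0 python3 benchmark_attention.py --model_id wan --attn_provider sage --dtype bf16 --offloading_type none --output_dir ./benchmark-results
CUDA_VISIBLE_DEVICES=0 python3 benchmark_attention.py --model_id wan --attn_provider flash_varlen --dtype bf16 --offloading_type leaf_level --use_stream --output_dir ./benchmark-results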

Results: 4090

Results with PyTorch 2.7 stable, CUDA 12.6

Wan

| model_id | attn_provider | time (s) | offloading_type | use_stream | num_blocks | model_memory (GB) | inference_memory (GB) | compile |
|----------|---------------|----------|-----------------|------------|------------|-------------------|-----------------------|---------|
| wan | flash | 142.816 | none | False | | 2.912 | 4.455 | False |
| wan | flash_varlen | 144.221 | none | False | | 2.912 | 4.455 | False |
| wan | flex | 146.176 | none | False | | 2.912 | 4.455 | False |
| wan | native | 144.692 | none | False | | 2.912 | 4.455 | False |
| wan | _native_cudnn | 144.901 | none | False | | 2.912 | 4.455 | False |
| wan | _native_efficient | 184.593 | none | False | | 2.912 | 4.455 | False |
| wan | _native_flash | 144.611 | none | False | | 2.912 | 4.455 | False |
| wan | sage | 102.281 | none | False | | 2.912 | 4.455 | False |
| wan | sage_varlen | 112.254 | none | False | | 2.912 | 4.455 | False |
| wan | xformers | 142.909 | none | False | | 2.912 | 4.455 | False |
| wan | flash | 147.230 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | flash_varlen | 148.197 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | flex | 150.197 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | native | 148.783 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | _native_cudnn | 149.177 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | _native_efficient | 188.643 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | _native_flash | 148.753 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | sage | 106.032 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | sage_varlen | 116.081 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | xformers | 147.119 | leaf_level | True | | 0.249 | 1.819 | False |

Results: A100

Results with PyTorch 2.7 stable, CUDA 12.2

Wan

| model_id | attn_provider | time (s) | offloading_type | use_stream | num_blocks | model_memory (GB) | inference_memory (GB) | compile |
|----------|---------------|----------|-----------------|------------|------------|-------------------|-----------------------|---------|
| wan | flash | 123.107 | none | False | | 2.912 | 4.455 | False |
| wan | flash_varlen | 125.355 | none | False | | 2.912 | 4.455 | False |
| wan | flex | 143.088 | none | False | | 2.912 | 4.455 | False |
| wan | native | 130.183 | none | False | | 2.912 | 4.455 | False |
| wan | _native_cudnn | 137.591 | none | False | | 2.912 | 4.455 | False |
| wan | _native_efficient | 183.795 | none | False | | 2.912 | 4.455 | False |
| wan | _native_flash | 131.384 | none | False | | 2.912 | 4.455 | False |
| wan | sage | 119.741 | none | False | | 2.912 | 4.455 | False |
| wan | sage_varlen | 131.515 | none | False | | 2.912 | 4.455 | False |
| wan | xformers | 125.414 | none | False | | 2.912 | 4.455 | False |
| wan | flash | 127.484 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | flash_varlen | 129.351 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | flex | 146.739 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | native | 133.718 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | _native_cudnn | 141.970 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | _native_efficient | 188.268 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | _native_flash | 133.996 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | sage | 123.269 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | sage_varlen | 133.422 | leaf_level | True | | 0.249 | 1.819 | False |
| wan | xformers | 127.743 | leaf_level | True | | 0.249 | 1.819 | False |

cc @DN6 @sayakpaul @yiyixuxu

Supported backends: flash, flash_varlen, flex, native, sage, sage_varlen, xformers

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

sayakpaul (Member) left a comment

Interesting PR! I only left some higher-level comments. My major comment is around having an attention config class instead of environment vars. Or would that be too much for this PR?


For the attention config class (if decided to proceed that route), I was thinking of the following APIs:

attn_config = AttentionConfig(
    attn_implementation="...",
    enable_gqa=...
)
model.set_attn_config(attn_config)

class BlockMask:
    def __init__(self, *args, **kwargs):
        raise OptionalDependencyNotAvailable(
            "The `torch` library version is too old. Please update it to at least 2.5.0."

We could further clarify that "To use BlockMask you need an updated torch installation."
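
A minimal sketch of how that clarified message could read (wording illustrative; it extends the dummy class quoted above):

class BlockMask:
    def __init__(self, *args, **kwargs):
        raise OptionalDependencyNotAvailable(
            "`BlockMask` is used by the flex attention backend and requires `torch>=2.5.0`. "
            "Please upgrade your torch installation to use it."
        )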

Comment on lines 44 to 45
DIFFUSERS_ATTN_PROVIDER = os.getenv("DIFFUSERS_ATTN_PROVIDER", "native")
DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
sayakpaul (Member) commented Apr 22, 2025

Would it instead make sense to have them parsed through some kind of AttentionConfig class?
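
One possible shape for such a class, sketched purely for illustration (AttentionConfig and its field names are hypothetical, not part of this PR):

import os
from dataclasses import dataclass

ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}


@dataclass
class AttentionConfig:
    provider: str = "native"
    run_checks: bool = False

    @classmethod
    def from_env(cls) -> "AttentionConfig":
        # Mirrors the two environment variables quoted above.
        return cls(
            provider=os.getenv("DIFFUSERS_ATTN_PROVIDER", "native"),
            run_checks=os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES,
        )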

Comment on lines 153 to 154
    def get_active_provider(cls):
        return cls._active_provider, cls._providers[cls._active_provider]

Should it only return cls._active_provider?
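
For illustration, the single-value variant would simply be:

@classmethod
def get_active_provider(cls):
    # Return only the name; the callable can be looked up via cls._providers[name] where needed.
    return cls._active_provider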

a-r-r-o-w (Member, Author)

The environment vars were initially only for my quick testing from the CLI instead of changing the code every time. We can get rid of them completely.

The intended API in my mind, and what currently exists in the PR, is a context manager:

from diffusers import attention_provider

with attention_provider("sage_varlen"):
    model(...)

We can change this once we finalize something.

DN6 (Collaborator) left a comment

It's looking good 👍🏽 Nice work! Registry makes sense here. Just some minor comments on the initial pass.

Would also add torch NPU backend and XLA flash attention

hidden_states = torch_npu.npu_fusion_attention(

from torch_xla.experimental.custom_kernel import flash_attention

I do also think configuring attention without env variables and a context manager might be needed, e.g. you want to run the transformer in the pipeline with SageAttention while the other components use regular SDPA. The config object that @sayakpaul suggested makes sense.
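
For example, the kind of per-component control being discussed might look like this (purely illustrative; AttentionConfig and set_attn_config are the hypothetical API from the suggestion above, not something this PR implements):

# Only the transformer uses SageAttention; other components keep native SDPA.
attn_config = AttentionConfig(attn_implementation="sage")
pipe.transformer.set_attn_config(attn_config)

# Text encoder, VAE, etc. are untouched and dispatch to torch SDPA as before.
image = pipe(prompt).images[0]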

@@ -143,6 +143,7 @@
[
"AllegroTransformer3DModel",
"AsymmetricAutoencoderKL",
"AttentionProvider",

Would prefer to preserve torch-like semantics and call this AttentionBackend

    finally:
        _AttentionProviderRegistry._active_provider = old_provider


def attention_dispatch(

Nit: would prefer dispatch_attention_fn.

    scale: Optional[float] = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):

Question: how does this compare to FA2 from source? I think they should be equivalent, no?

a-r-r-o-w (Member, Author) commented May 16, 2025

For most shapes, FA2 from source seems to be faster than torch native flash. I ran the following script to obtain the table:

code
import torch
from diffusers.models.attention_dispatch import attention_backend, dispatch_attention_fn

torch.manual_seed(0)

# Wan 1.3B/CogVideoX
batch = 1
num_heads = 12
head_dim = 128
dtype = torch.bfloat16

resolutions = [(1, 512, 512), (1, 1024, 1024), (49, 480, 720), (29, 1024, 1024), (81, 480, 832)]
seq_lens = [((res[0] - 1) // 4 + 1) * res[1] * res[2] // 8 // 8 // 4 for res in resolutions]
print("Sequence lengths:", seq_lens)

for seq_len in seq_lens:
    flops = 4 * batch * num_heads * head_dim * seq_len * seq_len

    torch.manual_seed(0)
    query = torch.randn(batch, num_heads, seq_len, head_dim, dtype=dtype, device="cuda")
    key = torch.randn(batch, num_heads, seq_len, head_dim, dtype=dtype, device="cuda")
    value = torch.randn(batch, num_heads, seq_len, head_dim, dtype=dtype, device="cuda")

    results = {}
    
    for backend in ["flash", "_native_flash", "_native_cudnn", "_native_efficient", "xformers", "_sage_qk_int8_pv_fp16_cuda"]:
        with attention_backend(backend):
            for _ in range(5):
                # Warmup
                _ = dispatch_attention_fn(query, key, value)

            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)

            start.record()
            result = dispatch_attention_fn(query, key, value)
            end.record()
            torch.cuda.synchronize()

            elapsed_time = start.elapsed_time(end) / 1000
            results[backend] = elapsed_time
    
    tflops_s_flash = flops / results["flash"] / 1e12
    tflops_s_native_flash = flops / results["_native_flash"] / 1e12
    tflops_s_native_cudnn = flops / results["_native_cudnn"] / 1e12
    tflops_s_native_efficient = flops / results["_native_efficient"] / 1e12
    tflops_s_xformers = flops / results["xformers"] / 1e12
    tflops_s_sage_qk_int8_pv_fp16_cuda = flops / results["_sage_qk_int8_pv_fp16_cuda"] / 1e12

    print()
    print(f"Shape: {query.shape}")
    print(f"TFLOPs: {flops / 1e12:.2f}")
    print("===== TFLOPS =====")
    print(f"                     (flash): {tflops_s_flash:.2f}")
    print(f"              (native_flash): {tflops_s_native_flash:.2f}")
    print(f"              (native_cudnn): {tflops_s_native_cudnn:.2f}")
    print(f"          (native_efficient): {tflops_s_native_efficient:.2f}")
    print(f"                  (xformers): {tflops_s_xformers:.2f}")
    print(f"(_sage_qk_int8_pv_fp16_cuda): {tflops_s_sage_qk_int8_pv_fp16_cuda:.2f}")
    print("==========")

hf-dgx-01: A100

| Shape | Attention | TFLOPS |
|-------|-----------|--------|
| torch.Size([1, 12, 1024, 128]) | flash | 32.77 |
| | native_flash | 60.49 |
| | native_cudnn | 59.92 |
| | native_efficient | 78.64 |
| | xformers | 22.15 |
| | _sage_qk_int8_pv_fp16_cuda | 20.43 |
| torch.Size([1, 12, 4096, 128]) | flash | 179.76 |
| | native_flash | 167.21 |
| | native_cudnn | 158.52 |
| | native_efficient | 91.68 |
| | xformers | 179.44 |
| | _sage_qk_int8_pv_fp16_cuda | 160.04 |
| torch.Size([1, 12, 17550, 128]) | flash | 183.86 |
| | native_flash | 164.21 |
| | native_cudnn | 155.09 |
| | native_efficient | 96.26 |
| | xformers | 188.00 |
| | _sage_qk_int8_pv_fp16_cuda | 200.41 |
| torch.Size([1, 12, 32768, 128]) | flash | 183.47 |
| | native_flash | 169.01 |
| | native_cudnn | 160.52 |
| | native_efficient | 97.89 |
| | xformers | 183.15 |
| | _sage_qk_int8_pv_fp16_cuda | 200.47 |
| torch.Size([1, 12, 32760, 128]) | flash | 178.40 |
| | native_flash | 166.07 |
| | native_cudnn | 154.97 |
| | native_efficient | 97.86 |
| | xformers | 180.94 |
| | _sage_qk_int8_pv_fp16_cuda | 201.17 |

audace: RTX 4090

| Shape | Attention Type | TFLOPS |
|-------|----------------|--------|
| torch.Size([1, 12, 1024, 128]) | flash | 81.71 |
| | native_flash | 82.78 |
| | native_cudnn | 92.52 |
| | native_efficient | 65.54 |
| | xformers | 50.33 |
| | _sage_qk_int8_pv_fp16_cuda | 40.59 |
| torch.Size([1, 12, 4096, 128]) | flash | 149.35 |
| | native_flash | 146.74 |
| | native_cudnn | 150.69 |
| | native_efficient | 97.17 |
| | xformers | 149.13 |
| | _sage_qk_int8_pv_fp16_cuda | 198.94 |
| torch.Size([1, 12, 17550, 128]) | flash | 153.68 |
| | native_flash | 151.06 |
| | native_cudnn | 159.64 |
| | native_efficient | 103.39 |
| | xformers | 163.58 |
| | _sage_qk_int8_pv_fp16_cuda | 243.06 |
| torch.Size([1, 12, 32768, 128]) | flash | 165.93 |
| | native_flash | 160.99 |
| | native_cudnn | 165.72 |
| | native_efficient | 105.52 |
| | xformers | 165.89 |
| | _sage_qk_int8_pv_fp16_cuda | 253.78 |
| torch.Size([1, 12, 32760, 128]) | flash | 165.33 |
| | native_flash | 161.65 |
| | native_cudnn | 161.88 |
| | native_efficient | 105.30 |
| | xformers | 165.28 |
| | _sage_qk_int8_pv_fp16_cuda | 253.74 |

a-r-r-o-w marked this pull request as ready for review May 16, 2025 12:22
a-r-r-o-w (Member, Author)

For the attention config class (if decided to proceed that route), I was thinking of the following APIs:

attn_config = AttentionConfig(
    attn_implementation="...",
    enable_gqa=...
)
model.set_attn_config(attn_config)

@sayakpaul @DN6 How would you recommend we set a per-model attention backend? The backend info needs to be propagated to the attention dispatcher when the forward method is called. The easiest way, and how I've done it for training/CP, is to attach a simple pre-forward hook that sets the backend, cp_mesh, and any other attributes when the forward method is invoked. If you have recommendations, I'll modify the implementation accordingly.
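
A rough sketch of the hook idea, using only the public pieces shown elsewhere in this PR (illustrative, not the actual implementation; it enters the attention_backend context on pre-forward and exits it on post-forward):

from diffusers import attention_backend
from diffusers.hooks import HookRegistry, ModelHook


class AttentionBackendHook(ModelHook):
    def __init__(self, backend_name: str):
        super().__init__()
        self.backend_name = backend_name
        self._ctx = None

    def pre_forward(self, module, *args, **kwargs):
        # Activate the chosen backend for the duration of this module's forward.
        self._ctx = attention_backend(self.backend_name)
        self._ctx.__enter__()
        return args, kwargs

    def post_forward(self, module, output):
        self._ctx.__exit__(None, None, None)
        self._ctx = None
        return output


# Hypothetical usage: per-model backend selection without a context manager at the call site.
# registry = HookRegistry.check_if_exists_or_initialize(pipe.transformer)
# registry.register_hook(AttentionBackendHook("sage_varlen"), "AttentionBackendHook")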

Currently, you need to first replace the calls to F.scaled_dot_product_attention with diffusers.models.attention_dispatch.dispatch_attention_fn in the modeling code and then invoke one or more models under the attention_backend context manager:

from diffusers import attention_backend

with attention_backend("flash_varlen"):
    output = transformer(...)

If the context manager is not used, it defaults to the original behaviour of calling native torch attention.
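
Concretely, the modeling-code change amounts to something like the following (a simplified sketch, assuming dispatch_attention_fn mirrors the basic F.scaled_dot_product_attention call signature used in the benchmarks above):

import torch
from diffusers.models.attention_dispatch import dispatch_attention_fn


def attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
    # Previously: torch.nn.functional.scaled_dot_product_attention(query, key, value)
    # The dispatcher routes to the active backend: native SDPA by default, or whatever
    # the surrounding attention_backend(...) context manager has selected.
    return dispatch_attention_fn(query, key, value)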

a-r-r-o-w requested review from DN6, sayakpaul and yiyixuxu May 16, 2025 18:54