Attention Dispatcher #11368
Conversation
Supported backends: flash, flash_varlen, flex, native, sage, sage_varlen, xformers
… flux attention processors
Interesting PR! I only left some higher-level comments. My major comment is around having an attention config class instead of environment vars. Or would that be too much for this PR?
For the attention config class (if we decide to go that route), I was thinking of the following API:

```python
attn_config = AttentionConfig(
    attn_implementation="...",
    enable_gqa=...,
)
model.set_attn_config(attn_config)
```
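A minimal sketch of what such a config class could look like, purely to make the proposal concrete (field names and defaults are assumptions, not part of the PR):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class AttentionConfig:
    # Illustrative fields mirroring the proposed call above.
    attn_implementation: str = "native"
    enable_gqa: bool = False
    scale: Optional[float] = None
```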
```python
class BlockMask:
    def __init__(self, *args, **kwargs):
        raise OptionalDependencyNotAvailable(
            "The `torch` library version is too old. Please update it to at least 2.5.0."
        )
```
We could further clarify that "To use `BlockMask`, you need an updated torch installation."
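Applied to the stub above, that could read, for instance:

```python
class BlockMask:
    def __init__(self, *args, **kwargs):
        raise OptionalDependencyNotAvailable(
            "The `torch` library version is too old. To use `BlockMask`, please update it to at least 2.5.0."
        )
```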
src/diffusers/utils/constants.py

```python
DIFFUSERS_ATTN_PROVIDER = os.getenv("DIFFUSERS_ATTN_PROVIDER", "native")
DIFFUSERS_ATTN_CHECKS = os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES
```
Would it instead make sense to have them parsed through some kind of `AttentionConfig` class?
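For instance, a hypothetical `from_env` constructor on such a config class could absorb this parsing (names below are illustrative, not the PR's API):

```python
import os
from dataclasses import dataclass

# Mirrors ENV_VARS_TRUE_VALUES from diffusers.utils.constants.
ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"}


@dataclass
class AttentionConfig:
    provider: str = "native"
    run_checks: bool = False

    @classmethod
    def from_env(cls) -> "AttentionConfig":
        # Same parsing as the constants above, but scoped to a config object
        # instead of module-level globals.
        return cls(
            provider=os.getenv("DIFFUSERS_ATTN_PROVIDER", "native"),
            run_checks=os.getenv("DIFFUSERS_ATTN_CHECKS", "0") in ENV_VARS_TRUE_VALUES,
        )
```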
```python
def get_active_provider(cls):
    return cls._active_provider, cls._providers[cls._active_provider]
```
Should it only return `cls._active_provider`?
The environment vars were initially only for my quick testing from the CLI instead of changing the code every time. We can get rid of them completely. The intended API in my mind, and what currently exists in the PR, is with context managers:

```python
from diffusers import attention_provider

with attention_provider("sage_varlen"):
    model(...)
```

Can change once we finalize something.
It's looking good 👍🏽 Nice work! Registry makes sense here. Just some minor comments on the initial pass.
Would also add the torch NPU backend (`hidden_states = torch_npu.npu_fusion_attention(`) and XLA flash attention (`from torch_xla.experimental.custom_kernel import flash_attention`).
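As a rough illustration, registering such an extra backend could look like the sketch below; the `register` decorator, the `"xla_flash"` key, and the import of the private registry are all assumptions, not the PR's actual API:

```python
from torch_xla.experimental.custom_kernel import flash_attention

from diffusers.models.attention_dispatch import _AttentionProviderRegistry

# Hypothetical registration of an XLA flash-attention provider.
@_AttentionProviderRegistry.register("xla_flash")
def _xla_flash_attention(query, key, value, **kwargs):
    # Delegate to the torch_xla kernel; extra dispatcher kwargs are ignored here.
    return flash_attention(query, key, value)
```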
I also think configuring attention without env variables or a context manager might be needed, e.g. you want to run the transformer in the pipeline with SageAttention while the other components use regular SDPA. The config object that @sayakpaul suggested makes sense.
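For illustration, per-component configuration along those lines might look as follows; this assumes the hypothetical `AttentionConfig`/`set_attn_config` API from the proposal above and a loaded `pipe` pipeline, neither of which is an existing diffusers API:

```python
# Hypothetical usage: only the transformer runs with SageAttention,
# the remaining components keep the default SDPA path.
pipe.transformer.set_attn_config(AttentionConfig(attn_implementation="sage"))
pipe.text_encoder.set_attn_config(AttentionConfig(attn_implementation="native"))
pipe.vae.set_attn_config(AttentionConfig(attn_implementation="native"))
```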
src/diffusers/__init__.py

```diff
@@ -143,6 +143,7 @@
     [
         "AllegroTransformer3DModel",
         "AsymmetricAutoencoderKL",
+        "AttentionProvider",
```
Would prefer to preserve torch-like semantics and call this AttentionBackend
```python
    finally:
        _AttentionProviderRegistry._active_provider = old_provider


def attention_dispatch(
```
Nit: would prefer `dispatch_attention_fn`.
```python
    scale: Optional[float] = None,
    enable_gqa: bool = False,
) -> torch.Tensor:
    with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.FLASH_ATTENTION):
```
Question: how does this compare to FA2 from source? I think they should be equivalent, no?
For most shapes, FA2 from source seems to be faster than torch native flash. I ran the following script to obtain the table:
Code:

```python
import torch

from diffusers.models.attention_dispatch import attention_backend, dispatch_attention_fn

torch.manual_seed(0)

# Wan 1.3B/CogVideoX
batch = 1
num_heads = 12
head_dim = 128
dtype = torch.bfloat16

resolutions = [(1, 512, 512), (1, 1024, 1024), (49, 480, 720), (29, 1024, 1024), (81, 480, 832)]
seq_lens = [((res[0] - 1) // 4 + 1) * res[1] * res[2] // 8 // 8 // 4 for res in resolutions]
print("Sequence lengths:", seq_lens)

for seq_len in seq_lens:
    flops = 4 * batch * num_heads * head_dim * seq_len * seq_len

    torch.manual_seed(0)
    query = torch.randn(batch, num_heads, seq_len, head_dim, dtype=dtype, device="cuda")
    key = torch.randn(batch, num_heads, seq_len, head_dim, dtype=dtype, device="cuda")
    value = torch.randn(batch, num_heads, seq_len, head_dim, dtype=dtype, device="cuda")

    results = {}
    for backend in ["flash", "_native_flash", "_native_cudnn", "_native_efficient", "xformers", "_sage_qk_int8_pv_fp16_cuda"]:
        with attention_backend(backend):
            for _ in range(5):
                # Warmup
                _ = dispatch_attention_fn(query, key, value)

            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            result = dispatch_attention_fn(query, key, value)
            end.record()
            torch.cuda.synchronize()

            elapsed_time = start.elapsed_time(end) / 1000
            results[backend] = elapsed_time

    tflops_s_flash = flops / results["flash"] / 1e12
    tflops_s_native_flash = flops / results["_native_flash"] / 1e12
    tflops_s_native_cudnn = flops / results["_native_cudnn"] / 1e12
    tflops_s_native_efficient = flops / results["_native_efficient"] / 1e12
    tflops_s_xformers = flops / results["xformers"] / 1e12
    tflops_s_sage_qk_int8_pv_fp16_cuda = flops / results["_sage_qk_int8_pv_fp16_cuda"] / 1e12

    print()
    print(f"Shape: {query.shape}")
    print(f"TFLOPs: {flops / 1e12:.2f}")
    print("===== TFLOPS =====")
    print(f"                      (flash): {tflops_s_flash:.2f}")
    print(f"               (native_flash): {tflops_s_native_flash:.2f}")
    print(f"               (native_cudnn): {tflops_s_native_cudnn:.2f}")
    print(f"           (native_efficient): {tflops_s_native_efficient:.2f}")
    print(f"                   (xformers): {tflops_s_xformers:.2f}")
    print(f" (_sage_qk_int8_pv_fp16_cuda): {tflops_s_sage_qk_int8_pv_fp16_cuda:.2f}")
    print("==========")
```
hf-dgx-01: A100

| Shape | Attention | TFLOPS |
|---|---|---|
| torch.Size([1, 12, 1024, 128]) | flash | 32.77 |
| | native_flash | 60.49 |
| | native_cudnn | 59.92 |
| | native_efficient | 78.64 |
| | xformers | 22.15 |
| | _sage_qk_int8_pv_fp16_cuda | 20.43 |
| torch.Size([1, 12, 4096, 128]) | flash | 179.76 |
| | native_flash | 167.21 |
| | native_cudnn | 158.52 |
| | native_efficient | 91.68 |
| | xformers | 179.44 |
| | _sage_qk_int8_pv_fp16_cuda | 160.04 |
| torch.Size([1, 12, 17550, 128]) | flash | 183.86 |
| | native_flash | 164.21 |
| | native_cudnn | 155.09 |
| | native_efficient | 96.26 |
| | xformers | 188.00 |
| | _sage_qk_int8_pv_fp16_cuda | 200.41 |
| torch.Size([1, 12, 32768, 128]) | flash | 183.47 |
| | native_flash | 169.01 |
| | native_cudnn | 160.52 |
| | native_efficient | 97.89 |
| | xformers | 183.15 |
| | _sage_qk_int8_pv_fp16_cuda | 200.47 |
| torch.Size([1, 12, 32760, 128]) | flash | 178.40 |
| | native_flash | 166.07 |
| | native_cudnn | 154.97 |
| | native_efficient | 97.86 |
| | xformers | 180.94 |
| | _sage_qk_int8_pv_fp16_cuda | 201.17 |
audace: RTX 4090

| Shape | Attention | TFLOPS |
|---|---|---|
| torch.Size([1, 12, 1024, 128]) | flash | 81.71 |
| | native_flash | 82.78 |
| | native_cudnn | 92.52 |
| | native_efficient | 65.54 |
| | xformers | 50.33 |
| | _sage_qk_int8_pv_fp16_cuda | 40.59 |
| torch.Size([1, 12, 4096, 128]) | flash | 149.35 |
| | native_flash | 146.74 |
| | native_cudnn | 150.69 |
| | native_efficient | 97.17 |
| | xformers | 149.13 |
| | _sage_qk_int8_pv_fp16_cuda | 198.94 |
| torch.Size([1, 12, 17550, 128]) | flash | 153.68 |
| | native_flash | 151.06 |
| | native_cudnn | 159.64 |
| | native_efficient | 103.39 |
| | xformers | 163.58 |
| | _sage_qk_int8_pv_fp16_cuda | 243.06 |
| torch.Size([1, 12, 32768, 128]) | flash | 165.93 |
| | native_flash | 160.99 |
| | native_cudnn | 165.72 |
| | native_efficient | 105.52 |
| | xformers | 165.89 |
| | _sage_qk_int8_pv_fp16_cuda | 253.78 |
| torch.Size([1, 12, 32760, 128]) | flash | 165.33 |
| | native_flash | 161.65 |
| | native_cudnn | 161.88 |
| | native_efficient | 105.30 |
| | xformers | 165.28 |
| | _sage_qk_int8_pv_fp16_cuda | 253.74 |
@sayakpaul @DN6 How would you recommend we set a per-model attention backend? The backend info needs to be propagated to the attention dispatcher when the forward method is called. The easiest way, and how I've done it for training/CP, is to attach a simple pre-forward hook that sets the backend, cp_mesh, and any other attributes when the forward method is invoked. If you have recommendations, I'll modify the implementation accordingly. Currently, you need to first replace the calls to the native attention function with the dispatcher, and then use the context manager:

```python
from diffusers import attention_backend

with attention_backend("flash_varlen"):
    output = transformer(...)
```

If the context manager is not used, it defaults to the original behaviour of calling native torch attention.
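To make the hook idea above concrete, here is a minimal, hypothetical sketch; `set_attention_backend_hook` and the module-level state are stand-ins for the PR's registry, not its actual API:

```python
import torch

# Hypothetical stand-in for the dispatcher's registry state.
_ACTIVE_ATTENTION_BACKEND = "native"


def set_attention_backend_hook(module: torch.nn.Module, backend: str) -> None:
    """Attach a pre-forward hook so `backend` becomes active whenever `module` runs."""

    def _pre_forward(mod, args, kwargs):
        global _ACTIVE_ATTENTION_BACKEND
        # The dispatcher would read this value when the attention op is invoked
        # inside this module's forward pass.
        _ACTIVE_ATTENTION_BACKEND = backend
        return args, kwargs

    module.register_forward_pre_hook(_pre_forward, with_kwargs=True)


# Usage sketch: only the transformer switches backends; other components stay on SDPA.
# set_attention_backend_hook(pipe.transformer, "sage_varlen")
```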
Usage
attention-only benchmark
Model benchmark
Results: 4090
Results with PyTorch 2.7 stable, CUDA 12.6
Wan
Results: A100
Results with PyTorch 2.7 stable, CUDA 12.2
Wan
cc @DN6 @sayakpaul @yiyixuxu