🎯 Guiders (support literally everything...?) #11205

Closed · wants to merge 21 commits
81 changes: 52 additions & 29 deletions docs/source/en/api/cache.md
@@ -11,6 +11,50 @@ specific language governing permissions and limitations under the License. -->

# Caching methods

## Faster Cache

[FasterCache](https://huggingface.co/papers/2410.19355) from Zhengyao Lv, Chenyang Si, Junhao Song, Zhenyu Yang, Yu Qiao, Ziwei Liu, Kwan-Yee K. Wong.

FasterCache is a method that speeds up inference in diffusion transformers by:
- Reusing attention states across successive inference steps, which tend to be highly similar
- Skipping the unconditional branch of classifier-free guidance at some timesteps, exploiting the redundancy between unconditional and conditional outputs at the same timestep to approximate the unconditional output from the conditional one

```python
import torch
from diffusers import CogVideoXPipeline, FasterCacheConfig

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

config = FasterCacheConfig(
    spatial_attention_block_skip_range=2,
    spatial_attention_timestep_skip_range=(-1, 681),
    current_timestep_callback=lambda: pipe.current_timestep,
    attention_weight_callback=lambda _: 0.3,
    unconditional_batch_skip_range=5,
    unconditional_batch_timestep_skip_range=(-1, 781),
    tensor_format="BFCHW",
)
pipe.transformer.enable_cache(config)
```

## First Block Cache

[First Block Cache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching) builds on the ideas of [TeaCache](https://huggingface.co/papers/2411.19108) to speed up inference in diffusion transformers, greatly reducing inference time while largely preserving generation quality. The method always computes the output of the first transformer block and compares it with the first-block output from the previous inference step. If the difference is smaller than a predefined threshold, the remaining transformer blocks are skipped and the cached output is reused; otherwise the full computation is performed as usual.

```python
import torch
from diffusers import CogVideoXPipeline, FirstBlockCacheConfig

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# Increasing the threshold may lead to faster inference speeds, but may also lead to poorer quality of generated videos.
# Smaller values between 0.02 and 0.2 are recommended based on the model being used. The default value is 0.05.
config = FirstBlockCacheConfig(threshold=0.07)
pipe.transformer.enable_cache(config)
```

## Pyramid Attention Broadcast

[Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588) from Xuanlei Zhao, Xiaolong Jin, Kai Wang, Yang You.
@@ -38,45 +82,24 @@ config = PyramidAttentionBroadcastConfig(
pipe.transformer.enable_cache(config)
```
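
All of these caching techniques are enabled through the same `enable_cache` entry point on the model. A cache can be removed again afterwards; a minimal sketch, assuming `CacheMixin` (documented below) exposes the matching `disable_cache` method:

```python
# Remove whichever cache hook is currently installed on the transformer.
pipe.transformer.disable_cache()
```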

### CacheMixin

[[autodoc]] CacheMixin

### FasterCacheConfig

[[autodoc]] FasterCacheConfig

[[autodoc]] apply_faster_cache

### FirstBlockCacheConfig

[[autodoc]] FirstBlockCacheConfig

[[autodoc]] apply_first_block_cache

### PyramidAttentionBroadcastConfig

[[autodoc]] PyramidAttentionBroadcastConfig

[[autodoc]] apply_pyramid_attention_broadcast
25 changes: 25 additions & 0 deletions src/diffusers/__init__.py
@@ -33,6 +33,7 @@

_import_structure = {
    "configuration_utils": ["ConfigMixin"],
    "guiders": [],
    "hooks": [],
    "loaders": ["FromOriginalModelMixin"],
    "models": [],
@@ -129,12 +130,25 @@
_import_structure["utils.dummy_pt_objects"] = [name for name in dir(dummy_pt_objects) if not name.startswith("_")]

else:
    _import_structure["guiders"].extend(
        [
            "AdaptiveProjectedGuidance",
            "ClassifierFreeGuidance",
            "ClassifierFreeZeroStarGuidance",
            "PerturbedAttentionGuidance",
            "SkipLayerGuidance",
        ]
    )
    _import_structure["hooks"].extend(
        [
            "FasterCacheConfig",
            "FirstBlockCacheConfig",
            "HookRegistry",
            "LayerSkipConfig",
            "PyramidAttentionBroadcastConfig",
            "apply_faster_cache",
            "apply_first_block_cache",
            "apply_layer_skip",
            "apply_pyramid_attention_broadcast",
        ]
    )
@@ -708,11 +722,22 @@
except OptionalDependencyNotAvailable:
    from .utils.dummy_pt_objects import *  # noqa F403
else:
    from .guiders import (
        AdaptiveProjectedGuidance,
        ClassifierFreeGuidance,
        ClassifierFreeZeroStarGuidance,
        PerturbedAttentionGuidance,
        SkipLayerGuidance,
    )
    from .hooks import (
        FasterCacheConfig,
        FirstBlockCacheConfig,
        HookRegistry,
        LayerSkipConfig,
        PyramidAttentionBroadcastConfig,
        apply_faster_cache,
        apply_first_block_cache,
        apply_layer_skip,
        apply_pyramid_attention_broadcast,
    )
    from .models import (
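
Taken together, the export lists above make the new guiders importable from the package root alongside the cache hooks. A minimal sketch of the added import surface, assuming a torch install (the names are gated on torch availability) and that the constructors take a `guidance_scale` argument as `AdaptiveProjectedGuidance`, shown in full later in this diff, does:

```python
# Sketch of the new top-level imports added by this PR (requires torch).
from diffusers import AdaptiveProjectedGuidance, ClassifierFreeGuidance

guider = AdaptiveProjectedGuidance(guidance_scale=7.5)
```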
24 changes: 24 additions & 0 deletions src/diffusers/guiders/__init__.py
@@ -0,0 +1,24 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..utils import is_torch_available


if is_torch_available():
    from .adaptive_projected_guidance import AdaptiveProjectedGuidance
    from .classifier_free_guidance import ClassifierFreeGuidance
    from .classifier_free_zero_star_guidance import ClassifierFreeZeroStarGuidance
    from .guider_utils import GuidanceMixin, _raise_guidance_deprecation_warning
    from .perturbed_attention_guidance import PerturbedAttentionGuidance
    from .skip_layer_guidance import SkipLayerGuidance
145 changes: 145 additions & 0 deletions src/diffusers/guiders/adaptive_projected_guidance.py
@@ -0,0 +1,145 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Optional

import torch

from .guider_utils import GuidanceMixin, rescale_noise_cfg


class AdaptiveProjectedGuidance(GuidanceMixin):
    """
    Adaptive Projected Guidance (APG): https://huggingface.co/papers/2410.02416

    Args:
        guidance_scale (`float`, defaults to `7.5`):
            The scale parameter for classifier-free guidance. Higher values result in stronger conditioning on the
            text prompt, while lower values allow for more freedom in generation. Higher values may lead to saturation
            and deterioration of image quality.
        adaptive_projected_guidance_momentum (`float`, defaults to `None`):
            The momentum parameter for the adaptive projected guidance. Disabled if set to `None`.
        adaptive_projected_guidance_rescale (`float`, defaults to `15.0`):
            The threshold applied to the norm of the guidance update (passed to `normalized_guidance` as
            `norm_threshold`); updates with a larger norm are scaled down to this value.
        eta (`float`, defaults to `1.0`):
            The weight applied to the component of the guidance update that is parallel to the conditional
            prediction.
        guidance_rescale (`float`, defaults to `0.0`):
            The rescale factor applied to the noise predictions. This is used to improve image quality and fix
            overexposure. Based on Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
            Flawed](https://huggingface.co/papers/2305.08891).
        use_original_formulation (`bool`, defaults to `False`):
            Whether to use the original formulation of classifier-free guidance as proposed in the paper. By default,
            we use the diffusers-native implementation that has been in the codebase for a long time. See
            [~guiders.classifier_free_guidance.ClassifierFreeGuidance] for more details.
    """

    _input_predictions = ["pred_cond", "pred_uncond"]

    def __init__(
        self,
        guidance_scale: float = 7.5,
        adaptive_projected_guidance_momentum: Optional[float] = None,
        adaptive_projected_guidance_rescale: float = 15.0,
        eta: float = 1.0,
        guidance_rescale: float = 0.0,
        use_original_formulation: bool = False,
    ):
        super().__init__()

        self.guidance_scale = guidance_scale
        self.adaptive_projected_guidance_momentum = adaptive_projected_guidance_momentum
        self.adaptive_projected_guidance_rescale = adaptive_projected_guidance_rescale
        self.eta = eta
        self.guidance_rescale = guidance_rescale
        self.use_original_formulation = use_original_formulation
        self.momentum_buffer = None

    def prepare_inputs(self, *args):
        if self._step == 0:
            if self.adaptive_projected_guidance_momentum is not None:
                self.momentum_buffer = MomentumBuffer(self.adaptive_projected_guidance_momentum)
        return super().prepare_inputs(*args)

    def forward(self, pred_cond: torch.Tensor, pred_uncond: Optional[torch.Tensor] = None) -> torch.Tensor:
        pred = None

        if not self._is_cfg_enabled():
            # Guidance is effectively disabled; fall back to the conditional prediction.
            pred = pred_cond
        else:
            pred = normalized_guidance(
                pred_cond,
                pred_uncond,
                self.guidance_scale,
                self.momentum_buffer,
                self.eta,
                self.adaptive_projected_guidance_rescale,
                self.use_original_formulation,
            )

        if self.guidance_rescale > 0.0:
            pred = rescale_noise_cfg(pred, pred_cond, self.guidance_rescale)

        return pred

    @property
    def num_conditions(self) -> int:
        num_conditions = 1
        if self._is_cfg_enabled():
            num_conditions += 1
        return num_conditions

    def _is_cfg_enabled(self) -> bool:
        if self.use_original_formulation:
            return not math.isclose(self.guidance_scale, 0.0)
        else:
            return not math.isclose(self.guidance_scale, 1.0)


class MomentumBuffer:
    def __init__(self, momentum: float):
        self.momentum = momentum
        self.running_average = 0

    def update(self, update_value: torch.Tensor):
        # running_average <- update_value + momentum * running_average
        new_average = self.momentum * self.running_average
        self.running_average = update_value + new_average
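

# Illustrative numeric check of the update rule above, assuming an APG-style negative momentum:
#   buf = MomentumBuffer(momentum=-0.75)
#   buf.update(torch.tensor(1.0))  # running_average = 1.0
#   buf.update(torch.tensor(1.0))  # running_average = 1.0 + (-0.75) * 1.0 = 0.25
# so oscillations in the cond-uncond difference are damped across steps.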


def normalized_guidance(
    pred_cond: torch.Tensor,
    pred_uncond: torch.Tensor,
    guidance_scale: float,
    momentum_buffer: Optional[MomentumBuffer] = None,
    eta: float = 1.0,
    norm_threshold: float = 0.0,
    use_original_formulation: bool = False,
):
    diff = pred_cond - pred_uncond
    # Reduce over all non-batch dimensions.
    dim = [-i for i in range(1, len(diff.shape))]
    if momentum_buffer is not None:
        momentum_buffer.update(diff)
        diff = momentum_buffer.running_average
    if norm_threshold > 0:
        # Clip the norm of the guidance update to norm_threshold.
        ones = torch.ones_like(diff)
        diff_norm = diff.norm(p=2, dim=dim, keepdim=True)
        scale_factor = torch.minimum(ones, norm_threshold / diff_norm)
        diff = diff * scale_factor
    # Split the update into components parallel and orthogonal to the conditional prediction.
    v0, v1 = diff.double(), pred_cond.double()
    v1 = torch.nn.functional.normalize(v1, dim=dim)
    v0_parallel = (v0 * v1).sum(dim=dim, keepdim=True) * v1
    v0_orthogonal = v0 - v0_parallel
    diff_parallel, diff_orthogonal = v0_parallel.type_as(diff), v0_orthogonal.type_as(diff)
    # Reweight the parallel component by eta before applying the guidance scale.
    normalized_update = diff_orthogonal + eta * diff_parallel
    pred = pred_cond if use_original_formulation else pred_uncond
    pred = pred + (guidance_scale - 1) * normalized_update
    return pred
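
To make the control flow concrete: `forward` computes `diff = pred_cond - pred_uncond`, optionally smooths it with the momentum buffer, clips its norm at `adaptive_projected_guidance_rescale`, splits it into components parallel and orthogonal to `pred_cond`, and reweights the parallel part by `eta` before applying the guidance scale. A minimal sketch of exercising this standalone, with random tensors standing in for real model outputs, the module path assumed from this PR, and assuming `GuidanceMixin` needs no extra setup for a direct `forward` call:

```python
import torch

from diffusers.guiders.adaptive_projected_guidance import AdaptiveProjectedGuidance

guider = AdaptiveProjectedGuidance(guidance_scale=7.5, adaptive_projected_guidance_rescale=15.0)

# Random stand-ins for the conditional and unconditional noise predictions (B, C, H, W).
pred_cond = torch.randn(2, 4, 64, 64)
pred_uncond = torch.randn(2, 4, 64, 64)

pred = guider.forward(pred_cond, pred_uncond)
print(pred.shape)  # torch.Size([2, 4, 64, 64])
```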