
Commit 1ddf3f3

a-r-r-o-w and DN6 authored
Improve information about group offloading and layerwise casting (#11101)
* update
* Update docs/source/en/optimization/memory.md
* Apply suggestions from code review
Co-authored-by: Dhruv Nair <[email protected]>
* apply review suggestions
* update
---------
Co-authored-by: Dhruv Nair <[email protected]>
1 parent 7aac77a commit 1ddf3f3

File tree: 2 files changed, +25 -1 lines changed

docs/source/en/optimization/memory.md (+20)
@@ -198,6 +198,18 @@ export_to_video(video, "output.mp4", fps=8)
Group offloading (for CUDA devices with support for asynchronous data transfer streams) overlaps data transfer and computation to reduce the overall execution time compared to sequential offloading. This is enabled using layer prefetching with CUDA streams. The next layer to be executed is loaded onto the accelerator device while the current layer is being executed - this increases the memory requirements slightly. Group offloading also supports leaf-level offloading (equivalent to sequential CPU offloading) but can be made much faster when using streams.
<Tip>
- Group offloading may not work with all models out-of-the-box. If the forward implementations of the model contain weight-dependent device-casting of inputs, it may clash with the offloading mechanism's handling of device-casting.
- The `offload_type` parameter can be set to either `block_level` or `leaf_level`. `block_level` offloads groups of `torch.nn.ModuleList` or `torch.nn.Sequential` modules based on the configurable `num_blocks_per_group` attribute. For example, if you set `num_blocks_per_group=2` on a standard transformer model containing 40 layers, it will onload/offload 2 layers at a time, for a total of 20 onload/offload operations. This drastically reduces the VRAM requirements. `leaf_level` offloads individual layers at the lowest level, which is equivalent to sequential offloading. However, unlike sequential offloading, group offloading can be made much faster when using streams, with minimal compromise to end-to-end generation time.
- The `use_stream` parameter can be used with CUDA devices to enable prefetching layers for onload. It defaults to `False`. Layer prefetching allows overlapping computation and data transfer of model weights, which drastically reduces the overall execution time compared to other offloading methods. However, it can increase the CPU RAM usage significantly. Ensure that the available CPU RAM is at least twice the size of the model when setting `use_stream=True`. You can find more information about CUDA streams [here](https://pytorch.org/docs/stable/generated/torch.cuda.Stream.html).
- If specifying `use_stream=True` on VAEs with tiling enabled, make sure to do a dummy forward pass (possibly with dummy inputs) before the actual inference to avoid device-mismatch errors. This may not work on all implementations. Please open an issue if you encounter any problems.
- The parameter `low_cpu_mem_usage` can be set to `True` to reduce CPU memory usage when using streams for group offloading. This is useful when the CPU memory is the bottleneck, but it may counteract the benefits of using streams and increase the overall execution time. The CPU memory savings come from creating pinned tensors on-the-fly instead of pre-pinning them. This parameter is better suited for `leaf_level` offloading.
For more information about available parameters and an explanation of how group offloading works, refer to [`~hooks.group_offloading.apply_group_offloading`].
</Tip>
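To make the options in this tip concrete, here is a minimal sketch of applying group offloading to a transformer model. The checkpoint name and the `onload_device`/`offload_device` arguments are illustrative assumptions and are not part of this commit; `offload_type`, `num_blocks_per_group`, `use_stream`, and `low_cpu_mem_usage` are the parameters documented above and in the diff below.

```python
import torch
from diffusers import CogVideoXPipeline
from diffusers.hooks.group_offloading import apply_group_offloading

# Illustrative checkpoint; any diffusers transformer-based model works similarly.
pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16)

# Block-level group offloading: onload/offload 2 transformer blocks at a time.
# use_stream=True prefetches the next group on a CUDA stream while the current
# group executes, overlapping data transfer with computation.
apply_group_offloading(
    pipe.transformer,
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=2,
    use_stream=True,
)

# Alternatively, offload_type="leaf_level" with low_cpu_mem_usage=True mirrors
# sequential offloading while keeping CPU RAM usage lower, as described above.
```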
## FP8 layerwise weight-casting
PyTorch supports `torch.float8_e4m3fn` and `torch.float8_e5m2` as weight storage dtypes, but they can't be used for computation in many different tensor operations due to unimplemented kernel support. However, you can use these dtypes to store model weights in fp8 precision and upcast them on-the-fly when the layers are used in the forward pass. This is known as layerwise weight-casting.
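As a rough sketch of the fp8 storage plus on-the-fly upcasting described above (the model checkpoint here is an illustrative assumption), layerwise casting can be enabled through the [`~ModelMixin.enable_layerwise_casting`] helper referenced later in this diff:

```python
import torch
from diffusers import CogVideoXTransformer3DModel

# Load the model in the compute dtype, then store its weights in fp8 and
# upcast them to bfloat16 on-the-fly during the forward pass.
transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)
transformer.enable_layerwise_casting(
    storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16
)
```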
@@ -235,6 +247,14 @@ In the above example, layerwise casting is enabled on the transformer component
However, you gain more control and flexibility by directly utilizing the [`~hooks.layerwise_casting.apply_layerwise_casting`] function instead of [`~ModelMixin.enable_layerwise_casting`].
<Tip>
- Layerwise casting may not work with all models out-of-the-box. Sometimes, the forward implementations of the model might contain internal typecasting of weight values. Such implementations are not supported due to the currently simplistic implementation of layerwise casting, which assumes that the forward pass is independent of the weight precision and that the input dtypes are always in `compute_dtype`. An example of an incompatible implementation can be found [here](https://github.com/huggingface/transformers/blob/7f5077e53682ca855afc826162b204ebf809f1f9/src/transformers/models/t5/modeling_t5.py#L294-L299).
- Layerwise casting may fail on custom modeling implementations that make use of [PEFT](https://github.com/huggingface/peft) layers. Some minimal checks to handle this case are implemented, but they are not extensively tested or guaranteed to work in all cases.
- It can also be applied partially to specific layers of a model. Partially applying layerwise casting can either be done manually by calling the `apply_layerwise_casting` function on specific internal modules, or by specifying the `skip_modules_pattern` and `skip_modules_classes` parameters for a root module. These parameters are particularly useful for layers such as normalization and modulation.
</Tip>
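As a minimal sketch of the partial application described in the last bullet of this tip, [`~hooks.layerwise_casting.apply_layerwise_casting`] can be called on a root module with skip patterns; the model and the specific patterns are illustrative assumptions:

```python
import torch
from diffusers import CogVideoXTransformer3DModel
from diffusers.hooks.layerwise_casting import apply_layerwise_casting

transformer = CogVideoXTransformer3DModel.from_pretrained(
    "THUDM/CogVideoX-5b", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Store most layers in fp8, but keep modules whose names match these patterns
# (e.g. normalization and embedding layers) in the compute dtype, since they
# tend to be precision-sensitive.
apply_layerwise_casting(
    transformer,
    storage_dtype=torch.float8_e4m3fn,
    compute_dtype=torch.bfloat16,
    skip_modules_pattern=("norm", "patch_embed"),
)
```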
## Channels-last memory format
The channels-last memory format is an alternative way of ordering NCHW tensors in memory to preserve dimension ordering. Channels-last tensors are ordered in such a way that the channels become the densest dimension (storing images pixel-per-pixel). Since not all operators currently support the channels-last format, it may result in worse performance, but you should still try it to see if it works for your model.
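For reference, a minimal sketch of switching a pipeline's UNet to the channels-last memory format (the checkpoint name is an illustrative assumption):

```python
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5", torch_dtype=torch.float16
)

# .to(memory_format=...) reorders the parameters in place so that channels
# become the densest (stride-1) dimension.
pipe.unet.to(memory_format=torch.channels_last)
print(pipe.unet.conv_out.weight.stride())  # a stride of 1 in the channels dimension indicates channels-last
```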

src/diffusers/hooks/group_offloading.py (+5 -1)
@@ -331,7 +331,7 @@ def apply_group_offloading(
num_blocks_per_group: Optional[int] = None,
non_blocking: bool = False,
use_stream: bool = False,
-low_cpu_mem_usage=False,
+low_cpu_mem_usage: bool = False,
) -> None:
r"""
Applies group offloading to the internal layers of a torch.nn.Module. To understand what group offloading is, and
@@ -378,6 +378,10 @@ def apply_group_offloading(
use_stream (`bool`, defaults to `False`):
If True, offloading and onloading is done asynchronously using a CUDA stream. This can be useful for
overlapping computation and data transfer.
low_cpu_mem_usage (`bool`, defaults to `False`):
If True, the CPU memory usage is minimized by pinning tensors on-the-fly instead of pre-pinning them. This
option only matters when using streamed CPU offloading (i.e. `use_stream=True`). This can be useful when
the CPU memory is a bottleneck but may counteract the benefits of using streams.
Example:
```python
