Commit 5799954

Gothos authored and sayakpaul committed
Add Flux inpainting and Flux Img2Img (#9135)
---------

Co-authored-by: yiyixuxu <[email protected]>

* Update `UNet2DConditionModel`'s error messages (#9230): refactor
* [CI] Update Single file Nightly Tests (#9357): update; update feedback
* improve README for flux dreambooth lora (#9290): improve readme
* fix one uncaught deprecation warning for accessing vae_latent_channels in VaeImagePreprocessor (#9372): deprecation warning vae_latent_channels
* add mixed int8 tests and more tests to nf4
* [core] Freenoise memory improvements (#9262): implement prompt interpolation; resnet memory optimizations; more memory optimizations (todo: refactor); update animatediff controlnet with latest changes; refactor chunked inference changes; remove print statements; chunk -> split; remove changes from incorrect conflict resolution; add explanation of SplitInferenceModule; update docs; Revert "update docs" (reverts commit c55a50a); update docstring for freenoise split inference; apply suggestions from review; add tests
* quantization docs
1 parent 3b2d6e1 commit 5799954

39 files changed (+3487, -307 lines)

docs/source/en/_toctree.yml

Lines changed: 8 additions & 0 deletions
@@ -178,6 +178,12 @@
       title: Habana Gaudi
     title: Optimized hardware
   title: Accelerate inference and reduce memory
+- sections:
+  - local: quantization/overview
+    title: Getting Started
+  - local: quantization/bitsandbytes
+    title: bitsandbytes
+  title: Quantization
 - sections:
   - local: conceptual/philosophy
     title: Philosophy
@@ -203,6 +209,8 @@
     title: Logging
   - local: api/outputs
     title: Outputs
+  - local: api/quantization
+    title: Quantization
   title: Main Classes
 - isExpanded: false
   sections:

docs/source/en/api/pipelines/flux.md

Lines changed: 12 additions & 0 deletions
@@ -163,3 +163,15 @@ image.save("flux-fp8-dev.png")
 [[autodoc]] FluxPipeline
   - all
   - __call__
+
+## FluxImg2ImgPipeline
+
+[[autodoc]] FluxImg2ImgPipeline
+  - all
+  - __call__
+
+## FluxInpaintPipeline
+
+[[autodoc]] FluxInpaintPipeline
+  - all
+  - __call__
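
For orientation, a minimal usage sketch for the newly documented img2img pipeline (the input image URL, prompt, and parameter values are illustrative and not part of this commit):

```py
import torch
from diffusers import FluxImg2ImgPipeline
from diffusers.utils import load_image

pipe = FluxImg2ImgPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Any RGB image can serve as the starting point; this URL is a placeholder.
init_image = load_image("https://example.com/input.png").resize((1024, 1024))

# strength controls how far the output may drift from the input image.
image = pipe(
    prompt="a cat wearing a space suit",
    image=init_image,
    strength=0.6,
    num_inference_steps=28,
    guidance_scale=3.5,
).images[0]
image.save("flux-img2img.png")
```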

docs/source/en/api/quantization.md

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Quantization

Quantization techniques reduce memory and computational costs by representing weights and activations with lower-precision data types like 8-bit integers (int8). This makes it possible to load larger models that normally wouldn't fit into memory, and to speed up inference. Diffusers supports 8-bit and 4-bit quantization with [`bitsandbytes`](https://github.com/bitsandbytes-foundation/bitsandbytes).

Quantization techniques that aren't supported in Transformers can be added with the [`DiffusersQuantizer`] class.

<Tip>

Learn how to quantize models in the [Quantization](TODO) guide.

</Tip>
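
For orientation, a minimal sketch of the workflow documented here (the checkpoint and settings are illustrative, not part of this commit): build a [`BitsAndBytesConfig`] and pass it to a model's `from_pretrained` method.

```py
import torch
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig

# 8-bit weight-only quantization of the Flux transformer.
quantization_config = BitsAndBytesConfig(load_in_8bit=True)

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.float16,
)
```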

## BitsAndBytesConfig

[[autodoc]] BitsAndBytesConfig

## DiffusersQuantizer

[[autodoc]] quantizers.base.DiffusersQuantizer
Lines changed: 265 additions & 0 deletions
@@ -0,0 +1,265 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# bitsandbytes

[bitsandbytes](https://github.com/TimDettmers/bitsandbytes) is the easiest option for quantizing a model to 8-bit and 4-bit. 8-bit quantization multiplies outliers in fp16 with non-outliers in int8, converts the non-outlier values back to fp16, and then adds them together to return the weights in fp16. This reduces the degradative effect outlier values have on a model's performance. 4-bit quantization compresses a model even further, and it is commonly used with [QLoRA](https://hf.co/papers/2305.14314) to finetune quantized LLMs.

To use bitsandbytes, make sure you have the following libraries installed:

```bash
pip install diffusers transformers accelerate bitsandbytes -U
```

Now you can quantize a model by passing a `BitsAndBytesConfig` to the [`~ModelMixin.from_pretrained`] method. This works for any model in any modality, as long as it supports loading with Accelerate and contains `torch.nn.Linear` layers.

<hfoptions id="bnb">
<hfoption id="8-bit">

Quantizing a model in 8-bit halves the memory usage:

```py
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True)

model_8bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config
)
```

By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter if you want:

```py
import torch
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True)

model_8bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.float32
)
model_8bit.transformer_blocks.layers[-1].norm2.weight.dtype
```

Once a model is quantized, you can push it to the Hub with the [`~ModelMixin.push_to_hub`] method. The quantization `config.json` file is pushed first, followed by the quantized model weights.

```py
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_8bit=True)

model_8bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config
)

# push the quantized weights to a repo of your choice (placeholder repo id)
model_8bit.push_to_hub("your-username/FLUX.1-dev-bnb-8bit")
```

</hfoption>
<hfoption id="4-bit">

Quantizing a model in 4-bit reduces your memory usage by 4x:

```py
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_4bit=True)

model_4bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config
)
```

By default, all the other modules such as `torch.nn.LayerNorm` are converted to `torch.float16`. You can change the data type of these modules with the `torch_dtype` parameter if you want:

```py
import torch
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_4bit=True)

model_4bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.float32
)
model_4bit.transformer_blocks.layers[-1].norm2.weight.dtype
```

You can call `model.push_to_hub()` after loading the model in 4-bit precision. You can also save the serialized 4-bit model locally with `model.save_pretrained()`.
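
For reference, a minimal sketch of the save-and-reload round trip described above (the local directory name is a placeholder, not from this commit):

```py
from diffusers import FluxTransformer2DModel

# Serialize the 4-bit weights and quantization config locally.
model_4bit.save_pretrained("flux.1-dev-transformer-4bit")

# Reload later; the quantization config travels with the checkpoint.
reloaded = FluxTransformer2DModel.from_pretrained("flux.1-dev-transformer-4bit")
```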

</hfoption>
</hfoptions>

<Tip warning={true}>

Training with 8-bit and 4-bit weights is only supported for training *extra* parameters.

</Tip>

You can check your memory footprint with the `get_memory_footprint` method:

```py
print(model.get_memory_footprint())
```

Quantized models can be loaded with the [`~ModelMixin.from_pretrained`] method without needing to specify the `quantization_config` parameter:

```py
from diffusers import FluxTransformer2DModel

model_4bit = FluxTransformer2DModel.from_pretrained(
    "sayakpaul/flux.1-dev-nf4-pkg", subfolder="transformer"
)
```
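
Once the quantized transformer is loaded, it can be dropped into a pipeline for inference. The snippet below is a minimal sketch, not part of this commit; the dtype, offloading, and generation settings are illustrative:

```py
import torch
from diffusers import FluxPipeline

# Reuse the 4-bit transformer loaded above; the other components load in bf16.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    transformer=model_4bit,
    torch_dtype=torch.bfloat16,
)
pipe.enable_model_cpu_offload()

image = pipe("a photo of a dog on the moon", num_inference_steps=28).images[0]
image.save("flux-nf4.png")
```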

## 8-bit (LLM.int8() algorithm)

<Tip>

Learn more about the details of 8-bit quantization in this [blog post](https://huggingface.co/blog/hf-bitsandbytes-integration)!

</Tip>

This section explores some of the specific features of 8-bit models, such as outlier thresholds and skipping module conversion.

### Outlier threshold

An "outlier" is a hidden state value greater than a certain threshold, and these values are computed in fp16. While the values are usually normally distributed ([-3.5, 3.5]), this distribution can be very different for large models ([-60, 6] or [6, 60]). 8-bit quantization works well for values ~5, but beyond that, there is a significant performance penalty. A good default threshold value is 6, but a lower threshold may be needed for more unstable models (small models or finetuning).

To find the best threshold for your model, we recommend experimenting with the `llm_int8_threshold` parameter in [`BitsAndBytesConfig`]:

```py
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_8bit=True, llm_int8_threshold=10,
)

model_8bit = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quantization_config,
)
```

### Skip module conversion

For some models, you don't need to quantize every module to 8-bit; doing so can actually cause instability. For example, for diffusion models like [Stable Diffusion 3](../api/pipelines/stable_diffusion/stable_diffusion_3), the `proj_out` module can be skipped using the `llm_int8_skip_modules` parameter in [`BitsAndBytesConfig`]:

```py
from diffusers import SD3Transformer2DModel, BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(
    load_in_8bit=True, llm_int8_skip_modules=["proj_out"],
)

model_8bit = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=quantization_config,
)
```

## 4-bit (QLoRA algorithm)

<Tip>

Learn more about its details in this [blog post](https://huggingface.co/blog/4bit-transformers-bitsandbytes).

</Tip>

This section explores some of the specific features of 4-bit models, such as changing the compute data type, using the Normal Float 4 (NF4) data type, and using nested quantization.

### Compute data type

To speed up computation, you can change the data type from float32 (the default value) to bf16 using the `bnb_4bit_compute_dtype` parameter in [`BitsAndBytesConfig`]:

```py
import torch
from diffusers import BitsAndBytesConfig

quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.bfloat16)
```

### Normal Float 4 (NF4)

NF4 is a 4-bit data type from the [QLoRA](https://hf.co/papers/2305.14314) paper, adapted for weights initialized from a normal distribution. You should use NF4 for training 4-bit base models. This can be configured with the `bnb_4bit_quant_type` parameter in the [`BitsAndBytesConfig`]:

```py
from diffusers import SD3Transformer2DModel, BitsAndBytesConfig

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
)

model_nf4 = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=nf4_config,
)
```

For inference, the `bnb_4bit_quant_type` does not have a huge impact on performance. However, to remain consistent with the model weights, you should use the same `bnb_4bit_compute_dtype` and `torch_dtype` values.
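
As a minimal sketch of that recommendation (the values are illustrative, not from this commit), keep the storage, compute, and load dtypes aligned:

```py
import torch
from diffusers import SD3Transformer2DModel, BitsAndBytesConfig

# NF4 storage with bf16 compute, and the non-quantized modules loaded in bf16 too.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model_nf4 = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=torch.bfloat16,
)
```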

### Nested quantization

Nested quantization is a technique that can save additional memory at no additional performance cost. This feature performs a second quantization of the already quantized weights to save an additional 0.4 bits/parameter.

```py
from diffusers import SD3Transformer2DModel, BitsAndBytesConfig

double_quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
)

double_quant_model = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=double_quant_config,
)
```

## Dequantizing `bitsandbytes` models

Once quantized, you can dequantize the model to the original precision, but this might result in a small loss of model quality. Make sure you have enough GPU RAM to fit the dequantized model.

```python
from diffusers import SD3Transformer2DModel, BitsAndBytesConfig

double_quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
)

double_quant_model = SD3Transformer2DModel.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    subfolder="transformer",
    quantization_config=double_quant_config,
)
double_quant_model.dequantize()
```
Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Quantization

Quantization techniques focus on representing data with less information while also trying not to lose too much accuracy. This often means converting a data type to represent the same information with fewer bits. For example, if your model weights are stored as 32-bit floating points and they're quantized to 16-bit floating points, this halves the model size, which makes it easier to store and reduces memory usage. Lower precision can also speed up inference because it takes less time to perform calculations with fewer bits.
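
As a rough illustration of why this matters (a back-of-the-envelope sketch assuming a hypothetical 12B-parameter model, not a figure from these docs):

```py
# Approximate weight-only memory at different precisions.
num_params = 12_000_000_000
bytes_per_param = {"fp32": 4, "fp16/bf16": 2, "int8": 1, "nf4": 0.5}

for name, nbytes in bytes_per_param.items():
    print(f"{name}: {num_params * nbytes / 1024**3:.1f} GiB")
# fp32: 44.7 GiB, fp16/bf16: 22.4 GiB, int8: 11.2 GiB, nf4: 5.6 GiB
```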

<Tip>

Interested in adding a new quantization method to Diffusers? Read the [`DiffusersQuantizer`](../conceptual/contribution) guide to learn how!

</Tip>

<Tip>

If you are new to the quantization field, we recommend checking out these beginner-friendly courses about quantization, created in collaboration with DeepLearning.AI:

* [Quantization Fundamentals with Hugging Face](https://www.deeplearning.ai/short-courses/quantization-fundamentals-with-hugging-face/)
* [Quantization in Depth](https://www.deeplearning.ai/short-courses/quantization-in-depth/)

</Tip>

## When to use what?

This section will be expanded once Diffusers has multiple quantization backends. Currently, we only support `bitsandbytes`. [This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
