[Quantization] Add Quanto backend #10756
Changes from 13 commits
@@ -0,0 +1,107 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

-->
# Quanto

[Quanto](https://github.com/huggingface/optimum-quanto) is a PyTorch quantization backend for [Optimum](https://huggingface.co/docs/optimum/en/index). It has been designed with versatility and simplicity in mind:

- All features are available in eager mode (works with non-traceable models)
- Supports quantization-aware training
- Quantized models are compatible with `torch.compile`
- Quantized models are device agnostic (e.g. CUDA, XPU, MPS, CPU)

To use the Quanto backend, first install `optimum-quanto>=0.2.6` and `accelerate`:

```shell
pip install "optimum-quanto>=0.2.6" accelerate
```

Now you can quantize a model by passing a `QuantoConfig` object to the `from_pretrained()` method. The following snippet demonstrates how to apply `float8` quantization with Quanto.

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights="float8")
transformer = FluxTransformer2DModel.from_pretrained(model_id, quantization_config=quantization_config, torch_dtype=torch.bfloat16)

pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
).images[0]
image.save("output.png")
```

Review comment: Perhaps a comment to note that only
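To get a rough sense of the memory savings from `float8` weights, you can check the model's footprint after loading. This is a sketch: it assumes the `get_memory_footprint()` helper, familiar from the other diffusers quantization backends, is available on the quantized model.

```python
# Rough memory check for the quantized transformer (sketch; assumes
# `get_memory_footprint()` is available, as with other diffusers quantization backends).
print(f"Quantized transformer footprint: {transformer.get_memory_footprint() / 1e9:.2f} GB")
```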
## Using `from_single_file` with the Quanto Backend

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
quantization_config = QuantoConfig(weights="float8")
transformer = FluxTransformer2DModel.from_single_file(ckpt_path, quantization_config=quantization_config, torch_dtype=torch.bfloat16)
```

Review comment: Oh lovely. Not to digress from this PR, but would it make sense to also do something similar for bitsandbytes and torchao for
Reply: TorchAO should just work out of the box. We can add a section in the docs page. For BnB the conversion step in single file is still a bottleneck. We need to figure out how to handle that gracefully.
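Following up on the reply above, here is a hedged sketch of the same pattern with TorchAO through `from_single_file`. It is unverified here: support via `from_single_file` is only asserted by the reviewer, and the `"int8wo"` quant type is an illustrative choice borrowed from the existing torchao backend docs.

```python
# Sketch (assumption): per the review reply, TorchAO should also work through
# from_single_file. The quant_type string "int8wo" is an illustrative choice.
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
torchao_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path, quantization_config=torchao_config, torch_dtype=torch.bfloat16
)
```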
## Saving Quantized models

Diffusers supports serializing and saving Quanto models using the `save_pretrained` method.

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights="float8")
transformer = FluxTransformer2DModel.from_pretrained(model_id, quantization_config=quantization_config, torch_dtype=torch.bfloat16)

# save quantized model to reuse
transformer.save_pretrained("<your quantized model save path>")

# you can reload your quantized model with
model = FluxTransformer2DModel.from_pretrained("<your quantized model save path>")
```
## Using `torch.compile` with Quanto

Currently the Quanto backend only supports `torch.compile` for `int8` weights and activations.

Review comment: Not sure where this restriction comes from ... did you have issues with float8 or int4 because of the custom kernels?
Reply: So I tested with Float8 weight quantization and ran into this error. And with INT4 I was running into what looks like a dtype issue, which I don't seem to run into when I'm not using
Reply: Did you try out PyTorch nightlies? How's the performance improvement with

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights="int8")
transformer = FluxTransformer2DModel.from_pretrained(model_id, quantization_config=quantization_config, torch_dtype=torch.bfloat16)
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)

pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")
```

Review comment: Great that this works.
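On the reviewer's question about the performance gain from `torch.compile`, here is a minimal timing sketch. It is illustrative only: it assumes a CUDA device, that the `pipe` from the snippet above is in scope, and it discards the first call because that call includes compilation time.

```python
import time

import torch

def time_pipe(pipe, prompt="A cat holding a sign that says hello world"):
    # Time a full generation; synchronize so GPU work is included in the measurement.
    torch.cuda.synchronize()
    start = time.perf_counter()
    pipe(prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512)
    torch.cuda.synchronize()
    return time.perf_counter() - start

_ = time_pipe(pipe)  # warmup: the first call triggers compilation
print(f"Steady-state latency: {time_pipe(pipe):.2f} s")
```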
## Supported Quantization Types

### Weights

- float8
- int8
- int4
- int2

### Activations

Review comment: Let's show an example from the docs as well? Additionally, we could refer the users to this blog post so that they have a sense of the savings around memory and latency?

- float8
- int8
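Picking up the review suggestion above, here is a hedged sketch of activation quantization. The `activations` argument is an assumption borrowed from the transformers `QuantoConfig`; the diffusers argument name may differ, so check the final API before relying on this.

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
# `activations` is assumed here by analogy with transformers' QuantoConfig.
quantization_config = QuantoConfig(weights="int8", activations="int8")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id, quantization_config=quantization_config, torch_dtype=torch.bfloat16
)
```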
@@ -866,7 +866,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
         # Check if `_keep_in_fp32_modules` is not None
         use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
-            (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
+            (torch_dtype in [torch.float16, torch.bfloat16]) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
         )
         if use_keep_in_fp32_modules:
             keep_in_fp32_modules = cls._keep_in_fp32_modules
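For context, a simplified illustration of what `_keep_in_fp32_modules` controls. This is not the actual diffusers implementation, just a sketch of the idea: modules whose names are listed stay in float32 while the rest of the model is cast to the requested half-precision dtype.

```python
import torch
import torch.nn as nn

def cast_except_fp32_modules(model: nn.Module, dtype: torch.dtype, keep_in_fp32_modules: list) -> nn.Module:
    # Cast everything to the requested dtype first...
    model.to(dtype)
    # ...then upcast any module whose qualified name matches a keep-in-fp32 entry.
    for name, module in model.named_modules():
        if any(key in name for key in keep_in_fp32_modules):
            module.to(torch.float32)
    return model
```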
@@ -1041,7 +1041,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
                 model,
                 state_dict,
                 device=param_device,
-                dtype=torch_dtype,
                 model_name_or_path=pretrained_model_name_or_path,
                 hf_quantizer=hf_quantizer,
                 keep_in_fp32_modules=keep_in_fp32_modules,

Review comment: Why is this going away?
Reply: Oh this is a mistake. Thanks for catching.
@@ -26,8 +26,10 @@
     GGUFQuantizationConfig,
     QuantizationConfigMixin,
     QuantizationMethod,
+    QuantoConfig,
     TorchAoConfig,
 )
+from .quanto import QuantoQuantizer
 from .torchao import TorchAoHfQuantizer
@@ -36,13 +38,15 @@
     "bitsandbytes_8bit": BnB8BitDiffusersQuantizer,
     "gguf": GGUFQuantizer,
     "torchao": TorchAoHfQuantizer,
+    "quanto": QuantoQuantizer,
 }

 AUTO_QUANTIZATION_CONFIG_MAPPING = {
     "bitsandbytes_4bit": BitsAndBytesConfig,
     "bitsandbytes_8bit": BitsAndBytesConfig,
     "gguf": GGUFQuantizationConfig,
     "torchao": TorchAoConfig,
+    "quanto": QuantoConfig,
 }

Review comment: (nit) Maybe this should go above torchao to keep quantizers in alphabetical order (does not really have to be addressed; we can do it in order of quantization backend addition as well).
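For orientation, a simplified sketch of how these string keys get resolved. The real logic lives in diffusers' auto-quantizer helpers; the name of the first mapping is assumed to be `AUTO_QUANTIZER_MAPPING` (mirroring the config mapping's naming), and the variable names below are illustrative.

```python
# Continuing from the two mappings defined above (simplified sketch).
quant_method = "quanto"

config_cls = AUTO_QUANTIZATION_CONFIG_MAPPING[quant_method]  # -> QuantoConfig
quantizer_cls = AUTO_QUANTIZER_MAPPING[quant_method]         # -> QuantoQuantizer

# Build the config, then hand it to the matching quantizer.
quantization_config = config_cls(weights="float8")
quantizer = quantizer_cls(quantization_config)
```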
@@ -0,0 +1 @@
from .quanto_quantizer import QuantoQuantizer
Review comment: Have we verified this? Last time I checked only weight-quantized models were compatible with `torch.compile`. Cc: @dacorvo.
Reply: True, but this should be fixed in pytorch 2.6 (I did not check though).
Reply: @dacorvo I tried to run torch compile with float8 weights in the following way and hit an error during inference.
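The referenced snippet is not preserved in this capture; presumably (an assumption) it mirrored the `int8` example from the docs above, with `weights="float8"`:

```python
# Hedged reconstruction of the float8 + torch.compile attempt (assumption; the
# original snippet was not preserved in this capture).
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights="float8")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id, quantization_config=quantization_config, torch_dtype=torch.bfloat16
)
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)

pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")
# The compile call itself reportedly succeeds; the error surfaces during the forward pass.
```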
Traceback:
The `torch.compile` step seems to work. The error is raised during the forward pass.
Reply: Same with nightly?
Reply: Yeah, same errors with nightly.
Reply: Let's be specific that only `int8` supports `torch.compile` for now?
Reply: Mentioned in the compile section of the docs:
diffusers/docs/source/en/quantization/quanto.md, Line 105 in bb7fb66