[Quantization] Add Quanto backend #10756

Merged (41 commits) on Mar 10, 2025
Changes from 13 commits

Commits (41)
ff50418
update
DN6 Feb 5, 2025
ba5bba7
updaet
DN6 Feb 5, 2025
aa8cdaf
update
DN6 Feb 5, 2025
39e20e2
update
DN6 Feb 8, 2025
f52050a
update
DN6 Feb 8, 2025
f4c14c2
update
DN6 Feb 10, 2025
f67d97c
update
DN6 Feb 10, 2025
5cff237
update
DN6 Feb 10, 2025
f734c09
update
DN6 Feb 10, 2025
7472f18
update
DN6 Feb 10, 2025
e96686e
update
DN6 Feb 10, 2025
4ae8691
update
DN6 Feb 10, 2025
7b841dc
Update docs/source/en/quantization/quanto.md
DN6 Feb 11, 2025
e090177
update
DN6 Feb 11, 2025
559f124
Merge https://github.com/huggingface/diffusers into add-quanto
DN6 Feb 11, 2025
9e5a3d0
Merge branch 'add-quanto' of https://github.com/huggingface/diffusers…
DN6 Feb 11, 2025
b136d23
update
DN6 Feb 11, 2025
2c7f303
update
DN6 Feb 11, 2025
c80d4d4
update
DN6 Feb 12, 2025
d355e6a
update
DN6 Feb 13, 2025
9a72fef
Merge branch 'main' into add-quanto
DN6 Feb 13, 2025
79901e4
update
DN6 Feb 18, 2025
c4b6e24
update
DN6 Feb 20, 2025
c29684f
Merge branch 'main' into add-quanto
DN6 Feb 20, 2025
6cf9a78
update
DN6 Feb 20, 2025
0736f87
update
DN6 Feb 20, 2025
4eabed7
update
DN6 Feb 25, 2025
f512c28
update
DN6 Feb 25, 2025
dbaef7c
update
DN6 Feb 25, 2025
963559f
update
DN6 Feb 25, 2025
156db08
Merge branch 'main' into add-quanto
DN6 Mar 3, 2025
4516f22
update
DN6 Mar 7, 2025
830b734
update
DN6 Mar 7, 2025
8afff1b
Merge branch 'main' into add-quanto
DN6 Mar 7, 2025
8163687
update
DN6 Mar 7, 2025
bb7fb66
update
DN6 Mar 7, 2025
6cad1d5
update
DN6 Mar 7, 2025
d5ab9ca
Update src/diffusers/quantizers/quanto/utils.py
DN6 Mar 7, 2025
deebc22
update
DN6 Mar 7, 2025
cf4694e
Merge branch 'add-quanto' of https://github.com/huggingface/diffusers…
DN6 Mar 7, 2025
1b46a32
update
DN6 Mar 10, 2025
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
@@ -163,6 +163,8 @@
title: gguf
- local: quantization/torchao
title: torchao
- local: quantization/quanto
title: quanto
title: Quantization Methods
- sections:
- local: optimization/fp16
1 change: 1 addition & 0 deletions docs/source/en/quantization/overview.md
@@ -36,5 +36,6 @@ Diffusers currently supports the following quantization methods.
- [BitsandBytes](./bitsandbytes)
- [TorchAO](./torchao)
- [GGUF](./gguf)
- [Quanto](./quanto)

[This resource](https://huggingface.co/docs/transformers/main/en/quantization/overview#when-to-use-what) provides a good overview of the pros and cons of different quantization techniques.
107 changes: 107 additions & 0 deletions docs/source/en/quantization/quanto.md
@@ -0,0 +1,107 @@
<!--Copyright 2025 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.

-->

# Quanto

[Quanto](https://github.com/huggingface/optimum-quanto) is a PyTorch quantization backend for [Optimum](https://huggingface.co/docs/optimum/en/index). It has been designed with versatility and simplicity in mind:

- All features are available in eager mode (works with non-traceable models)
- Supports quantization aware training
- Quantized models are compatible with `torch.compile`
Member

Have we verified this? Last time I checked only weight-quantized models were compatible with torch.compile. Cc: @dacorvo.


True, but this should be fixed in pytorch 2.6 (I did not check though).

Collaborator Author

@dacorvo I tried to run torch compile with float8 weights in the following way and hit an error during inference:

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig
from optimum.quanto import quantize, freeze, qint8, qint4, qfloat8

model_id = "black-forest-labs/FLUX.1-dev"
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)
quantize(transformer, weights=qfloat8)
freeze(transformer)

transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
pipe = FluxPipeline.from_pretrained(
    model_id, transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.to("cuda")
images = pipe("A cat holding a sign that says hello").images[0]
images.save("flux-quanto-compile.png")
```

Traceback:

```
  File "/home/dhruv/miniconda3/envs/mochi/lib/python3.11/site-packages/torch/_subclasses/fake_tensor.py", line 2082, in validate
    raise AssertionError(
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function linear>(*(FakeTensor(..., device='cuda:0', size=(1, 4096, 64), dtype=torch.bfloat16
), MarlinF8QBytesTensor(MarlinF8PackedTensor(FakeTensor(..., device='cuda:0', size=(4, 12288), dtype=torch.int32)), scale=FakeTensor(..., device='cuda:0', size=(1, 3072
), dtype=torch.bfloat16), dtype=torch.bfloat16)), **{'bias': Parameter(FakeTensor(..., device='cuda:0', size=(3072,), dtype=torch.bfloat16,
           requires_grad=True))}):
Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs'. Found in quanto.gemm_f16f8_marlin.default(FakeTensor(..., de
vice='cuda:0', size=(4096, 64), dtype=torch.bfloat16), FakeTensor(..., device='cuda:0', size=(4, 12288), dtype=torch.int32), FakeTensor(..., device='cuda:0', size=(1, 3
072), dtype=torch.bfloat16), tensor([...], device='cuda:0', size=(768,), dtype=torch.int32), 8, 4096, 3072, 64)

from user code:
   File "/home/dhruv/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 482, in forward
    hidden_states = self.x_embedder(hidden_states)
  File "/home/dhruv/miniconda3/envs/mochi/lib/python3.11/site-packages/optimum/quanto/nn/qlinear.py", line 50, in forward
    return torch.nn.functional.linear(input, self.qweight, bias=self.bias)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
```

The torch.compile step seems to work. The error is raised during the forward pass.

Member

Same with nightly?

Collaborator Author

Yeah same errors with nightly.

Member

Let's be specific that only int8 supports torch.compile for now?

Collaborator Author

Mentioned in the compile section of the docs

- Quantized models are device agnostic (e.g. CUDA, XPU, MPS, CPU)

To use the Quanto backend, first install `optimum-quanto>=0.2.6` and `accelerate`:

```shell
pip install optimum-quanto accelerate
```

Now you can quantize a model by passing a `QuantoConfig` object to the `from_pretrained()` method. The following snippet demonstrates how to apply `float8` weight quantization with Quanto.

```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights="float8")
Member

Perhaps a comment to note that only weights will be quantized.

transformer = FluxTransformer2DModel.from_pretrained(model_id, quantization_config=quantization_config, torch_dtype=torch.bfloat16)
Member

subfolder missing.


pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")

prompt = "A cat holding a sign that says hello world"
image = pipe(
    prompt, num_inference_steps=50, guidance_scale=4.5, max_sequence_length=512
).images[0]
image.save("output.png")
```

## Using `from_single_file` with the Quanto Backend

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"
quantization_config = QuantoConfig(weights="float8")
transformer = FluxTransformer2DModel.from_single_file(ckpt_path, quantization_config=quantization_config, torch_dtype=torch.bfloat16)
Member

Oh lovely. Not to digress from this PR but would it make sense to also do something similar for bitsandbytes and torchao for from_single_file() or not yet?

Collaborator Author

TorchAO should just work out of the box. We can add a section in the docs page.

For BnB the conversion step in single file is still a bottleneck. We need to figure out how to handle that gracefully.

```
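To make the reply above concrete, here is a minimal, hypothetical sketch of loading a single-file checkpoint with the TorchAO backend instead of Quanto; the `"int8wo"` quantization type string is an assumption and is not taken from this PR:

```python
import torch
from diffusers import FluxTransformer2DModel, TorchAoConfig

ckpt_path = "https://huggingface.co/black-forest-labs/FLUX.1-dev/blob/main/flux1-dev.safetensors"

# Hypothetical usage: pass a TorchAoConfig the same way QuantoConfig is passed above.
# "int8wo" (int8 weight-only) is assumed to be a supported quant_type string.
quantization_config = TorchAoConfig("int8wo")
transformer = FluxTransformer2DModel.from_single_file(
    ckpt_path, quantization_config=quantization_config, torch_dtype=torch.bfloat16
)
```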

## Saving Quantized Models

Diffusers supports serializing and saving Quanto models using the `save_pretrained` method.

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights="float8")
transformer = FluxTransformer2DModel.from_pretrained(model_id, quantization_config=quantization_config, torch_dtype=torch.bfloat16)

# save quantized model to reuse
transformer.save_pretrained("<your quantized model save path>")

# you can reload your quantized model with
model = FluxTransformer2DModel.from_pretrained("<your quantized model save path>")
```

## Using `torch.compile` with Quanto

Currently the Quanto backend only supports `torch.compile` for `int8` weights and activations.

Not sure where this restriction comes from... did you have issues with float8 or int4 because of the custom kernels?

Collaborator Author

So I tested with Float8 weight quantization and ran into this error:

```
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function linear>(*(FakeTensor(..., device='cuda:0', size=(1, 4096, 64), dtype=torch.bfloat16), MarlinF8QBytesTensor(MarlinF8PackedTensor(FakeTensor(..., device='cuda:0', size=(4, 12288), dtype=torch.int32)), scale=FakeTensor(..., device='cuda:0', size=(1, 3072), dtype=torch.bfloat16), dtype=torch.bfloat16)), **{'bias': Parameter(FakeTensor(..., device='cuda:0', size=(3072,), dtype=torch.bfloat16))}):
Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs'. Found in quanto.gemm_f16f8_marlin.default(FakeTensor(..., device='cuda:0', size=(4096, 64), dtype=torch.bfloat16), FakeTensor(..., device='cuda:0', size=(4, 12288), dtype=torch.int32), FakeTensor(..., device='cuda:0', size=(1, 3072), dtype=torch.bfloat16), tensor([...], device='cuda:0', size=(768,), dtype=torch.int32), 8, 4096, 3072, 64)

from user code:
   File "/home/dhruv/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 482, in forward
    hidden_states = self.x_embedder(hidden_states)
  File "/home/dhruv/optimum-quanto/optimum/quanto/nn/qlinear.py", line 50, in forward
    return torch.nn.functional.linear(input, self.qweight, bias=self.bias)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
```

And with INT4 I was running into what looks like a dtype issue, which I don't seem to run into when I'm not using compile:

```
torch._dynamo.exc.TorchRuntimeError: Failed running call_function <built-in function scaled_dot_product_attention>(*(FakeTensor(..., device='cuda:0', size=(1, 24, 4608, 128), dtype=torch.bfloat16), FakeTensor(..., device='cuda:0', size=(1, 24, 4608, 128), dtype=torch.bfloat16), FakeTensor(..., device='cuda:0', size=(1, 24, 4608, 128))), **{'dropout_p': 0.0, 'is_causal': False}):
Expected query, key, and value to have the same dtype, but got query.dtype: c10::BFloat16 key.dtype: c10::BFloat16 and value.dtype: float instead.

from user code:
   File "/home/dhruv/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 529, in forward
    encoder_hidden_states, hidden_states = block(
  File "/home/dhruv/diffusers/src/diffusers/models/transformers/transformer_flux.py", line 188, in forward
    attention_outputs = self.attn(
  File "/home/dhruv/diffusers/src/diffusers/models/attention_processor.py", line 595, in forward
    return self.processor(
  File "/home/dhruv/diffusers/src/diffusers/models/attention_processor.py", line 2328, in __call__
    hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
```

Member

Did you try out PyTorch nightlies? How's the performance improvement with torch.compile() and int8?
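As a rough way to answer the latency question above, one could time the pipeline before and after compiling the quantized transformer. This is only a sketch under the assumption that `pipe` is a `FluxPipeline` on CUDA whose transformer was loaded with `QuantoConfig(weights="int8")`, as in the snippet below:

```python
import time

import torch

def time_pipeline(pipe, prompt="A cat holding a sign that says hello", steps=20):
    # Synchronize around the call so GPU work is included in the measurement.
    torch.cuda.synchronize()
    start = time.perf_counter()
    pipe(prompt, num_inference_steps=steps)
    torch.cuda.synchronize()
    return time.perf_counter() - start

eager_s = time_pipeline(pipe)
pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
_ = time_pipeline(pipe)           # first compiled call pays the compilation cost
compiled_s = time_pipeline(pipe)  # steady-state latency
print(f"eager: {eager_s:.1f}s, compiled (warm): {compiled_s:.1f}s")
```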


```python
import torch
from diffusers import FluxPipeline, FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"
quantization_config = QuantoConfig(weights="int8")
transformer = FluxTransformer2DModel.from_pretrained(model_id, quantization_config=quantization_config, torch_dtype=torch.bfloat16)
transformer = torch.compile(transformer, mode="max-autotune", fullgraph=True)
Member

Great that this works.


pipe = FluxPipeline.from_pretrained(model_id, transformer=transformer, torch_dtype=torch.bfloat16)
pipe.to("cuda")
```

## Supported Quantization Types

### Weights

- float8
- int8
- int4
- int2

### Activations
Member

Let's show an example from the docs as well?

Additionally, we could refer the users to this blog post so that they have a sense of the savings around memory and latency?

- float8
- int8
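Picking up the suggestion above, here is a short sketch of requesting one of the listed activation dtypes together with int8 weights. Treat it as an assumption-laden illustration of the `activations` argument rather than official usage; calibration details are not covered here:

```python
import torch
from diffusers import FluxTransformer2DModel, QuantoConfig

model_id = "black-forest-labs/FLUX.1-dev"

# int8 weights plus int8 activations; leaving `activations` unset (the default, None)
# keeps weight-only quantization.
quantization_config = QuantoConfig(weights="int8", activations="int8")
transformer = FluxTransformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=quantization_config,
    torch_dtype=torch.bfloat16,
)
```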
4 changes: 4 additions & 0 deletions setup.py
@@ -127,6 +127,10 @@
"GitPython<3.1.19",
"scipy",
"onnx",
"optimum_quanto>=0.2.6",
"gguf>=0.10.0",
"torchao>=0.7.0",
"bitsandbytes>=0.43.3",
"regex!=2019.12.17",
"requests",
"tensorboard",
94 changes: 92 additions & 2 deletions src/diffusers/__init__.py
@@ -2,6 +2,15 @@

from typing import TYPE_CHECKING

from diffusers.quantizers import quantization_config
from diffusers.utils import dummy_gguf_objects
from diffusers.utils.import_utils import (
is_bitsandbytes_available,
is_gguf_available,
is_optimum_quanto_version,
is_torchao_available,
)

from .utils import (
DIFFUSERS_SLOW_IMPORT,
OptionalDependencyNotAvailable,
@@ -11,6 +20,7 @@
is_librosa_available,
is_note_seq_available,
is_onnx_available,
is_optimum_quanto_available,
is_scipy_available,
is_sentencepiece_available,
is_torch_available,
@@ -32,7 +42,7 @@
"loaders": ["FromOriginalModelMixin"],
"models": [],
"pipelines": [],
"quantizers.quantization_config": ["BitsAndBytesConfig", "GGUFQuantizationConfig", "TorchAoConfig"],
"quantizers.quantization_config": [],
"schedulers": [],
"utils": [
"OptionalDependencyNotAvailable",
@@ -54,6 +64,55 @@
],
}

try:
if not is_bitsandbytes_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_bitsandbytes_objects

_import_structure["utils.dummy_bitsandbytes_objects"] = [
name for name in dir(dummy_bitsandbytes_objects) if not name.startswith("_")
]
else:
_import_structure["quantizers.quantization_config"].append("BitsAndBytesConfig")

try:
if not is_gguf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_gguf_objects

_import_structure["utils.dummy_gguf_objects"] = [
name for name in dir(dummy_gguf_objects) if not name.startswith("_")
]
else:
_import_structure["quantizers.quantization_config"].append("GGUFQuantizationConfig")

try:
if not is_torchao_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_torchao_objects

_import_structure["utils.dummy_torchao_objects"] = [
name for name in dir(dummy_torchao_objects) if not name.startswith("_")
]
else:
_import_structure["quantizers.quantization_config"].append("TorchAoConfig")

try:
if not is_optimum_quanto_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils import dummy_optimum_quanto_objects

_import_structure["utils.dummy_optimum_quanto_objects"] = [
name for name in dir(dummy_optimum_quanto_objects) if not name.startswith("_")
]
else:
_import_structure["quantizers.quantization_config"].append("QuantoConfig")


try:
if not is_onnx_available():
raise OptionalDependencyNotAvailable()
@@ -581,7 +640,38 @@

if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
from .configuration_utils import ConfigMixin
from .quantizers.quantization_config import BitsAndBytesConfig, GGUFQuantizationConfig, TorchAoConfig

try:
if not is_bitsandbytes_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_bitsandbytes_objects import *
else:
from .quantizers.quantization_config import BitsAndBytesConfig

try:
if not is_gguf_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_gguf_objects import *
else:
from .quantizers.quantization_config import GGUFQuantizationConfig

try:
if not is_torchao_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_torchao_objects import *
else:
from .quantizers.quantization_config import TorchAoConfig

try:
if not is_optimum_quanto_available():
raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
from .utils.dummy_optimum_quanto_objects import *
else:
from .quantizers.quantization_config import QuantoConfig

try:
if not is_onnx_available():
4 changes: 4 additions & 0 deletions src/diffusers/dependency_versions_table.py
@@ -35,6 +35,10 @@
"GitPython": "GitPython<3.1.19",
"scipy": "scipy",
"onnx": "onnx",
"optimum_quanto": "optimum_quanto>=0.2.6",
"gguf": "gguf>=0.10.0",
"torchao": "torchao>=0.7.0",
"bitsandbytes": "bitsandbytes>=0.43.3",
"regex": "regex!=2019.12.17",
"requests": "requests",
"tensorboard": "tensorboard",
6 changes: 4 additions & 2 deletions src/diffusers/models/model_loading_utils.py
@@ -220,7 +220,7 @@ def load_model_dict_into_meta(
and any(
module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
)
and dtype == torch.float16
and dtype in [torch.float16, torch.bfloat16]
):
param = param.to(torch.float32)
if accepts_dtype:
@@ -248,7 +248,9 @@ def load_model_dict_into_meta(
if is_quantized and (
hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
):
hf_quantizer.create_quantized_param(model, param, param_name, device, state_dict, unexpected_keys)
hf_quantizer.create_quantized_param(
model, param, param_name, device, state_dict, unexpected_keys, dtype=dtype
)
else:
if accepts_dtype:
set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
3 changes: 1 addition & 2 deletions src/diffusers/models/modeling_utils.py
@@ -866,7 +866,7 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P

# Check if `_keep_in_fp32_modules` is not None
use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
(torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
(torch_dtype in [torch.float16, torch.bfloat16]) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
)
if use_keep_in_fp32_modules:
keep_in_fp32_modules = cls._keep_in_fp32_modules
@@ -1041,7 +1041,6 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
model,
state_dict,
device=param_device,
dtype=torch_dtype,
Member

Why is this going away?

Collaborator Author

Oh this is a mistake. Thanks for catching.

model_name_or_path=pretrained_model_name_or_path,
hf_quantizer=hf_quantizer,
keep_in_fp32_modules=keep_in_fp32_modules,
4 changes: 4 additions & 0 deletions src/diffusers/quantizers/auto.py
@@ -26,8 +26,10 @@
GGUFQuantizationConfig,
QuantizationConfigMixin,
QuantizationMethod,
QuantoConfig,
TorchAoConfig,
)
from .quanto import QuantoQuantizer
from .torchao import TorchAoHfQuantizer


@@ -36,13 +38,15 @@
"bitsandbytes_8bit": BnB8BitDiffusersQuantizer,
"gguf": GGUFQuantizer,
"torchao": TorchAoHfQuantizer,
"quanto": QuantoQuantizer,
}

AUTO_QUANTIZATION_CONFIG_MAPPING = {
"bitsandbytes_4bit": BitsAndBytesConfig,
"bitsandbytes_8bit": BitsAndBytesConfig,
"gguf": GGUFQuantizationConfig,
"torchao": TorchAoConfig,
"quanto": QuantoConfig,
Member

(nit): maybe this should go above torchao to keep quantizers in alphabetical order (does not really have to be addressed; we can also do it in order of quantization backend addition)

}


45 changes: 45 additions & 0 deletions src/diffusers/quantizers/quantization_config.py
@@ -45,6 +45,7 @@ class QuantizationMethod(str, Enum):
BITS_AND_BYTES = "bitsandbytes"
GGUF = "gguf"
TORCHAO = "torchao"
QUANTO = "quanto"


@dataclass
@@ -674,3 +675,47 @@ def __repr__(self):
"""
config_dict = self.to_dict()
return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"


@dataclass
class QuantoConfig(QuantizationConfigMixin):
"""
This is a wrapper class covering all the attributes and features that you can use with a model that has been
loaded using `quanto` quantization.

Args:
weights (`str`, *optional*, defaults to `"int8"`):
The target dtype for the weights after quantization. Supported values are ("float8", "int8", "int4", "int2").
activations (`str`, *optional*, defaults to `None`):
The target dtype for the activations after quantization. Supported values are (None, "int8", "float8").
modules_to_not_convert (`list`, *optional*, defaults to `None`):
The list of modules to not quantize, useful for quantizing models that explicitly require some modules
to be left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
"""

def __init__(
self,
weights="int8",
activations=None,
modules_to_not_convert: Optional[List] = None,
compute_dtype: Optional["torch.dtype"] = None,
**kwargs,
):
self.quant_method = QuantizationMethod.QUANTO
self.weights = weights
self.activations = activations
self.modules_to_not_convert = modules_to_not_convert

self.post_init()

def post_init(self):
r"""
Safety checker that arguments are correct
"""
accepted_weights = ["float8", "int8", "int4", "int2"]
accepted_activations = [None, "int8", "float8"]
if self.weights not in accepted_weights:
raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights}")

if self.activations not in accepted_activations:
raise ValueError(f"Only support weights in {accepted_activations} but found {self.activations}")
1 change: 1 addition & 0 deletions src/diffusers/quantizers/quanto/__init__.py
@@ -0,0 +1 @@
from .quanto_quantizer import QuantoQuantizer