[Tests] improve quantization tests by additionally measuring the inference memory savings #11021

Merged · 3 commits · Mar 10, 2025
2 changes: 2 additions & 0 deletions src/diffusers/quantizers/bitsandbytes/bnb_quantizer.py
@@ -135,6 +135,7 @@ def create_quantized_param(
target_device: "torch.device",
state_dict: Dict[str, Any],
unexpected_keys: Optional[List[str]] = None,
**kwargs,
Member Author commented:

Or should we just do `dtype` and keep it unused? Unused params aren't new to this method. For example:

state_dict: Dict[str, Any],
unexpected_keys: List[str],

):
import bitsandbytes as bnb

@@ -445,6 +446,7 @@ def create_quantized_param(
target_device: "torch.device",
state_dict: Dict[str, Any],
unexpected_keys: Optional[List[str]] = None,
**kwargs,
):
import bitsandbytes as bnb

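As a side note on the `**kwargs` additions above (and the `dtype` question raised in the comment), accepting `**kwargs` lets the loading code pass extra keyword arguments, such as a dtype hint, without every quantizer backend having to declare them explicitly. A minimal, hypothetical sketch of that signature pattern (not the actual diffusers implementation):

```python
from typing import Any, Dict, List, Optional

import torch


class ToyQuantizer:
    """Illustrative only: mirrors the forward-compatible signature pattern."""

    def create_quantized_param(
        self,
        model: torch.nn.Module,
        param_value: torch.Tensor,
        param_name: str,
        target_device: "torch.device",
        state_dict: Dict[str, Any],
        unexpected_keys: Optional[List[str]] = None,
        **kwargs,  # extra arguments (e.g. a dtype hint) are absorbed and may simply go unused
    ):
        # A real quantizer would quantize `param_value` here; this sketch only places it.
        module_path, _, attr = param_name.rpartition(".")
        module = model.get_submodule(module_path) if module_path else model
        setattr(module, attr, torch.nn.Parameter(param_value.to(target_device), requires_grad=False))
```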
1 change: 1 addition & 0 deletions src/diffusers/quantizers/gguf/gguf_quantizer.py
@@ -108,6 +108,7 @@ def create_quantized_param(
target_device: "torch.device",
state_dict: Optional[Dict[str, Any]] = None,
unexpected_keys: Optional[List[str]] = None,
**kwargs,
):
module, tensor_name = get_module_from_name(model, param_name)
if tensor_name not in module._parameters and tensor_name not in module._buffers:
1 change: 1 addition & 0 deletions src/diffusers/quantizers/torchao/torchao_quantizer.py
@@ -215,6 +215,7 @@ def create_quantized_param(
target_device: "torch.device",
state_dict: Dict[str, Any],
unexpected_keys: List[str],
**kwargs,
):
r"""
Each nn.Linear layer that needs to be quantized is processed here. First, we set the value of the weight tensor,
Empty file added tests/quantization/__init__.py
57 changes: 33 additions & 24 deletions tests/quantization/bnb/test_4bit.py
@@ -54,29 +54,8 @@ def get_some_linear_layer(model):

if is_torch_available():
import torch
import torch.nn as nn

class LoRALayer(nn.Module):
"""Wraps a linear layer with LoRA-like adapter - Used for testing purposes only

Taken from
https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
"""

def __init__(self, module: nn.Module, rank: int):
super().__init__()
self.module = module
self.adapter = nn.Sequential(
nn.Linear(module.in_features, rank, bias=False),
nn.Linear(rank, module.out_features, bias=False),
)
small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
nn.init.normal_(self.adapter[0].weight, std=small_std)
nn.init.zeros_(self.adapter[1].weight)
self.adapter.to(module.weight.device)

def forward(self, input, *args, **kwargs):
return self.module(input, *args, **kwargs) + self.adapter(input)
from ..utils import LoRALayer, get_memory_consumption_stat


if is_bitsandbytes_available():
@@ -96,6 +75,8 @@ class Base4bitTests(unittest.TestCase):
# This was obtained on audace so the number might slightly change
expected_rel_difference = 3.69

expected_memory_saving_ratio = 0.8

prompt = "a beautiful sunset amidst the mountains."
num_inference_steps = 10
seed = 0
@@ -140,8 +121,10 @@ def setUp(self):
)

def tearDown(self):
del self.model_fp16
del self.model_4bit
if hasattr(self, "model_fp16"):
del self.model_fp16
if hasattr(self, "model_4bit"):
del self.model_4bit

gc.collect()
torch.cuda.empty_cache()
@@ -180,6 +163,32 @@ def test_memory_footprint(self):
linear = get_some_linear_layer(self.model_4bit)
self.assertTrue(linear.weight.__class__ == bnb.nn.Params4bit)

def test_model_memory_usage(self):
# Delete to not let anything interfere.
del self.model_4bit, self.model_fp16

# Re-instantiate.
inputs = self.get_dummy_inputs()
inputs = {
k: v.to(device=torch_device, dtype=torch.float16) for k, v in inputs.items() if not isinstance(v, bool)
}
model_fp16 = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", torch_dtype=torch.float16
).to(torch_device)
unquantized_model_memory = get_memory_consumption_stat(model_fp16, inputs)
del model_fp16

nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)
model_4bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=nf4_config, torch_dtype=torch.float16
)
quantized_model_memory = get_memory_consumption_stat(model_4bit, inputs)
assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_saving_ratio

def test_original_dtype(self):
r"""
A simple test to check if the model successfully stores the original dtype
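For readers skimming the diff: the new `test_model_memory_usage` above reduces to a single ratio assertion on peak inference memory. A toy illustration with made-up numbers (not real measurements):

```python
# Hypothetical numbers purely to illustrate the assertion, not measured values.
unquantized_model_memory = 10.0 * 1024**3  # peak bytes allocated while running the fp16 model
quantized_model_memory = 4.5 * 1024**3     # peak bytes allocated while running the NF4 model
expected_memory_saving_ratio = 0.8

# The test passes when the fp16 peak is at least `expected_memory_saving_ratio`
# times the quantized peak; here 10.0 / 4.5 ≈ 2.22 >= 0.8.
assert unquantized_model_memory / quantized_model_memory >= expected_memory_saving_ratio
```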
55 changes: 30 additions & 25 deletions tests/quantization/bnb/test_mixed_int8.py
@@ -60,29 +60,8 @@ def get_some_linear_layer(model):

if is_torch_available():
import torch
import torch.nn as nn

class LoRALayer(nn.Module):
"""Wraps a linear layer with LoRA-like adapter - Used for testing purposes only

Taken from
https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_8bit.py#L62C5-L78C77
"""

def __init__(self, module: nn.Module, rank: int):
super().__init__()
self.module = module
self.adapter = nn.Sequential(
nn.Linear(module.in_features, rank, bias=False),
nn.Linear(rank, module.out_features, bias=False),
)
small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
nn.init.normal_(self.adapter[0].weight, std=small_std)
nn.init.zeros_(self.adapter[1].weight)
self.adapter.to(module.weight.device)

def forward(self, input, *args, **kwargs):
return self.module(input, *args, **kwargs) + self.adapter(input)
from ..utils import LoRALayer, get_memory_consumption_stat


if is_bitsandbytes_available():
@@ -102,6 +81,8 @@ class Base8bitTests(unittest.TestCase):
# This was obtained on audace so the number might slightly change
expected_rel_difference = 1.94

expected_memory_saving_ratio = 0.7

prompt = "a beautiful sunset amidst the mountains."
num_inference_steps = 10
seed = 0
@@ -142,8 +123,10 @@ def setUp(self):
)

def tearDown(self):
del self.model_fp16
del self.model_8bit
if hasattr(self, "model_fp16"):
del self.model_fp16
if hasattr(self, "model_8bit"):
del self.model_8bit

gc.collect()
torch.cuda.empty_cache()
@@ -182,6 +165,28 @@ def test_memory_footprint(self):
linear = get_some_linear_layer(self.model_8bit)
self.assertTrue(linear.weight.__class__ == bnb.nn.Int8Params)

def test_model_memory_usage(self):
# Delete to not let anything interfere.
del self.model_8bit, self.model_fp16

# Re-instantiate.
inputs = self.get_dummy_inputs()
inputs = {
k: v.to(device=torch_device, dtype=torch.float16) for k, v in inputs.items() if not isinstance(v, bool)
}
model_fp16 = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", torch_dtype=torch.float16
).to(torch_device)
unquantized_model_memory = get_memory_consumption_stat(model_fp16, inputs)
del model_fp16

config = BitsAndBytesConfig(load_in_8bit=True)
model_8bit = SD3Transformer2DModel.from_pretrained(
self.model_name, subfolder="transformer", quantization_config=config, torch_dtype=torch.float16
)
quantized_model_memory = get_memory_consumption_stat(model_8bit, inputs)
assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_saving_ratio

def test_original_dtype(self):
r"""
A simple test to check if the model successfully stores the original dtype
@@ -248,7 +253,7 @@ def test_llm_skip(self):
self.assertTrue(linear.weight.dtype == torch.int8)
self.assertTrue(isinstance(linear, bnb.nn.Linear8bitLt))

self.assertTrue(isinstance(model_8bit.proj_out, nn.Linear))
self.assertTrue(isinstance(model_8bit.proj_out, torch.nn.Linear))
self.assertTrue(model_8bit.proj_out.weight.dtype != torch.int8)

def test_config_from_pretrained(self):
Empty file.
49 changes: 14 additions & 35 deletions tests/quantization/quanto/test_quanto.py
@@ -19,29 +19,8 @@

if is_torch_available():
import torch
import torch.nn as nn

class LoRALayer(nn.Module):
"""Wraps a linear layer with LoRA-like adapter - Used for testing purposes only

Taken from
https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
"""

def __init__(self, module: nn.Module, rank: int):
super().__init__()
self.module = module
self.adapter = nn.Sequential(
nn.Linear(module.in_features, rank, bias=False),
nn.Linear(rank, module.out_features, bias=False),
)
small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
nn.init.normal_(self.adapter[0].weight, std=small_std)
nn.init.zeros_(self.adapter[1].weight)
self.adapter.to(module.weight.device)

def forward(self, input, *args, **kwargs):
return self.module(input, *args, **kwargs) + self.adapter(input)
from ..utils import LoRALayer, get_memory_consumption_stat


@nightly
@@ -85,20 +64,20 @@ def test_quanto_layers(self):
assert isinstance(module, QLinear)

def test_quanto_memory_usage(self):
unquantized_model = self.model_cls.from_pretrained(self.model_id, torch_dtype=self.torch_dtype)
unquantized_model_memory = unquantized_model.get_memory_footprint() / 1024**3

model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
inputs = self.get_dummy_inputs()
inputs = {
k: v.to(device=torch_device, dtype=torch.bfloat16) for k, v in inputs.items() if not isinstance(v, bool)
}

torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()
unquantized_model = self.model_cls.from_pretrained(self.model_id, torch_dtype=self.torch_dtype)
unquantized_model.to(torch_device)
unquantized_model_memory = get_memory_consumption_stat(unquantized_model, inputs)

model.to(torch_device)
with torch.no_grad():
model(**inputs)
max_memory = torch.cuda.max_memory_allocated() / 1024**3
assert (1.0 - (max_memory / unquantized_model_memory)) >= self.expected_memory_reduction
quantized_model = self.model_cls.from_pretrained(**self.get_dummy_model_init_kwargs())
quantized_model.to(torch_device)
quantized_model_memory = get_memory_consumption_stat(quantized_model, inputs)

assert unquantized_model_memory / quantized_model_memory >= self.expected_memory_reduction

def test_keep_modules_in_fp32(self):
r"""
@@ -318,14 +297,14 @@ def test_training(self):


class FluxTransformerFloat8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
expected_memory_reduction = 0.3
expected_memory_reduction = 0.6

def get_dummy_init_kwargs(self):
return {"weights_dtype": "float8"}


class FluxTransformerInt8WeightsTest(FluxTransformerQuantoMixin, unittest.TestCase):
expected_memory_reduction = 0.3
expected_memory_reduction = 0.6
_test_torch_compile = True

def get_dummy_init_kwargs(self):
Empty file.
38 changes: 17 additions & 21 deletions tests/quantization/torchao/test_torchao.py
@@ -50,27 +50,7 @@
import torch
import torch.nn as nn

class LoRALayer(nn.Module):
"""Wraps a linear layer with LoRA-like adapter - Used for testing purposes only

Taken from
https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
"""

def __init__(self, module: nn.Module, rank: int):
super().__init__()
self.module = module
self.adapter = nn.Sequential(
nn.Linear(module.in_features, rank, bias=False),
nn.Linear(rank, module.out_features, bias=False),
)
small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
nn.init.normal_(self.adapter[0].weight, std=small_std)
nn.init.zeros_(self.adapter[1].weight)
self.adapter.to(module.weight.device)

def forward(self, input, *args, **kwargs):
return self.module(input, *args, **kwargs) + self.adapter(input)
from ..utils import LoRALayer, get_memory_consumption_stat


if is_torchao_available():
@@ -503,6 +483,22 @@ def test_memory_footprint(self):
# there is additional overhead of scales and zero points
self.assertTrue(total_bf16 < total_int4wo)

def test_model_memory_usage(self):
model_id = "hf-internal-testing/tiny-flux-pipe"
expected_memory_saving_ratio = 2.0

inputs = self.get_dummy_tensor_inputs(device=torch_device)

transformer_bf16 = self.get_dummy_components(None, model_id=model_id)["transformer"]
transformer_bf16.to(torch_device)
unquantized_model_memory = get_memory_consumption_stat(transformer_bf16, inputs)
del transformer_bf16

transformer_int8wo = self.get_dummy_components(TorchAoConfig("int8wo"), model_id=model_id)["transformer"]
transformer_int8wo.to(torch_device)
quantized_model_memory = get_memory_consumption_stat(transformer_int8wo, inputs)
assert unquantized_model_memory / quantized_model_memory >= expected_memory_saving_ratio

def test_wrong_config(self):
with self.assertRaises(ValueError):
self.get_dummy_components(TorchAoConfig("int42"))
38 changes: 38 additions & 0 deletions tests/quantization/utils.py
@@ -0,0 +1,38 @@
from diffusers.utils import is_torch_available


if is_torch_available():
import torch
import torch.nn as nn

class LoRALayer(nn.Module):
"""Wraps a linear layer with LoRA-like adapter - Used for testing purposes only

Taken from
https://github.com/huggingface/transformers/blob/566302686a71de14125717dea9a6a45b24d42b37/tests/quantization/bnb/test_4bit.py#L62C5-L78C77
"""

def __init__(self, module: nn.Module, rank: int):
super().__init__()
self.module = module
self.adapter = nn.Sequential(
nn.Linear(module.in_features, rank, bias=False),
nn.Linear(rank, module.out_features, bias=False),
)
small_std = (2.0 / (5 * min(module.in_features, module.out_features))) ** 0.5
nn.init.normal_(self.adapter[0].weight, std=small_std)
nn.init.zeros_(self.adapter[1].weight)
self.adapter.to(module.weight.device)

def forward(self, input, *args, **kwargs):
return self.module(input, *args, **kwargs) + self.adapter(input)

@torch.no_grad()
@torch.inference_mode()
def get_memory_consumption_stat(model, inputs):
torch.cuda.reset_peak_memory_stats()
torch.cuda.empty_cache()

model(**inputs)
max_memory_mem_allocated = torch.cuda.max_memory_allocated()
return max_memory_mem_allocated
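To make the new helper concrete, here is a small usage sketch with a toy module. This is illustrative only: `ToyModel` and the tensor shapes are invented, and it assumes the repository root is on the import path so `tests.quantization.utils` resolves (the empty `tests/quantization/__init__.py` added above makes the directory a package).

```python
import torch
import torch.nn as nn

from tests.quantization.utils import get_memory_consumption_stat  # helper added by this PR


class ToyModel(nn.Module):
    """Stand-in for the real transformer models used in the tests."""

    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(4096, 4096)

    def forward(self, hidden_states):
        return self.proj(hidden_states)


if torch.cuda.is_available():
    model = ToyModel().to("cuda", dtype=torch.float16)
    inputs = {"hidden_states": torch.randn(8, 4096, device="cuda", dtype=torch.float16)}
    # Peak CUDA memory (in bytes) observed during one forward pass under inference_mode;
    # already-resident allocations such as the model weights are included in the peak.
    peak_bytes = get_memory_consumption_stat(model, inputs)
    print(f"peak inference memory: {peak_bytes / 1024**2:.1f} MiB")
```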