[bitsandbytes] allow directly CUDA placements of pipelines loaded with bnb components #9840
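
For context, a minimal usage sketch of what this change enables, assuming a CUDA machine, `bitsandbytes` installed, `transformers` > 4.44.0, and this branch of `diffusers`; the checkpoint id and prompt below are illustrative assumptions that mirror the SD3 tests added in this PR, not part of the diff:

```python
# Hedged sketch: direct CUDA placement of a pipeline with bnb-quantized components.
# Assumption: checkpoint id mirrors the SD3 tests in this PR; adjust to your model.
import torch
from diffusers import BitsAndBytesConfig, DiffusionPipeline, SD3Transformer2DModel

model_id = "stabilityai/stable-diffusion-3-medium-diffusers"  # assumed checkpoint

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)
transformer_4bit = SD3Transformer2DModel.from_pretrained(
    model_id,
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=torch.float16,
)

# With this PR, a pipeline holding bnb-quantized components can be moved to CUDA directly.
pipe = DiffusionPipeline.from_pretrained(
    model_id,
    transformer=transformer_4bit,
    torch_dtype=torch.float16,
).to("cuda")

image = pipe("a photo of a cat", num_inference_steps=2).images[0]
```

The diff below adds a `pipeline_has_bnb` check and per-module guards in `DiffusionPipeline.to()` so that this direct placement is handled correctly.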

Merged: 40 commits, Dec 4, 2024
Changes from 4 commits

Commits (40)
35b4cf2
allow device placement when using bnb quantization.
sayakpaul Nov 1, 2024
ec4d422
warning.
sayakpaul Nov 2, 2024
2afa9b0
tests
sayakpaul Nov 2, 2024
3679ebd
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 2, 2024
79633ee
fixes
sayakpaul Nov 5, 2024
876cd13
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 5, 2024
a28c702
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 5, 2024
ad1584d
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 5, 2024
34d0925
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 7, 2024
d713c41
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 11, 2024
e9ef6ea
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 15, 2024
6ce560e
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 16, 2024
329b32e
docs.
sayakpaul Nov 16, 2024
2f6b07d
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 18, 2024
fdeb500
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 19, 2024
53bc502
require accelerate version.
sayakpaul Nov 19, 2024
f81b71e
remove print.
sayakpaul Nov 19, 2024
8e1b6f5
revert to()
sayakpaul Nov 21, 2024
e3e3a96
tests
sayakpaul Nov 21, 2024
9e9561b
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 21, 2024
2ddcbf1
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 24, 2024
5130cc3
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 26, 2024
e76f93a
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 29, 2024
1963b5c
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Dec 2, 2024
a799ba8
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Dec 2, 2024
7d47364
fixes
sayakpaul Dec 2, 2024
ebfec45
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Dec 3, 2024
1fe8a79
fix: missing AutoencoderKL lora adapter (#9807)
beniz Dec 3, 2024
f05d81d
fixes
sayakpaul Dec 3, 2024
6e17cad
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Dec 4, 2024
ea09eb2
fix condition test
sayakpaul Dec 4, 2024
1779093
updates
sayakpaul Dec 4, 2024
6ff53e3
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Dec 4, 2024
7b73dc2
updates
sayakpaul Dec 4, 2024
729acea
remove is_offloaded.
sayakpaul Dec 4, 2024
3d3aab4
fixes
sayakpaul Dec 4, 2024
c033816
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Dec 4, 2024
b5cffab
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Dec 4, 2024
662868b
better
sayakpaul Dec 4, 2024
3fc15fe
empty
sayakpaul Dec 4, 2024
26 changes: 21 additions & 5 deletions src/diffusers/pipelines/pipeline_utils.py
@@ -55,6 +55,7 @@
     is_accelerate_version,
     is_torch_npu_available,
     is_torch_version,
+    is_transformers_available,
     is_transformers_version,
     logging,
     numpy_to_pil,
@@ -66,6 +67,8 @@
 if is_torch_npu_available():
     import torch_npu  # noqa: F401
 
+if is_transformers_available():
+    from transformers import PreTrainedModel
 
 from .pipeline_loading_utils import (
     ALL_IMPORTABLE_CLASSES,
@@ -410,10 +413,14 @@ def module_is_offloaded(module):
         pipeline_is_sequentially_offloaded = any(
             module_is_sequentially_offloaded(module) for _, module in self.components.items()
         )
+        pipeline_has_bnb = any(
+            (_check_bnb_status(module)[1] or _check_bnb_status(module)[-1]) for _, module in self.components.items()
+        )
         if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda":
-            raise ValueError(
-                "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
-            )
+            if not pipeline_has_bnb:
+                raise ValueError(
+                    "It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
+                )
 
         is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
         if is_pipeline_device_mapped:
@@ -448,8 +455,17 @@ def module_is_offloaded(module):
 
             # This can happen for `transformer` models. CPU placement was added in
             # https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
-            if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
-                module.to(device=device)
+            if is_loaded_in_4bit_bnb and device is not None:
+                if is_transformers_available() and isinstance(module, PreTrainedModel):
+                    if is_transformers_version(">", "4.44.0"):
+                        module.to(device=device)
+                    else:
+                        logger.warning(
+                            f"{module.__class__.__name__} could not be placed on {device}. Module is still on {module.device}. Please update your `transformers` installation to the latest."
+                        )
+                # For `diffusers` it should not be a problem.
+                else:
+                    module.to(device=device)
             elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb:
                 module.to(device, dtype)
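
For readers who do not want to trace the branching above, here is a self-contained sketch of the placement guard it implements; the helper name and boolean arguments are simplified stand-ins for illustration, not actual `diffusers` APIs:

```python
# Hedged, simplified sketch of the 4-bit placement guard added above.
# The function name and flag arguments are illustrative stand-ins, not diffusers helpers.
import logging

logger = logging.getLogger(__name__)


def maybe_move_module(module, device, *, is_4bit_bnb, is_8bit_bnb,
                      is_transformers_model, transformers_gt_4_44):
    """Mirror the branching in DiffusionPipeline.to() for bnb-quantized modules."""
    if is_4bit_bnb and device is not None:
        if is_transformers_model:
            if transformers_gt_4_44:
                # transformers > 4.44.0 supports moving 4-bit models across devices.
                module.to(device=device)
            else:
                # Older transformers: leave the module where it is and warn.
                logger.warning(
                    "%s could not be placed on %s; please update `transformers`.",
                    module.__class__.__name__, device,
                )
        else:
            # 4-bit diffusers-native models can be moved directly.
            module.to(device=device)
    elif not is_4bit_bnb and not is_8bit_bnb:
        # Non-quantized modules keep the original behavior; 8-bit bnb modules are left as-is.
        module.to(device)
```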
37 changes: 37 additions & 0 deletions tests/quantization/bnb/test_4bit.py
@@ -47,6 +47,7 @@ def get_some_linear_layer(model):
 
 
 if is_transformers_available():
+    from transformers import BitsAndBytesConfig as BnbConfig
     from transformers import T5EncoderModel
 
 if is_torch_available():
@@ -484,6 +485,42 @@ def test_moving_to_cpu_throws_warning(self):
 
         assert "Pipelines loaded with `dtype=torch.float16`" in cap_logger.out
 
+    def test_pipeline_cuda_placement_works_with_nf4(self):
+        transformer_nf4_config = BitsAndBytesConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.float16,
+        )
+        transformer_4bit = SD3Transformer2DModel.from_pretrained(
+            self.model_name,
+            subfolder="transformer",
+            quantization_config=transformer_nf4_config,
+            torch_dtype=torch.float16,
+        )
+        text_encoder_3_nf4_config = BnbConfig(
+            load_in_4bit=True,
+            bnb_4bit_quant_type="nf4",
+            bnb_4bit_compute_dtype=torch.float16,
+        )
+        text_encoder_3_4bit = T5EncoderModel.from_pretrained(
+            self.model_name,
+            subfolder="text_encoder_3",
+            quantization_config=text_encoder_3_nf4_config,
+            torch_dtype=torch.float16,
+        )
+        # CUDA device placement works.
+        pipeline_4bit = DiffusionPipeline.from_pretrained(
+            self.model_name,
+            transformer=transformer_4bit,
+            text_encoder_3=text_encoder_3_4bit,
+            torch_dtype=torch.float16,
+        ).to("cuda")
+
+        # Check if inference works.
+        _ = pipeline_4bit("table", max_sequence_length=20, num_inference_steps=2)
+
+        del pipeline_4bit
+
 
 @require_transformers_version_greater("4.44.0")
 class SlowBnb4BitFluxTests(Base4bitTests):
29 changes: 29 additions & 0 deletions tests/quantization/bnb/test_mixed_int8.py
@@ -44,6 +44,7 @@ def get_some_linear_layer(model):
 
 
 if is_transformers_available():
+    from transformers import BitsAndBytesConfig as BnbConfig
     from transformers import T5EncoderModel
 
 if is_torch_available():
@@ -432,6 +433,34 @@ def test_generate_quality_dequantize(self):
             output_type="np",
         ).images
 
+    def test_pipeline_cuda_placement_works_with_mixed_int8(self):
+        transformer_8bit_config = BitsAndBytesConfig(load_in_8bit=True)
+        transformer_8bit = SD3Transformer2DModel.from_pretrained(
+            self.model_name,
+            subfolder="transformer",
+            quantization_config=transformer_8bit_config,
+            torch_dtype=torch.float16,
+        )
+        text_encoder_3_8bit_config = BnbConfig(load_in_8bit=True)
+        text_encoder_3_8bit = T5EncoderModel.from_pretrained(
+            self.model_name,
+            subfolder="text_encoder_3",
+            quantization_config=text_encoder_3_8bit_config,
+            torch_dtype=torch.float16,
+        )
+        # CUDA device placement works.
+        pipeline_8bit = DiffusionPipeline.from_pretrained(
+            self.model_name,
+            transformer=transformer_8bit,
+            text_encoder_3=text_encoder_3_8bit,
+            torch_dtype=torch.float16,
+        ).to("cuda")
+
+        # Check if inference works.
+        _ = pipeline_8bit("table", max_sequence_length=20, num_inference_steps=2)
+
+        del pipeline_8bit
+
 
 @require_transformers_version_greater("4.44.0")
 class SlowBnb8bitFluxTests(Base8bitTests):