[bitsandbytes] allow directly CUDA placements of pipelines loaded with bnb components #9840

Merged: 40 commits, Dec 4, 2024
Commits (40):
35b4cf2
allow device placement when using bnb quantization.
sayakpaul Nov 1, 2024
ec4d422
warning.
sayakpaul Nov 2, 2024
2afa9b0
tests
sayakpaul Nov 2, 2024
3679ebd
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 2, 2024
79633ee
fixes
sayakpaul Nov 5, 2024
876cd13
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 5, 2024
a28c702
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 5, 2024
ad1584d
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 5, 2024
34d0925
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 7, 2024
d713c41
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 11, 2024
e9ef6ea
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 15, 2024
6ce560e
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 16, 2024
329b32e
docs.
sayakpaul Nov 16, 2024
2f6b07d
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 18, 2024
fdeb500
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 19, 2024
53bc502
require accelerate version.
sayakpaul Nov 19, 2024
f81b71e
remove print.
sayakpaul Nov 19, 2024
8e1b6f5
revert to()
sayakpaul Nov 21, 2024
e3e3a96
tests
sayakpaul Nov 21, 2024
9e9561b
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 21, 2024
2ddcbf1
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 24, 2024
5130cc3
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 26, 2024
e76f93a
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Nov 29, 2024
1963b5c
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Dec 2, 2024
a799ba8
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Dec 2, 2024
7d47364
fixes
sayakpaul Dec 2, 2024
ebfec45
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Dec 3, 2024
1fe8a79
fix: missing AutoencoderKL lora adapter (#9807)
beniz Dec 3, 2024
f05d81d
fixes
sayakpaul Dec 3, 2024
6e17cad
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Dec 4, 2024
ea09eb2
fix condition test
sayakpaul Dec 4, 2024
1779093
updates
sayakpaul Dec 4, 2024
6ff53e3
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Dec 4, 2024
7b73dc2
updates
sayakpaul Dec 4, 2024
729acea
remove is_offloaded.
sayakpaul Dec 4, 2024
3d3aab4
fixes
sayakpaul Dec 4, 2024
c033816
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Dec 4, 2024
b5cffab
Merge branch 'main' into allow-device-placement-bnb
sayakpaul Dec 4, 2024
662868b
better
sayakpaul Dec 4, 2024
3fc15fe
empty
sayakpaul Dec 4, 2024
16 changes: 11 additions & 5 deletions src/diffusers/pipelines/pipeline_utils.py
@@ -66,7 +66,6 @@
if is_torch_npu_available():
import torch_npu # noqa: F401


from .pipeline_loading_utils import (
ALL_IMPORTABLE_CLASSES,
CONNECTED_PIPES_KEYS,
@@ -388,6 +387,7 @@ def to(self, *args, **kwargs):
)

device = device or device_arg
pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())
Collaborator:
It seems to have some overlapping logic with the code just a little bit below this, no?

if is_loaded_in_8bit_bnb and device is not None:

@sayakpaul (Member, Author), Dec 3, 2024:
Good point.

However, the line of code you pointed out is relevant when we're transferring an 8-bit quantized model from one device to another. It's a log to let users know that this model has already been placed on a GPU and will remain so; requesting to put it on a CPU will be ineffective.

We call self.to("cpu") when doing enable_model_cpu_offload():

self.to("cpu", silence_dtype_warnings=True)

So, this kind of log becomes informative in the context of using enable_model_cpu_offload(), for example.

This PR, however, allows users to move an entire pipeline to a GPU when memory permits. Previously, this wasn't possible.

So, maybe this apparent overlap is justified. LMK.
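
To make the warn-and-skip behavior being described here easier to picture, here is a minimal, self-contained sketch. The helper name, the is_loaded_in_8bit attribute check, and the loop shape are assumptions for illustration only; the warning text mirrors the message quoted later in this thread, but this is not the exact diffusers source.

import logging

import torch

logger = logging.getLogger(__name__)


def move_pipeline_components(components: dict, device) -> None:
    """Hypothetical helper mirroring how .to() treats bnb 8-bit modules."""
    for module in components.values():
        if not isinstance(module, torch.nn.Module):
            continue
        # Assumption: stand-in for diffusers' internal _check_bnb_status(module).
        is_loaded_in_8bit_bnb = getattr(module, "is_loaded_in_8bit", False)
        if is_loaded_in_8bit_bnb and device is not None:
            # The module is skipped and a log is emitted instead of failing, which is why
            # the internal self.to("cpu", silence_dtype_warnings=True) call made by
            # enable_model_cpu_offload() stays safe for quantized components.
            logger.warning(
                f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to"
                f" {device} via `.to()` is not supported. Module is still on {module.device}."
            )
            continue
        module.to(device)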

Collaborator:
This PR, however, allows users to move an entire pipeline to a GPU when memory permits. Previously, this wasn't possible.

Did I miss something? This PR adds a check that throws a ValueError under certain conditions; it doesn't enable a new use case like the one you described here, no?

@sayakpaul (Member, Author):

Well, the enablement comes from the accelerate fix huggingface/accelerate#3223 and this PR adds a check for that as you described. Sorry for the wrong order of words 😅

If you have other comments on the PR, I'm happy to address them.

Collaborator:

My previous comment stands: it has overlapping logic with the other checks you have below and is very, very confusing.

You're not enabling a new use case here; this PR corrects a previously wrong error message and lets the user take the correct action. I would simply update the warning message here to cover the other possible scenario: the user is trying to call to("cuda") on a quantized model without offloading, and they need to upgrade accelerate in order to do that.

if pipeline_is_offloaded and device and torch.device(device).type == "cuda":

@sayakpaul (Member, Author), Dec 4, 2024:

this PR corrects a previously wrong error message

What was the wrong error message?

IIUC, the line you're pointing to has nothing to do with the changes introduced in this PR and has been in the codebase for quite a while.

The problem line (fixed by the accelerate PR) was this:

if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda":

So, what I have done in 1779093 is as follows:

Updated the condition under which the following error message is raised:

"You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."

to:

if (
      not pipeline_is_offloaded
      and not pipeline_is_sequentially_offloaded
      and pipeline_has_bnb
      and torch.device(device).type == "cuda"
      and is_accelerate_version("<", "1.1.0.dev0")
):

This now also considers when the pipeline is not offloaded. Additionally,

f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."

now also considers if the pipeline is not offloaded:

if is_loaded_in_8bit_bnb and not is_offloaded and device is not None:
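
Putting these pieces together, the device-placement checks compose roughly as follows after the PR. This is a consolidated sketch based on the snippets above and the diff below, not the verbatim diffusers source; the helper signature is hypothetical and the first error message is shortened here.

import torch

from diffusers.utils import is_accelerate_version


def validate_cuda_placement(device, pipeline_has_bnb, pipeline_is_sequentially_offloaded):
    """Hypothetical helper: raise early for unsupported .to('cuda') calls."""
    if device and torch.device(device).type == "cuda":
        if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
            # Unchanged error: sequential offloading is incompatible with moving to GPU
            # (full message shortened here).
            raise ValueError(
                "It seems like you have activated sequential model offloading by calling"
                " `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU."
            )
        # PR: https://github.com/huggingface/accelerate/pull/3223/
        elif pipeline_has_bnb and is_accelerate_version("<", "1.1.0.dev0"):
            raise ValueError(
                "You are trying to call `.to('cuda')` on a pipeline that has models quantized with"
                " `bitsandbytes`. Your current `accelerate` installation does not support it."
                " Please upgrade the installation."
            )
    # Per module, later in to(): the 8-bit warning now also requires the pipeline to
    # *not* be offloaded, i.e. roughly:
    #   is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
    #   if is_loaded_in_8bit_bnb and not is_offloaded and device is not None:
    #       logger.warning(...)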


# throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
def module_is_sequentially_offloaded(module):
@@ -410,10 +410,16 @@ def module_is_offloaded(module):
pipeline_is_sequentially_offloaded = any(
module_is_sequentially_offloaded(module) for _, module in self.components.items()
)
if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda":
raise ValueError(
"It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
)
if device and torch.device(device).type == "cuda":
if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
Collaborator:
My previous comments apply almost exactly here, so I will just repeat them:
#9840

The error message you want to throw targets this scenario:

  • accelerate < 1.1.0.dev0
  • you call pipeline.to("cuda") on a pipeline that has bnb

If these 2 conditions are met (older accelerate version + bnb), it will not reach the error message you intended; it will be caught here at this first check, and the error message is the same as before this PR (about offloading).

Can you do this? #9840 (comment)

If not, please remove the changes to pipeline_utils.py and we can merge (I will work on it in a separate PR). I think the added tests are fine without the changes: if the accelerate version is new, it is not affected by the changes in this PR; if it is not, it throws a different error, that's all.

Collaborator:
ok I was wrong! will merge

@sayakpaul (Member, Author), Dec 4, 2024:
Sure, that works, but here's my last try.

If these 2 conditions are met (older accelerate version + bnb), it will not reach the error message you intended; it will be caught here at this first check, and the error message is the same as before this PR (about offloading).

When you have:

model_id = "hf-internal-testing/flux.1-dev-nf4-pkg"
t5_4bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
transformer_4bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer")
pipeline_4bit = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    text_encoder_2=t5_4bit,
    transformer=transformer_4bit,
    torch_dtype=torch.float16,
)

In if pipeline_is_sequentially_offloaded and not pipeline_has_bnb, pipeline_is_sequentially_offloaded will be True (older accelerate version); however, not pipeline_has_bnb will be False (as expected). So, the following error won't be raised:

"It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."

And it will hit the elif branch instead.

To test, you can run the following with accelerate 1.0.1:

from diffusers import DiffusionPipeline, FluxTransformer2DModel
from transformers import T5EncoderModel
import torch 

model_id = "hf-internal-testing/flux.1-dev-nf4-pkg"
t5_4bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
transformer_4bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer")
pipeline_4bit = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    text_encoder_2=t5_4bit,
    transformer=transformer_4bit,
    torch_dtype=torch.float16,
).to("cuda")

It throws:

ValueError: You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation.

Isn't this what we expect or am I missing something?

Collaborator:
Yeah, I missed that not pipeline_has_bnb in the statement; it works.

@sayakpaul (Member, Author):
Saw your comment. Thanks for bearing with me :)

raise ValueError(
"It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
)
# PR: https://github.com/huggingface/accelerate/pull/3223/
elif pipeline_has_bnb and is_accelerate_version("<", "1.1.0.dev0"):
raise ValueError(
"You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
)

is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
if is_pipeline_device_mapped:
45 changes: 44 additions & 1 deletion tests/quantization/bnb/test_4bit.py
@@ -18,10 +18,11 @@
import unittest

import numpy as np
import pytest
import safetensors.torch

from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel
from diffusers.utils import logging
from diffusers.utils import is_accelerate_version, logging
from diffusers.utils.testing_utils import (
CaptureLogger,
is_bitsandbytes_available,
@@ -47,6 +48,7 @@ def get_some_linear_layer(model):


if is_transformers_available():
from transformers import BitsAndBytesConfig as BnbConfig
from transformers import T5EncoderModel

if is_torch_available():
@@ -483,6 +485,47 @@ def test_moving_to_cpu_throws_warning(self):

assert "Pipelines loaded with `dtype=torch.float16`" in cap_logger.out

@pytest.mark.xfail(
condition=is_accelerate_version("<=", "1.1.1"),
reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.",
strict=True,
)
def test_pipeline_cuda_placement_works_with_nf4(self):
transformer_nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)
transformer_4bit = SD3Transformer2DModel.from_pretrained(
self.model_name,
subfolder="transformer",
quantization_config=transformer_nf4_config,
torch_dtype=torch.float16,
)
text_encoder_3_nf4_config = BnbConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_compute_dtype=torch.float16,
)
text_encoder_3_4bit = T5EncoderModel.from_pretrained(
self.model_name,
subfolder="text_encoder_3",
quantization_config=text_encoder_3_nf4_config,
torch_dtype=torch.float16,
)
# CUDA device placement works.
pipeline_4bit = DiffusionPipeline.from_pretrained(
self.model_name,
transformer=transformer_4bit,
text_encoder_3=text_encoder_3_4bit,
torch_dtype=torch.float16,
).to("cuda")

# Check if inference works.
_ = pipeline_4bit("table", max_sequence_length=20, num_inference_steps=2)

del pipeline_4bit


@require_transformers_version_greater("4.44.0")
class SlowBnb4BitFluxTests(Base4bitTests):
36 changes: 36 additions & 0 deletions tests/quantization/bnb/test_mixed_int8.py
@@ -17,8 +17,10 @@
import unittest

import numpy as np
import pytest

from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel, logging
from diffusers.utils import is_accelerate_version
from diffusers.utils.testing_utils import (
CaptureLogger,
is_bitsandbytes_available,
@@ -44,6 +46,7 @@ def get_some_linear_layer(model):


if is_transformers_available():
from transformers import BitsAndBytesConfig as BnbConfig
from transformers import T5EncoderModel

if is_torch_available():
@@ -432,6 +435,39 @@ def test_generate_quality_dequantize(self):
output_type="np",
).images

@pytest.mark.xfail(
condition=is_accelerate_version("<=", "1.1.1"),
reason="Test will pass after https://github.com/huggingface/accelerate/pull/3223 is in a release.",
strict=True,
)
def test_pipeline_cuda_placement_works_with_mixed_int8(self):
transformer_8bit_config = BitsAndBytesConfig(load_in_8bit=True)
transformer_8bit = SD3Transformer2DModel.from_pretrained(
self.model_name,
subfolder="transformer",
quantization_config=transformer_8bit_config,
torch_dtype=torch.float16,
)
text_encoder_3_8bit_config = BnbConfig(load_in_8bit=True)
text_encoder_3_8bit = T5EncoderModel.from_pretrained(
self.model_name,
subfolder="text_encoder_3",
quantization_config=text_encoder_3_8bit_config,
torch_dtype=torch.float16,
)
# CUDA device placement works.
pipeline_8bit = DiffusionPipeline.from_pretrained(
self.model_name,
transformer=transformer_8bit,
text_encoder_3=text_encoder_3_8bit,
torch_dtype=torch.float16,
).to("cuda")

# Check if inference works.
_ = pipeline_8bit("table", max_sequence_length=20, num_inference_steps=2)

del pipeline_8bit


@require_transformers_version_greater("4.44.0")
class SlowBnb8bitFluxTests(Base8bitTests):