[bitsandbytes] allow directly CUDA placements of pipelines loaded with bnb components #9840


Merged: 40 commits into main from allow-device-placement-bnb on Dec 4, 2024

Conversation

sayakpaul (Member)

What does this PR do?

When a pipeline is loaded with models that have a quantization config, we should still be able to call `to("cuda")` on the pipeline object. On GPUs with enough memory (such as a 4090), this has clear performance benefits over CPU offloading (as demonstrated below).

| Model CPU Offload | Batch Size | Time (seconds) | Memory (GB) |
|---|---|---|---|
| False | 1 | 19.316 | 14.935 |
| True | 1 | 36.746 | 12.139 |
| False | 4 | 80.665 | 20.576 |
| True | 4 | 98.612 | 12.138 |

Flux.1 Dev, steps: 30

Currently, calling `to("cuda")` is not possible, because the quantized component comes with an accelerate hook attached:

from transformers import T5EncoderModel
from transformers import BitsAndBytesConfig as BnbConfig
import torch

ckpt_id = "black-forest-labs/FLUX.1-dev"

# Load the T5 text encoder in 4-bit NF4 via bitsandbytes.
text_encoder_2_config = BnbConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
text_encoder_2 = T5EncoderModel.from_pretrained(
    ckpt_id,
    subfolder="text_encoder_2",
    quantization_config=text_encoder_2_config,
    torch_dtype=torch.bfloat16,
)
# accelerate attaches a device-placement hook to the quantized model.
print(text_encoder_2._hf_hook)

has:

AlignDevicesHook(execution_device=0, offload=False, io_same_device=True, offload_buffers=False, place_submodules=True, skip_keys=None)
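
For context, the hook can be inspected directly (a minimal sketch continuing the snippet above; it only looks at the attribute whose repr is printed here and is not the exact check diffusers performs):

from accelerate.hooks import AlignDevicesHook

# The 4-bit model ends up with an AlignDevicesHook attached, the same kind of
# hook that sequential CPU offloading installs, so the pipeline looks offloaded.
hook = getattr(text_encoder_2, "_hf_hook", None)
print(isinstance(hook, AlignDevicesHook))  # True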

This is why this line complains:

if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda":

This PR fixes that behavior.
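
With the fix in place (together with the companion accelerate change, huggingface/accelerate#3223), the following pattern, mirroring the repro snippet later in this thread, is expected to work:

from diffusers import DiffusionPipeline, FluxTransformer2DModel
from transformers import T5EncoderModel
import torch

model_id = "hf-internal-testing/flux.1-dev-nf4-pkg"
t5_4bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
transformer_4bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer")

pipeline_4bit = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    text_encoder_2=t5_4bit,
    transformer=transformer_4bit,
    torch_dtype=torch.float16,
).to("cuda")  # no longer raises with a new-enough accelerate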

Benchmarking code:

from diffusers import DiffusionPipeline, FluxTransformer2DModel, BitsAndBytesConfig
from transformers import T5EncoderModel
from transformers import BitsAndBytesConfig as BnbConfig
import torch.utils.benchmark as benchmark
import torch 
import fire

# Time f(*args, **kwargs) with torch.utils.benchmark and report the mean runtime in seconds.
def benchmark_fn(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)",
        globals={"args": args, "kwargs": kwargs, "f": f},
        num_threads=torch.get_num_threads(),
    )
    return f"{(t0.blocked_autorange().mean):.3f}"

def load_pipeline(model_cpu_offload=False):
    ckpt_id = "black-forest-labs/FLUX.1-dev"

    transformer_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    transformer = FluxTransformer2DModel.from_pretrained(
        ckpt_id, 
        subfolder="transformer",
        quantization_config=transformer_config,
        torch_dtype=torch.bfloat16
    )

    text_encoder_2_config = BnbConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    text_encoder_2 = T5EncoderModel.from_pretrained(
        ckpt_id,
        subfolder="text_encoder_2",
        quantization_config=text_encoder_2_config,
        torch_dtype=torch.bfloat16
    )

    pipeline = DiffusionPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",
        text_encoder_2=text_encoder_2,
        transformer=transformer,
        torch_dtype=torch.bfloat16,
    )
    # Either offload submodules to the CPU on demand, or place the whole pipeline on the GPU.
    if model_cpu_offload:
        pipeline.enable_model_cpu_offload()
    else:
        pipeline = pipeline.to("cuda")

    pipeline.set_progress_bar_config(disable=True)
    return pipeline

def run_pipeline(pipeline, batch_size=1):
    _ = pipeline(
        prompt="a dog sitting besides a sea", 
        guidance_scale=3.5, 
        max_sequence_length=512, 
        num_inference_steps=30,
        num_images_per_prompt=batch_size
    )


def main(batch_size: int = 1, model_cpu_offload: bool = False):
    pipeline = load_pipeline(model_cpu_offload=model_cpu_offload)

    # Warm-up runs before timing.
    for _ in range(5):
        run_pipeline(pipeline)

    time = benchmark_fn(run_pipeline, pipeline, batch_size)
    # Peak GPU memory over the run, reported in GB.
    memory = torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024
    print(f"{model_cpu_offload=}, {batch_size=} {time=} seconds {memory=} GB.")

    image = pipeline(
        prompt="a dog sitting besides a sea", 
        guidance_scale=3.5, 
        max_sequence_length=512, 
        num_inference_steps=30,
        num_images_per_prompt=1
    ).images[0]
    img_name = f"mco@{model_cpu_offload}-bs@{batch_size}.png"
    image.save(img_name)


if __name__ == "__main__":
    fire.Fire(main)
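
Assuming the script is saved as benchmark_flux_bnb.py (the filename is arbitrary), python-fire exposes main's arguments as CLI flags, e.g.:

python benchmark_flux_bnb.py --batch_size=4 --model_cpu_offload=True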

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

@sayakpaul requested review from DN6, yiyixuxu and SunMarc on November 2, 2024 at 04:34
@SunMarc (Member) left a comment:

Thanks for the PR! Left a suggestion.

@sayakpaul (Member Author)

@SunMarc WDYT now?

@sayakpaul requested a review from SunMarc on November 11, 2024 at 11:33
@SunMarc (Member) left a comment:

Thanks for adding this! LGTM! I'll merge the PR on accelerate as well.

@sayakpaul (Member Author)

Have run the integration tests and they are passing.

@SunMarc (Member) commented Nov 18, 2024

> Have run the integration tests and they are passing.

On diffusers?

@sayakpaul (Member Author)

@SunMarc yes, on diffusers. Anywhere else they need to be run?

@SunMarc (Member) commented Nov 18, 2024

No, I read that as a question, my bad ;)

@sayakpaul requested a review from yiyixuxu on November 26, 2024 at 06:11
Comment on lines 427 to 429
pipeline_has_bnb = any(
    (_check_bnb_status(module)[1] or _check_bnb_status(module)[-1]) for _, module in self.components.items()
)
Collaborator:

IMO cleaner.

Suggested change, from:

pipeline_has_bnb = any(
    (_check_bnb_status(module)[1] or _check_bnb_status(module)[-1]) for _, module in self.components.items()
)

to:

pipeline_has_bnb = any(
    any((_check_bnb_status(module))) for _, module in self.components.items()
)
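
A quick illustration of why the suggested spelling covers the original one (a sketch; the exact tuple returned by `_check_bnb_status` is assumed here to be a triple of booleans, which this thread does not spell out):

# Stand-in for _check_bnb_status(module); the assumed shape is
# (is_loaded_in_bnb, is_loaded_in_4bit, is_loaded_in_8bit).
status = (True, True, False)

print(status[1] or status[-1])  # original: looks at two specific flags
print(any(status))              # suggestion: True if any flag is set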

Collaborator:

If this check is placed after the sequential offloading check, placement would still fail right?

Collaborator:

Running the test gives:

E           ValueError: It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading.

src/diffusers/pipelines/pipeline_utils.py:417: ValueError

@sayakpaul (Member Author):

> If this check is placed after the sequential offloading check, placement would still fail right?

I have modified the placement of the logic. Could you check again?

Re. tests, I just ran `pytest tests/quantization/bnb/test_4bit.py::SlowBnb4BitTests` and `pytest tests/quantization/bnb/test_mixed_int8.py::SlowBnb8bitTests` and everything passed.

You need this PR huggingface/accelerate#3223 for this to work.

@sayakpaul requested a review from DN6 on December 2, 2024 at 10:29
@@ -389,6 +392,13 @@ def to(self, *args, **kwargs):

device = device or device_arg

pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())
Collaborator:

It seems to have some overlapping logic with the code just a little bit below this, no?

if is_loaded_in_8bit_bnb and device is not None:

@sayakpaul (Member Author) commented Dec 3, 2024:

Good point.

However, the line you pointed out is relevant when we're transferring an 8-bit quantized model from one device to another. It's a log to let users know that the model has already been placed on a GPU and will remain there; requesting to put it on the CPU will be ineffective.

We call `self.to("cpu")` when doing `enable_model_cpu_offload()`:

self.to("cpu", silence_dtype_warnings=True)

So, this kind of log becomes informative in the context of using `enable_model_cpu_offload()`, for example.

This PR, however, allows users to move an entire pipeline to a GPU when the memory permits. Previously it wasn't possible.

So, maybe this apparent overlap is justified. LMK.

Collaborator:

> This PR, however, allows users to move an entire pipeline to a GPU when the memory permits. Previously it wasn't possible.

Did I miss something? This PR adds a check that throws a ValueError under certain conditions; it doesn't enable a new use case like you described here, no?

@sayakpaul (Member Author):

Well, the enablement comes from the accelerate fix huggingface/accelerate#3223, and this PR adds a check for that, as you described. Sorry for the wrong order of words 😅

If you have other comments on the PR happy to address them.

Collaborator:

My previous comment stands: this has overlapping logic with the other checks you have below and is very confusing.

You're not enabling a new use case here; this PR corrects a previously wrong error message and allows users to take the correct action. I would simply update the warning message here to cover the other possible scenario: they are trying to call `to("cuda")` on a quantized model without offloading, and they need to upgrade accelerate in order to do that.

if pipeline_is_offloaded and device and torch.device(device).type == "cuda":

@sayakpaul (Member Author) commented Dec 4, 2024:

> this PR corrects a previously wrong error message

What was the wrong error message?

IIUC, the line you're pointing to has nothing to do with the changes introduced in this PR and has been in the codebase for quite a while.

The problem line (fixed by the accelerate PR) was this:

if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda":

So, what I have done in 1779093 is as follows:

Updated the condition of the error message:

"You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."

to:

if (
      not pipeline_is_offloaded
      and not pipeline_is_sequentially_offloaded
      and pipeline_has_bnb
      and torch.device(device).type == "cuda"
      and is_accelerate_version("<", "1.1.0.dev0")
):

This now also covers the case where the pipeline is not offloaded. Additionally,

f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."

now also takes into account whether the pipeline is offloaded:

if is_loaded_in_8bit_bnb and not is_offloaded and device is not None:

and torch.device(device).type == "cuda"
and is_accelerate_version("<", "1.1.0.dev0")
):
raise ValueError(
Collaborator:

This is the error message you want to throw for this scenario, no?

  1. accelerate < 1.1.0.dev0
  2. you call pipeline.to("cuda") on a pipeline that has bnb

But if these two conditions are met (older accelerate version + bnb):

  1. `not pipeline_is_sequentially_offloaded` will be `False` here, so you will not reach the ValueError
  2. you will reach this check first and get an error message; this is the wrong error message I was talking about:
    if pipeline_is_sequentially_offloaded and device and torch.device(device).type == "cuda":
 if (
            not pipeline_is_offloaded
            and not pipeline_is_sequentially_offloaded
            and pipeline_has_bnb
            and torch.device(device).type == "cuda"
            and is_accelerate_version("<", "1.1.0.dev0")
        ):

@sayakpaul (Member Author):

Yeah, this makes a ton of sense. Thanks for the elaborate clarification. I have reflected this in my latest commits.

I have also run most of the SLOW tests and they are passing. This is to ensure existing functionality doesn't break with the current changes.

LMK.

"It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
)
if device and torch.device(device).type == "cuda":
if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
Collaborator:

My previous comment applies almost exactly here, so I will just repeat it: #9840

The error message you want to throw for this scenario:

  • accelerate < 1.1.0.dev0
  • you call pipeline.to("cuda") on a pipeline that has bnb

If these two conditions are met (older accelerate version + bnb), it will not reach the error message you intended; it will be caught here at this first check, and the error message is the same as before this PR (about offloading).

Can you do this? #9840 (comment)

If not, please remove the changes to pipeline_utils.py and we can merge (I will work on it in a separate PR). I think the added tests are fine without the changes: if the accelerate version is new, it is not affected by the changes in this PR; if it is not, it throws a different error, that's all.

Collaborator:

OK, I was wrong! Will merge.

@sayakpaul (Member Author) commented Dec 4, 2024:

Sure, that works, but here's my last try.

> If these two conditions are met (older accelerate version + bnb), it will not reach the error message you intended; it will be caught here at this first check, and the error message is the same as before this PR (about offloading).

When you have:

from diffusers import DiffusionPipeline, FluxTransformer2DModel
from transformers import T5EncoderModel
import torch

model_id = "hf-internal-testing/flux.1-dev-nf4-pkg"
t5_4bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
transformer_4bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer")
pipeline_4bit = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    text_encoder_2=t5_4bit,
    transformer=transformer_4bit,
    torch_dtype=torch.float16,
)

In `if pipeline_is_sequentially_offloaded and not pipeline_has_bnb`, `pipeline_is_sequentially_offloaded` will be True (older accelerate version); however, `not pipeline_has_bnb` will be False (as expected). So, the following error won't be raised:

"It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."

And it will hit the else branch instead.
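
In other words, the control flow being described looks roughly like this (a simplified sketch, not the exact pipeline_utils.py code; names and error texts follow the snippets quoted earlier in this thread):

if device and torch.device(device).type == "cuda":
    if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
        # sequential offloading without bnb: keep the original offloading error
        raise ValueError("... enable_sequential_cpu_offload ... not compatible with offloading ...")
    elif pipeline_has_bnb and is_accelerate_version("<", "1.1.0.dev0"):
        # bnb components with an old accelerate: ask the user to upgrade
        raise ValueError("... upgrade your `accelerate` installation ...")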

To test, you can run the following with accelerate 1.0.1:

from diffusers import DiffusionPipeline, FluxTransformer2DModel
from transformers import T5EncoderModel
import torch 

model_id = "hf-internal-testing/flux.1-dev-nf4-pkg"
t5_4bit = T5EncoderModel.from_pretrained(model_id, subfolder="text_encoder_2")
transformer_4bit = FluxTransformer2DModel.from_pretrained(model_id, subfolder="transformer")
pipeline_4bit = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    text_encoder_2=t5_4bit,
    transformer=transformer_4bit,
    torch_dtype=torch.float16,
).to("cuda")

It throws:

ValueError: You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation.

Isn't this what we expect, or am I missing something?

Collaborator:

Yeah, I missed that `not pipeline_has_bnb` in the statement; it works.

@sayakpaul (Member Author):

Saw your comment. Thanks for bearing with me :)

@sayakpaul merged commit e8da75d into main on Dec 4, 2024
18 checks passed
@sayakpaul deleted the allow-device-placement-bnb branch on December 4, 2024 at 16:57
sayakpaul added a commit that referenced this pull request Dec 23, 2024
[bitsandbytes] allow directly CUDA placements of pipelines loaded with bnb components (#9840)

* allow device placement when using bnb quantization.

* warning.

* tests

* fixes

* docs.

* require accelerate version.

* remove print.

* revert to()

* tests

* fixes

* fix: missing AutoencoderKL lora adapter (#9807)

* fix: missing AutoencoderKL lora adapter

* fix

---------

Co-authored-by: Sayak Paul <[email protected]>

* fixes

* fix condition test

* updates

* updates

* remove is_offloaded.

* fixes

* better

* empty

---------

Co-authored-by: Emmanuel Benazera <[email protected]>