
Offloading behaviour for HiDream seems broken #11376

Closed

Description

@sayakpaul

When using enable_model_cpu_offload(), I would expect the final component in the offloading chain to be offloaded to the CPU so that the memory savings are actually realized. But that doesn't seem to be the case.
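For context (this is my reading of the internals, not something confirmed in the docs), enable_model_cpu_offload() chains accelerate's cpu_offload_with_hook() hooks across the pipeline components: each model is moved back to the CPU only when the next model in the chain starts its forward pass, so the last model in the chain stays on the GPU unless its hook's offload() is called explicitly. A minimal sketch with toy nn.Linear modules (not the actual pipeline components):

import torch
from accelerate import cpu_offload_with_hook

device = "cuda"
model_a = torch.nn.Linear(8, 8)
model_b = torch.nn.Linear(8, 8)

# Chain the hooks the way pipeline-level offloading chains its components.
model_a, hook_a = cpu_offload_with_hook(model_a, execution_device=device)
model_b, hook_b = cpu_offload_with_hook(model_b, execution_device=device, prev_module_hook=hook_a)

x = torch.randn(1, 8, device=device)
y = model_a(x)  # model_a is moved to the GPU for its forward pass
y = model_b(y)  # model_b's pre-forward hook offloads model_a, then model_b runs on the GPU

print(next(model_a.parameters()).device)  # cpu
print(next(model_b.parameters()).device)  # cuda:0 -- the last model in the chain stays on the GPU

hook_b.offload()  # only an explicit offload() moves the last model back to the CPU
print(next(model_b.parameters()).device)  # cpu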

Consider the script below:

Code
from transformers import PreTrainedTokenizerFast, LlamaForCausalLM
from diffusers import DiffusionPipeline
import torch
import gc

def print_detailed_memory(step_name):
    print(f"\n=== CUDA Memory Stats {step_name} ===")
    torch.cuda.reset_peak_memory_stats()
    print(f"Current allocated: {torch.cuda.memory_allocated() / 1024**3:.2f} GB")
    print(f"Max allocated: {torch.cuda.max_memory_allocated() / 1024**3:.2f} GB")
    print(f"Current reserved: {torch.cuda.memory_reserved() / 1024**3:.2f} GB")
    print(f"Max reserved: {torch.cuda.max_memory_reserved() / 1024**3:.2f} GB")

tokenizer_4 = PreTrainedTokenizerFast.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
text_encoder_4 = LlamaForCausalLM.from_pretrained(
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
    output_hidden_states=True,
    output_attentions=True,
    torch_dtype=torch.bfloat16,
)
repo_id = "HiDream-ai/HiDream-I1-Full"
pipe = DiffusionPipeline.from_pretrained(
    repo_id, 
    transformer=None, 
    vae=None, 
    tokenizer_4=tokenizer_4, 
    text_encoder_4=text_encoder_4, 
    torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()
with torch.no_grad():
    (
        prompt_embeds_t5,
        negative_prompt_embeds_t5,
        prompt_embeds_llama3,
        negative_prompt_embeds_llama3,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(prompt="hello")

print_detailed_memory("after_prompt_encoding")

for name, component in pipe.components.items():
    if component is not None and isinstance(component, torch.nn.Module):
        print(f"{name=}{component.device=}")
        del component 
    gc.collect()
torch.cuda.empty_cache()
print_detailed_memory("after_prompt_encoding_cleanup")

It prints the following logs:

=== CUDA Memory Stats start of stuff ===
Current allocated: 0.00 GB
Max allocated: 0.00 GB
Current reserved: 0.00 GB
Max reserved: 0.00 GB

=== CUDA Memory Stats after_prompt_encoding ===
Current allocated: 15.05 GB
Max allocated: 15.05 GB
Current reserved: 15.29 GB
Max reserved: 15.29 GB
name='text_encoder'component.device=device(type='cpu')
name='text_encoder_2'component.device=device(type='cpu')
name='text_encoder_3'component.device=device(type='cpu')
name='text_encoder_4'component.device=device(type='cuda', index=0)

=== CUDA Memory Stats after_prompt_encoding_cleanup ===
Current allocated: 15.05 GB
Max allocated: 15.05 GB
Current reserved: 15.18 GB
Max reserved: 15.18 GB

We can see that the GPU memory is NOT getting freed up because the last model in the chain (text_encoder_4) is still on the GPU. Interestingly, the following change fixes the behaviour, i.e., we can clearly see the GPU memory savings:

for name, component in pipe.components.items():
    if component is not None and isinstance(component, torch.nn.Module):
        print(f"{name=}{component.device=}")
+        if component.device.type != "cpu":
+            component.cpu()
        del component 
    gc.collect()
torch.cuda.empty_cache()
print_detailed_memory("after_prompt_encoding_cleanup")

It now prints:

=== CUDA Memory Stats start of stuff ===
Current allocated: 0.00 GB
Max allocated: 0.00 GB
Current reserved: 0.00 GB
Max reserved: 0.00 GB

=== CUDA Memory Stats after_prompt_encoding ===
Current allocated: 15.05 GB
Max allocated: 15.05 GB
Current reserved: 15.29 GB
Max reserved: 15.29 GB
name='text_encoder'component.device=device(type='cpu')
name='text_encoder_2'component.device=device(type='cpu')
name='text_encoder_3'component.device=device(type='cpu')
name='text_encoder_4'component.device=device(type='cuda', index=0)

=== CUDA Memory Stats after_prompt_encoding_cleanup ===
Current allocated: 0.10 GB
Max allocated: 0.10 GB
Current reserved: 0.10 GB
Max reserved: 0.10 GB
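As an aside, a cleaner workaround than manually moving components might be to release the offload hooks after encoding. If I'm reading the offloading code right, the pipeline's maybe_free_model_hooks() should send the last hooked model back to the CPU, though I haven't verified this against main:

with torch.no_grad():
    _ = pipe.encode_prompt(prompt="hello")

# Assumed behaviour: maybe_free_model_hooks() offloads all hooked models back to
# the CPU (as it would at the end of a regular pipeline __call__), so the GPU
# memory held by text_encoder_4 should be released here.
pipe.maybe_free_model_hooks()
gc.collect()
torch.cuda.empty_cache()
print_detailed_memory("after_maybe_free_model_hooks")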

A similar issue happens with Flux too 👀
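For anyone who wants to check other pipelines (Flux included) for the same leak, here is a small helper along the lines of the loop above; report_cuda_residents is just an illustrative name, not an existing API:

def report_cuda_residents(pipeline):
    # List the pipeline components that are still resident on a CUDA device
    # after prompt encoding (they should all be on the CPU once offloaded).
    for name, component in pipeline.components.items():
        if component is not None and isinstance(component, torch.nn.Module):
            if component.device.type == "cuda":
                print(f"{name} is still on {component.device}")

report_cuda_residents(pipe)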

Cc: @DN6 @SunMarc

Also cc @asomoza as we both found this to be the case for HiDream.
