
AttributeError: _hf_hook caused by delattr in hooks.remove_hook_from_module() #10729

Open
@eppaneamd

Description


System Info

- `Accelerate` version: 1.4.0.dev0
- Platform: Linux-6.2.0-39-generic-x86_64-with-glibc2.39
- `accelerate` bash location: /workspaces/.venv_py311/bin/accelerate
- Python version: 3.11.11
- Numpy version: 1.26.3
- PyTorch version (GPU?): 2.5.1+rocm6.2 (True)
- System RAM: 1007.70 GB
- GPU type: AMD Instinct MI250X/MI250
- `Accelerate` default config:
        Not found

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

When using the HunyuanVideo model through the diffusers framework with the following script:

```python
import torch
from diffusers import HunyuanVideoPipeline, HunyuanVideoTransformer3DModel

MODEL_ID = "hunyuanvideo-community/HunyuanVideo"
PROMPT = "A cat walks on the grass, realistic"

transformer = HunyuanVideoTransformer3DModel.from_pretrained(
    MODEL_ID, subfolder="transformer", torch_dtype=torch.bfloat16
)
pipe = HunyuanVideoPipeline.from_pretrained(
    MODEL_ID, transformer=transformer, torch_dtype=torch.float16
)

pipe.enable_model_cpu_offload()
pipe.vae.enable_tiling()
pipe.vae.enable_slicing()

pipe.transformer = torch.compile(pipe.transformer, mode="default")

# Warmup
_ = pipe(
    prompt=PROMPT,
    height=320,
    width=512,
    num_frames=61,
    num_inference_steps=1,
    max_sequence_length=256,
    guidance_scale=0.0,
    generator=torch.Generator(device="cuda").manual_seed(42),
)

# Inference
output = pipe(
    prompt=PROMPT,
    height=320,
    width=512,
    num_frames=61,
    num_inference_steps=30,
    max_sequence_length=256,
    guidance_scale=0.0,
    output_type="pil",
    generator=torch.Generator(device="cuda").manual_seed(42),
)
```

The following error is raised:

```
Traceback (most recent call last):
  File "/workspaces/huvideo_diffusers_repro.py", line 21, in <module>
    _ = pipe(
        ^^^^^
  File "/workspaces/.venv_py311/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/workspaces/diffusers/src/diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py", line 689, in __call__
    self.maybe_free_model_hooks()
  File "/workspaces/diffusers/src/diffusers/pipelines/pipeline_utils.py", line 1119, in maybe_free_model_hooks
    self.enable_model_cpu_offload(device=getattr(self, "_offload_device", "cuda"))
  File "/workspaces/diffusers/src/diffusers/pipelines/pipeline_utils.py", line 1050, in enable_model_cpu_offload
    self.remove_all_hooks()
  File "/workspaces/diffusers/src/diffusers/pipelines/pipeline_utils.py", line 1017, in remove_all_hooks
    accelerate.hooks.remove_hook_from_module(model, recurse=True)
  File "/workspaces/accelerate/src/accelerate/hooks.py", line 203, in remove_hook_from_module
    delattr(module, "_hf_hook")
  File "/workspaces/.venv_py311/lib/python3.11/site-packages/torch/nn/modules/module.py", line 2043, in __delattr__
    super().__delattr__(name)
AttributeError: 'OptimizedModule' object has no attribute '_hf_hook'
```

Wrapping these two lines in a `try ... except` block makes the run succeed:

https://github.com/huggingface/accelerate/blob/main/src/accelerate/hooks.py#L203
https://github.com/huggingface/accelerate/blob/main/src/accelerate/hooks.py#L212

Logging the caught exceptions shows:

```
Error:  AttributeError – 'OptimizedModule' object has no attribute '_hf_hook'
Error:  AttributeError – 'OptimizedModule' object has no attribute '_old_forward'
```

This points to a logic error in the following pattern:

```python
if hasattr(obj, attr):
    delattr(obj, attr)
```

In other words, even when `hasattr` returns `True`, there is no guarantee that `delattr` will succeed.
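The mismatch can be reproduced without PyTorch. This sketch assumes, as in PyTorch 2.x, that `OptimizedModule` forwards unknown attribute lookups to the wrapped `_orig_mod` through `__getattr__`, so `hasattr` finds an attribute stored on the inner module. Deletion, by contrast, goes through `nn.Module.__delattr__` and then `super().__delattr__(name)` (the last frames of the traceback), which only inspects the wrapper itself. `Inner` and `ForwardingWrapper` are hypothetical stand-ins, not PyTorch classes.

```python
class Inner:
    """Stand-in for the original module that actually stores the hook."""

    def __init__(self):
        self._hf_hook = "hook"


class ForwardingWrapper:
    """Stand-in for OptimizedModule: unknown attributes are looked up on _orig_mod."""

    def __init__(self, orig):
        self._orig_mod = orig

    def __getattr__(self, name):
        # Only called when normal lookup on the wrapper itself fails.
        return getattr(self._orig_mod, name)


wrapped = ForwardingWrapper(Inner())
print(hasattr(wrapped, "_hf_hook"))  # True: found via __getattr__ on the inner object
try:
    # object.__delattr__ looks only at the wrapper's own attributes.
    delattr(wrapped, "_hf_hook")
except AttributeError as exc:
    print(type(exc).__name__, exc)
```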

Expected behavior

The script runs without errors when `torch.compile` is combined with model CPU offloading.
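One possible direction, sketched here under assumptions and not taken from Accelerate, is to delete the attribute on the object that actually stores it instead of on the wrapper. `delattr_on_owner` is a hypothetical helper, and the stand-in classes assume the compiled wrapper exposes the original module as `_orig_mod`, as `OptimizedModule` does in PyTorch 2.x.

```python
class Inner:
    """Stand-in for the original module that stores the attribute."""

    def __init__(self):
        self._hf_hook = "hook"


class ForwardingWrapper:
    """Stand-in for OptimizedModule: unknown attribute lookups go to _orig_mod."""

    def __init__(self, orig):
        self._orig_mod = orig

    def __getattr__(self, name):
        return getattr(self._orig_mod, name)


def delattr_on_owner(obj, name):
    """Hypothetical helper: delete `name` from the object whose __dict__ holds it.

    Walks the `_orig_mod` chain (wrapper -> wrapped object) and returns True
    if the attribute was found and deleted, False otherwise.
    """
    current = obj
    while current is not None:
        if name in vars(current):
            delattr(current, name)
            return True
        # Objects without `_orig_mod` end the walk.
        current = getattr(current, "_orig_mod", None)
    return False


wrapped = ForwardingWrapper(Inner())
print(delattr_on_owner(wrapped, "_hf_hook"))  # True: deleted from the inner object
print(hasattr(wrapped, "_hf_hook"))  # False: no longer visible through the wrapper
```

This covers only the `delattr` step; restoring `forward` on a compiled wrapper may need separate handling.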
