Description
Describe the bug
I am using the Wan 2.1 model from the diffusers library for an image-to-video task on an NVIDIA RTX 4090. To reduce memory usage, I chose the group offloading method and wanted to compare resource consumption across different configurations. During testing, however, I ran into two main issues:
1. When using the group_offload_leaf_stream method:
I received warnings that some layers were not executed during the forward pass:
It seems like some layers were not executed during the forward pass. This may lead to problems when applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please make sure that all layers are executed during the forward pass. The following layers were not executed:
unexecuted_layers=['blocks.25.attn2.norm_added_q', 'blocks.10.attn2.norm_added_q', 'blocks.13.attn2.norm_added_q', 'blocks.11.attn2.norm_added_q', 'blocks.34.attn2.norm_added_q', 'blocks.0.attn2.norm_added_q', 'blocks.35.attn2.norm_added_q', 'blocks.33.attn2.norm_added_q', 'blocks.21.attn2.norm_added_q', 'blocks.20.attn2.norm_added_q', 'blocks.3.attn2.norm_added_q', 'blocks.7.attn2.norm_added_q', 'blocks.22.attn2.norm_added_q', 'blocks.14.attn2.norm_added_q', 'blocks.29.attn2.norm_added_q', 'blocks.9.attn2.norm_added_q', 'blocks.1.attn2.norm_added_q', 'blocks.37.attn2.norm_added_q', 'blocks.18.attn2.norm_added_q', 'blocks.30.attn2.norm_added_q', 'blocks.4.attn2.norm_added_q', 'blocks.32.attn2.norm_added_q', 'blocks.36.attn2.norm_added_q', 'blocks.26.attn2.norm_added_q', 'blocks.6.attn2.norm_added_q', 'blocks.38.attn2.norm_added_q', 'blocks.17.attn2.norm_added_q', 'blocks.12.attn2.norm_added_q', 'blocks.19.attn2.norm_added_q', 'blocks.16.attn2.norm_added_q', 'blocks.15.attn2.norm_added_q', 'blocks.28.attn2.norm_added_q', 'blocks.24.attn2.norm_added_q', 'blocks.31.attn2.norm_added_q', 'blocks.8.attn2.norm_added_q', 'blocks.5.attn2.norm_added_q', 'blocks.27.attn2.norm_added_q', 'blocks.2.attn2.norm_added_q', 'blocks.39.attn2.norm_added_q', 'blocks.23.attn2.norm_added_q']
This issue resulted in severe degradation of the generated output.
This is the image I selected.
This is the incorrect video I got:
https://github.com/user-attachments/assets/7a8b55a2-6a71-493a-b7ae-64566b321954
When I use the default pipe, i.e., without group_offload_leaf_stream, I get the correct result:
https://github.com/user-attachments/assets/9b54c2f2-fa93-422f-b3df-619ee96bb3c8
2. When using the group_offload_block_1_stream method:
I encountered a runtime error: "RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same". It appears that the VAE module was not correctly assigned to the GPU device.
Traceback (most recent call last):
  File "/maindata/data/shared/public/haobang.geng/code/video-generate/i2v-baseline/wanx-all-profile.py", line 171, in <module>
    main(args)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/maindata/data/shared/public/haobang.geng/code/video-generate/i2v-baseline/wanx-all-profile.py", line 143, in main
    run_inference()
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/memory_profiler.py", line 1188, in wrapper
    val = prof(func)(*args, **kwargs)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/memory_profiler.py", line 761, in f
    return func(*args, **kwds)
  File "/maindata/data/shared/public/haobang.geng/code/video-generate/i2v-baseline/wanx-all-profile.py", line 130, in run_inference
    output = pipe(
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/diffusers/pipelines/wan/pipeline_wan_i2v.py", line 587, in __call__
    latents, condition = self.prepare_latents(
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/diffusers/pipelines/wan/pipeline_wan_i2v.py", line 392, in prepare_latents
    latent_condition = retrieve_latents(self.vae.encode(video_condition), generator)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/diffusers/utils/accelerate_utils.py", line 46, in wrapper
    return method(self, *args, **kwargs)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 795, in encode
    h = self._encode(x)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 762, in _encode
    out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 439, in forward
    x = self.conv_in(x, feat_cache[idx])
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 78, in forward
    return super().forward(x)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 725, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 720, in _conv_forward
    return F.conv3d(
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same
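A workaround I am considering for this second issue (a minimal sketch derived from my reproduction code below, not a verified fix) is to keep the VAE fully on the GPU and only apply group offloading to the other components, which should avoid this particular device mismatch at the cost of the VAE's memory savings:

# Sketch: offload everything except the VAE, and keep the VAE resident on CUDA
# so its conv weights live on the same device as the input tensors.
for module in [pipe.text_encoder, pipe.transformer, pipe.image_encoder]:
    apply_group_offloading(
        module,
        onload_device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        offload_type="block_level",
        num_blocks_per_group=1,
        use_stream=True,
    )
pipe.vae.to("cuda")  # VAE stays on the GPU for encode/decode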
Request for Help:
1. Are there recommended approaches to ensure that all layers are executed properly, especially with the group_offload_leaf_stream method? (See the diagnostic sketch below.)
2. How can I resolve the device-mismatch issue affecting the VAE? (See the workaround sketch after the traceback above.)
Any suggestions or guidance would be greatly appreciated!
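For reference, the unexecuted layers can be double-checked independently of the library's warning with plain PyTorch forward hooks (a minimal diagnostic sketch, assuming `pipe` is set up as in the reproduction below; it is not part of the original profiling script):

# Diagnostic sketch: record which leaf modules of the transformer actually run
# during one pipe(...) call, then print the registered leaf modules that never ran.
executed = set()

def make_hook(name):
    def hook(module, inputs, output):
        executed.add(name)
    return hook

leaves = {n: m for n, m in pipe.transformer.named_modules() if len(list(m.children())) == 0}
handles = [m.register_forward_hook(make_hook(n)) for n, m in leaves.items()]

# ... run a single short pipe(...) call here ...

# Expected to match the unexecuted_layers list from the warning above.
print(sorted(set(leaves) - executed))
for h in handles:
    h.remove()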
Reproduction
Here is my code:
import argparse
import functools
import json
import os
import pathlib
import time

import numpy as np
import psutil
import torch
from memory_profiler import profile
from transformers import CLIPVisionModel

from diffusers import AutoencoderKLWan, UniPCMultistepScheduler, WanImageToVideoPipeline
from diffusers.hooks import apply_group_offloading
from diffusers.utils import export_to_video, load_image


def get_memory_usage():
    # Resident set size (RSS) of the current process, in bytes.
    process = psutil.Process(os.getpid())
    mem_bytes = process.memory_info().rss
    return mem_bytes


@profile(precision=2)
def apply_offload(pipe: WanImageToVideoPipeline, method: str) -> None:
    if method == "full_cuda":
        pipe.to("cuda")
    elif method == "model_offload":
        pipe.enable_model_cpu_offload()
    elif method == "sequential_offload":
        pipe.enable_sequential_cpu_offload()
    elif method == "group_offload_block_1":
        offloader_fn = functools.partial(
            apply_group_offloading,
            onload_device=torch.device("cuda"),
            offload_device=torch.device("cpu"),
            offload_type="block_level",
            num_blocks_per_group=1,
            use_stream=False,
        )
        list(map(offloader_fn, [pipe.text_encoder, pipe.transformer, pipe.vae, pipe.image_encoder]))
    elif method == "group_offload_leaf":
        offloader_fn = functools.partial(
            apply_group_offloading,
            onload_device=torch.device("cuda"),
            offload_device=torch.device("cpu"),
            offload_type="leaf_level",
            use_stream=False,
        )
        list(map(offloader_fn, [pipe.text_encoder, pipe.transformer, pipe.vae, pipe.image_encoder]))
    elif method == "group_offload_block_1_stream":
        offloader_fn = functools.partial(
            apply_group_offloading,
            onload_device=torch.device("cuda"),
            offload_device=torch.device("cpu"),
            offload_type="block_level",
            num_blocks_per_group=1,
            use_stream=True,
        )
        list(map(offloader_fn, [pipe.text_encoder, pipe.transformer, pipe.vae, pipe.image_encoder]))
    elif method == "group_offload_leaf_stream":
        offloader_fn = functools.partial(
            apply_group_offloading,
            onload_device=torch.device("cuda"),
            offload_device=torch.device("cpu"),
            offload_type="leaf_level",
            use_stream=True,
        )
        list(map(offloader_fn, [pipe.text_encoder, pipe.transformer, pipe.vae, pipe.image_encoder]))


@profile(precision=2)
def load_pipeline():
    model_id = "Wan2.1-I2V-14B-480P-Diffusers"
    image_encoder = CLIPVisionModel.from_pretrained(
        model_id, subfolder="image_encoder", torch_dtype=torch.float32
    )
    vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
    scheduler_b = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)
    pipe = WanImageToVideoPipeline.from_pretrained(
        model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16, scheduler=scheduler_b
    )
    return pipe


@torch.no_grad()
def main(args):
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(f"./results/check-wanmulti-framework/{args.method}/", exist_ok=True)
    pipe = load_pipeline()
    apply_offload(pipe, args.method)
    apply_offload_memory_usage = get_memory_usage()
    torch.cuda.reset_peak_memory_stats()
    cuda_model_memory = torch.cuda.max_memory_reserved()
    output_dir = pathlib.Path(args.output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)
    run_inference_memory_usage_list = []

    # Currently unused; kept for sampling CPU memory during inference.
    def cpu_mem_callback():
        nonlocal run_inference_memory_usage_list
        run_inference_memory_usage_list.append(get_memory_usage())

    @profile(precision=2)
    def run_inference():
        image = load_image("./dataset/character-img/imgs3/1.jpeg")
        max_area = 480 * 832
        aspect_ratio = image.height / image.width
        mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
        height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
        width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
        prompt = "A person smile."
        negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
        generator = torch.Generator("cuda").manual_seed(100)
        output = pipe(
            image=image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_frames=81,
            guidance_scale=5.0,
            generator=generator,
        ).frames[0]
        export_to_video(output, f"./results/check-wanmulti-framework/{args.method}/wanx_diffusers.mp4", fps=16)

    t1 = time.time()
    run_inference()
    torch.cuda.synchronize()
    t2 = time.time()
    cuda_inference_memory = torch.cuda.max_memory_reserved()
    time_required = t2 - t1
    # run_inference_memory_usage = sum(run_inference_memory_usage_list) / len(run_inference_memory_usage_list)
    # print(f"Run inference memory usage list: {run_inference_memory_usage_list}")
    info = {
        "time": round(time_required, 2),
        "cuda_model_memory": round(cuda_model_memory / 1024**3, 2),
        "cuda_inference_memory": round(cuda_inference_memory / 1024**3, 2),
        "cpu_offload_memory": round(apply_offload_memory_usage / 1024**3, 2),
    }
    with open(output_dir / f"memory_usage_{args.method}.json", "w") as f:
        json.dump(info, f, indent=4)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--method",
        type=str,
        default="full_cuda",
        choices=[
            "full_cuda",
            "model_offload",
            "sequential_offload",
            "group_offload_block_1",
            "group_offload_leaf",
            "group_offload_block_1_stream",
            "group_offload_leaf_stream",
        ],
    )
    parser.add_argument("--output_dir", type=str, default="./results/offload_profiling")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    main(args)
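Each configuration is run as a separate invocation of the script, for example:

python wanx-all-profile.py --method group_offload_leaf_stream --output_dir ./results/offload_profiling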
Here is my environment (pip list):
Package Version
--------------------------------- --------------------
absl-py 2.1.0
accelerate 1.4.0
addict 2.4.0
aiofiles 23.2.1
aiohappyeyeballs 2.4.3
aiohttp 3.10.10
aiosignal 1.3.1
airportsdata 20241001
albucore 0.0.17
albumentations 1.4.18
aliyun-python-sdk-core 2.16.0
aliyun-python-sdk-kms 2.16.5
altair 5.4.1
annotated-types 0.7.0
antlr4-python3-runtime 4.9.3
anyio 4.6.2.post1
astor 0.8.1
asttokens 2.4.1
astunparse 1.6.3
async-timeout 4.0.3
attrs 24.2.0
av 13.1.0
beautifulsoup4 4.12.3
blake3 1.0.4
blinker 1.9.0
boto3 1.35.60
botocore 1.35.60
braceexpand 0.1.7
certifi 2024.8.30
cffi 1.17.1
charset-normalizer 3.4.0
click 8.1.7
clip 0.2.0
cloudpickle 3.1.0
coloredlogs 15.0.1
comm 0.2.2
compressed-tensors 0.8.0
ConfigArgParse 1.7
contourpy 1.3.0
controlnet_aux 0.0.7
cpm-kernels 1.0.11
crcmod 1.7
cryptography 44.0.1
cupy-cuda12x 13.3.0
cycler 0.12.1
Cython 3.0.12
dash 2.18.2
dash-core-components 2.0.0
dash-html-components 2.0.0
dash-table 5.0.0
dashscope 1.22.2
datasets 3.0.1
debugpy 1.8.10
decorator 4.4.2
decord 0.6.0
deepspeed 0.15.2
depyf 0.18.0
diffsynth 1.1.2
diffusers 0.33.0.dev0
dill 0.3.8
diskcache 5.6.3
distro 1.9.0
dnspython 2.7.0
docker-pycreds 0.4.0
easydict 1.13
einops 0.8.0
email_validator 2.2.0
eval_type_backport 0.2.0
exceptiongroup 1.2.2
executing 2.1.0
facexlib 0.3.0
fairscale 0.4.13
fastapi 0.115.2
fastjsonschema 2.20.0
fastrlock 0.8.3
ffmpy 0.4.0
filelock 3.16.1
filterpy 1.4.5
flash-attn 2.6.3
Flask 3.0.3
flatbuffers 24.3.25
fonttools 4.54.1
frozenlist 1.4.1
fsspec 2024.6.1
ftfy 6.3.0
func_timeout 4.3.5
future 1.0.0
fvcore 0.1.5.post20221221
gast 0.6.0
gguf 0.10.0
gitdb 4.0.11
GitPython 3.1.43
google-pasta 0.2.0
gradio 5.5.0
gradio_client 1.4.2
grpcio 1.66.2
h11 0.14.0
h5py 3.12.1
hjson 3.1.0
httpcore 1.0.6
httptools 0.6.4
httpx 0.27.2
huggingface-hub 0.29.1
humanfriendly 10.0
idna 3.10
imageio 2.36.0
imageio-ffmpeg 0.5.1
imgaug 0.4.0
importlib_metadata 8.5.0
iniconfig 2.0.0
interegular 0.3.3
iopath 0.1.10
ipykernel 6.29.5
ipython 8.29.0
ipywidgets 8.1.5
itsdangerous 2.2.0
jaxtyping 0.2.34
jedi 0.19.1
Jinja2 3.1.4
jiter 0.7.0
jmespath 0.10.0
joblib 1.4.2
jsonschema 4.23.0
jsonschema-specifications 2024.10.1
jupyter_client 8.6.3
jupyter_core 5.7.2
jupyterlab_widgets 3.0.13
keras 3.7.0
kiwisolver 1.4.7
lark 1.2.2
lazy_loader 0.4
libclang 18.1.1
libigl 2.5.1
linkify-it-py 2.0.3
llvmlite 0.43.0
lm-format-enforcer 0.10.9
lmdb 1.6.2
loguru 0.7.3
lvis 0.5.3
Markdown 3.7
markdown-it-py 2.2.0
MarkupSafe 2.1.5
matplotlib 3.9.2
matplotlib-inline 0.1.7
mdit-py-plugins 0.3.3
mdurl 0.1.2
memory-profiler 0.61.0
mistral_common 1.5.1
ml-dtypes 0.4.1
modelscope 1.23.2
moviepy 1.0.3
mpmath 1.3.0
msgpack 1.1.0
msgspec 0.18.6
multidict 6.1.0
multiprocess 0.70.16
namex 0.0.8
narwhals 1.10.0
natsort 8.4.0
nbformat 5.10.4
nest-asyncio 1.6.0
networkx 3.4.1
ninja 1.11.1.3
numba 0.60.0
numpy 1.26.4
nvdiffrast 0.3.3
nvidia-cublas-cu12 12.4.5.8
nvidia-cuda-cupti-cu12 12.4.127
nvidia-cuda-nvrtc-cu12 12.4.127
nvidia-cuda-runtime-cu12 12.4.127
nvidia-cudnn-cu12 9.1.0.70
nvidia-cufft-cu12 11.2.1.3
nvidia-curand-cu12 10.3.5.147
nvidia-cusolver-cu12 11.6.1.9
nvidia-cusparse-cu12 12.3.1.170
nvidia-cusparselt-cu12 0.6.2
nvidia-ml-py 12.560.30
nvidia-nccl-cu12 2.21.5
nvidia-nvjitlink-cu12 12.4.127
nvidia-nvtx-cu12 12.4.127
omegaconf 2.3.0
onnxruntime 1.20.0
open3d 0.18.0
openai 1.54.4
openai-clip 1.0.1
opencv-python 4.10.0.84
opencv-python-headless 4.10.0.84
opt_einsum 3.4.0
optree 0.13.1
orjson 3.10.7
oss2 2.19.1
outlines 0.0.46
packaging 24.1
pandas 2.2.3
parso 0.8.4
partial-json-parser 0.2.1.1.post4
peft 0.13.2
pexpect 4.9.0
pillow 10.4.0
pip 24.2
platformdirs 4.3.6
plotly 5.24.1
pluggy 1.5.0
pooch 1.8.2
portalocker 2.10.1
proglog 0.1.10
prometheus_client 0.21.0
prometheus-fastapi-instrumentator 7.0.0
prompt_toolkit 3.0.48
propcache 0.2.0
protobuf 5.28.2
psutil 6.0.0
ptyprocess 0.7.0
pudb 2024.1.2
pure_eval 0.2.3
py-cpuinfo 9.0.0
pyairports 2.1.1
pyarrow 17.0.0
pybind11 2.13.6
pycocoevalcap 1.2
pycocotools 2.0.8
pycountry 24.6.1
pycparser 2.22
pycryptodome 3.21.0
pydantic 2.9.2
pydantic_core 2.23.4
pydub 0.25.1
Pygments 2.18.0
pyiqa 0.1.10
PyMatting 1.1.12
PyMCubes 0.1.6
pyparsing 3.2.0
pyquaternion 0.9.9
pytest 8.3.4
python-dateutil 2.9.0.post0
python-dotenv 1.0.1
python-multipart 0.0.12
pytorch3d 0.7.8
pytz 2024.2
PyYAML 6.0.2
pyzmq 26.2.0
qwen-vl-utils 0.0.10
ray 2.37.0
referencing 0.35.1
regex 2024.9.11
rembg 2.0.59
requests 2.32.3
requests-toolbelt 1.0.0
retrying 1.3.4
rich 13.9.2
rpds-py 0.20.0
ruff 0.6.9
s3transfer 0.10.3
safehttpx 0.1.1
safetensors 0.4.5
scikit-image 0.24.0
scikit-learn 1.5.2
scikit-video 1.1.11
scipy 1.14.1
semantic-version 2.10.0
sentencepiece 0.2.0
sentry-sdk 2.18.0
setproctitle 1.3.3
setuptools 75.2.0
shapely 2.0.7
shellingham 1.5.4
six 1.16.0
sk-video 1.1.10
smmap 5.0.1
sniffio 1.3.1
soupsieve 2.6
stack-data 0.6.3
starlette 0.40.0
SwissArmyTransformer 0.4.12
sympy 1.13.1
tabulate 0.9.0
tenacity 9.0.0
tensorboard 2.18.0
tensorboard-data-server 0.7.2
tensorboardX 2.6.2.2
tensorflow-io-gcs-filesystem 0.37.1
termcolor 2.5.0
thop 0.1.1.post2209072238
threadpoolctl 3.5.0
tifffile 2024.9.20
tiktoken 0.7.0
timm 1.0.11
tokenizers 0.20.3
tomesd 0.1.3
tomli 2.2.1
tomlkit 0.12.0
torch 2.6.0
torchaudio 2.6.0
torchdiffeq 0.2.4
torchsde 0.2.6
torchvision 0.21.0
tornado 6.4.2
tqdm 4.66.5
traitlets 5.14.3
trampoline 0.1.2
transformers 4.46.2
transformers-stream-generator 0.0.4
trimesh 4.5.2
triton 3.2.0
typeguard 2.13.3
typer 0.12.5
typing_extensions 4.12.2
tzdata 2024.2
uc-micro-py 1.0.3
urllib3 2.2.3
urwid 2.6.16
urwid_readline 0.15.1
uvicorn 0.32.0
uvloop 0.21.0
wandb 0.18.7
watchfiles 0.24.0
wcwidth 0.2.13
webdataset 0.2.100
websocket-client 1.8.0
websockets 12.0
Werkzeug 3.0.4
wheel 0.44.0
widgetsnbextension 4.0.13
wrapt 1.17.0
xatlas 0.0.9
xxhash 3.5.0
yacs 0.1.8
yapf 0.43.0
yarl 1.15.3
zipp 3.20.2
Logs
System Info
- 🤗 Diffusers version: 0.33.0.dev0
- Platform: Linux-3.10.0-1160.el7.x86_64-x86_64-with-glibc2.35
- Running on Google Colab?: No
- Python version: 3.10.15
- PyTorch version (GPU?): 2.6.0+cu124 (True)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Huggingface_hub version: 0.29.1
- Transformers version: 4.46.2
- Accelerate version: 1.4.0
- PEFT version: 0.13.2
- Bitsandbytes version: not installed
- Safetensors version: 0.4.5
- xFormers version: not installed
- Accelerator: NVIDIA A800-SXM4-80GB, 81251 MiB
- Using GPU in script?:
- Using distributed or parallel set-up in script?: