
WAN2.1 apply_group_offloading **ERROR** result #11041

Closed
@Passenger12138

Describe the bug

I am using the WAN 2.1 model from the diffusers library for an image-to-video task on an NVIDIA RTX 4090. To reduce memory usage, I chose group offloading and intended to compare resource consumption across different configurations. However, during testing I ran into two main issues:

  1. When using the group_offload_leaf_stream method, I received a warning that some layers were not executed during the forward pass:
It seems like some layers were not executed during the forward pass. This may lead to problems when applying lazy prefetching with automatic tracing and lead to device-mismatch related errors. Please make sure that all layers are executed during the forward pass. The following layers were not executed:
unexecuted_layers=['blocks.25.attn2.norm_added_q', 'blocks.10.attn2.norm_added_q', 'blocks.13.attn2.norm_added_q', 'blocks.11.attn2.norm_added_q', 'blocks.34.attn2.norm_added_q', 'blocks.0.attn2.norm_added_q', 'blocks.35.attn2.norm_added_q', 'blocks.33.attn2.norm_added_q', 'blocks.21.attn2.norm_added_q', 'blocks.20.attn2.norm_added_q', 'blocks.3.attn2.norm_added_q', 'blocks.7.attn2.norm_added_q', 'blocks.22.attn2.norm_added_q', 'blocks.14.attn2.norm_added_q', 'blocks.29.attn2.norm_added_q', 'blocks.9.attn2.norm_added_q', 'blocks.1.attn2.norm_added_q', 'blocks.37.attn2.norm_added_q', 'blocks.18.attn2.norm_added_q', 'blocks.30.attn2.norm_added_q', 'blocks.4.attn2.norm_added_q', 'blocks.32.attn2.norm_added_q', 'blocks.36.attn2.norm_added_q', 'blocks.26.attn2.norm_added_q', 'blocks.6.attn2.norm_added_q', 'blocks.38.attn2.norm_added_q', 'blocks.17.attn2.norm_added_q', 'blocks.12.attn2.norm_added_q', 'blocks.19.attn2.norm_added_q', 'blocks.16.attn2.norm_added_q', 'blocks.15.attn2.norm_added_q', 'blocks.28.attn2.norm_added_q', 'blocks.24.attn2.norm_added_q', 'blocks.31.attn2.norm_added_q', 'blocks.8.attn2.norm_added_q', 'blocks.5.attn2.norm_added_q', 'blocks.27.attn2.norm_added_q', 'blocks.2.attn2.norm_added_q', 'blocks.39.attn2.norm_added_q', 'blocks.23.attn2.norm_added_q']


This issue resulted in severely degraded output.
This is the input image I used:
Image
This is the incorrect video I got:
https://github.com/user-attachments/assets/7a8b55a2-6a71-493a-b7ae-64566b321954
When I use the default pipe, i.e. without group_offload_leaf_stream, I get the correct result:
https://github.com/user-attachments/assets/9b54c2f2-fa93-422f-b3df-619ee96bb3c8

  2. When using the group_offload_block_1_stream method, I encountered a runtime error: "RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same". It appears that the VAE weights were not moved to the GPU.

Traceback (most recent call last):
  File "/maindata/data/shared/public/haobang.geng/code/video-generate/i2v-baseline/wanx-all-profile.py", line 171, in <module>
    main(args)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/maindata/data/shared/public/haobang.geng/code/video-generate/i2v-baseline/wanx-all-profile.py", line 143, in main
    run_inference()
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/memory_profiler.py", line 1188, in wrapper
    val = prof(func)(*args, **kwargs)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/memory_profiler.py", line 761, in f
    return func(*args, **kwds)
  File "/maindata/data/shared/public/haobang.geng/code/video-generate/i2v-baseline/wanx-all-profile.py", line 130, in run_inference
    output = pipe(
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/diffusers/pipelines/wan/pipeline_wan_i2v.py", line 587, in __call__
    latents, condition = self.prepare_latents(
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/diffusers/pipelines/wan/pipeline_wan_i2v.py", line 392, in prepare_latents
    latent_condition = retrieve_latents(self.vae.encode(video_condition), generator)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/diffusers/utils/accelerate_utils.py", line 46, in wrapper
    return method(self, *args, **kwargs)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 795, in encode
    h = self._encode(x)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 762, in _encode
    out = self.encoder(x[:, :, :1, :, :], feat_cache=self._enc_feat_map, feat_idx=self._enc_conv_idx)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 439, in forward
    x = self.conv_in(x, feat_cache[idx])
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1750, in _call_impl
    return forward_call(*args, **kwargs)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/diffusers/models/autoencoders/autoencoder_kl_wan.py", line 78, in forward
    return super().forward(x)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 725, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/maindata/data/shared/public/haobang.geng/miniconda/envs/vdm/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 720, in _conv_forward
    return F.conv3d(
RuntimeError: Input type (torch.cuda.FloatTensor) and weight type (torch.FloatTensor) should be the same
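One plausible explanation for this VAE error, offered as a hypothesis rather than a confirmed diagnosis: if the hook that onloads the VAE's submodules is attached to the top-level module's `forward`, then the pipeline's call to `self.vae.encode(...)` shown in the traceback never triggers it, because `encode` is a separate method and does not go through `nn.Module.__call__`. The toy sketch below runs in pure PyTorch on CPU; `ToyVAE` and `onload_pre_hook` are hypothetical names, and a call counter stands in for a real device move. It shows that a forward pre-hook fires for `module(x)` but not for `module.encode(x)`.

```python
import torch
from torch import nn


class ToyVAE(nn.Module):
    """Hypothetical stand-in for a VAE whose pipeline entry point is `encode`, not `forward`."""

    def __init__(self) -> None:
        super().__init__()
        self.encoder = nn.Conv2d(3, 4, kernel_size=3, padding=1)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        # Calls the child directly; the parent's __call__ (and its hooks) is never invoked.
        return self.encoder(x)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.encode(x)


calls = []


def onload_pre_hook(module: nn.Module, args):
    # Hypothetical stand-in for an offloading hook that would move weights to the onload device.
    calls.append(type(module).__name__)


vae = ToyVAE()
vae.register_forward_pre_hook(onload_pre_hook)
x = torch.randn(1, 3, 8, 8)

vae.encode(x)   # hook does NOT fire: encode bypasses nn.Module.__call__
after_encode = len(calls)
vae(x)          # hook fires: __call__ runs forward pre-hooks, then forward
after_call = len(calls)
print(after_encode, after_call)  # 0 1
```

If this is the cause, one workaround to try is leaving `pipe.vae` out of the `apply_group_offloading` calls and moving it to the GPU with `pipe.vae.to("cuda")`, since the VAE is small compared with the 14B transformer.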

Request for Help:

Are there recommended approaches to ensure all layers are properly executed, especially for the group_offload_leaf_stream method?
How can I resolve the device mismatch issue related to the VAE?
Any suggestions or guidance would be greatly appreciated!
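To check which layers actually run, independently of the group-offloading warning, one option is to register a forward hook on every leaf module, run a single forward pass, and list the modules whose hook never fired. The sketch below does this on a hypothetical `ToyAttention` module that, like the `attn2.norm_added_q` layers named in the warning, defines a normalization layer its forward pass never uses. `find_unexecuted_leaves` is an illustrative helper, not a diffusers API.

```python
import torch
from torch import nn


class ToyAttention(nn.Module):
    """Hypothetical module with a layer that forward() never calls."""

    def __init__(self, dim: int = 8) -> None:
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.norm_q = nn.LayerNorm(dim)
        self.norm_added_q = nn.LayerNorm(dim)  # defined but unused, like the layers in the warning

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm_q(self.to_q(x))


class ToyBlock(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.attn2 = ToyAttention()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.attn2(x)


def is_leaf(module: nn.Module) -> bool:
    return next(module.children(), None) is None


def find_unexecuted_leaves(model: nn.Module, *inputs) -> list:
    """Run one forward pass and return the names of leaf modules whose forward never ran."""
    executed = set()
    handles = []
    for name, module in model.named_modules():
        if is_leaf(module):
            # Bind `name` per iteration via a default argument; returning None leaves outputs unchanged.
            handles.append(module.register_forward_hook(
                lambda mod, args, out, name=name: executed.add(name)))
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        for handle in handles:
            handle.remove()
    return sorted(name for name, module in model.named_modules()
                  if is_leaf(module) and name not in executed)


unexecuted = find_unexecuted_leaves(ToyBlock(), torch.randn(2, 8))
print(unexecuted)  # ['attn2.norm_added_q']
```

Per the warning text, layers that are never executed can cause problems for lazy prefetching with automatic tracing, so a list like this helps confirm whether the reported layers are genuinely unused in the model's forward pass.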

Reproduction

Here is my code:

import argparse
import functools
import json
import os
import pathlib
import psutil
import time

import torch
from diffusers.hooks import apply_group_offloading
from memory_profiler import profile
import numpy as np
from diffusers import AutoencoderKLWan, UniPCMultistepScheduler, WanImageToVideoPipeline
from diffusers.utils import export_to_video, load_image
from transformers import CLIPVisionModel


def get_memory_usage():
    process = psutil.Process(os.getpid())
    mem_bytes = process.memory_info().rss
    return mem_bytes


@profile(precision=2)
def apply_offload(pipe: WanImageToVideoPipeline, method: str) -> None:
    if method == "full_cuda":
        pipe.to("cuda")
    
    elif method == "model_offload":
        pipe.enable_model_cpu_offload()
    
    elif method == "sequential_offload":
        pipe.enable_sequential_cpu_offload()
    
    elif method == "group_offload_block_1":
        offloader_fn = functools.partial(
            apply_group_offloading,
            onload_device=torch.device("cuda"),
            offload_device=torch.device("cpu"),
            offload_type="block_level",
            num_blocks_per_group=1,
            use_stream=False,
        )
        list(map(offloader_fn, [pipe.text_encoder, pipe.transformer, pipe.vae, pipe.image_encoder]))

    elif method == "group_offload_leaf":
        offloader_fn = functools.partial(
            apply_group_offloading,
            onload_device=torch.device("cuda"),
            offload_device=torch.device("cpu"),
            offload_type="leaf_level",
            use_stream=False,
        )
        list(map(offloader_fn, [pipe.text_encoder, pipe.transformer, pipe.vae, pipe.image_encoder]))

    
    elif method == "group_offload_block_1_stream":
        offloader_fn = functools.partial(
            apply_group_offloading,
            onload_device=torch.device("cuda"),
            offload_device=torch.device("cpu"),
            offload_type="block_level",
            num_blocks_per_group=1,
            use_stream=True,
        )
        list(map(offloader_fn, [pipe.text_encoder, pipe.transformer, pipe.vae, pipe.image_encoder]))
    
    elif method == "group_offload_leaf_stream":
        offloader_fn = functools.partial(
            apply_group_offloading,
            onload_device=torch.device("cuda"),
            offload_device=torch.device("cpu"),
            offload_type="leaf_level",
            use_stream=True,
        )
        list(map(offloader_fn, [pipe.text_encoder, pipe.transformer, pipe.vae, pipe.image_encoder]))


@profile(precision=2)
def load_pipeline():
    model_id = "Wan2.1-I2V-14B-480P-Diffusers"
    image_encoder = CLIPVisionModel.from_pretrained(
        model_id, subfolder="image_encoder", torch_dtype=torch.float32
    )
    vae = AutoencoderKLWan.from_pretrained(model_id, subfolder="vae", torch_dtype=torch.float32)
    scheduler_b = UniPCMultistepScheduler(prediction_type="flow_prediction", use_flow_sigmas=True, flow_shift=3.0)

    pipe = WanImageToVideoPipeline.from_pretrained(
        model_id, vae=vae, image_encoder=image_encoder, torch_dtype=torch.bfloat16, scheduler=scheduler_b
    )
    return pipe


@torch.no_grad()
def main(args):
    os.makedirs(args.output_dir, exist_ok=True)
    os.makedirs(f"./results/check-wanmulti-framework/{args.method}/", exist_ok=True)
    pipe = load_pipeline()
    apply_offload(pipe, args.method)
    apply_offload_memory_usage = get_memory_usage()

    torch.cuda.reset_peak_memory_stats()
    cuda_model_memory = torch.cuda.max_memory_reserved()

    output_dir = pathlib.Path(args.output_dir)
    output_dir.mkdir(exist_ok=True, parents=True)

    run_inference_memory_usage_list = []
    
    def cpu_mem_callback():
        nonlocal run_inference_memory_usage_list
        run_inference_memory_usage_list.append(get_memory_usage())

    @profile(precision=2)
    def run_inference():
        image = load_image("./dataset/character-img/imgs3/1.jpeg")
        max_area = 480 * 832
        aspect_ratio = image.height / image.width
        mod_value = pipe.vae_scale_factor_spatial * pipe.transformer.config.patch_size[1]
        height = round(np.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
        width = round(np.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
        prompt = (
            "A person smile."
        )
        negative_prompt = "Bright tones, overexposed, static, blurred details, subtitles, style, works, paintings, images, static, overall gray, worst quality, low quality, JPEG compression residue, ugly, incomplete, extra fingers, poorly drawn hands, poorly drawn faces, deformed, disfigured, misshapen limbs, fused fingers, still picture, messy background, three legs, many people in the background, walking backwards"
        generator = torch.Generator("cuda").manual_seed(100)
        output = pipe(
            image=image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=height,
            width=width,
            num_frames=81,
            guidance_scale=5.0,
            generator=generator,
        ).frames[0]
        export_to_video(output, f"./results/check-wanmulti-framework/{args.method}/wanx_diffusers.mp4", fps=16)

    t1 = time.time()
    run_inference()
    torch.cuda.synchronize()
    t2 = time.time()
    cuda_inference_memory = torch.cuda.max_memory_reserved()
    time_required = t2 - t1
    # run_inference_memory_usage = sum(run_inference_memory_usage_list) / len(run_inference_memory_usage_list)
    # print(f"Run inference memory usage list: {run_inference_memory_usage_list}")

    info = {
        "time": round(time_required, 2),
        "cuda_model_memory": round(cuda_model_memory / 1024**3, 2),
        "cuda_inference_memory": round(cuda_inference_memory / 1024**3, 2),
        "cpu_offload_memory": round(apply_offload_memory_usage / 1024**3, 2),
    }
    with open(output_dir / f"memory_usage_{args.method}.json", "w") as f:
        json.dump(info, f, indent=4)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--method", type=str, default="full_cuda", choices=["full_cuda", "model_offload", "sequential_offload", "group_offload_block_1", "group_offload_leaf", "group_offload_block_1_stream", "group_offload_leaf_stream"])
    parser.add_argument("--output_dir", type=str, default="./results/offload_profiling")
    return parser.parse_args()


if __name__ == "__main__":
    args = get_args()
    main(args)
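The height/width logic in `run_inference` picks a resolution of roughly `max_area` pixels that preserves the input aspect ratio, then rounds each side down to a multiple of `mod_value`. The sketch below isolates that arithmetic in pure Python, with `math.sqrt` swapped in for `np.sqrt`. It assumes `mod_value = 16` (a spatial VAE scale factor of 8 times a transformer spatial patch size of 2), which is an assumption about the WAN 2.1 configs rather than a value read from the script; `compute_resolution` is a hypothetical helper name.

```python
import math


def compute_resolution(image_height: int, image_width: int,
                       max_area: int = 480 * 832, mod_value: int = 16) -> tuple:
    """Mirror of the height/width computation in run_inference (assumed mod_value=16)."""
    aspect_ratio = image_height / image_width
    # Solve h * w ~= max_area with h / w == aspect_ratio, then round down to a multiple of mod_value.
    height = round(math.sqrt(max_area * aspect_ratio)) // mod_value * mod_value
    width = round(math.sqrt(max_area / aspect_ratio)) // mod_value * mod_value
    return height, width


print(compute_resolution(480, 832))    # (480, 832): landscape input already at the target area
print(compute_resolution(1000, 1000))  # (624, 624): square input, 632 rounded down to a multiple of 16
```

Rounding down to a multiple of `mod_value` keeps each side divisible by the VAE downsampling factor times the patch size, so, under the assumption above, the latent grid splits evenly into patches.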

Here is my environment:

Package                           Version
--------------------------------- --------------------
absl-py                           2.1.0
accelerate                        1.4.0
addict                            2.4.0
aiofiles                          23.2.1
aiohappyeyeballs                  2.4.3
aiohttp                           3.10.10
aiosignal                         1.3.1
airportsdata                      20241001
albucore                          0.0.17
albumentations                    1.4.18
aliyun-python-sdk-core            2.16.0
aliyun-python-sdk-kms             2.16.5
altair                            5.4.1
annotated-types                   0.7.0
antlr4-python3-runtime            4.9.3
anyio                             4.6.2.post1
astor                             0.8.1
asttokens                         2.4.1
astunparse                        1.6.3
async-timeout                     4.0.3
attrs                             24.2.0
av                                13.1.0
beautifulsoup4                    4.12.3
blake3                            1.0.4
blinker                           1.9.0
boto3                             1.35.60
botocore                          1.35.60
braceexpand                       0.1.7
certifi                           2024.8.30
cffi                              1.17.1
charset-normalizer                3.4.0
click                             8.1.7
clip                              0.2.0
cloudpickle                       3.1.0
coloredlogs                       15.0.1
comm                              0.2.2
compressed-tensors                0.8.0
ConfigArgParse                    1.7
contourpy                         1.3.0
controlnet_aux                    0.0.7
cpm-kernels                       1.0.11
crcmod                            1.7
cryptography                      44.0.1
cupy-cuda12x                      13.3.0
cycler                            0.12.1
Cython                            3.0.12
dash                              2.18.2
dash-core-components              2.0.0
dash-html-components              2.0.0
dash-table                        5.0.0
dashscope                         1.22.2
datasets                          3.0.1
debugpy                           1.8.10
decorator                         4.4.2
decord                            0.6.0
deepspeed                         0.15.2
depyf                             0.18.0
diffsynth                         1.1.2
diffusers                         0.33.0.dev0
dill                              0.3.8
diskcache                         5.6.3
distro                            1.9.0
dnspython                         2.7.0
docker-pycreds                    0.4.0
easydict                          1.13
einops                            0.8.0
email_validator                   2.2.0
eval_type_backport                0.2.0
exceptiongroup                    1.2.2
executing                         2.1.0
facexlib                          0.3.0
fairscale                         0.4.13
fastapi                           0.115.2
fastjsonschema                    2.20.0
fastrlock                         0.8.3
ffmpy                             0.4.0
filelock                          3.16.1
filterpy                          1.4.5
flash-attn                        2.6.3
Flask                             3.0.3
flatbuffers                       24.3.25
fonttools                         4.54.1
frozenlist                        1.4.1
fsspec                            2024.6.1
ftfy                              6.3.0
func_timeout                      4.3.5
future                            1.0.0
fvcore                            0.1.5.post20221221
gast                              0.6.0
gguf                              0.10.0
gitdb                             4.0.11
GitPython                         3.1.43
google-pasta                      0.2.0
gradio                            5.5.0
gradio_client                     1.4.2
grpcio                            1.66.2
h11                               0.14.0
h5py                              3.12.1
hjson                             3.1.0
httpcore                          1.0.6
httptools                         0.6.4
httpx                             0.27.2
huggingface-hub                   0.29.1
humanfriendly                     10.0
idna                              3.10
imageio                           2.36.0
imageio-ffmpeg                    0.5.1
imgaug                            0.4.0
importlib_metadata                8.5.0
iniconfig                         2.0.0
interegular                       0.3.3
iopath                            0.1.10
ipykernel                         6.29.5
ipython                           8.29.0
ipywidgets                        8.1.5
itsdangerous                      2.2.0
jaxtyping                         0.2.34
jedi                              0.19.1
Jinja2                            3.1.4
jiter                             0.7.0
jmespath                          0.10.0
joblib                            1.4.2
jsonschema                        4.23.0
jsonschema-specifications         2024.10.1
jupyter_client                    8.6.3
jupyter_core                      5.7.2
jupyterlab_widgets                3.0.13
keras                             3.7.0
kiwisolver                        1.4.7
lark                              1.2.2
lazy_loader                       0.4
libclang                          18.1.1
libigl                            2.5.1
linkify-it-py                     2.0.3
llvmlite                          0.43.0
lm-format-enforcer                0.10.9
lmdb                              1.6.2
loguru                            0.7.3
lvis                              0.5.3
Markdown                          3.7
markdown-it-py                    2.2.0
MarkupSafe                        2.1.5
matplotlib                        3.9.2
matplotlib-inline                 0.1.7
mdit-py-plugins                   0.3.3
mdurl                             0.1.2
memory-profiler                   0.61.0
mistral_common                    1.5.1
ml-dtypes                         0.4.1
modelscope                        1.23.2
moviepy                           1.0.3
mpmath                            1.3.0
msgpack                           1.1.0
msgspec                           0.18.6
multidict                         6.1.0
multiprocess                      0.70.16
namex                             0.0.8
narwhals                          1.10.0
natsort                           8.4.0
nbformat                          5.10.4
nest-asyncio                      1.6.0
networkx                          3.4.1
ninja                             1.11.1.3
numba                             0.60.0
numpy                             1.26.4
nvdiffrast                        0.3.3
nvidia-cublas-cu12                12.4.5.8
nvidia-cuda-cupti-cu12            12.4.127
nvidia-cuda-nvrtc-cu12            12.4.127
nvidia-cuda-runtime-cu12          12.4.127
nvidia-cudnn-cu12                 9.1.0.70
nvidia-cufft-cu12                 11.2.1.3
nvidia-curand-cu12                10.3.5.147
nvidia-cusolver-cu12              11.6.1.9
nvidia-cusparse-cu12              12.3.1.170
nvidia-cusparselt-cu12            0.6.2
nvidia-ml-py                      12.560.30
nvidia-nccl-cu12                  2.21.5
nvidia-nvjitlink-cu12             12.4.127
nvidia-nvtx-cu12                  12.4.127
omegaconf                         2.3.0
onnxruntime                       1.20.0
open3d                            0.18.0
openai                            1.54.4
openai-clip                       1.0.1
opencv-python                     4.10.0.84
opencv-python-headless            4.10.0.84
opt_einsum                        3.4.0
optree                            0.13.1
orjson                            3.10.7
oss2                              2.19.1
outlines                          0.0.46
packaging                         24.1
pandas                            2.2.3
parso                             0.8.4
partial-json-parser               0.2.1.1.post4
peft                              0.13.2
pexpect                           4.9.0
pillow                            10.4.0
pip                               24.2
platformdirs                      4.3.6
plotly                            5.24.1
pluggy                            1.5.0
pooch                             1.8.2
portalocker                       2.10.1
proglog                           0.1.10
prometheus_client                 0.21.0
prometheus-fastapi-instrumentator 7.0.0
prompt_toolkit                    3.0.48
propcache                         0.2.0
protobuf                          5.28.2
psutil                            6.0.0
ptyprocess                        0.7.0
pudb                              2024.1.2
pure_eval                         0.2.3
py-cpuinfo                        9.0.0
pyairports                        2.1.1
pyarrow                           17.0.0
pybind11                          2.13.6
pycocoevalcap                     1.2
pycocotools                       2.0.8
pycountry                         24.6.1
pycparser                         2.22
pycryptodome                      3.21.0
pydantic                          2.9.2
pydantic_core                     2.23.4
pydub                             0.25.1
Pygments                          2.18.0
pyiqa                             0.1.10
PyMatting                         1.1.12
PyMCubes                          0.1.6
pyparsing                         3.2.0
pyquaternion                      0.9.9
pytest                            8.3.4
python-dateutil                   2.9.0.post0
python-dotenv                     1.0.1
python-multipart                  0.0.12
pytorch3d                         0.7.8
pytz                              2024.2
PyYAML                            6.0.2
pyzmq                             26.2.0
qwen-vl-utils                     0.0.10
ray                               2.37.0
referencing                       0.35.1
regex                             2024.9.11
rembg                             2.0.59
requests                          2.32.3
requests-toolbelt                 1.0.0
retrying                          1.3.4
rich                              13.9.2
rpds-py                           0.20.0
ruff                              0.6.9
s3transfer                        0.10.3
safehttpx                         0.1.1
safetensors                       0.4.5
scikit-image                      0.24.0
scikit-learn                      1.5.2
scikit-video                      1.1.11
scipy                             1.14.1
semantic-version                  2.10.0
sentencepiece                     0.2.0
sentry-sdk                        2.18.0
setproctitle                      1.3.3
setuptools                        75.2.0
shapely                           2.0.7
shellingham                       1.5.4
six                               1.16.0
sk-video                          1.1.10
smmap                             5.0.1
sniffio                           1.3.1
soupsieve                         2.6
stack-data                        0.6.3
starlette                         0.40.0
SwissArmyTransformer              0.4.12
sympy                             1.13.1
tabulate                          0.9.0
tenacity                          9.0.0
tensorboard                       2.18.0
tensorboard-data-server           0.7.2
tensorboardX                      2.6.2.2
tensorflow-io-gcs-filesystem      0.37.1
termcolor                         2.5.0
thop                              0.1.1.post2209072238
threadpoolctl                     3.5.0
tifffile                          2024.9.20
tiktoken                          0.7.0
timm                              1.0.11
tokenizers                        0.20.3
tomesd                            0.1.3
tomli                             2.2.1
tomlkit                           0.12.0
torch                             2.6.0
torchaudio                        2.6.0
torchdiffeq                       0.2.4
torchsde                          0.2.6
torchvision                       0.21.0
tornado                           6.4.2
tqdm                              4.66.5
traitlets                         5.14.3
trampoline                        0.1.2
transformers                      4.46.2
transformers-stream-generator     0.0.4
trimesh                           4.5.2
triton                            3.2.0
typeguard                         2.13.3
typer                             0.12.5
typing_extensions                 4.12.2
tzdata                            2024.2
uc-micro-py                       1.0.3
urllib3                           2.2.3
urwid                             2.6.16
urwid_readline                    0.15.1
uvicorn                           0.32.0
uvloop                            0.21.0
wandb                             0.18.7
watchfiles                        0.24.0
wcwidth                           0.2.13
webdataset                        0.2.100
websocket-client                  1.8.0
websockets                        12.0
Werkzeug                          3.0.4
wheel                             0.44.0
widgetsnbextension                4.0.13
wrapt                             1.17.0
xatlas                            0.0.9
xxhash                            3.5.0
yacs                              0.1.8
yapf                              0.43.0
yarl                              1.15.3
zipp                              3.20.2

Logs

System Info

  • 🤗 Diffusers version: 0.33.0.dev0
  • Platform: Linux-3.10.0-1160.el7.x86_64-x86_64-with-glibc2.35
  • Running on Google Colab?: No
  • Python version: 3.10.15
  • PyTorch version (GPU?): 2.6.0+cu124 (True)
  • Flax version (CPU?/GPU?/TPU?): not installed (NA)
  • Jax version: not installed
  • JaxLib version: not installed
  • Huggingface_hub version: 0.29.1
  • Transformers version: 4.46.2
  • Accelerate version: 1.4.0
  • PEFT version: 0.13.2
  • Bitsandbytes version: not installed
  • Safetensors version: 0.4.5
  • xFormers version: not installed
  • Accelerator: NVIDIA A800-SXM4-80GB, 81251 MiB
  • Using GPU in script?:
  • Using distributed or parallel set-up in script?:

Who can help?

@DN6 @a-r-r-o-w

Labels: bug (Something isn't working)