Still an issue with Flux DreamBooth LoRA training (#9237) #9548

Closed
@jeongiin

Description


Describe the bug

I tried running train_dreambooth_lora_flux.py again with the merged source code, but I am still encountering an issue similar to #9237 during the log_validation stage.

I have resolved this issue with the following modification:

autocast_ctx = nullcontext()

to

autocast_ctx = torch.autocast(accelerator.device.type, dtype=torch_dtype)

I am currently verifying that, with this fix, the validation results are correctly uploaded to wandb before submitting a PR with the change.
If you have any suggestions for a better solution, I would greatly appreciate your feedback!
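For context, here is a minimal sketch of where the change sits in log_validation. The function signature and pipeline_args below are simplified and only illustrative; just the autocast_ctx lines and the image-generation list comprehension come from the training script.

import torch
from contextlib import nullcontext

def log_validation(pipeline, args, accelerator, pipeline_args, generator, torch_dtype):
    # Before: without autocast, the validation pipeline feeds float32 activations
    # into the bf16 VAE weights, and decode fails with
    # "Input type (float) and bias type (c10::BFloat16) should be the same".
    # autocast_ctx = nullcontext()

    # After: wrap inference in autocast so activations are cast to match the
    # bf16 weights during the VAE decode.
    autocast_ctx = torch.autocast(accelerator.device.type, dtype=torch_dtype)

    with autocast_ctx:
        images = [
            pipeline(**pipeline_args, generator=generator).images[0]
            for _ in range(args.num_validation_images)
        ]
    return images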

Reproduction

CUDA_VISIBLE_DEVICES=0 accelerate launch train_dreambooth_lora_flux.py \
--pretrained_model_name_or_path="/FLUX.1-dev" \
--instance_data_dir="/dataset/dog" \
--output_dir="trained-flux-dog-0928" \
--mixed_precision=bf16 \
--instance_prompt="a photo of sks dog" \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=4 \
--learning_rate=1e-4 \
--lr_scheduler=constant \
--lr_warmup_steps=0 \
--max_train_steps=500 \
--checkpointing_steps=50 \
--seed=0 \
--rank=32 \
--report_to="wandb" \
--validation_prompt="A photo of sks dog in a bucket" \
--validation_epochs=25

Logs

[WARNING]  async_io requires the dev libaio .so object and headers but these were not found.
 [WARNING]  async_io: please install the libaio-dev package with apt
 [WARNING]  If libaio is already installed (perhaps from source), try setting the CFLAGS and LDFLAGS environment variables to where it can be found.
 [WARNING]  Please specify the CUTLASS repo directory as environment variable $CUTLASS_PATH
 [WARNING]  sparse_attn requires a torch version >= 1.5 and < 2.0 but detected 2.4
 [WARNING]  using untested triton version (3.0.0), only 1.0.0 is known to be compatible

stderr: Detected kernel version 4.18.0, which is below the recommended minimum of 5.5.0; this can cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher.
09/28/2024 14:20:08 - INFO - __main__ - Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: cuda

Mixed precision type: bf16

You set `add_prefix_space`. The tokenizer needs to be converted from the slow tokenizers
You are using a model of type clip_text_model to instantiate a model of type . This is not supported for all configurations of models and can yield errors.
You are using a model of type t5 to instantiate a model of type . This is not supported for all configurations of models and can yield errors.

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]
Loading checkpoint shards:  50%|█████     | 1/2 [00:03<00:03,  3.92s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:07<00:00,  3.60s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:07<00:00,  3.65s/it]
{'axes_dims_rope'} was not found in config. Values will be initialized to default values.
wandb: Currently logged in as: timdalee (timdalee-ai). Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.17.8
wandb: Run data is saved locally in /diffusers/examples/dreambooth/wandb/run-20240928_142109-n4e0rrva
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run brisk-cherry-9
wandb: ⭐️ View project at https://wandb.ai/timdalee-ai/dreambooth-flux-dev-lora
wandb: 🚀 View run at https://wandb.ai/timdalee-ai/dreambooth-flux-dev-lora/runs/n4e0rrva
09/28/2024 14:21:13 - INFO - __main__ - ***** Running training *****
09/28/2024 14:21:13 - INFO - __main__ -   Num examples = 5
09/28/2024 14:21:13 - INFO - __main__ -   Num batches each epoch = 5
09/28/2024 14:21:13 - INFO - __main__ -   Num Epochs = 250
09/28/2024 14:21:13 - INFO - __main__ -   Instantaneous batch size per device = 1
09/28/2024 14:21:13 - INFO - __main__ -   Total train batch size (w. parallel, distributed & accumulation) = 4
09/28/2024 14:21:13 - INFO - __main__ -   Gradient Accumulation steps = 4
09/28/2024 14:21:13 - INFO - __main__ -   Total optimization steps = 500

Steps:   0%|          | 0/500 [00:00<?, ?it/s]Passing `txt_ids` 3d torch.Tensor is deprecated.Please remove the batch dimension and pass it as a 2d torch Tensor

Steps:   0%|          | 0/500 [00:01<?, ?it/s, loss=0.559, lr=0.0001]Passing `txt_ids` 3d torch.Tensor is deprecated.Please remove the batch dimension and pass it as a 2d torch Tensor

Steps:   0%|          | 0/500 [00:01<?, ?it/s, loss=0.574, lr=0.0001]Passing `txt_ids` 3d torch.Tensor is deprecated.Please remove the batch dimension and pass it as a 2d torch Tensor

Steps:   0%|          | 0/500 [00:02<?, ?it/s, loss=0.529, lr=0.0001]Passing `txt_ids` 3d torch.Tensor is deprecated.Please remove the batch dimension and pass it as a 2d torch Tensor

Steps:   0%|          | 1/500 [00:02<24:27,  2.94s/it, loss=0.529, lr=0.0001]
Steps:   0%|          | 1/500 [00:02<24:27,  2.94s/it, loss=0.691, lr=0.0001]Passing `txt_ids` 3d torch.Tensor is deprecated.Please remove the batch dimension and pass it as a 2d torch Tensor

Steps:   0%|          | 2/500 [00:03<12:46,  1.54s/it, loss=0.691, lr=0.0001]
Steps:   0%|          | 2/500 [00:03<12:46,  1.54s/it, loss=0.762, lr=0.0001]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]
Loading checkpoint shards:  50%|█████     | 1/2 [00:04<00:04,  4.13s/it]
Loading checkpoint shards: 100%|██████████| 2/2 [00:07<00:00,  3.78s/it]
/root/miniconda3/envs/flux_diffusers/lib/python3.10/site-packages/deepspeed/runtime/zero/linear.py:49: FutureWarning: `torch.cuda.amp.custom_fwd(args...)` is deprecated. Please use `torch.amp.custom_fwd(args..., device_type='cuda')` instead.
  def forward(ctx, input, weight, bias=None):
/root/miniconda3/envs/flux_diffusers/lib/python3.10/site-packages/deepspeed/runtime/zero/linear.py:67: FutureWarning: `torch.cuda.amp.custom_bwd(args...)` is deprecated. Please use `torch.amp.custom_bwd(args..., device_type='cuda')` instead.
  def backward(ctx, grad_output):

Loaded scheduler as FlowMatchEulerDiscreteScheduler from `scheduler` subfolder of /FLUX.1-dev.
Loaded tokenizer as CLIPTokenizer from `tokenizer` subfolder of /FLUX.1-dev.
Loaded tokenizer_2 as T5TokenizerFast from `tokenizer_2` subfolder of /FLUX.1-dev.


Loading pipeline components...: 100%|██████████| 7/7 [00:00<00:00, 22.50it/s]
09/28/2024 14:21:27 - INFO - __main__ - Running validation... 
 Generating 4 images with prompt: A photo of sks dog in a bucket.
Traceback (most recent call last):
  File "/diffusers/examples/dreambooth/train_dreambooth_lora_flux.py", line 1893, in <module>
    main(args)
  File "/diffusers/examples/dreambooth/train_dreambooth_lora_flux.py", line 1813, in main
    images = log_validation(
  File "/diffusers/examples/dreambooth/train_dreambooth_lora_flux.py", line 191, in log_validation
    images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
  File "/diffusers/examples/dreambooth/train_dreambooth_lora_flux.py", line 191, in <listcomp>
    images = [pipeline(**pipeline_args, generator=generator).images[0] for _ in range(args.num_validation_images)]
  File "/root/miniconda3/envs/flux_diffusers/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/diffusers/src/diffusers/pipelines/flux/pipeline_flux.py", line 763, in __call__
    image = self.vae.decode(latents, return_dict=False)[0]
  File "/diffusers/src/diffusers/utils/accelerate_utils.py", line 46, in wrapper
    return method(self, *args, **kwargs)
  File "/diffusers/src/diffusers/models/autoencoders/autoencoder_kl.py", line 326, in decode
    decoded = self._decode(z).sample
  File "/diffusers/src/diffusers/models/autoencoders/autoencoder_kl.py", line 297, in _decode
    dec = self.decoder(z)
  File "/root/miniconda3/envs/flux_diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/flux_diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/diffusers/src/diffusers/models/autoencoders/vae.py", line 291, in forward
    sample = self.conv_in(sample)
  File "/root/miniconda3/envs/flux_diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1553, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/miniconda3/envs/flux_diffusers/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1562, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/miniconda3/envs/flux_diffusers/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 458, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/root/miniconda3/envs/flux_diffusers/lib/python3.10/site-packages/torch/nn/modules/conv.py", line 454, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Input type (float) and bias type (c10::BFloat16) should be the same

Steps:   0%|          | 2/500 [00:39<2:44:07, 19.77s/it, loss=0.755, lr=0.0001]
Traceback (most recent call last):
  File "/root/miniconda3/envs/flux_diffusers/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/root/miniconda3/envs/flux_diffusers/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 48, in main
    args.func(args)
  File "/root/miniconda3/envs/flux_diffusers/lib/python3.10/site-packages/accelerate/commands/launch.py", line 1106, in launch_command
    simple_launcher(args)
  File "/root/miniconda3/envs/flux_diffusers/lib/python3.10/site-packages/accelerate/commands/launch.py", line 704, in simple_launcher
    raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
subprocess.CalledProcessError: Command '['/root/miniconda3/envs/flux_diffusers/bin/python', 'train_dreambooth_lora_flux.py', '--pretrained_model_name_or_path=/FLUX.1-dev', '--instance_data_dir=/dataset/dog', '--output_dir=trained-flux-dog-0928', '--mixed_precision=bf16', '--instance_prompt=a photo of sks dog', '--resolution=512', '--train_batch_size=1', '--gradient_accumulation_steps=4', '--learning_rate=1e-4', '--lr_scheduler=constant', '--lr_warmup_steps=0', '--max_train_steps=500', '--checkpointing_steps=50', '--seed=0', '--rank=32', '--validation_prompt=A photo of sks dog in a bucket', '--validation_epochs=25']' returned non-zero exit status 1.

System Info

🤗 Diffusers version: 0.31.0.dev0
Platform: Linux-4.18.0-513.11.1.el8_9.x86_64-x86_64-with-glibc2.31
Running on Google Colab?: No
Python version: 3.10.14
PyTorch version (GPU?): 2.4.0+cu121 (True)
Flax version (CPU?/GPU?/TPU?): not installed (NA)
Jax version: not installed
JaxLib version: not installed
Huggingface_hub version: 0.24.6
Transformers version: 4.44.2
Accelerate version: 0.33.0
PEFT version: 0.12.0
Bitsandbytes version: not installed
Safetensors version: 0.4.4
xFormers version: not installed
Accelerator: NVIDIA A100 80GB PCIe, 81920 MiB
Using GPU in script?:
Using distributed or parallel set-up in script?:

Who can help?

@sayakpaul @linoytsaban
