train_dreambooth_lora.py failed on two machines #3363

Closed
@bohong13


Describe the bug

I have found two errors.

  1. Error when the process saves a checkpoint:
Traceback (most recent call last):
  File "/home/momistest/db/diffusers/examples/dreambooth/train_dreambooth_lora.py", line 1112, in <module>
    main(args)
  File "/home/momistest/db/diffusers/examples/dreambooth/train_dreambooth_lora.py", line 991, in main
    LoraLoaderMixin.save_lora_weights(
  File "/home/momistest/db/diffusers/src/diffusers/loaders.py", line 1111, in save_lora_weights
    for module_name, param in unet_lora_layers.state_dict().items()
  File "/home/momistest/anaconda3/envs/hg/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1818, in state_dict
    module.state_dict(destination=destination, prefix=prefix + name + '.', keep_vars=keep_vars)
  File "/home/momistest/anaconda3/envs/hg/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1820, in state_dict
    hook_result = hook(self, destination, prefix, local_metadata)
  File "/home/momistest/db/diffusers/src/diffusers/loaders.py", line 74, in map_to
    num = int(key.split(".")[1])  # 0 is always "layers"
ValueError: invalid literal for int() with base 10: 'layers'
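
My understanding (an assumption on my part) is that the key parsing in map_to fails because accelerator.prepare wraps the LoRA layers in DDP, which prepends "module." to every state-dict key, so the numeric index is no longer at split position 1. A minimal sketch with hypothetical key names:

# Hypothetical key names, only to illustrate why int(key.split(".")[1]) raises.
key_single_process = "layers.0.to_q_lora.down.weight"
key_ddp_wrapped = "module.layers.0.to_q_lora.down.weight"

print(int(key_single_process.split(".")[1]))  # 0 -> works
print(int(key_ddp_wrapped.split(".")[1]))     # ValueError: invalid literal for int() with base 10: 'layers'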

I then tried to work around this error using the method from issue #3284, but that led to a different error:

Traceback (most recent call last):
  File "/home/momistest/db/diffusers/examples/dreambooth/train_dreambooth_lora.py", line 1112, in <module>
    main(args)
  File "/home/momistest/db/diffusers/examples/dreambooth/train_dreambooth_lora.py", line 1067, in main
    pipeline.load_lora_weights(args.output_dir)
  File "/home/momistest/db/diffusers/src/diffusers/loaders.py", line 846, in load_lora_weights
    self.unet.load_attn_procs(unet_lora_state_dict)
  File "/home/momistest/db/diffusers/src/diffusers/loaders.py", line 305, in load_attn_procs
    self.set_attn_processor(attn_processors)
  File "/home/momistest/db/diffusers/src/diffusers/models/unet_2d_condition.py", line 533, in set_attn_processor
    fn_recursive_attn_processor(name, module, processor)
  File "/home/momistest/db/diffusers/src/diffusers/models/unet_2d_condition.py", line 530, in fn_recursive_attn_processor
    fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
  File "/home/momistest/db/diffusers/src/diffusers/models/unet_2d_condition.py", line 530, in fn_recursive_attn_processor
    fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
  File "/home/momistest/db/diffusers/src/diffusers/models/unet_2d_condition.py", line 530, in fn_recursive_attn_processor
    fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
  [Previous line repeated 3 more times]
  File "/home/momistest/db/diffusers/src/diffusers/models/unet_2d_condition.py", line 527, in fn_recursive_attn_processor
    module.set_processor(processor.pop(f"{name}.processor"))
KeyError: 'down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor'
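
If it helps, the workaround I was considering (hypothetical, not verified) was to strip the "module." prefix from the saved weights before loading them, instead of editing loaders.py:

import torch
from diffusers import DiffusionPipeline

# Unverified sketch: strip a leading "module." prefix (added by DDP) from the
# saved LoRA state dict, then hand the cleaned dict to the pipeline.
# Assumes load_lora_weights also accepts an in-memory state dict.
pipeline = DiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
lora_path = "/home/momistest/db/diffusers/examples/dreambooth/lora_output/pytorch_lora_weights.bin"
state_dict = torch.load(lora_path, map_location="cpu")
state_dict = {k.removeprefix("module."): v for k, v in state_dict.items()}
pipeline.load_lora_weights(state_dict)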
  2. Abnormal model parameter exchange.
    I have two machines on the same local network, but when I monitor the network traffic with iftop, the TX and RX volumes for the model parameter exchange are very different (see the sketch after the iftop numbers below).
192.168.1.123 => 192.168.1.183     20.2kb
              <=                     416b

TX:23.0MB
RX:108MB
TOTAL 131MB
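
To sanity-check how much data should be exchanged per step, I would count the trainable LoRA parameters (a rough estimate, assuming DDP all-reduces fp32 gradients of the LoRA layers only; unet_lora_layers is the AttnProcsLayers object from the training script):

# Rough estimate of the per-step all-reduce volume for the LoRA layers.
trainable_params = sum(p.numel() for p in unet_lora_layers.parameters() if p.requires_grad)
print(f"{trainable_params} trainable params ~= {trainable_params * 4 / 1e6:.2f} MB per fp32 all-reduce")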

Reproduction

I followed the DreamBooth dog example to run the program on two machines.

I have two laptops with NVIDIA RTX 3080 GPUs.
machine 1 IP is 192.168.1.123
machine 2 IP is 192.168.1.183

The environment and package versions on the two machines are exactly the same.

Accelerate env is:
- `Accelerate` version: 0.18.0
- Platform: Linux-5.13.0-30-generic-x86_64-with-glibc2.31
- Python version: 3.10.9
- Numpy version: 1.24.3
- PyTorch version (GPU?): 2.0.0+cu117 (True)
- `Accelerate` default config:
	- compute_environment: LOCAL_MACHINE
	- distributed_type: MULTI_GPU
	- mixed_precision: no
	- use_cpu: False
	- num_processes: 2
	- machine_rank: 0
	- num_machines: 2
	- gpu_ids: all
	- main_process_ip: 192.168.1.123
	- main_process_port: 29500
	- rdzv_backend: static
	- same_network: True
	- main_training_function: main
	- downcast_bf16: no
	- tpu_use_cluster: False
	- tpu_use_sudo: False
	- tpu_env: []
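
On the second machine I use the same config; as far as I understand the standard multi-node setup, only the rank differs (my assumption, nothing else was changed):

	- machine_rank: 1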

I run this script on both machines:

export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export INSTANCE_DIR="/home/momistest/db/diffusers/examples/dreambooth/dog"
export OUTPUT_DIR="/home/momistest/db/diffusers/examples/dreambooth/lora_output"

NCCL_DEBUG=INFO accelerate launch train_dreambooth_lora.py \
  --pretrained_model_name_or_path=$MODEL_NAME  \
  --instance_data_dir=$INSTANCE_DIR \
  --output_dir=$OUTPUT_DIR \
  --instance_prompt="a photo of sks dog" \
  --resolution=512 \
  --train_batch_size=1 \
  --gradient_accumulation_steps=1 \
  --checkpointing_steps=100 \
  --learning_rate=1e-4 \
  --report_to="wandb" \
  --lr_scheduler="constant" \
  --lr_warmup_steps=0 \
  --max_train_steps=500 \
  --validation_prompt="A photo of sks dog in a bucket" \
  --validation_epochs=50 \
  --seed="0"

Logs

Steps: 100%|██████████████████████████████████| 500/500 [06:52<00:00,  1.71it/s, loss=0.208, lr=0.0001]
Model weights saved in /home/momistest/db/diffusers/examples/dreambooth/lora_output/pytorch_lora_weights.bin
{'requires_safety_checker'} was not found in config. Values will be initialized to default values.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
{'prediction_type'} was not found in config. Values will be initialized to default values.
{'scaling_factor'} was not found in config. Values will be initialized to default values.
{'timestep_post_act', 'num_class_embeds', 'resnet_time_scale_shift', 'resnet_skip_time_act', 'addition_embed_type_num_heads', 'conv_in_kernel', 'mid_block_only_cross_attention', 'only_cross_attention', 'time_embedding_act_fn', 'addition_embed_type', 'encoder_hid_dim', 'use_linear_projection', 'conv_out_kernel', 'upcast_attention', 'class_embeddings_concat', 'class_embed_type', 'time_embedding_dim', 'mid_block_type', 'projection_class_embeddings_input_dim', 'dual_cross_attention', 'resnet_out_scale_factor', 'cross_attention_norm', 'time_embedding_type', 'time_cond_proj_dim'} was not found in config. Values will be initialized to default values.
{'sample_max_value', 'thresholding', 'solver_type', 'solver_order', 'dynamic_thresholding_ratio', 'use_karras_sigmas', 'algorithm_type', 'lower_order_final'} was not found in config. Values will be initialized to default values.
Loading unet.
Traceback (most recent call last):
  File "/home/momistest/db/diffusers/examples/dreambooth/train_dreambooth_lora.py", line 1112, in <module>
    main(args)
  File "/home/momistest/db/diffusers/examples/dreambooth/train_dreambooth_lora.py", line 1067, in main
    pipeline.load_lora_weights(args.output_dir)
  File "/home/momistest/db/diffusers/src/diffusers/loaders.py", line 847, in load_lora_weights
    self.unet.load_attn_procs(unet_lora_state_dict)
  File "/home/momistest/db/diffusers/src/diffusers/loaders.py", line 305, in load_attn_procs
    self.set_attn_processor(attn_processors)
  File "/home/momistest/db/diffusers/src/diffusers/models/unet_2d_condition.py", line 533, in set_attn_processor
    fn_recursive_attn_processor(name, module, processor)
  File "/home/momistest/db/diffusers/src/diffusers/models/unet_2d_condition.py", line 530, in fn_recursive_attn_processor
    fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
  File "/home/momistest/db/diffusers/src/diffusers/models/unet_2d_condition.py", line 530, in fn_recursive_attn_processor
    fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
  File "/home/momistest/db/diffusers/src/diffusers/models/unet_2d_condition.py", line 530, in fn_recursive_attn_processor
    fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
  [Previous line repeated 3 more times]
  File "/home/momistest/db/diffusers/src/diffusers/models/unet_2d_condition.py", line 527, in fn_recursive_attn_processor
    module.set_processor(processor.pop(f"{name}.processor"))
KeyError: 'down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor'
wandb: Waiting for W&B process to finish... (failed 1). Press Control-C to abort syncing.
wandb: 
wandb: Run history:
wandb: loss ▁▂▂▁▁▁▃█▂▁▂▂▁▄▁▁▁▁▂▁▁▃▂▁▂▃▁▂▂▁▁▁▄▂▁▂▂▁▁▁
wandb:   lr ▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wandb: 
wandb: Run summary:
wandb: loss 0.20755
wandb:   lr 0.0001
wandb: 
wandb: 🚀 View run swift-aardvark-13 at: https://wandb.ai/account/dreambooth-lora/runs/qvn0373n
wandb: Synced 6 W&B file(s), 16 media file(s), 0 artifact file(s) and 0 other file(s)
wandb: Find logs at: ./wandb/run-xxxxxx-qvn0373n/logs
ERROR:torch.distributed.elastic.multiprocessing.api:failed (exitcode: 1) local_rank: 0 (pid: 200217) of binary: /home/momistest/anaconda3/envs/hg/bin/python
Traceback (most recent call last):
  File "/home/momistest/anaconda3/envs/hg/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/home/momistest/anaconda3/envs/hg/lib/python3.10/site-packages/accelerate/commands/accelerate_cli.py", line 45, in main
    args.func(args)
  File "/home/momistest/anaconda3/envs/hg/lib/python3.10/site-packages/accelerate/commands/launch.py", line 914, in launch_command
    multi_gpu_launcher(args)
  File "/home/momistest/anaconda3/envs/hg/lib/python3.10/site-packages/accelerate/commands/launch.py", line 603, in multi_gpu_launcher
    distrib_run.run(args)
  File "/home/momistest/anaconda3/envs/hg/lib/python3.10/site-packages/torch/distributed/run.py", line 785, in run
    elastic_launch(
  File "/home/momistest/anaconda3/envs/hg/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/home/momistest/anaconda3/envs/hg/lib/python3.10/site-packages/torch/distributed/launcher/api.py", line 250, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError: 
============================================================
train_dreambooth_lora.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-05-08_02:07:48
  host      : host-192-168-1-123.openstacklocal
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 200217)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html

System Info

- `diffusers` version: 0.17.0.dev0
- Platform: Linux-5.13.0-30-generic-x86_64-with-glibc2.31
- Python version: 3.10.9
- PyTorch version (GPU?): 2.0.0+cu117 (True)
- Huggingface_hub version: 0.14.1
- Transformers version: 4.28.1
- Accelerate version: 0.18.0
- xFormers version: 0.0.19
- Using GPU in script?: Yes (one NVIDIA RTX 3080 per machine)
- Using distributed or parallel set-up in script?: Yes (Accelerate multi-GPU across two machines)
