### Describe the bug

While working on #6569, which is related to #6545, I ran the InstructPix2Pix SDXL training example and noticed this issue.
I think we should fix this issue before merging #6569.
### Reproduction

I just followed the [Toy example guide](https://github.com/huggingface/diffusers/blob/main/examples/instruct_pix2pix/README_sdxl.md#toy-example):
```bash
export MODEL_NAME="stabilityai/stable-diffusion-xl-base-1.0"
export DATASET_ID="fusing/instructpix2pix-1000-samples"

accelerate launch train_instruct_pix2pix_sdxl.py \
  --pretrained_model_name_or_path=stabilityai/stable-diffusion-xl-base-1.0 \
  --dataset_name=$DATASET_ID \
  --use_ema \
  --enable_xformers_memory_efficient_attention \
  --resolution=512 --random_flip \
  --train_batch_size=4 --gradient_accumulation_steps=4 --gradient_checkpointing \
  --max_train_steps=15000 \
  --checkpointing_steps=5000 --checkpoints_total_limit=1 \
  --learning_rate=5e-05 --lr_warmup_steps=0 \
  --conditioning_dropout_prob=0.05 \
  --seed=42 \
  --val_image_url_or_path="https://datasets-server.huggingface.co/assets/fusing/instructpix2pix-1000-samples/--/fusing--instructpix2pix-1000-samples/train/23/input_image/image.jpg" \
  --validation_prompt="make it in japan" \
  --report_to=wandb \
  --push_to_hub
```
### Logs

```
Traceback (most recent call last):
  File "/root/dev/diffusers/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py", line 1234, in <module>
    main()
  File "/root/dev/diffusers/examples/instruct_pix2pix/train_instruct_pix2pix_sdxl.py", line 1147, in main
    a_val_img = pipeline(
  File "/opt/env/conda/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/root/dev/diffusers/src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py", line 952, in __call__
    scaled_latent_model_input = torch.cat([scaled_latent_model_input, image_latents], dim=1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 128 but got size 64 for tensor number 1 in the list.
```
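For context, here is a minimal sketch of the failing concatenation. The concrete shapes are an assumption inferred from the error message: the noise latents appear to have a 128x128 latent resolution (e.g. 1024 px / VAE scale factor 8), while the validation image latents are 64x64 (e.g. 512 px / 8), so the spatial dimensions disagree:

```python
import torch

# Hypothetical shapes inferred from the traceback above:
# tensor 0 (the noise latents) has spatial size 128, tensor 1 (the image latents) has 64.
scaled_latent_model_input = torch.randn(2, 4, 128, 128)
image_latents = torch.randn(2, 4, 64, 64)

# torch.cat along dim=1 requires all other dimensions to match, so this raises:
# RuntimeError: Sizes of tensors must match except in dimension 1.
# Expected size 128 but got size 64 for tensor number 1 in the list.
torch.cat([scaled_latent_model_input, image_latents], dim=1)
```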
### System Info

- `diffusers` version: 0.26.0.dev0
- Platform: Linux-4.19.93-1
- Python version: 3.10.8
- PyTorch version (GPU?): 2.0.1+cu117 (True)
- Huggingface_hub version: 0.20.2
- Transformers version: 4.36.2
- Accelerate version: 0.26.1
- xFormers version: 0.0.22
- Using GPU in script?: Yes, a single GPU
- Using distributed or parallel set-up in script?: No