[Examples] fix checkpointing and casting bugs in train_text_to_image_lora_sdxl.py #4632

Merged: 5 commits, Aug 23, 2023

examples/test_examples.py: 81 additions, 0 deletions
@@ -828,6 +828,87 @@ def test_text_to_image_lora_checkpointing_checkpoints_total_limit(self):
{"checkpoint-4", "checkpoint-6"},
)

def test_text_to_image_lora_sdxl_checkpointing_checkpoints_total_limit(self):
prompt = "a prompt"
pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"

with tempfile.TemporaryDirectory() as tmpdir:
# Run training script with checkpointing
# max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
# Should create checkpoints at steps 2, 4, 6
# with checkpoint at step 2 deleted

initial_run_args = f"""
examples/text_to_image/train_text_to_image_lora_sdxl.py
--pretrained_model_name_or_path {pipeline_path}
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--checkpoints_total_limit=2
""".split()

run_command(self._launch_args + initial_run_args)

pipe = DiffusionPipeline.from_pretrained(pipeline_path)
pipe.load_lora_weights(tmpdir)
pipe(prompt, num_inference_steps=2)

# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
# checkpoint-2 should have been deleted
{"checkpoint-4", "checkpoint-6"},
)

def test_text_to_image_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
prompt = "a prompt"
pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"

with tempfile.TemporaryDirectory() as tmpdir:
# Run training script with checkpointing
# max_train_steps == 7, checkpointing_steps == 2, checkpoints_total_limit == 2
# Should create checkpoints at steps 2, 4, 6
# with checkpoint at step 2 deleted

initial_run_args = f"""
examples/text_to_image/train_text_to_image_lora_sdxl.py
--pretrained_model_name_or_path {pipeline_path}
--dataset_name hf-internal-testing/dummy_image_text_data
--resolution 64
--train_batch_size 1
--gradient_accumulation_steps 1
--max_train_steps 7
--learning_rate 5.0e-04
--scale_lr
--lr_scheduler constant
--train_text_encoder
--lr_warmup_steps 0
--output_dir {tmpdir}
--checkpointing_steps=2
--checkpoints_total_limit=2
""".split()

run_command(self._launch_args + initial_run_args)

pipe = DiffusionPipeline.from_pretrained(pipeline_path)
pipe.load_lora_weights(tmpdir)
pipe(prompt, num_inference_steps=2)

# check checkpoint directories exist
self.assertEqual(
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
# checkpoint-2 should have been deleted
{"checkpoint-4", "checkpoint-6"},
)

def test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
pretrained_model_name_or_path = "hf-internal-testing/tiny-stable-diffusion-pipe"
prompt = "a prompt"
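
The two new tests exercise the interaction between --checkpointing_steps and --checkpoints_total_limit. The pruning logic itself lives in the training script and is not part of this diff; below is a minimal sketch of the rotation pattern those flags imply (function name and exact code are illustrative, not the script's implementation):

```python
import os
import shutil


def prune_checkpoints(output_dir: str, checkpoints_total_limit: int) -> None:
    """Keep only the newest `checkpoints_total_limit - 1` checkpoint dirs so that,
    after the next save, at most `checkpoints_total_limit` remain on disk."""
    checkpoints = [d for d in os.listdir(output_dir) if d.startswith("checkpoint")]
    checkpoints = sorted(checkpoints, key=lambda name: int(name.split("-")[1]))

    if len(checkpoints) >= checkpoints_total_limit:
        num_to_remove = len(checkpoints) - checkpoints_total_limit + 1
        for checkpoint in checkpoints[:num_to_remove]:
            shutil.rmtree(os.path.join(output_dir, checkpoint))
```

With --max_train_steps 7 and --checkpointing_steps 2, checkpoints are written at steps 2, 4, and 6; with a limit of 2, checkpoint-2 is pruned before checkpoint-6 is written, which is exactly the state both tests assert.
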
examples/text_to_image/train_text_to_image_lora_sdxl.py: 11 additions, 21 deletions
@@ -396,16 +396,6 @@ def parse_args(input_args=None):
" flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
),
)
parser.add_argument(
"--prior_generation_precision",
type=str,
default=None,
choices=["no", "fp32", "fp16", "bf16"],
help=(
"Choose prior generation precision between fp32, fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
" 1.10.and an Nvidia Ampere GPU. Default to fp16 if a GPU is available else fp32."
),
)
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
@@ -724,11 +714,15 @@ def load_model_hook(models, input_dir):

lora_state_dict, network_alphas = LoraLoaderMixin.lora_state_dict(input_dir)
LoraLoaderMixin.load_lora_into_unet(lora_state_dict, network_alphas=network_alphas, unet=unet_)

text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
text_encoder_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_one_
)

text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}
LoraLoaderMixin.load_lora_into_text_encoder(
lora_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
text_encoder_2_state_dict, network_alphas=network_alphas, text_encoder=text_encoder_two_
)

accelerator.register_save_state_pre_hook(save_model_hook)
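
The change above stops passing the full combined state dict to both text encoders and instead filters by key prefix. A small self-contained illustration of that filtering (toy tensors; the key names are made up and do not match the real SDXL LoRA checkpoint layout):

```python
import torch

# Toy stand-in for a combined SDXL LoRA state dict: UNet keys plus keys for
# both text encoders. Key names are illustrative only.
lora_state_dict = {
    "unet.up_blocks.0.attn.to_q.lora.down.weight": torch.zeros(4, 8),
    "text_encoder.layers.0.q_proj.lora.down.weight": torch.zeros(4, 8),
    "text_encoder_2.layers.0.q_proj.lora.down.weight": torch.zeros(4, 8),
}

# Same filtering as in load_model_hook above: each encoder only receives its own keys.
# The trailing dot in "text_encoder." is what keeps "text_encoder_2." keys out of
# the first dict.
text_encoder_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder." in k}
text_encoder_2_state_dict = {k: v for k, v in lora_state_dict.items() if "text_encoder_2." in k}

print(sorted(text_encoder_state_dict))    # only the "text_encoder." key
print(sorted(text_encoder_2_state_dict))  # only the "text_encoder_2." key
```
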
@@ -1002,9 +996,12 @@ def collate_fn(examples):
continue

with accelerator.accumulate(unet):
pixel_values = batch["pixel_values"].to(dtype=weight_dtype)

# Convert images to latent space
if args.pretrained_vae_model_name_or_path is not None:
pixel_values = batch["pixel_values"].to(dtype=weight_dtype)
else:
pixel_values = batch["pixel_values"]

model_input = vae.encode(pixel_values).latent_dist.sample()
model_input = model_input * vae.config.scaling_factor
if args.pretrained_vae_model_name_or_path is None:
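
Context for the change above: the script runs the VAE in weight_dtype only when a separate VAE is supplied via --pretrained_vae_model_name_or_path; the stock SDXL VAE is kept in float32 for numerical stability, so its input must stay float32 as well. Casting pixel_values unconditionally (the old behavior) therefore fed half-precision images into full-precision VAE weights under mixed precision. A minimal sketch of that failure mode and of the conditional cast, using a plain convolution as a stand-in for the VAE encoder (not the script's code):

```python
import torch
from torch import nn

# Stand-in for a VAE encoder kept in float32 for numerical stability.
vae_encoder = nn.Conv2d(3, 4, kernel_size=3).float()

pixel_values = torch.randn(1, 3, 64, 64)  # comes out of the dataloader in float32
weight_dtype = torch.float16              # mixed-precision training dtype

# Old behavior: unconditional cast, fp16 input hits fp32 weights.
try:
    vae_encoder(pixel_values.to(dtype=weight_dtype))
except RuntimeError as err:
    print(f"dtype mismatch: {err}")

# Fixed behavior: cast only when the VAE itself runs in weight_dtype.
vae_runs_in_weight_dtype = False  # True only when --pretrained_vae_model_name_or_path is set
inputs = pixel_values.to(dtype=weight_dtype) if vae_runs_in_weight_dtype else pixel_values
latents = vae_encoder(inputs)  # float32 input into float32 weights works
```
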
@@ -1147,13 +1144,6 @@ def compute_time_ids(original_size, crops_coords_top_left):
f" {args.validation_prompt}."
)
# create pipeline
if not args.train_text_encoder:
text_encoder_one = text_encoder_cls_one.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
)
text_encoder_two = text_encoder_cls_two.from_pretrained(
args.pretrained_model_name_or_path, subfolder="text_encoder_2", revision=args.revision
)
pipeline = StableDiffusionXLPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
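
The deleted block above re-loaded fresh text encoders just for validation; presumably the point of removing it is that the encoders already held in memory (cast to weight_dtype, and carrying the LoRA layers when --train_text_encoder is used) can be passed straight into the validation pipeline. A rough sketch of reusing in-memory components when building a pipeline, using the tiny test checkpoint from the tests above (illustrative only, not the script's exact validation code):

```python
from diffusers import StableDiffusionXLPipeline

repo = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"  # tiny checkpoint used in the tests

# Stand-ins for the modules the training loop already holds in memory.
base = StableDiffusionXLPipeline.from_pretrained(repo)

# Passing existing modules to from_pretrained reuses them instead of loading
# fresh float32 copies from disk.
pipeline = StableDiffusionXLPipeline.from_pretrained(
    repo,
    vae=base.vae,
    text_encoder=base.text_encoder,
    text_encoder_2=base.text_encoder_2,
    unet=base.unet,
)
assert pipeline.text_encoder is base.text_encoder
```
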
tests/models/test_lora_layers.py: 1 addition, 0 deletions
@@ -664,6 +664,7 @@ def test_load_lora_locally(self):
unet_lora_layers=lora_components["unet_lora_layers"],
text_encoder_lora_layers=lora_components["text_encoder_one_lora_layers"],
text_encoder_2_lora_layers=lora_components["text_encoder_two_lora_layers"],
safe_serialization=False,
)
self.assertTrue(os.path.isfile(os.path.join(tmpdirname, "pytorch_lora_weights.bin")))
sd_pipe.load_lora_weights(os.path.join(tmpdirname, "pytorch_lora_weights.bin"))
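
Passing safe_serialization=False makes save_lora_weights fall back to torch.save and the pytorch_lora_weights.bin filename that this test asserts; with the default safe_serialization=True, diffusers writes pytorch_lora_weights.safetensors instead. A minimal illustration of the two serialization paths using plain torch and safetensors rather than the pipeline API (filenames follow the convention above):

```python
import os
import tempfile

import torch
from safetensors.torch import save_file

state_dict = {"layer.lora.down.weight": torch.zeros(4, 8)}  # toy LoRA weight

with tempfile.TemporaryDirectory() as tmpdir:
    # safe_serialization=True (the default): safetensors format.
    save_file(state_dict, os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
    # safe_serialization=False: plain torch.save, the .bin file the test looks for.
    torch.save(state_dict, os.path.join(tmpdir, "pytorch_lora_weights.bin"))
    print(sorted(os.listdir(tmpdir)))
```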