
[Examples] Add support for Min-SNR weighting strategy for better convergence #2899


Merged (27 commits) on Apr 6, 2023
Changes from 18 commits
Commits
a009f1d
improve stable unclip doc.
sayakpaul Mar 25, 2023
ecf008f
Merge branch 'main' of https://github.com/huggingface/diffusers
sayakpaul Mar 28, 2023
01b4d70
Merge branch 'main' of https://github.com/huggingface/diffusers
sayakpaul Mar 28, 2023
24cab1d
Merge branch 'main' of https://github.com/huggingface/diffusers
sayakpaul Mar 29, 2023
c73fdba
Merge branch 'main' of https://github.com/huggingface/diffusers
sayakpaul Mar 30, 2023
c8a2856
feat: support for applying min-snr weighting for faster convergence.
sayakpaul Mar 30, 2023
ca0c158
add: support for validation logging with wandb
sayakpaul Mar 30, 2023
76e9446
make not a required arg.
sayakpaul Mar 30, 2023
052bc88
fix: arg name.
sayakpaul Mar 30, 2023
c481147
fix: cli args.
sayakpaul Mar 30, 2023
835b5ee
fix: tracker config.
sayakpaul Mar 30, 2023
7c842f2
fix: loss calculation.
sayakpaul Mar 30, 2023
3f078bc
fix: validation logging.
sayakpaul Mar 30, 2023
1d9f3bc
fix: unwrap call.
sayakpaul Mar 30, 2023
d2ce5e6
fix: validation logging.
sayakpaul Mar 30, 2023
667d23d
fix: internval.
sayakpaul Mar 30, 2023
a154335
fix: checkpointing push to hub.
sayakpaul Mar 30, 2023
ad3fb92
fix: https://github.com/huggingface/diffusers/commit/c8a2856c6d5e4557…
sayakpaul Mar 31, 2023
084a341
Merge branch 'main' into feat/better-convergence
sayakpaul Apr 4, 2023
f91f6bd
fix: norm group test for UNet3D.
sayakpaul Apr 4, 2023
565566f
Merge branch 'main' of https://github.com/huggingface/diffusers
sayakpaul Apr 4, 2023
bf837f5
Merge branch 'main' of https://github.com/huggingface/diffusers
sayakpaul Apr 5, 2023
7434dcd
address PR comments.
sayakpaul Apr 5, 2023
077c957
resolve conflicts.
sayakpaul Apr 5, 2023
db8bbbd
remove unneeded code.
sayakpaul Apr 5, 2023
96e7254
add: entry in the readme and docs.
sayakpaul Apr 5, 2023
245b558
Apply suggestions from code review
sayakpaul Apr 5, 2023
170 changes: 162 additions & 8 deletions examples/text_to_image/train_text_to_image.py
@@ -42,15 +42,23 @@
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, deprecate
from diffusers.utils import check_min_version, deprecate, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available


if is_wandb_available():
import wandb


# Will error if the minimal version of diffusers is not installed. Remove at your own risks.
check_min_version("0.15.0.dev0")

logger = get_logger(__name__, log_level="INFO")

DATASET_NAME_MAPPING = {
"lambdalabs/pokemon-blip-captions": ("image", "text"),
}


def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
@@ -112,6 +120,13 @@ def parse_args():
"value if set."
),
)
parser.add_argument(
"--validation_prompts",
type=str,
default=None,
nargs="+",
help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
)
parser.add_argument(
"--output_dir",
type=str,
@@ -193,6 +208,13 @@ def parse_args():
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--snr_gamma",
type=float,
default=None,
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
"More details here: https://arxiv.org/abs/2303.09556.",
)
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
)
@@ -298,6 +320,21 @@ def parse_args():
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
parser.add_argument(
"--validation_epochs",
type=int,
default=5,
help="Run validation every X epochs.",
)
parser.add_argument(
"--tracker_project_name",
type=str,
default="text2image-fine-tune",
help=(
"The `project_name` argument passed to Accelerator.init_trackers for"
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
),
)

args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -325,9 +362,83 @@ def get_full_repo_name(model_id: str, organization: Optional[str] = None, token:
return f"{organization}/{model_id}"


dataset_name_mapping = {
"lambdalabs/pokemon-blip-captions": ("image", "text"),
}
def expand_tensor(arr, timesteps, broadcast_shape):
"""
Extract values from a 1-D numpy array for a batch of indices.
Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
"""
res = arr.to(device=timesteps.device)[timesteps].float()
while len(res.shape) < len(broadcast_shape):
res = res[..., None]
return res.expand(broadcast_shape)


def compute_snr(noise_scheduler):
"""
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
"""
alphas_cumprod = noise_scheduler.alphas_cumprod
sqrt_alphas_cumprod = alphas_cumprod**0.5
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

def fn(timesteps):
alpha = expand_tensor(sqrt_alphas_cumprod, timesteps, timesteps.shape)
sigma = expand_tensor(sqrt_one_minus_alphas_cumprod, timesteps, timesteps.shape)
snr = (alpha / sigma) ** 2
return snr

return fn
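
Not part of the diff: the following is a minimal sketch of how these two helpers could be exercised on their own, assuming expand_tensor and compute_snr from this script are in scope. The checkpoint id and batch size are illustrative assumptions; any Stable Diffusion checkpoint with a "scheduler" subfolder should behave the same way.

import torch
from diffusers import DDPMScheduler

# Load only the scheduler; the model id here is an assumption for illustration.
noise_scheduler = DDPMScheduler.from_pretrained(
    "CompVis/stable-diffusion-v1-4", subfolder="scheduler"
)
snr_fn = compute_snr(noise_scheduler)

# Sample a small batch of training timesteps and inspect their signal-to-noise ratios.
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (4,))
snr = snr_fn(timesteps)  # shape (4,); early (low) timesteps give large SNR, late ones small SNR
print(list(zip(timesteps.tolist(), snr.tolist())))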


def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
Review comments on log_validation:

Contributor: Is this related to the PR title, or does it just add logging?

Member Author: The PR could have been done without this utility, but it was important to include because otherwise it was difficult to validate the effectiveness of the method. FWIW, I am not a fan of adding unrelated changes in a PR, but this seemed important.

Contributor: Ok for me!

logger.info("Running validation... ")

pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=accelerator.unwrap_model(unet),
safety_checker=None,
revision=args.revision,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)

if args.enable_xformers_memory_efficient_attention:
pipeline.enable_xformers_memory_efficient_attention()

if args.seed is None:
generator = None
else:
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

images = []
for i in range(len(args.validation_prompts)):
with torch.autocast("cuda"):
image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]

images.append(image)

for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
elif tracker.name == "wandb":
tracker.log(
{
"validation": [
wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
for i, image in enumerate(images)
]
}
)
else:
logger.warn(f"image logging not implemented for {tracker.name}")

del pipeline
torch.cuda.empty_cache()


def main():
@@ -389,6 +500,8 @@ def main():
gitignore.write("step_*\n")
if "epoch_*" not in gitignore:
gitignore.write("epoch_*\n")
if "checkpoint-*" not in gitignore:
gitignore.write("checkpoint-*\n")
elif args.output_dir is not None:
os.makedirs(args.output_dir, exist_ok=True)

@@ -476,6 +589,9 @@ def load_model_hook(models, input_dir):
args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
)

if args.snr_gamma is not None:
snr_fn = compute_snr(noise_scheduler)

# Initialize the optimizer
if args.use_8bit_adam:
try:
@@ -526,7 +642,7 @@ def load_model_hook(models, input_dir):
column_names = dataset["train"].column_names

# 6. Get the column names for input/target.
dataset_columns = dataset_name_mapping.get(args.dataset_name, None)
dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
if args.image_column is None:
image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
else:
@@ -645,7 +761,9 @@ def collate_fn(examples):
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initializes automatically on the main process.
if accelerator.is_main_process:
accelerator.init_trackers("text2image-fine-tune", config=vars(args))
tracker_config = dict(vars(args))
tracker_config.pop("validation_prompts")
accelerator.init_trackers(args.tracker_project_name, tracker_config)

# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -734,7 +852,23 @@ def collate_fn(examples):

# Predict the noise residual and compute loss
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

if args.snr_gamma is None:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
else:
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
Comment on lines +831 to +832

Contributor: Some models (SD 2.1 and above) use v-prediction; does this formulation also work with v-prediction?

Member Author: This is discussed in Section 4.2 of the same paper.

snr = snr_fn(timesteps)
mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
Comment on lines +835 to +843

Contributor: Nice!


# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
@@ -769,6 +903,26 @@ def collate_fn(examples):
if global_step >= args.max_train_steps:
break

if accelerator.is_main_process:
if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
if args.use_ema:
# Store the UNet parameters temporarily and load the EMA parameters to perform inference.
ema_unet.store(unet.parameters())
ema_unet.copy_to(unet.parameters())
log_validation(
vae,
text_encoder,
tokenizer,
unet,
args,
accelerator,
weight_dtype,
global_step,
)
if args.use_ema:
# Switch back to the original UNet parameters.
ema_unet.restore(unet.parameters())

# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
@@ -786,7 +940,7 @@ def collate_fn(examples):
pipeline.save_pretrained(args.output_dir)

if args.push_to_hub:
repo.push_to_hub(commit_message="End of training", blocking=False, auto_lfs_prune=True)
repo.push_to_hub(commit_message="End of training", blocking=True, auto_lfs_prune=True)

accelerator.end_training()
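
For reference, the weighting applied in the training loop implements w_t = min(SNR(t), gamma) / SNR(t): the paper's x_0-space weight min(SNR(t), gamma), divided by SNR(t) because this script predicts noise rather than x_0, so low-noise (high-SNR) timesteps are down-weighted while high-noise timesteps keep weight 1. Below is a self-contained sketch with made-up SNR values and dummy tensor shapes (assumptions for illustration, not taken from the PR):

import torch
import torch.nn.functional as F

snr_gamma = 5.0  # the value recommended in https://arxiv.org/abs/2303.09556
snr = torch.tensor([50.0, 5.0, 0.5, 0.05])  # hypothetical per-sample SNRs for a batch of 4

# min(SNR, gamma) / SNR, mirroring the torch.stack(...).min(...) construction in the script.
mse_loss_weights = torch.stack([snr, snr_gamma * torch.ones_like(snr)], dim=1).min(dim=1)[0] / snr
# -> tensor([0.1000, 1.0000, 1.0000, 1.0000]): only the very low-noise sample is down-weighted.

# Rebalance a per-sample MSE loss with those weights (random 4x4x8x8 tensors stand in for latents).
model_pred = torch.randn(4, 4, 8, 8)
target = torch.randn(4, 4, 8, 8)
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()

In the script itself this path is only taken when --snr_gamma is passed (e.g. --snr_gamma 5.0); leaving it unset keeps the plain mean-reduced MSE loss.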
