[Examples] Add support for Min-SNR weighting strategy for better convergence #2899


Merged
27 commits merged on Apr 6, 2023

Commits (27)
a009f1d
improve stable unclip doc.
sayakpaul Mar 25, 2023
ecf008f
Merge branch 'main' of https://github.com/huggingface/diffusers
sayakpaul Mar 28, 2023
01b4d70
Merge branch 'main' of https://github.com/huggingface/diffusers
sayakpaul Mar 28, 2023
24cab1d
Merge branch 'main' of https://github.com/huggingface/diffusers
sayakpaul Mar 29, 2023
c73fdba
Merge branch 'main' of https://github.com/huggingface/diffusers
sayakpaul Mar 30, 2023
c8a2856
feat: support for applying min-snr weighting for faster convergence.
sayakpaul Mar 30, 2023
ca0c158
add: support for validation logging with wandb
sayakpaul Mar 30, 2023
76e9446
make not a required arg.
sayakpaul Mar 30, 2023
052bc88
fix: arg name.
sayakpaul Mar 30, 2023
c481147
fix: cli args.
sayakpaul Mar 30, 2023
835b5ee
fix: tracker config.
sayakpaul Mar 30, 2023
7c842f2
fix: loss calculation.
sayakpaul Mar 30, 2023
3f078bc
fix: validation logging.
sayakpaul Mar 30, 2023
1d9f3bc
fix: unwrap call.
sayakpaul Mar 30, 2023
d2ce5e6
fix: validation logging.
sayakpaul Mar 30, 2023
667d23d
fix: interval.
sayakpaul Mar 30, 2023
a154335
fix: checkpointing push to hub.
sayakpaul Mar 30, 2023
ad3fb92
fix: https://github.com/huggingface/diffusers/commit/c8a2856c6d5e4557…
sayakpaul Mar 31, 2023
084a341
Merge branch 'main' into feat/better-convergence
sayakpaul Apr 4, 2023
f91f6bd
fix: norm group test for UNet3D.
sayakpaul Apr 4, 2023
565566f
Merge branch 'main' of https://github.com/huggingface/diffusers
sayakpaul Apr 4, 2023
bf837f5
Merge branch 'main' of https://github.com/huggingface/diffusers
sayakpaul Apr 5, 2023
7434dcd
address PR comments.
sayakpaul Apr 5, 2023
077c957
resolve conflicts.
sayakpaul Apr 5, 2023
db8bbbd
remove unneeded code.
sayakpaul Apr 5, 2023
96e7254
add: entry in the readme and docs.
sayakpaul Apr 5, 2023
245b558
Apply suggestions from code review
sayakpaul Apr 5, 2023
22 changes: 22 additions & 0 deletions docs/source/en/training/text2image.mdx
@@ -155,6 +155,28 @@ python train_text_to_image_flax.py \
</jax>
</frameworkcontent>

## Training with Min-SNR weighting

We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556), which helps achieve faster convergence by rebalancing the loss. To use it, set the `--snr_gamma` argument; the recommended value is 5.0.

Contributor: Out of curiosity, where is this value proposed? Is there a rule of thumb when choosing a value for this?

Member Author: It's reported in the paper. A gamma of 5.0 always leads to better results in the experiments presented by the authors.

You can refer to [this project on Weights and Biases](https://wandb.ai/sayakpaul/text2image-finetune-minsnr), which compares the loss surfaces of the following setups:

* Training without the Min-SNR weighting strategy
* Training with the Min-SNR weighting strategy (`snr_gamma` set to 5.0)
* Training with the Min-SNR weighting strategy (`snr_gamma` set to 1.0)

For our small Pokemons dataset, the effects of the Min-SNR weighting strategy may not appear pronounced, but we expect them to be more noticeable on larger datasets.

Also note that in this example, we predict either `epsilon` (i.e., the noise) or `v_prediction`. The formulation of the Min-SNR weighting strategy used here holds for both cases.
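
Concretely, the per-timestep loss weight applied during training is `min(SNR(t), gamma) / SNR(t)`. Below is a minimal, illustrative sketch of that computation, assuming a 1-D `snr` tensor like the one returned by the `compute_snr` helper in the training script; it is not part of the example script itself:

```python
import torch

def min_snr_weights(snr: torch.Tensor, gamma: float = 5.0) -> torch.Tensor:
    # min(SNR(t), gamma) / SNR(t): lightly-noised timesteps (large SNR) are
    # down-weighted, while heavily-noised timesteps keep a weight close to 1.
    return torch.clamp(snr, max=gamma) / snr

# Rebalancing a per-sample MSE loss with these weights, mirroring the training loop:
# snr = compute_snr(timesteps)                                   # shape: (batch_size,)
# loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
# loss = loss.mean(dim=list(range(1, len(loss.shape)))) * min_snr_weights(snr)
# loss = loss.mean()
```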

<Tip warning={true}>

Training with the Min-SNR weighting strategy is only supported in PyTorch.

Contributor: For a future PR: it could be cool to add this in JAX as well; it will be useful for the JAX event.

Member Author: @yiyixuxu could you take a look?


</Tip>

## LoRA

You can also use Low-Rank Adaptation of Large Language Models (LoRA), a fine-tuning technique that accelerates the training of large models, for fine-tuning text-to-image models. For more details, take a look at the [LoRA training](lora#text-to-image) guide.
16 changes: 16 additions & 0 deletions examples/text_to_image/README.md
@@ -111,6 +111,22 @@ image = pipe(prompt="yoda").images[0]
image.save("yoda-pokemon.png")
```

#### Training with Min-SNR weighting

We support training with the Min-SNR weighting strategy proposed in [Efficient Diffusion Training via Min-SNR Weighting Strategy](https://arxiv.org/abs/2303.09556), which helps achieve faster convergence by rebalancing the loss. To use it, set the `--snr_gamma` argument; the recommended value is 5.0.
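
To get a feel for how the loss is rebalanced, here is a small, purely hypothetical illustration of the weight `min(SNR(t), gamma) / SNR(t)` for a few made-up SNR values:

```python
import torch

gamma = 5.0
# Hypothetical per-timestep SNR values; a large SNR means the timestep is only lightly noised.
snr = torch.tensor([0.2, 1.0, 5.0, 25.0, 400.0])
weights = torch.clamp(snr, max=gamma) / snr
print(weights)  # approximately [1.0, 1.0, 1.0, 0.2, 0.0125]
# Heavily-noised timesteps keep full weight, while nearly-clean (high-SNR) timesteps
# are down-weighted, which is what rebalances the loss across timesteps.
```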

You can refer to [this project on Weights and Biases](https://wandb.ai/sayakpaul/text2image-finetune-minsnr), which compares the loss surfaces of the following setups:

* Training without the Min-SNR weighting strategy
* Training with the Min-SNR weighting strategy (`snr_gamma` set to 5.0)
* Training with the Min-SNR weighting strategy (`snr_gamma` set to 1.0)

For our small Pokemons dataset, the effects of the Min-SNR weighting strategy may not appear pronounced, but we expect them to be more noticeable on larger datasets.

Also note that in this example, we predict either `epsilon` (i.e., the noise) or `v_prediction`. The formulation of the Min-SNR weighting strategy used here holds for both cases.

## Training with LoRA

Low-Rank Adaptation of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.
163 changes: 154 additions & 9 deletions examples/text_to_image/train_text_to_image.py
@@ -41,15 +41,74 @@
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import check_min_version, deprecate
from diffusers.utils import check_min_version, deprecate, is_wandb_available
from diffusers.utils.import_utils import is_xformers_available


if is_wandb_available():
import wandb


# Will error if the minimal version of diffusers is not installed. Remove at your own risk.
check_min_version("0.15.0.dev0")

logger = get_logger(__name__, log_level="INFO")

DATASET_NAME_MAPPING = {
"lambdalabs/pokemon-blip-captions": ("image", "text"),
}


def log_validation(vae, text_encoder, tokenizer, unet, args, accelerator, weight_dtype, epoch):
Member Author: Added this method as part of the PR as well. It handles EMA offload and unload properly to ensure inference is done with the EMA'd checkpoints.

logger.info("Running validation... ")

pipeline = StableDiffusionPipeline.from_pretrained(
args.pretrained_model_name_or_path,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=accelerator.unwrap_model(unet),
safety_checker=None,
revision=args.revision,
torch_dtype=weight_dtype,
)
pipeline = pipeline.to(accelerator.device)
pipeline.set_progress_bar_config(disable=True)

if args.enable_xformers_memory_efficient_attention:
pipeline.enable_xformers_memory_efficient_attention()

if args.seed is None:
generator = None
else:
generator = torch.Generator(device=accelerator.device).manual_seed(args.seed)

images = []
for i in range(len(args.validation_prompts)):
with torch.autocast("cuda"):
image = pipeline(args.validation_prompts[i], num_inference_steps=20, generator=generator).images[0]

images.append(image)

for tracker in accelerator.trackers:
if tracker.name == "tensorboard":
np_images = np.stack([np.asarray(img) for img in images])
tracker.writer.add_images("validation", np_images, epoch, dataformats="NHWC")
elif tracker.name == "wandb":
tracker.log(
{
"validation": [
wandb.Image(image, caption=f"{i}: {args.validation_prompts[i]}")
for i, image in enumerate(images)
]
}
)
else:
logger.warn(f"image logging not implemented for {tracker.name}")

del pipeline
torch.cuda.empty_cache()


def parse_args():
parser = argparse.ArgumentParser(description="Simple example of a training script.")
@@ -111,6 +170,13 @@ def parse_args():
"value if set."
),
)
parser.add_argument(
"--validation_prompts",
type=str,
default=None,
nargs="+",
help=("A set of prompts evaluated every `--validation_epochs` and logged to `--report_to`."),
)
parser.add_argument(
"--output_dir",
type=str,
@@ -192,6 +258,13 @@ def parse_args():
parser.add_argument(
"--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
)
parser.add_argument(
"--snr_gamma",
type=float,
default=None,
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
"More details here: https://arxiv.org/abs/2303.09556.",
)
parser.add_argument(
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
)
@@ -297,6 +370,21 @@ def parse_args():
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument("--noise_offset", type=float, default=0, help="The scale of noise offset.")
parser.add_argument(
"--validation_epochs",
type=int,
default=5,
help="Run validation every X epochs.",
)
parser.add_argument(
"--tracker_project_name",
type=str,
default="text2image-fine-tune",
help=(
"The `project_name` argument passed to Accelerator.init_trackers for"
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
),
)

args = parser.parse_args()
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -314,11 +402,6 @@ def parse_args():
return args


dataset_name_mapping = {
"lambdalabs/pokemon-blip-captions": ("image", "text"),
}


def main():
args = parse_args()

@@ -410,6 +493,30 @@ def main():
else:
raise ValueError("xformers is not available. Make sure it is installed correctly")

def compute_snr(timesteps):
"""
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
"""
alphas_cumprod = noise_scheduler.alphas_cumprod
sqrt_alphas_cumprod = alphas_cumprod**0.5
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

# Expand the tensors.
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)

sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)

# Compute SNR.
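# Equivalently, SNR(t) = alphas_cumprod[t] / (1 - alphas_cumprod[t]), because
# alpha = sqrt(alphas_cumprod[t]) and sigma = sqrt(1 - alphas_cumprod[t]).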
snr = (alpha / sigma) ** 2
return snr

# `accelerate` 0.16.0 will have better support for customized saving
if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
# create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
@@ -507,7 +614,7 @@ def load_model_hook(models, input_dir):
column_names = dataset["train"].column_names

# 6. Get the column names for input/target.
dataset_columns = dataset_name_mapping.get(args.dataset_name, None)
dataset_columns = DATASET_NAME_MAPPING.get(args.dataset_name, None)
if args.image_column is None:
image_column = dataset_columns[0] if dataset_columns is not None else column_names[0]
else:
@@ -626,7 +733,9 @@ def collate_fn(examples):
# We need to initialize the trackers we use, and also store our configuration.
# The trackers initialize automatically on the main process.
if accelerator.is_main_process:
accelerator.init_trackers("text2image-fine-tune", config=vars(args))
tracker_config = dict(vars(args))
tracker_config.pop("validation_prompts")
accelerator.init_trackers(args.tracker_project_name, tracker_config)

# Train!
total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
@@ -715,7 +824,23 @@ def collate_fn(examples):

# Predict the noise residual and compute loss
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

if args.snr_gamma is None:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
else:
# Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556.
# Since we predict the noise instead of x_0, the original formulation is slightly changed.
Contributor (on lines +831 to +832): Some models (SD 2.1 and above) use v-prediction; does this formulation also work with v-prediction?
# This is discussed in Section 4.2 of the same paper.
snr = compute_snr(timesteps)
mse_loss_weights = (
torch.stack([snr, args.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
)
# We first calculate the original loss. Then we mean over the non-batch dimensions and
# rebalance the sample-wise losses with their respective loss weights.
# Finally, we take the mean of the rebalanced loss.
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
Contributor (on lines +835 to +843): Nice!


# Gather the losses across all processes for logging (if we use distributed training).
avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
@@ -750,6 +875,26 @@ def collate_fn(examples):
if global_step >= args.max_train_steps:
break

if accelerator.is_main_process:
if args.validation_prompts is not None and epoch % args.validation_epochs == 0:
if args.use_ema:
# Store the UNet parameters temporarily and load the EMA parameters to perform inference.
ema_unet.store(unet.parameters())
ema_unet.copy_to(unet.parameters())
log_validation(
vae,
text_encoder,
tokenizer,
unet,
args,
accelerator,
weight_dtype,
global_step,
)
if args.use_ema:
# Switch back to the original UNet parameters.
ema_unet.restore(unet.parameters())

# Create the pipeline using the trained modules and save it.
accelerator.wait_for_everyone()
if accelerator.is_main_process: