Skip to content

Commit 9d16daa

Browse files
Add DREAM training (#6381)
A new function compute_dream_and_update_latents has been added to the training utilities that allows you to do DREAM rectified training in line with the paper https://arxiv.org/abs/2312.00210. The method can be used with an extra argument in the train_text_to_image.py script. Co-authored-by: Jimmy <39@🇺🇸.com>
1 parent 8e4ca1b commit 9d16daa

File tree

3 files changed

+88
-2
lines changed

3 files changed

+88
-2
lines changed

examples/text_to_image/README.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,11 @@ For our small Pokemons dataset, the effects of Min-SNR weighting strategy might
170170

171171
Also, note that in this example, we either predict `epsilon` (i.e., the noise) or the `v_prediction`. For both of these cases, the formulation of the Min-SNR weighting strategy that we have used holds.
172172

173+
#### Training with DREAM
174+
175+
We support training epsilon (noise) prediction models using the [DREAM (Diffusion Rectification and Estimation-Adaptive Models) strategy](https://arxiv.org/abs/2312.00210). DREAM claims to increase model fidelity at the performance cost of one extra gradient-free UNet `forward` step per iteration of the training loop. You can turn on DREAM training with the `--dream_training` argument. The `--dream_detail_preservation` argument controls the detail preservation variable p; it defaults to 1, the value used in the paper.
176+
177+
173178
## Training with LoRA
174179

175180
Low-Rank Adaptation of Large Language Models was first introduced by Microsoft in [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685) by *Edward J. Hu, Yelong Shen, Phillip Wallis, Zeyuan Allen-Zhu, Yuanzhi Li, Shean Wang, Lu Wang, Weizhu Chen*.

examples/text_to_image/train_text_to_image.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
import diffusers
4646
from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline, UNet2DConditionModel
4747
from diffusers.optimization import get_scheduler
48-
from diffusers.training_utils import EMAModel, compute_snr
48+
from diffusers.training_utils import EMAModel, compute_dream_and_update_latents, compute_snr
4949
from diffusers.utils import check_min_version, deprecate, is_wandb_available, make_image_grid
5050
from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
5151
from diffusers.utils.import_utils import is_xformers_available
@@ -361,6 +361,20 @@ def parse_args():
361361
help="SNR weighting gamma to be used if rebalancing the loss. Recommended value is 5.0. "
362362
"More details here: https://arxiv.org/abs/2303.09556.",
363363
)
364+
parser.add_argument(
365+
"--dream_training",
366+
action="store_true",
367+
help=(
368+
"Use the DREAM training method, which makes training more efficient and accurate at the ",
369+
"expense of doing an extra forward pass. See: https://arxiv.org/abs/2312.00210",
370+
),
371+
)
372+
parser.add_argument(
373+
"--dream_detail_preservation",
374+
type=float,
375+
default=1.0,
376+
help="Dream detail preservation factor p (should be greater than 0; default=1.0, as suggested in the paper)",
377+
)
364378
parser.add_argument(
365379
"--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
366380
)
@@ -948,6 +962,18 @@ def unwrap_model(model):
948962
else:
949963
raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
950964

965+
if args.dream_training:
966+
noisy_latents, target = compute_dream_and_update_latents(
967+
unet,
968+
noise_scheduler,
969+
timesteps,
970+
noise,
971+
noisy_latents,
972+
target,
973+
encoder_hidden_states,
974+
args.dream_detail_preservation,
975+
)
976+
951977
# Predict the noise residual and compute loss
952978
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
953979

src/diffusers/training_utils.py

Lines changed: 56 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
import contextlib
22
import copy
33
import random
4-
from typing import Any, Dict, Iterable, List, Optional, Union
4+
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
55

66
import numpy as np
77
import torch
88

99
from .models import UNet2DConditionModel
10+
from .schedulers import SchedulerMixin
1011
from .utils import (
1112
convert_state_dict_to_diffusers,
1213
convert_state_dict_to_peft,
@@ -117,6 +118,60 @@ def resolve_interpolation_mode(interpolation_type: str):
117118
return interpolation_mode
118119

119120

121+
def compute_dream_and_update_latents(
    unet: UNet2DConditionModel,
    noise_scheduler: SchedulerMixin,
    timesteps: torch.Tensor,
    noise: torch.Tensor,
    noisy_latents: torch.Tensor,
    target: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    dream_detail_preservation: float = 1.0,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Implements "DREAM (Diffusion Rectification and Estimation-Adaptive Models)" from http://arxiv.org/abs/2312.00210.
    DREAM helps align training with sampling to help training be more efficient and accurate at the cost of an extra
    forward step without gradients.

    Args:
        `unet`: The unet, in its current state, used to make the gradient-free prediction.
        `noise_scheduler`: The noise scheduler used to add noise for the given timesteps. Its `alphas_cumprod`
            and `config.prediction_type` are read.
        `timesteps`: The timesteps for the noise_scheduler to use; also used to index `alphas_cumprod`.
        `noise`: A tensor of noise in the shape of noisy_latents.
        `noisy_latents`: Previously noised latents from the training loop.
        `target`: The ground-truth tensor to predict after eps is removed.
        `encoder_hidden_states`: Text embeddings from the text model.
        `dream_detail_preservation`: A float value that indicates detail preservation level (the exponent p
            applied to sqrt(1 - alphas_cumprod)). See reference.

    Returns:
        `tuple[torch.Tensor, torch.Tensor]`: Adjusted noisy_latents and target. Both are new tensors; the
        `noisy_latents`, `target` and `noise` arguments are not modified.

    Raises:
        `NotImplementedError`: If the scheduler is configured for "v_prediction".
        `ValueError`: If the scheduler's prediction type is anything other than "epsilon" or "v_prediction".
    """
    # Reject unsupported configurations up front so no forward pass is wasted on them.
    prediction_type = noise_scheduler.config.prediction_type
    if prediction_type == "v_prediction":
        raise NotImplementedError("DREAM has not been implemented for v-prediction")
    if prediction_type != "epsilon":
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

    # Index per-sample alphas and add three trailing singleton axes so they broadcast over the latents.
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps, None, None, None]
    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

    # The paper uses lambda = sqrt(1 - alpha) ** p, with p = 1 in their experiments.
    dream_lambda = sqrt_one_minus_alphas_cumprod**dream_detail_preservation

    # Extra forward pass; no_grad keeps it out of the autograd graph.
    with torch.no_grad():
        predicted_noise = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    # The subtraction allocates a fresh tensor, so scaling it in place leaves the caller's `noise` untouched.
    delta_noise = (noise - predicted_noise).detach()
    delta_noise.mul_(dream_lambda)

    # Results go into new names: the inputs must stay intact because they are read right here.
    # (Rebinding `noisy_latents`/`target` to None before this point made `.add` fail on NoneType.)
    dream_noisy_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise)
    dream_target = target.add(delta_noise)

    return dream_noisy_latents, dream_target
173+
174+
120175
def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
121176
r"""
122177
Returns:

0 commit comments

Comments
 (0)