 import contextlib
 import copy
 import random
-from typing import Any, Dict, Iterable, List, Optional, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
 
 from .models import UNet2DConditionModel
+from .schedulers import SchedulerMixin
 from .utils import (
     convert_state_dict_to_diffusers,
     convert_state_dict_to_peft,
@@ -117,6 +118,60 @@ def resolve_interpolation_mode(interpolation_type: str):
     return interpolation_mode
 
 
+def compute_dream_and_update_latents(
+    unet: UNet2DConditionModel,
+    noise_scheduler: SchedulerMixin,
+    timesteps: torch.Tensor,
+    noise: torch.Tensor,
+    noisy_latents: torch.Tensor,
+    target: torch.Tensor,
+    encoder_hidden_states: torch.Tensor,
+    dream_detail_preservation: float = 1.0,
+) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+    """
+    Implements "DREAM (Diffusion Rectification and Estimation-Adaptive Models)" from http://arxiv.org/abs/2312.00210.
+    DREAM helps align training with sampling, making training more efficient and accurate at the cost of an extra
+    forward step without gradients.
+
+    Args:
+        `unet`: The unet used to make the prediction.
+        `noise_scheduler`: The noise scheduler used to add noise for the given timestep.
+        `timesteps`: The timesteps for the noise_scheduler to use.
+        `noise`: A tensor of noise with the same shape as noisy_latents.
+        `noisy_latents`: Previously noised latents from the training loop.
+        `target`: The ground-truth tensor to predict after eps is removed.
+        `encoder_hidden_states`: Text embeddings from the text model.
+        `dream_detail_preservation`: A float value that indicates the detail preservation level; see the reference
+            paper.
+
+    Returns:
+        `tuple[torch.Tensor, torch.Tensor]`: The adjusted noisy_latents and target.
+    """
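+    # Gather alpha_bar_t for each sample's timestep and reshape it to (batch, 1, 1, 1)
+    # so it broadcasts over the latent channel and spatial dimensions.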
+    alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps, None, None, None]
+    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
+
+    # The paper uses lambda = sqrt(1 - alpha) ** p, with p = 1 in their experiments.
+    dream_lambda = sqrt_one_minus_alphas_cumprod**dream_detail_preservation
+
+    pred = None
+    with torch.no_grad():
+        pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
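+    # DREAM rectification for the epsilon parameterization: with delta = noise - predicted_noise,
+    # shift the pair to noisy_latents + sqrt(1 - alpha_bar_t) * lambda * delta and
+    # target + lambda * delta, so the loss is computed against the rectified trajectory.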
+    _noisy_latents, _target = (None, None)
+    if noise_scheduler.config.prediction_type == "epsilon":
+        predicted_noise = pred
+        delta_noise = (noise - predicted_noise).detach()
+        delta_noise.mul_(dream_lambda)
+        _noisy_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise)
+        _target = target.add(delta_noise)
+    elif noise_scheduler.config.prediction_type == "v_prediction":
+        raise NotImplementedError("DREAM has not been implemented for v-prediction")
+    else:
+        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+    return _noisy_latents, _target
+
+
 def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
     r"""
     Returns:
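Below is a minimal sketch (not part of this commit) of how the new helper could be wired into an epsilon-prediction training step, assuming the helper lives in diffusers.training_utils as the relative imports suggest. The `unet`, `noise_scheduler`, `latents`, and `encoder_hidden_states` objects are assumed to come from an existing training setup; only the DREAM call is specific to this change.

import torch
import torch.nn.functional as F

from diffusers.training_utils import compute_dream_and_update_latents

# Assumed to exist already: unet, noise_scheduler, latents, encoder_hidden_states.
noise = torch.randn_like(latents)
timesteps = torch.randint(
    0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],), device=latents.device
)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
target = noise  # ground truth for "epsilon" prediction

# One extra no-grad UNet pass rectifies the (noisy_latents, target) pair.
noisy_latents, target = compute_dream_and_update_latents(
    unet,
    noise_scheduler,
    timesteps,
    noise,
    noisy_latents,
    target,
    encoder_hidden_states,
    dream_detail_preservation=1.0,
)

model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")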