|
| 1 | +"""SAMPLING ONLY.""" |
| 2 | + |
| 3 | +import torch |
| 4 | +import numpy as np |
| 5 | +from tqdm import tqdm |
| 6 | +from functools import partial |
| 7 | + |
| 8 | +from ldm.modules.diffusionmodules.util import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like |
| 9 | + |
| 10 | + |
| 11 | +class PLMSSampler(object): |
| 12 | + def __init__(self, model, schedule="linear", **kwargs): |
| 13 | + super().__init__() |
| 14 | + self.model = model |
| 15 | + self.ddpm_num_timesteps = model.num_timesteps |
| 16 | + self.schedule = schedule |
| 17 | + |
| 18 | + def register_buffer(self, name, attr): |
| 19 | + if type(attr) == torch.Tensor: |
| 20 | + if attr.device != torch.device("cuda"): |
| 21 | + attr = attr.to(torch.device("cuda")) |
| 22 | + setattr(self, name, attr) |
| 23 | + |
| 24 | + def make_schedule(self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0., verbose=True): |
| 25 | + if ddim_eta != 0: |
| 26 | + raise ValueError('ddim_eta must be 0 for PLMS') |
| 27 | + self.ddim_timesteps = make_ddim_timesteps(ddim_discr_method=ddim_discretize, num_ddim_timesteps=ddim_num_steps, |
| 28 | + num_ddpm_timesteps=self.ddpm_num_timesteps,verbose=verbose) |
| 29 | + alphas_cumprod = self.model.alphas_cumprod |
| 30 | + assert alphas_cumprod.shape[0] == self.ddpm_num_timesteps, 'alphas have to be defined for each timestep' |
| 31 | + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) |
| 32 | + |
| 33 | + self.register_buffer('betas', to_torch(self.model.betas)) |
| 34 | + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) |
| 35 | + self.register_buffer('alphas_cumprod_prev', to_torch(self.model.alphas_cumprod_prev)) |
| 36 | + |
| 37 | + # calculations for diffusion q(x_t | x_{t-1}) and others |
| 38 | + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod.cpu()))) |
| 39 | + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod.cpu()))) |
| 40 | + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod.cpu()))) |
| 41 | + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu()))) |
| 42 | + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod.cpu() - 1))) |
| 43 | + |
| 44 | + # ddim sampling parameters |
| 45 | + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters(alphacums=alphas_cumprod.cpu(), |
| 46 | + ddim_timesteps=self.ddim_timesteps, |
| 47 | + eta=ddim_eta,verbose=verbose) |
| 48 | + self.register_buffer('ddim_sigmas', ddim_sigmas) |
| 49 | + self.register_buffer('ddim_alphas', ddim_alphas) |
| 50 | + self.register_buffer('ddim_alphas_prev', ddim_alphas_prev) |
| 51 | + self.register_buffer('ddim_sqrt_one_minus_alphas', np.sqrt(1. - ddim_alphas)) |
| 52 | + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( |
| 53 | + (1 - self.alphas_cumprod_prev) / (1 - self.alphas_cumprod) * ( |
| 54 | + 1 - self.alphas_cumprod / self.alphas_cumprod_prev)) |
| 55 | + self.register_buffer('ddim_sigmas_for_original_num_steps', sigmas_for_original_sampling_steps) |
| 56 | + |
| 57 | + @torch.no_grad() |
| 58 | + def sample(self, |
| 59 | + S, |
| 60 | + batch_size, |
| 61 | + shape, |
| 62 | + conditioning=None, |
| 63 | + callback=None, |
| 64 | + normals_sequence=None, |
| 65 | + img_callback=None, |
| 66 | + quantize_x0=False, |
| 67 | + eta=0., |
| 68 | + mask=None, |
| 69 | + x0=None, |
| 70 | + temperature=1., |
| 71 | + noise_dropout=0., |
| 72 | + score_corrector=None, |
| 73 | + corrector_kwargs=None, |
| 74 | + verbose=True, |
| 75 | + x_T=None, |
| 76 | + log_every_t=100, |
| 77 | + unconditional_guidance_scale=1., |
| 78 | + unconditional_conditioning=None, |
| 79 | + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... |
| 80 | + **kwargs |
| 81 | + ): |
| 82 | + if conditioning is not None: |
| 83 | + if isinstance(conditioning, dict): |
| 84 | + cbs = conditioning[list(conditioning.keys())[0]].shape[0] |
| 85 | + if cbs != batch_size: |
| 86 | + print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}") |
| 87 | + else: |
| 88 | + if conditioning.shape[0] != batch_size: |
| 89 | + print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}") |
| 90 | + |
| 91 | + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) |
| 92 | + # sampling |
| 93 | + C, H, W = shape |
| 94 | + size = (batch_size, C, H, W) |
| 95 | + print(f'Data shape for PLMS sampling is {size}') |
| 96 | + |
| 97 | + samples, intermediates = self.plms_sampling(conditioning, size, |
| 98 | + callback=callback, |
| 99 | + img_callback=img_callback, |
| 100 | + quantize_denoised=quantize_x0, |
| 101 | + mask=mask, x0=x0, |
| 102 | + ddim_use_original_steps=False, |
| 103 | + noise_dropout=noise_dropout, |
| 104 | + temperature=temperature, |
| 105 | + score_corrector=score_corrector, |
| 106 | + corrector_kwargs=corrector_kwargs, |
| 107 | + x_T=x_T, |
| 108 | + log_every_t=log_every_t, |
| 109 | + unconditional_guidance_scale=unconditional_guidance_scale, |
| 110 | + unconditional_conditioning=unconditional_conditioning, |
| 111 | + ) |
| 112 | + return samples, intermediates |
| 113 | + |
| 114 | + @torch.no_grad() |
| 115 | + def plms_sampling(self, cond, shape, |
| 116 | + x_T=None, ddim_use_original_steps=False, |
| 117 | + callback=None, timesteps=None, quantize_denoised=False, |
| 118 | + mask=None, x0=None, img_callback=None, log_every_t=100, |
| 119 | + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, |
| 120 | + unconditional_guidance_scale=1., unconditional_conditioning=None,): |
| 121 | + device = self.model.betas.device |
| 122 | + b = shape[0] |
| 123 | + if x_T is None: |
| 124 | + img = torch.randn(shape, device=device) |
| 125 | + else: |
| 126 | + img = x_T |
| 127 | + |
| 128 | + if timesteps is None: |
| 129 | + timesteps = self.ddpm_num_timesteps if ddim_use_original_steps else self.ddim_timesteps |
| 130 | + elif timesteps is not None and not ddim_use_original_steps: |
| 131 | + subset_end = int(min(timesteps / self.ddim_timesteps.shape[0], 1) * self.ddim_timesteps.shape[0]) - 1 |
| 132 | + timesteps = self.ddim_timesteps[:subset_end] |
| 133 | + |
| 134 | + intermediates = {'x_inter': [img], 'pred_x0': [img]} |
| 135 | + time_range = list(reversed(range(0,timesteps))) if ddim_use_original_steps else np.flip(timesteps) |
| 136 | + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] |
| 137 | + print(f"Running PLMS Sampling with {total_steps} timesteps") |
| 138 | + |
| 139 | + iterator = tqdm(time_range, desc='PLMS Sampler', total=total_steps) |
| 140 | + old_eps = [] |
| 141 | + |
| 142 | + for i, step in enumerate(iterator): |
| 143 | + index = total_steps - i - 1 |
| 144 | + ts = torch.full((b,), step, device=device, dtype=torch.long) |
| 145 | + ts_next = torch.full((b,), time_range[min(i + 1, len(time_range) - 1)], device=device, dtype=torch.long) |
| 146 | + |
| 147 | + if mask is not None: |
| 148 | + assert x0 is not None |
| 149 | + img_orig = self.model.q_sample(x0, ts) # TODO: deterministic forward pass? |
| 150 | + img = img_orig * mask + (1. - mask) * img |
| 151 | + |
| 152 | + outs = self.p_sample_plms(img, cond, ts, index=index, use_original_steps=ddim_use_original_steps, |
| 153 | + quantize_denoised=quantize_denoised, temperature=temperature, |
| 154 | + noise_dropout=noise_dropout, score_corrector=score_corrector, |
| 155 | + corrector_kwargs=corrector_kwargs, |
| 156 | + unconditional_guidance_scale=unconditional_guidance_scale, |
| 157 | + unconditional_conditioning=unconditional_conditioning, |
| 158 | + old_eps=old_eps, t_next=ts_next) |
| 159 | + img, pred_x0, e_t = outs |
| 160 | + old_eps.append(e_t) |
| 161 | + if len(old_eps) >= 4: |
| 162 | + old_eps.pop(0) |
| 163 | + if callback: callback(i) |
| 164 | + if img_callback: img_callback(pred_x0, i) |
| 165 | + |
| 166 | + if index % log_every_t == 0 or index == total_steps - 1: |
| 167 | + intermediates['x_inter'].append(img) |
| 168 | + intermediates['pred_x0'].append(pred_x0) |
| 169 | + |
| 170 | + return img, intermediates |
| 171 | + |
| 172 | + @torch.no_grad() |
| 173 | + def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=False, quantize_denoised=False, |
| 174 | + temperature=1., noise_dropout=0., score_corrector=None, corrector_kwargs=None, |
| 175 | + unconditional_guidance_scale=1., unconditional_conditioning=None, old_eps=None, t_next=None): |
| 176 | + b, *_, device = *x.shape, x.device |
| 177 | + |
| 178 | + def get_model_output(x, t): |
| 179 | + if unconditional_conditioning is None or unconditional_guidance_scale == 1.: |
| 180 | + e_t = self.model.apply_model(x, t, c) |
| 181 | + else: |
| 182 | + x_in = torch.cat([x] * 2) |
| 183 | + t_in = torch.cat([t] * 2) |
| 184 | + c_in = torch.cat([unconditional_conditioning, c]) |
| 185 | + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) |
| 186 | + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) |
| 187 | + |
| 188 | + if score_corrector is not None: |
| 189 | + assert self.model.parameterization == "eps" |
| 190 | + e_t = score_corrector.modify_score(self.model, e_t, x, t, c, **corrector_kwargs) |
| 191 | + |
| 192 | + return e_t |
| 193 | + |
| 194 | + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas |
| 195 | + alphas_prev = self.model.alphas_cumprod_prev if use_original_steps else self.ddim_alphas_prev |
| 196 | + sqrt_one_minus_alphas = self.model.sqrt_one_minus_alphas_cumprod if use_original_steps else self.ddim_sqrt_one_minus_alphas |
| 197 | + sigmas = self.model.ddim_sigmas_for_original_num_steps if use_original_steps else self.ddim_sigmas |
| 198 | + |
| 199 | + def get_x_prev_and_pred_x0(e_t, index): |
| 200 | + # select parameters corresponding to the currently considered timestep |
| 201 | + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) |
| 202 | + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) |
| 203 | + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) |
| 204 | + sqrt_one_minus_at = torch.full((b, 1, 1, 1), sqrt_one_minus_alphas[index],device=device) |
| 205 | + |
| 206 | + # current prediction for x_0 |
| 207 | + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() |
| 208 | + if quantize_denoised: |
| 209 | + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) |
| 210 | + # direction pointing to x_t |
| 211 | + dir_xt = (1. - a_prev - sigma_t**2).sqrt() * e_t |
| 212 | + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature |
| 213 | + if noise_dropout > 0.: |
| 214 | + noise = torch.nn.functional.dropout(noise, p=noise_dropout) |
| 215 | + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise |
| 216 | + return x_prev, pred_x0 |
| 217 | + |
| 218 | + e_t = get_model_output(x, t) |
| 219 | + if len(old_eps) == 0: |
| 220 | + # Pseudo Improved Euler (2nd order) |
| 221 | + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t, index) |
| 222 | + e_t_next = get_model_output(x_prev, t_next) |
| 223 | + e_t_prime = (e_t + e_t_next) / 2 |
| 224 | + elif len(old_eps) == 1: |
| 225 | + # 2nd order Pseudo Linear Multistep (Adams-Bashforth) |
| 226 | + e_t_prime = (3 * e_t - old_eps[-1]) / 2 |
| 227 | + elif len(old_eps) == 2: |
| 228 | + # 3nd order Pseudo Linear Multistep (Adams-Bashforth) |
| 229 | + e_t_prime = (23 * e_t - 16 * old_eps[-1] + 5 * old_eps[-2]) / 12 |
| 230 | + elif len(old_eps) >= 3: |
| 231 | + # 4nd order Pseudo Linear Multistep (Adams-Bashforth) |
| 232 | + e_t_prime = (55 * e_t - 59 * old_eps[-1] + 37 * old_eps[-2] - 9 * old_eps[-3]) / 24 |
| 233 | + |
| 234 | + x_prev, pred_x0 = get_x_prev_and_pred_x0(e_t_prime, index) |
| 235 | + |
| 236 | + return x_prev, pred_x0, e_t |
0 commit comments