Skip to content

Commit 5b11c5d

Browse files
yiyixuxuyiyixuxu
and
yiyixuxu
authored
fix the add_noise function for dpm-multi et al (#5158)
* remove to _device() for sigmas * update add_noise to use simgas --------- Co-authored-by: yiyixuxu <yixu310@gmail,com>
1 parent 310cf32 commit 5b11c5d

File tree

5 files changed

+27
-23
lines changed

5 files changed

+27
-23
lines changed

src/diffusers/schedulers/scheduling_deis_multistep.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -243,8 +243,8 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
243243
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
244244
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
245245

246-
self.sigmas = torch.from_numpy(sigmas).to(device=device)
247-
self.timesteps = torch.from_numpy(timesteps).to(device)
246+
self.sigmas = torch.from_numpy(sigmas)
247+
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
248248

249249
self.num_inference_steps = len(timesteps)
250250

@@ -707,12 +707,12 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch
707707
"""
708708
return sample
709709

710-
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
710+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
711711
def add_noise(
712712
self,
713713
original_samples: torch.FloatTensor,
714714
noise: torch.FloatTensor,
715-
timesteps: torch.FloatTensor,
715+
timesteps: torch.IntTensor,
716716
) -> torch.FloatTensor:
717717
# Make sure sigmas and timesteps have the same device and dtype as original_samples
718718
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
@@ -730,7 +730,8 @@ def add_noise(
730730
while len(sigma.shape) < len(original_samples.shape):
731731
sigma = sigma.unsqueeze(-1)
732732

733-
noisy_samples = original_samples + noise * sigma
733+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
734+
noisy_samples = alpha_t * original_samples + sigma_t * noise
734735
return noisy_samples
735736

736737
def __len__(self):

src/diffusers/schedulers/scheduling_dpmsolver_multistep.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -263,8 +263,8 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
263263
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
264264
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
265265

266-
self.sigmas = torch.from_numpy(sigmas).to(device=device)
267-
self.timesteps = torch.from_numpy(timesteps).to(device)
266+
self.sigmas = torch.from_numpy(sigmas)
267+
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
268268

269269
self.num_inference_steps = len(timesteps)
270270

@@ -840,12 +840,11 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch
840840
"""
841841
return sample
842842

843-
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
844843
def add_noise(
845844
self,
846845
original_samples: torch.FloatTensor,
847846
noise: torch.FloatTensor,
848-
timesteps: torch.FloatTensor,
847+
timesteps: torch.IntTensor,
849848
) -> torch.FloatTensor:
850849
# Make sure sigmas and timesteps have the same device and dtype as original_samples
851850
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
@@ -863,7 +862,8 @@ def add_noise(
863862
while len(sigma.shape) < len(original_samples.shape):
864863
sigma = sigma.unsqueeze(-1)
865864

866-
noisy_samples = original_samples + noise * sigma
865+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
866+
noisy_samples = alpha_t * original_samples + sigma_t * noise
867867
return noisy_samples
868868

869869
def __len__(self):

src/diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,7 @@ def set_timesteps(self, num_inference_steps: int = None, device: Union[str, torc
274274
_, unique_indices = np.unique(timesteps, return_index=True)
275275
timesteps = timesteps[np.sort(unique_indices)]
276276

277-
self.timesteps = torch.from_numpy(timesteps).to(device)
277+
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
278278

279279
self.num_inference_steps = len(timesteps)
280280

@@ -858,12 +858,12 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch
858858
"""
859859
return sample
860860

861-
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
861+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
862862
def add_noise(
863863
self,
864864
original_samples: torch.FloatTensor,
865865
noise: torch.FloatTensor,
866-
timesteps: torch.FloatTensor,
866+
timesteps: torch.IntTensor,
867867
) -> torch.FloatTensor:
868868
# Make sure sigmas and timesteps have the same device and dtype as original_samples
869869
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
@@ -881,7 +881,8 @@ def add_noise(
881881
while len(sigma.shape) < len(original_samples.shape):
882882
sigma = sigma.unsqueeze(-1)
883883

884-
noisy_samples = original_samples + noise * sigma
884+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
885+
noisy_samples = alpha_t * original_samples + sigma_t * noise
885886
return noisy_samples
886887

887888
def __len__(self):

src/diffusers/schedulers/scheduling_dpmsolver_singlestep.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
275275

276276
self.sigmas = torch.from_numpy(sigmas).to(device=device)
277277

278-
self.timesteps = torch.from_numpy(timesteps).to(device)
278+
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
279279
self.model_outputs = [None] * self.config.solver_order
280280
self.sample = None
281281

@@ -870,12 +870,12 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch
870870
"""
871871
return sample
872872

873-
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
873+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
874874
def add_noise(
875875
self,
876876
original_samples: torch.FloatTensor,
877877
noise: torch.FloatTensor,
878-
timesteps: torch.FloatTensor,
878+
timesteps: torch.IntTensor,
879879
) -> torch.FloatTensor:
880880
# Make sure sigmas and timesteps have the same device and dtype as original_samples
881881
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
@@ -893,7 +893,8 @@ def add_noise(
893893
while len(sigma.shape) < len(original_samples.shape):
894894
sigma = sigma.unsqueeze(-1)
895895

896-
noisy_samples = original_samples + noise * sigma
896+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
897+
noisy_samples = alpha_t * original_samples + sigma_t * noise
897898
return noisy_samples
898899

899900
def __len__(self):

src/diffusers/schedulers/scheduling_unipc_multistep.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -254,8 +254,8 @@ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.devic
254254
sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
255255
sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
256256

257-
self.sigmas = torch.from_numpy(sigmas).to(device=device)
258-
self.timesteps = torch.from_numpy(timesteps).to(device)
257+
self.sigmas = torch.from_numpy(sigmas)
258+
self.timesteps = torch.from_numpy(timesteps).to(device=device, dtype=torch.int64)
259259

260260
self.num_inference_steps = len(timesteps)
261261

@@ -801,12 +801,12 @@ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch
801801
"""
802802
return sample
803803

804-
# Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
804+
# Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.add_noise
805805
def add_noise(
806806
self,
807807
original_samples: torch.FloatTensor,
808808
noise: torch.FloatTensor,
809-
timesteps: torch.FloatTensor,
809+
timesteps: torch.IntTensor,
810810
) -> torch.FloatTensor:
811811
# Make sure sigmas and timesteps have the same device and dtype as original_samples
812812
sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
@@ -824,7 +824,8 @@ def add_noise(
824824
while len(sigma.shape) < len(original_samples.shape):
825825
sigma = sigma.unsqueeze(-1)
826826

827-
noisy_samples = original_samples + noise * sigma
827+
alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
828+
noisy_samples = alpha_t * original_samples + sigma_t * noise
828829
return noisy_samples
829830

830831
def __len__(self):

0 commit comments

Comments
 (0)