Skip to content

Commit 3654b9c

Browse files
committed
bugfixes
1 parent 3bcbd85 commit 3654b9c

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1242,7 +1242,7 @@ def invert(
12421242
callback(i, t, latents)
12431243

12441244
assert len(inverted_latents) == len(timesteps)
1245-
latents = torch.stack(list(reversed(inverted_latents)), 0)
1245+
latents = torch.cat(list(reversed(inverted_latents)))
12461246

12471247
# 8. Post-processing
12481248
image = self.decode_latents(latents.detach())
@@ -1478,7 +1478,7 @@ def __call__(
14781478
)
14791479

14801480
# 7. Preprocess inverted latents
1481-
inverted_latents_shape = (len(timesteps), batch_size, num_channels_latents, *vae_latent_size)
1481+
inverted_latents_shape = (len(timesteps) * batch_size, num_channels_latents, *vae_latent_size)
14821482
if inverted_latents is None:
14831483
raise ValueError(
14841484
"`inverted_latents` input cannot be undefined. Use `invert()` to compute `inverted_latents`."
@@ -1489,6 +1489,10 @@ def __call__(
14891489
)
14901490
if isinstance(inverted_latents, np.ndarray):
14911491
inverted_latents = torch.from_numpy(inverted_latents)
1492+
inverted_latents = torch.cat(
1493+
[inverted_latents.reshape(-1, batch_size, num_channels_latents, *vae_latent_size)] * num_images_per_prompt,
1494+
1,
1495+
)
14921496
inverted_latents = inverted_latents.to(device=device, dtype=prompt_embeds.dtype)
14931497

14941498
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline

0 commit comments

Comments
 (0)