@@ -1242,7 +1242,7 @@ def invert(
1242
1242
callback (i , t , latents )
1243
1243
1244
1244
assert len (inverted_latents ) == len (timesteps )
1245
- latents = torch .stack (list (reversed (inverted_latents )), 0 )
1245
+ latents = torch .cat (list (reversed (inverted_latents )))
1246
1246
1247
1247
# 8. Post-processing
1248
1248
image = self .decode_latents (latents .detach ())
@@ -1478,7 +1478,7 @@ def __call__(
1478
1478
)
1479
1479
1480
1480
# 7. Preprocess inverted latents
1481
- inverted_latents_shape = (len (timesteps ), batch_size , num_channels_latents , * vae_latent_size )
1481
+ inverted_latents_shape = (len (timesteps ) * batch_size , num_channels_latents , * vae_latent_size )
1482
1482
if inverted_latents is None :
1483
1483
raise ValueError (
1484
1484
"`inverted_latents` input cannot be undefined. Use `invert()` to compute `inverted_latents`."
@@ -1489,6 +1489,10 @@ def __call__(
1489
1489
)
1490
1490
if isinstance (inverted_latents , np .ndarray ):
1491
1491
inverted_latents = torch .from_numpy (inverted_latents )
1492
+ inverted_latents = torch .cat (
1493
+ [inverted_latents .reshape (- 1 , batch_size , num_channels_latents , * vae_latent_size )] * num_images_per_prompt ,
1494
+ 1 ,
1495
+ )
1492
1496
inverted_latents = inverted_latents .to (device = device , dtype = prompt_embeds .dtype )
1493
1497
1494
1498
# 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
0 commit comments