Commit 84781fb

Fix noise pred timestep, clip_tokenizer, CLIP image encoding, and other bugs.
1 parent 0300563 commit 84781fb

File tree: 2 files changed (+36, −49 lines)

  src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py
  tests/pipelines/unidiffuser/test_unidiffuser.py

src/diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py

Lines changed: 33 additions & 42 deletions
@@ -242,7 +242,7 @@ def _infer_mode(self, prompt, prompt_embeds, image, prompt_latents, vae_latents,
     def set_text_mode(self):
         self.mode = "text"

-    def set_img_mode(self):
+    def set_image_mode(self):
         self.mode = "img"

     def set_text_to_image_mode(self):
@@ -276,7 +276,8 @@ def _infer_batch_size(self, mode, prompt, prompt_embeds, image, num_samples):
             batch_size = num_samples
         return batch_size

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+    # Modified from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+    # self.tokenizer => self.clip_tokenizer
     def _encode_prompt(
         self,
         prompt,
@@ -319,25 +320,25 @@ def _encode_prompt(
             batch_size = prompt_embeds.shape[0]

         if prompt_embeds is None:
-            text_inputs = self.tokenizer(
+            text_inputs = self.clip_tokenizer(
                 prompt,
                 padding="max_length",
-                max_length=self.tokenizer.model_max_length,
+                max_length=self.clip_tokenizer.model_max_length,
                 truncation=True,
                 return_tensors="pt",
             )
             text_input_ids = text_inputs.input_ids
-            untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+            untruncated_ids = self.clip_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

             if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                 text_input_ids, untruncated_ids
             ):
-                removed_text = self.tokenizer.batch_decode(
-                    untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+                removed_text = self.clip_tokenizer.batch_decode(
+                    untruncated_ids[:, self.clip_tokenizer.model_max_length - 1 : -1]
                 )
                 logger.warning(
                     "The following part of your input was truncated because CLIP can only handle sequences up to"
-                    f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+                    f" {self.clip_tokenizer.model_max_length} tokens: {removed_text}"
                 )

         if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
@@ -380,7 +381,7 @@ def _encode_prompt(
                 uncond_tokens = negative_prompt

             max_length = prompt_embeds.shape[1]
-            uncond_input = self.tokenizer(
+            uncond_input = self.clip_tokenizer(
                 uncond_tokens,
                 padding="max_length",
                 max_length=max_length,
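
The hunks above only rename `self.tokenizer` to `self.clip_tokenizer`; the tokenization pattern itself is unchanged. For reference, a minimal standalone sketch of that pattern, assuming a stock `CLIPTokenizer` checkpoint (the checkpoint name below is illustrative, not necessarily the pipeline's):

```python
# Sketch of the _encode_prompt tokenization path; checkpoint name is an assumption.
from transformers import CLIPTokenizer

clip_tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32")

prompt = ["an astronaut riding a horse"]
text_inputs = clip_tokenizer(
    prompt,
    padding="max_length",
    max_length=clip_tokenizer.model_max_length,  # 77 for CLIP
    truncation=True,
    return_tensors="pt",
)
untruncated_ids = clip_tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

# Similar check to the pipeline's: warn about anything CLIP silently dropped.
if untruncated_ids.shape[-1] > text_inputs.input_ids.shape[-1]:
    removed = clip_tokenizer.batch_decode(
        untruncated_ids[:, clip_tokenizer.model_max_length - 1 : -1]
    )
    print(f"Truncated: {removed}")

print(text_inputs.input_ids.shape)  # (1, 77)
```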
@@ -480,24 +481,21 @@ def encode_image_clip_latents(
                 f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
             )

-        image = image.to(device=device, dtype=dtype)
+        preprocessed_image = self.image_processor.preprocess(
+            image,
+            do_center_crop=True,
+            crop_size=resolution,
+            return_tensors="pt",
+        )
+        preprocessed_image = preprocessed_image.to(device=device, dtype=dtype)

         if isinstance(generator, list):
             image_latents = [
-                self.image_encoder(
-                    **self.image_processor.preprocess(
-                        image[i : i + 1], do_center_crop=True, crop_size=resolution, return_tensors="pt"
-                    )
-                )
-                for i in range(batch_size)
+                self.image_encoder(**preprocessed_image[i : i + 1]).pooler_output for i in range(batch_size)
             ]
             image_latents = torch.cat(image_latents, dim=0)
         else:
-            # TODO: figure out self.image_processor.preprocess kwargs
-            inputs = self.image_processor.preprocess(
-                image, do_center_crop=True, crop_size=resolution, return_tensors="pt"
-            )
-            image_latents = self.image_encoder(**inputs)
+            image_latents = self.image_encoder(**preprocessed_image).pooler_output

         if batch_size > image_latents.shape[0] and batch_size % image_latents.shape[0] == 0:
             # expand image_latents for batch_size
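
This hunk preprocesses the image once and reads the pooled CLIP embedding instead of the raw encoder output. A minimal standalone sketch of that encoding path, assuming a plain `CLIPImageProcessor`/`CLIPVisionModel` pair (the checkpoint name is illustrative, not the pipeline's actual component):

```python
# Standalone sketch of the CLIP image-encoding path: preprocess once, take the
# pooled embedding. Model and processor names here are assumptions for illustration.
import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModel

processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")
encoder = CLIPVisionModel.from_pretrained("openai/clip-vit-base-patch32")

image = Image.new("RGB", (512, 512))  # placeholder input image
inputs = processor(images=image, return_tensors="pt")  # resize + center crop + normalize

with torch.no_grad():
    # pooler_output is the pooled embedding, shape (batch_size, hidden_size)
    image_clip_latents = encoder(**inputs).pooler_output

# The pipeline later unsqueezes to (batch_size, 1, hidden_size) before the U-Net.
image_clip_latents = image_clip_latents.unsqueeze(1)
print(image_clip_latents.shape)  # torch.Size([1, 1, 768]) for this checkpoint
```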
@@ -659,7 +657,7 @@ def get_noise_pred(
         prompt_embeds,
         img_vae,
         img_clip,
-        timesteps,
+        max_timestep,
         guidance_scale,
         generator,
         device,
@@ -689,17 +687,15 @@ def get_noise_pred(
             img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype)
             img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype)
             text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
-            t_img_uncond = torch.ones_like(t) * timesteps[0]
-            t_text_uncond = torch.ones_like(t) * timesteps[0]

             # print(f"t_img_uncond: {t_img_uncond}")
             # print(f"t_img_uncond shape: {t_img_uncond.shape}")

             # print("Running unconditional U-Net call 1 for CFG...")
-            _, _, text_out_uncond = self.unet(img_vae_T, img_clip_T, text_latents, t_img=t_img_uncond, t_text=t)
+            _, _, text_out_uncond = self.unet(img_vae_T, img_clip_T, text_latents, t_img=max_timestep, t_text=t)
             # print("Running unconditional U-Net call 2 for CFG...")
             img_vae_out_uncond, img_clip_out_uncond, _ = self.unet(
-                img_vae_latents, img_clip_latents, text_T, t_img=t, t_text=t_text_uncond
+                img_vae_latents, img_clip_latents, text_T, t_img=t, t_text=max_timestep
             )

             x_out_uncond = self._combine_joint(img_vae_out_uncond, img_clip_out_uncond, text_out_uncond)
@@ -708,10 +704,9 @@ def get_noise_pred(
         elif mode == "text2img":
             # Text-conditioned image generation
             img_vae_latents, img_clip_latents = self._split(latents, height, width)
-            t_text = torch.zeros(t.size(0), dtype=torch.int, device=device)

             img_vae_out, img_clip_out, text_out = self.unet(
-                img_vae_latents, img_clip_latents, prompt_embeds, t_img=t, t_text=t_text
+                img_vae_latents, img_clip_latents, prompt_embeds, t_img=t, t_text=0
             )

             img_out = self._combine(img_vae_out, img_clip_out)
@@ -721,48 +716,41 @@ def get_noise_pred(

             # Classifier-free guidance
             text_T = randn_tensor(prompt_embeds.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
-            t_text_uncond = torch.ones_like(t) * timesteps

             img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet(
-                img_vae_latents, img_clip_latents, text_T, t_img=timesteps, t_text=t_text_uncond
+                img_vae_latents, img_clip_latents, text_T, t_img=t, t_text=max_timestep
             )

             img_out_uncond = self._combine(img_vae_out_uncond, img_clip_out_uncond)

             return guidance_scale * img_out + (1.0 - guidance_scale) * img_out_uncond
         elif mode == "img2text":
             # Image-conditioned text generation
-            t_img = torch.zeros(t.size(0), dtype=torch.int, device=device)
-
-            img_vae_out, img_clip_out, text_out = self.unet(img_vae, img_clip, latents, t_img=t_img, t_text=t)
+            img_vae_out, img_clip_out, text_out = self.unet(img_vae, img_clip, latents, t_img=0, t_text=t)

             if guidance_scale <= 1.0:
                 return text_out

             # Classifier-free guidance
             img_vae_T = randn_tensor(img_vae.shape, generator=generator, device=device, dtype=img_vae.dtype)
             img_clip_T = randn_tensor(img_clip.shape, generator=generator, device=device, dtype=img_clip.dtype)
-            t_img_uncond = torch.ones_like(t) * timesteps

             img_vae_out_uncond, img_clip_out_uncond, text_out_uncond = self.unet(
-                img_vae_T, img_clip_T, latents, t_img=t_img_uncond, t_text=timesteps
+                img_vae_T, img_clip_T, latents, t_img=max_timestep, t_text=t
             )

             return guidance_scale * text_out + (1.0 - guidance_scale) * text_out_uncond
         elif mode == "text":
             # Unconditional ("marginal") text generation (no CFG)
-            t_img = torch.ones_like(t) * timesteps
-
-            img_vae_out, img_clip_out, text_out = self.unet(img_vae, img_clip, latents, t_img=t_img, t_text=t)
+            img_vae_out, img_clip_out, text_out = self.unet(img_vae, img_clip, latents, t_img=max_timestep, t_text=t)

             return text_out
         elif mode == "img":
             # Unconditional ("marginal") image generation (no CFG)
             img_vae_latents, img_clip_latents = self._split(latents, height, width)
-            t_text = torch.ones_like(t) * timesteps

             img_vae_out, img_clip_out, text_out = self.unet(
-                img_vae_latents, img_clip_latents, prompt_embeds, t_img=t, t_text=t_text
+                img_vae_latents, img_clip_latents, prompt_embeds, t_img=t, t_text=max_timestep
             )

             img_out = self._combine(img_vae_out, img_clip_out)
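
The timestep convention these hunks converge on: the conditioning modality is passed with timestep 0 (treated as clean), while the modality dropped for the unconditional classifier-free-guidance branch is replaced with noise and passed with the maximum timestep. A toy sketch of that pattern with a stubbed U-Net; all names, shapes, and the `toy_unet` function below are made up for illustration, not the pipeline's actual API:

```python
# Toy illustration of the max-timestep CFG convention (text2img-style branch).
import torch

def toy_unet(img_latents, text_latents, t_img, t_text):
    # Stand-in for the joint UniDiffuser U-Net: returns (img_noise_pred, text_noise_pred).
    return torch.randn_like(img_latents), torch.randn_like(text_latents)

batch = 2
img_latents = torch.randn(batch, 4, 8, 8)
prompt_embeds = torch.randn(batch, 77, 32)
t = torch.full((batch,), 500)        # current denoising timestep
max_timestep = torch.tensor(999)     # timesteps[0] from the scheduler
guidance_scale = 7.0

# Conditional branch: text is clean conditioning, so t_text = 0.
img_out, _ = toy_unet(img_latents, prompt_embeds, t_img=t, t_text=torch.zeros_like(t))

# Unconditional branch: replace the text with noise and mark it as fully noised.
text_T = torch.randn_like(prompt_embeds)
img_out_uncond, _ = toy_unet(img_latents, text_T, t_img=t, t_text=max_timestep)

# Guidance combination, as in the diff above.
noise_pred = guidance_scale * img_out + (1.0 - guidance_scale) * img_out_uncond
```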
@@ -980,7 +968,7 @@ def __call__(
             assert image is not None
             # Encode image using VAE
             image_vae = preprocess(image)
-            height, width = image.shape[-2:]
+            height, width = image_vae.shape[-2:]
             image_vae_latents = self.encode_image_vae_latents(
                 image_vae,
                 batch_size,
@@ -1001,6 +989,8 @@ def __call__(
                 device,
                 generator,
             )
+            # (batch_size, clip_hidden_size) => (batch_size, 1, clip_hidden_size)
+            image_clip_latents = image_clip_latents.unsqueeze(1)
         else:
             # 4.2. Prepare image latent variables, if input not available
             # Prepare image VAE latents
@@ -1030,6 +1020,7 @@ def __call__(
         # 5. Set timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps
+        max_timestep = timesteps[0]
         # print(f"Timesteps: {timesteps}")
         # print(f"Timesteps shape: {timesteps.shape}")

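`max_timestep = timesteps[0]` relies on diffusers schedulers returning their timesteps in descending order, so index 0 is the largest value. A quick check, using `DDPMScheduler` purely as an example scheduler:

```python
# Sanity check of the assumption behind `max_timestep = timesteps[0]`:
# scheduler timesteps are returned largest-first.
from diffusers import DDPMScheduler

scheduler = DDPMScheduler(num_train_timesteps=1000)
scheduler.set_timesteps(num_inference_steps=20)
timesteps = scheduler.timesteps
print(timesteps[:3])  # descending: largest timestep first
print(timesteps[0])   # the maximum timestep, used for the "fully noised" CFG branch
```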
@@ -1062,7 +1053,7 @@ def __call__(
                     prompt_embeds,
                     image_vae_latents,
                     image_clip_latents,
-                    timesteps,
+                    max_timestep,
                     guidance_scale,
                     generator,
                     device,

tests/pipelines/unidiffuser/test_unidiffuser.py

Lines changed: 3 additions & 7 deletions
@@ -2,7 +2,6 @@
 import unittest

 import numpy as np
-import pytest
 import torch
 from PIL import Image
 from transformers import (
@@ -158,7 +157,6 @@ def get_dummy_inputs(self, device, seed=0):
         }
         return inputs

-    # @pytest.mark.xfail(reason="not finished debugging")
     def test_unidiffuser_default_joint(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
@@ -186,7 +184,6 @@ def test_unidiffuser_default_joint(self):
         # TODO: need to figure out correct text output
         print(text)

-    @pytest.mark.xfail(reason="haven't begun debugging")
     def test_unidiffuser_default_text2img(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
@@ -208,7 +205,6 @@ def test_unidiffuser_default_text2img(self):
         expected_slice = np.array([0.3965, 0.4568, 0.4495, 0.4590, 0.4463, 0.4690, 0.5454, 0.5093, 0.4321])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

-    @pytest.mark.xfail(reason="haven't begun debugging")
     def test_unidiffuser_default_img2text(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
@@ -227,8 +223,8 @@ def test_unidiffuser_default_img2text(self):

         # TODO: need to figure out correct text output
         print(text)
+        assert 0 == 1

-    @pytest.mark.xfail(reason="haven't begun debugging")
     def test_unidiffuser_default_text(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
@@ -248,8 +244,8 @@ def test_unidiffuser_default_text(self):

         # TODO: need to figure out correct text output
         print(text)
+        assert 0 == 1

-    @pytest.mark.xfail(reason="haven't begun debugging")
     def test_unidiffuser_default_image(self):
         device = "cpu"  # ensure determinism for the device-dependent torch.Generator
         components = self.get_dummy_components()
@@ -268,8 +264,8 @@ def test_unidiffuser_default_image(self):
         image = unidiffuser_pipe(**inputs).images
         assert image.shape == (1, 32, 32, 3)

-        # TODO: get expected slice of image output
         image_slice = image[0, -3:, -3:, -1]
+        print(image_slice.flatten())
         expected_slice = np.array([0.3967, 0.4568, 0.4495, 0.4590, 0.4463, 0.4690, 0.5454, 0.5093, 0.4321])
         assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

275271
