Commit c2a38ef

Fix/update the LDM pipeline and tests (#1743)
* Fix/update LDM tests
* batched generators
1 parent 08cc36d commit c2a38ef

File tree

2 files changed: +145, -107 lines

src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py

Lines changed: 22 additions & 9 deletions
@@ -128,29 +128,42 @@ def __call__(
 
         # get unconditional embeddings for classifier free guidance
         if guidance_scale != 1.0:
-            uncond_input = self.tokenizer([""] * batch_size, padding="max_length", max_length=77, return_tensors="pt")
+            uncond_input = self.tokenizer(
+                [""] * batch_size, padding="max_length", max_length=77, truncation=True, return_tensors="pt"
+            )
             uncond_embeddings = self.bert(uncond_input.input_ids.to(self.device))[0]
 
         # get prompt text embeddings
-        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, return_tensors="pt")
+        text_input = self.tokenizer(prompt, padding="max_length", max_length=77, truncation=True, return_tensors="pt")
         text_embeddings = self.bert(text_input.input_ids.to(self.device))[0]
 
         # get the initial random noise unless the user supplied it
         latents_shape = (batch_size, self.unet.in_channels, height // 8, width // 8)
+        if isinstance(generator, list) and len(generator) != batch_size:
+            raise ValueError(
+                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+            )
+
         if latents is None:
-            if self.device.type == "mps":
-                # randn does not work reproducibly on mps
-                latents = torch.randn(latents_shape, generator=generator, device="cpu").to(self.device)
+            rand_device = "cpu" if self.device.type == "mps" else self.device
+
+            if isinstance(generator, list):
+                latents_shape = (1,) + latents_shape[1:]
+                latents = [
+                    torch.randn(latents_shape, generator=generator[i], device=rand_device, dtype=text_embeddings.dtype)
+                    for i in range(batch_size)
+                ]
+                latents = torch.cat(latents, dim=0)
             else:
                 latents = torch.randn(
-                    latents_shape,
-                    generator=generator,
-                    device=self.device,
+                    latents_shape, generator=generator, device=rand_device, dtype=text_embeddings.dtype
                 )
+            latents = latents.to(self.device)
         else:
             if latents.shape != latents_shape:
                 raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
-        latents = latents.to(self.device)
+            latents = latents.to(self.device)
 
         self.scheduler.set_timesteps(num_inference_steps)
 
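For context on the hunk above: the pipeline now accepts either a single torch.Generator or a list with one generator per prompt, draws each sample's starting noise separately, and concatenates along the batch dimension (with MPS falling back to CPU randn). Below is a minimal standalone sketch of that pattern; the shapes and seeds are illustrative, not the pipeline's actual arguments.

import torch

# Illustrative values; the pipeline derives the shape from
# (batch_size, unet.in_channels, height // 8, width // 8).
batch_size, channels, height, width = 2, 4, 32, 32
rand_device = "cpu"  # the hunk above picks "cpu" when self.device.type == "mps"

# One generator per sample keeps each latent individually reproducible.
generators = [torch.Generator(device=rand_device).manual_seed(seed) for seed in range(batch_size)]

per_sample_shape = (1, channels, height, width)
latents = torch.cat(
    [torch.randn(per_sample_shape, generator=generators[i], device=rand_device) for i in range(batch_size)],
    dim=0,
)
assert latents.shape == (batch_size, channels, height, width)

At the pipeline level this surfaces as passing generator=[g0, g1, ...] alongside a batch of prompts; the new length check raises a ValueError when the list length does not match the batch size.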
tests/pipelines/latent_diffusion/test_latent_diffusion.py

Lines changed: 123 additions & 98 deletions
@@ -13,24 +13,29 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import gc
 import unittest
 
 import numpy as np
 import torch
 
 from diffusers import AutoencoderKL, DDIMScheduler, LDMTextToImagePipeline, UNet2DConditionModel
-from diffusers.utils.testing_utils import require_torch, slow, torch_device
+from diffusers.utils.testing_utils import load_numpy, nightly, require_torch_gpu, slow, torch_device
 from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
 
+from ...test_pipelines_common import PipelineTesterMixin
+
 
 torch.backends.cuda.matmul.allow_tf32 = False
 
 
-class LDMTextToImagePipelineFastTests(unittest.TestCase):
-    @property
-    def dummy_cond_unet(self):
+class LDMTextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+    pipeline_class = LDMTextToImagePipeline
+    test_cpu_offload = False
+
+    def get_dummy_components(self):
         torch.manual_seed(0)
-        model = UNet2DConditionModel(
+        unet = UNet2DConditionModel(
             block_out_channels=(32, 64),
             layers_per_block=2,
             sample_size=32,
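The fast-test class now plugs into the shared PipelineTesterMixin instead of defining everything locally: it declares pipeline_class and supplies get_dummy_components() / get_dummy_inputs() hooks (completed in the hunks below), and the mixin drives the common checks through them. Roughly, and only as a schematic of that contract (not the actual mixin implementation in diffusers), a shared test looks like this:

class SchematicPipelineTesterMixin:
    # Schematic only: the real PipelineTesterMixin runs many more shared checks
    # (dict vs. tuple outputs, dtype handling, progress bar config, ...).
    pipeline_class = None  # the concrete test class sets this, e.g. LDMTextToImagePipeline

    def test_pipeline_call_smoke(self):
        components = self.get_dummy_components()   # hook from the concrete test class
        pipe = self.pipeline_class(**components)
        pipe.set_progress_bar_config(disable=None)
        inputs = self.get_dummy_inputs("cpu")      # hook from the concrete test class
        images = pipe(**inputs).images
        assert images.shape[0] == 1                # one image for the single-prompt input

Flags such as test_cpu_offload = False appear to opt the pipeline out of specific shared checks.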
@@ -40,25 +45,24 @@ def dummy_cond_unet(self):
             up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
             cross_attention_dim=32,
         )
-        return model
-
-    @property
-    def dummy_vae(self):
+        scheduler = DDIMScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="scaled_linear",
+            clip_sample=False,
+            set_alpha_to_one=False,
+        )
         torch.manual_seed(0)
-        model = AutoencoderKL(
-            block_out_channels=[32, 64],
+        vae = AutoencoderKL(
+            block_out_channels=(32, 64),
             in_channels=3,
             out_channels=3,
-            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
-            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+            down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D"),
+            up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D"),
             latent_channels=4,
         )
-        return model
-
-    @property
-    def dummy_text_encoder(self):
         torch.manual_seed(0)
-        config = CLIPTextConfig(
+        text_encoder_config = CLIPTextConfig(
             bos_token_id=0,
             eos_token_id=2,
             hidden_size=32,
@@ -69,96 +73,117 @@ def dummy_text_encoder(self):
             pad_token_id=1,
             vocab_size=1000,
         )
-        return CLIPTextModel(config)
-
-    def test_inference_text2img(self):
-        if torch_device != "cpu":
-            return
-
-        unet = self.dummy_cond_unet
-        scheduler = DDIMScheduler()
-        vae = self.dummy_vae
-        bert = self.dummy_text_encoder
+        text_encoder = CLIPTextModel(text_encoder_config)
         tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
 
-        ldm = LDMTextToImagePipeline(vqvae=vae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
-        ldm.to(torch_device)
-        ldm.set_progress_bar_config(disable=None)
-
-        prompt = "A painting of a squirrel eating a burger"
-
-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            generator = torch.manual_seed(0)
-            _ = ldm(
-                [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=1, output_type="numpy"
-            ).images
-
-        device = torch_device if torch_device != "mps" else "cpu"
-        generator = torch.Generator(device=device).manual_seed(0)
-
-        image = ldm(
-            [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="numpy"
-        ).images
-
-        device = torch_device if torch_device != "mps" else "cpu"
-        generator = torch.Generator(device=device).manual_seed(0)
-
-        image_from_tuple = ldm(
-            [prompt],
-            generator=generator,
-            guidance_scale=6.0,
-            num_inference_steps=2,
-            output_type="numpy",
-            return_dict=False,
-        )[0]
+        components = {
+            "unet": unet,
+            "scheduler": scheduler,
+            "vqvae": vae,
+            "bert": text_encoder,
+            "tokenizer": tokenizer,
+        }
+        return components
+
+    def get_dummy_inputs(self, device, seed=0):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device=device).manual_seed(seed)
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 6.0,
+            "output_type": "numpy",
+        }
+        return inputs
 
-        image_slice = image[0, -3:, -3:, -1]
-        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
-        assert image.shape == (1, 16, 16, 3)
-        expected_slice = np.array([0.6806, 0.5454, 0.5638, 0.4893, 0.4656, 0.4257, 0.6248, 0.5217, 0.5498])
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
-
-@slow
-@require_torch
-class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
     def test_inference_text2img(self):
-        ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
-        ldm.to(torch_device)
-        ldm.set_progress_bar_config(disable=None)
-
-        prompt = "A painting of a squirrel eating a burger"
-
-        device = torch_device if torch_device != "mps" else "cpu"
-        generator = torch.Generator(device=device).manual_seed(0)
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
 
-        image = ldm(
-            [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy"
-        ).images
+        components = self.get_dummy_components()
+        pipe = LDMTextToImagePipeline(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
 
+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs).images
         image_slice = image[0, -3:, -3:, -1]
 
-        assert image.shape == (1, 256, 256, 3)
-        expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
-    def test_inference_text2img_fast(self):
-        ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
-        ldm.to(torch_device)
-        ldm.set_progress_bar_config(disable=None)
-
-        prompt = "A painting of a squirrel eating a burger"
+        assert image.shape == (1, 16, 16, 3)
+        expected_slice = np.array([0.59450, 0.64078, 0.55509, 0.51229, 0.69640, 0.36960, 0.59296, 0.60801, 0.49332])
 
-        device = torch_device if torch_device != "mps" else "cpu"
-        generator = torch.Generator(device=device).manual_seed(0)
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
 
-        image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy").images
 
-        image_slice = image[0, -3:, -3:, -1]
+@slow
+@require_torch_gpu
+class LDMTextToImagePipelineSlowTests(unittest.TestCase):
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def get_inputs(self, device, dtype=torch.float32, seed=0):
+        generator = torch.Generator(device=device).manual_seed(seed)
+        latents = np.random.RandomState(seed).standard_normal((1, 4, 32, 32))
+        latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "latents": latents,
+            "generator": generator,
+            "num_inference_steps": 3,
+            "guidance_scale": 6.0,
+            "output_type": "numpy",
+        }
+        return inputs
+
+    def test_ldm_default_ddim(self):
+        pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256").to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_inputs(torch_device)
+        image = pipe(**inputs).images
+        image_slice = image[0, -3:, -3:, -1].flatten()
 
         assert image.shape == (1, 256, 256, 3)
-        expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        expected_slice = np.array([0.51825, 0.52850, 0.52543, 0.54258, 0.52304, 0.52569, 0.54363, 0.55276, 0.56878])
+        max_diff = np.abs(expected_slice - image_slice).max()
+        assert max_diff < 1e-3
+
+
+@nightly
+@require_torch_gpu
+class LDMTextToImagePipelineNightlyTests(unittest.TestCase):
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def get_inputs(self, device, dtype=torch.float32, seed=0):
+        generator = torch.Generator(device=device).manual_seed(seed)
+        latents = np.random.RandomState(seed).standard_normal((1, 4, 32, 32))
+        latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "latents": latents,
+            "generator": generator,
+            "num_inference_steps": 50,
+            "guidance_scale": 6.0,
+            "output_type": "numpy",
+        }
+        return inputs
+
+    def test_ldm_default_ddim(self):
+        pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256").to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_inputs(torch_device)
+        image = pipe(**inputs).images[0]
+
+        expected_image = load_numpy(
+            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/ldm_text2img/ldm_large_256_ddim.npy"
+        )
+        max_diff = np.abs(expected_image - image).max()
+        assert max_diff < 1e-3
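The slow and nightly tests pin down every source of randomness: a per-device seeded torch.Generator plus initial latents drawn from a NumPy RandomState, so the starting noise is identical across machines, and the nightly run compares the full output against a reference array fetched with load_numpy. A sketch of reproducing the slow-test configuration outside unittest follows; it assumes a CUDA GPU and that the CompVis/ldm-text2im-large-256 checkpoint can be downloaded from the Hub.

import numpy as np
import torch

from diffusers import LDMTextToImagePipeline

device = "cuda"
pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256").to(device)
pipe.set_progress_bar_config(disable=None)

seed = 0
generator = torch.Generator(device=device).manual_seed(seed)
# NumPy's RandomState produces the same noise on every platform, unlike device-side torch.randn.
latents = torch.from_numpy(
    np.random.RandomState(seed).standard_normal((1, 4, 32, 32))
).to(device=device, dtype=torch.float32)

image = pipe(
    "A painting of a squirrel eating a burger",
    latents=latents,
    generator=generator,
    num_inference_steps=3,
    guidance_scale=6.0,
    output_type="numpy",
).images
print(image.shape)  # the test expects (1, 256, 256, 3)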
