# See the License for the specific language governing permissions and
# limitations under the License.

+import gc
import unittest

import numpy as np
import torch

from diffusers import AutoencoderKL, DDIMScheduler, LDMTextToImagePipeline, UNet2DConditionModel
-from diffusers.utils.testing_utils import require_torch, slow, torch_device
+from diffusers.utils.testing_utils import load_numpy, nightly, require_torch_gpu, slow, torch_device
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer

+from ...test_pipelines_common import PipelineTesterMixin
+

torch.backends.cuda.matmul.allow_tf32 = False

-class LDMTextToImagePipelineFastTests(unittest.TestCase):
-    @property
-    def dummy_cond_unet(self):
+class LDMTextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+    pipeline_class = LDMTextToImagePipeline
+    test_cpu_offload = False
+
+    def get_dummy_components(self):
        torch.manual_seed(0)
-        model = UNet2DConditionModel(
+        unet = UNet2DConditionModel(
            block_out_channels=(32, 64),
            layers_per_block=2,
            sample_size=32,
@@ -40,25 +45,24 @@ def dummy_cond_unet(self):
            up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
            cross_attention_dim=32,
        )
-        return model
-
-    @property
-    def dummy_vae(self):
+        scheduler = DDIMScheduler(
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="scaled_linear",
+            clip_sample=False,
+            set_alpha_to_one=False,
+        )
        torch.manual_seed(0)
-        model = AutoencoderKL(
-            block_out_channels=[32, 64],
+        vae = AutoencoderKL(
+            block_out_channels=(32, 64),
            in_channels=3,
            out_channels=3,
-            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
-            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
+            down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D"),
+            up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D"),
            latent_channels=4,
        )
-        return model
-
-    @property
-    def dummy_text_encoder(self):
        torch.manual_seed(0)
-        config = CLIPTextConfig(
+        text_encoder_config = CLIPTextConfig(
            bos_token_id=0,
            eos_token_id=2,
            hidden_size=32,
@@ -69,96 +73,117 @@ def dummy_text_encoder(self):
            pad_token_id=1,
            vocab_size=1000,
        )
-        return CLIPTextModel(config)
-
-    def test_inference_text2img(self):
-        if torch_device != "cpu":
-            return
-
-        unet = self.dummy_cond_unet
-        scheduler = DDIMScheduler()
-        vae = self.dummy_vae
-        bert = self.dummy_text_encoder
+        text_encoder = CLIPTextModel(text_encoder_config)
        tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")

-        ldm = LDMTextToImagePipeline(vqvae=vae, bert=bert, tokenizer=tokenizer, unet=unet, scheduler=scheduler)
-        ldm.to(torch_device)
-        ldm.set_progress_bar_config(disable=None)
-
-        prompt = "A painting of a squirrel eating a burger"
-
-        # Warmup pass when using mps (see #372)
-        if torch_device == "mps":
-            generator = torch.manual_seed(0)
-            _ = ldm(
-                [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=1, output_type="numpy"
-            ).images
-
-        device = torch_device if torch_device != "mps" else "cpu"
-        generator = torch.Generator(device=device).manual_seed(0)
-
-        image = ldm(
-            [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=2, output_type="numpy"
-        ).images
-
-        device = torch_device if torch_device != "mps" else "cpu"
-        generator = torch.Generator(device=device).manual_seed(0)
-
-        image_from_tuple = ldm(
-            [prompt],
-            generator=generator,
-            guidance_scale=6.0,
-            num_inference_steps=2,
-            output_type="numpy",
-            return_dict=False,
-        )[0]
+        components = {
+            "unet": unet,
+            "scheduler": scheduler,
+            "vqvae": vae,
+            "bert": text_encoder,
+            "tokenizer": tokenizer,
+        }
+        return components
+
+    def get_dummy_inputs(self, device, seed=0):
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device=device).manual_seed(seed)
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "generator": generator,
+            "num_inference_steps": 2,
+            "guidance_scale": 6.0,
+            "output_type": "numpy",
+        }
+        return inputs

-        image_slice = image[0, -3:, -3:, -1]
-        image_from_tuple_slice = image_from_tuple[0, -3:, -3:, -1]
-
-        assert image.shape == (1, 16, 16, 3)
-        expected_slice = np.array([0.6806, 0.5454, 0.5638, 0.4893, 0.4656, 0.4257, 0.6248, 0.5217, 0.5498])
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-        assert np.abs(image_from_tuple_slice.flatten() - expected_slice).max() < 1e-2
-
-
-@slow
-@require_torch
-class LDMTextToImagePipelineIntegrationTests(unittest.TestCase):
    def test_inference_text2img(self):
-        ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
-        ldm.to(torch_device)
-        ldm.set_progress_bar_config(disable=None)
-
-        prompt = "A painting of a squirrel eating a burger"
-
-        device = torch_device if torch_device != "mps" else "cpu"
-        generator = torch.Generator(device=device).manual_seed(0)
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator

-        image = ldm(
-            [prompt], generator=generator, guidance_scale=6.0, num_inference_steps=20, output_type="numpy"
-        ).images
+        components = self.get_dummy_components()
+        pipe = LDMTextToImagePipeline(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)

+        inputs = self.get_dummy_inputs(device)
+        image = pipe(**inputs).images
        image_slice = image[0, -3:, -3:, -1]

-        assert image.shape == (1, 256, 256, 3)
-        expected_slice = np.array([0.9256, 0.9340, 0.8933, 0.9361, 0.9113, 0.8727, 0.9122, 0.8745, 0.8099])
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
-
-    def test_inference_text2img_fast(self):
-        ldm = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256")
-        ldm.to(torch_device)
-        ldm.set_progress_bar_config(disable=None)
-
-        prompt = "A painting of a squirrel eating a burger"
+        assert image.shape == (1, 16, 16, 3)
+        expected_slice = np.array([0.59450, 0.64078, 0.55509, 0.51229, 0.69640, 0.36960, 0.59296, 0.60801, 0.49332])

-        device = torch_device if torch_device != "mps" else "cpu"
-        generator = torch.Generator(device=device).manual_seed(0)
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-3

-        image = ldm(prompt, generator=generator, num_inference_steps=1, output_type="numpy").images

-        image_slice = image[0, -3:, -3:, -1]
+@slow
+@require_torch_gpu
+class LDMTextToImagePipelineSlowTests(unittest.TestCase):
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def get_inputs(self, device, dtype=torch.float32, seed=0):
+        generator = torch.Generator(device=device).manual_seed(seed)
+        latents = np.random.RandomState(seed).standard_normal((1, 4, 32, 32))
+        latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "latents": latents,
+            "generator": generator,
+            "num_inference_steps": 3,
+            "guidance_scale": 6.0,
+            "output_type": "numpy",
+        }
+        return inputs
+
+    def test_ldm_default_ddim(self):
+        pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256").to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_inputs(torch_device)
+        image = pipe(**inputs).images
+        image_slice = image[0, -3:, -3:, -1].flatten()

        assert image.shape == (1, 256, 256, 3)
-        expected_slice = np.array([0.3163, 0.8670, 0.6465, 0.1865, 0.6291, 0.5139, 0.2824, 0.3723, 0.4344])
-        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
+        expected_slice = np.array([0.51825, 0.52850, 0.52543, 0.54258, 0.52304, 0.52569, 0.54363, 0.55276, 0.56878])
+        max_diff = np.abs(expected_slice - image_slice).max()
+        assert max_diff < 1e-3
+
+
+@nightly
+@require_torch_gpu
+class LDMTextToImagePipelineNightlyTests(unittest.TestCase):
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def get_inputs(self, device, dtype=torch.float32, seed=0):
+        generator = torch.Generator(device=device).manual_seed(seed)
+        latents = np.random.RandomState(seed).standard_normal((1, 4, 32, 32))
+        latents = torch.from_numpy(latents).to(device=device, dtype=dtype)
+        inputs = {
+            "prompt": "A painting of a squirrel eating a burger",
+            "latents": latents,
+            "generator": generator,
+            "num_inference_steps": 50,
+            "guidance_scale": 6.0,
+            "output_type": "numpy",
+        }
+        return inputs
+
+    def test_ldm_default_ddim(self):
+        pipe = LDMTextToImagePipeline.from_pretrained("CompVis/ldm-text2im-large-256").to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_inputs(torch_device)
+        image = pipe(**inputs).images[0]
+
+        expected_image = load_numpy(
+            "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/ldm_text2img/ldm_large_256_ddim.npy"
+        )
+        max_diff = np.abs(expected_image - image).max()
+        assert max_diff < 1e-3
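
A minimal sketch of how the refactored tests might be run locally, assuming a diffusers checkout with pytest installed and this file at tests/pipelines/latent_diffusion/test_latent_diffusion.py (the path is inferred from the relative PipelineTesterMixin import, not stated in the diff). The fast tests build tiny dummy components and run on CPU; the @slow / @nightly GPU classes are typically gated behind the RUN_SLOW / RUN_NIGHTLY environment variables and a CUDA device.

# Sketch only; the test file path and the -k class-name filter are assumptions for illustration.
import pytest

if __name__ == "__main__":
    # Run just the CPU-friendly fast tests; drop the -k filter (and set RUN_SLOW=1 on a
    # GPU machine) to include the slow integration checks as well.
    pytest.main(["tests/pipelines/latent_diffusion/test_latent_diffusion.py", "-k", "FastTests", "-q"])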