@@ -111,7 +111,15 @@ def __init__(
111
111
)
112
112
self .register_to_config (requires_safety_checker = requires_safety_checker )
113
113
114
- def _encode_prompt (self , prompt , num_images_per_prompt , do_classifier_free_guidance , negative_prompt ):
114
+ def _encode_prompt (
115
+ self ,
116
+ prompt : Union [str , List [str ]],
117
+ num_images_per_prompt : Optional [int ],
118
+ do_classifier_free_guidance : bool ,
119
+ negative_prompt : Optional [str ],
120
+ prompt_embeds : Optional [np .ndarray ] = None ,
121
+ negative_prompt_embeds : Optional [np .ndarray ] = None ,
122
+ ):
115
123
r"""
116
124
Encodes the prompt into text encoder hidden states.
117
125
@@ -125,32 +133,48 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida
125
133
negative_prompt (`str` or `List[str]`):
126
134
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
127
135
if `guidance_scale` is less than `1`).
136
+ prompt_embeds (`np.ndarray`, *optional*):
137
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
138
+ provided, text embeddings will be generated from `prompt` input argument.
139
+ negative_prompt_embeds (`np.ndarray`, *optional*):
140
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
141
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
142
+ argument.
128
143
"""
129
- batch_size = len (prompt ) if isinstance (prompt , list ) else 1
130
-
131
- # get prompt text embeddings
132
- text_inputs = self .tokenizer (
133
- prompt ,
134
- padding = "max_length" ,
135
- max_length = self .tokenizer .model_max_length ,
136
- truncation = True ,
137
- return_tensors = "np" ,
138
- )
139
- text_input_ids = text_inputs .input_ids
140
- untruncated_ids = self .tokenizer (prompt , padding = "max_length" , return_tensors = "np" ).input_ids
144
+ if prompt is not None and isinstance (prompt , str ):
145
+ batch_size = 1
146
+ elif prompt is not None and isinstance (prompt , list ):
147
+ batch_size = len (prompt )
148
+ else :
149
+ batch_size = prompt_embeds .shape [0 ]
141
150
142
- if not np .array_equal (text_input_ids , untruncated_ids ):
143
- removed_text = self .tokenizer .batch_decode (untruncated_ids [:, self .tokenizer .model_max_length - 1 : - 1 ])
144
- logger .warning (
145
- "The following part of your input was truncated because CLIP can only handle sequences up to"
146
- f" { self .tokenizer .model_max_length } tokens: { removed_text } "
151
+ if prompt_embeds is None :
152
+ # get prompt text embeddings
153
+ text_inputs = self .tokenizer (
154
+ prompt ,
155
+ padding = "max_length" ,
156
+ max_length = self .tokenizer .model_max_length ,
157
+ truncation = True ,
158
+ return_tensors = "np" ,
147
159
)
160
+ text_input_ids = text_inputs .input_ids
161
+ untruncated_ids = self .tokenizer (prompt , padding = "max_length" , return_tensors = "np" ).input_ids
162
+
163
+ if not np .array_equal (text_input_ids , untruncated_ids ):
164
+ removed_text = self .tokenizer .batch_decode (
165
+ untruncated_ids [:, self .tokenizer .model_max_length - 1 : - 1 ]
166
+ )
167
+ logger .warning (
168
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
169
+ f" { self .tokenizer .model_max_length } tokens: { removed_text } "
170
+ )
171
+
172
+ prompt_embeds = self .text_encoder (input_ids = text_input_ids .astype (np .int32 ))[0 ]
148
173
149
- prompt_embeds = self .text_encoder (input_ids = text_input_ids .astype (np .int32 ))[0 ]
150
174
prompt_embeds = np .repeat (prompt_embeds , num_images_per_prompt , axis = 0 )
151
175
152
176
# get unconditional embeddings for classifier free guidance
153
- if do_classifier_free_guidance :
177
+ if do_classifier_free_guidance and negative_prompt_embeds is None :
154
178
uncond_tokens : List [str ]
155
179
if negative_prompt is None :
156
180
uncond_tokens = ["" ] * batch_size
@@ -170,7 +194,7 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida
170
194
else :
171
195
uncond_tokens = negative_prompt
172
196
173
- max_length = text_input_ids .shape [- 1 ]
197
+ max_length = prompt_embeds .shape [1 ]
174
198
uncond_input = self .tokenizer (
175
199
uncond_tokens ,
176
200
padding = "max_length" ,
@@ -179,6 +203,8 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida
179
203
return_tensors = "np" ,
180
204
)
181
205
negative_prompt_embeds = self .text_encoder (input_ids = uncond_input .input_ids .astype (np .int32 ))[0 ]
206
+
207
+ if do_classifier_free_guidance :
182
208
negative_prompt_embeds = np .repeat (negative_prompt_embeds , num_images_per_prompt , axis = 0 )
183
209
184
210
# For classifier free guidance, we need to do two forward passes.
@@ -188,9 +214,56 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida
188
214
189
215
return prompt_embeds
190
216
191
- def __call__ (
217
def check_inputs(
    self,
    prompt: Optional[Union[str, List[str]]],
    height: Optional[int],
    width: Optional[int],
    callback_steps: int,
    negative_prompt: Optional[Union[str, List[str]]] = None,
    prompt_embeds: Optional[np.ndarray] = None,
    negative_prompt_embeds: Optional[np.ndarray] = None,
):
    r"""
    Validates the arguments of a generation call. Returns `None` when every check passes.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts to guide the image generation. Exactly one of `prompt` and `prompt_embeds`
            has to be provided.
        height (`int`):
            Requested image height. Has to be divisible by 8.
        width (`int`):
            Requested image width. Has to be divisible by 8.
        callback_steps (`int`):
            Has to be a positive integer.
        negative_prompt (`str` or `List[str]`, *optional*):
            The prompt or prompts not to guide the image generation. Cannot be combined with
            `negative_prompt_embeds`.
        prompt_embeds (`np.ndarray`, *optional*):
            Pre-generated text embeddings. Cannot be combined with `prompt`.
        negative_prompt_embeds (`np.ndarray`, *optional*):
            Pre-generated negative text embeddings. When passed together with `prompt_embeds`, both arrays
            must have the same shape.

    Raises:
        ValueError: If any of the constraints described above is violated.
    """
    # `None` can never satisfy the divisibility requirement; report it through the same `ValueError` instead
    # of letting the `%` operator fail with an unrelated `TypeError`.
    if height is None or width is None or height % 8 != 0 or width % 8 != 0:
        raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    # `None` is not an `int`, so a single `isinstance` test also covers the missing-value case.
    if not isinstance(callback_steps, int) or callback_steps <= 0:
        raise ValueError(
            f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
            f" {type(callback_steps)}."
        )

    # Exactly one source for the conditional text embeddings is allowed.
    if prompt is not None and prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    elif prompt is None and prompt_embeds is None:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    elif prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    # At most one source for the unconditional (negative) text embeddings is allowed.
    if negative_prompt is not None and negative_prompt_embeds is not None:
        raise ValueError(
            f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
            f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
        )

    if prompt_embeds is not None and negative_prompt_embeds is not None:
        if prompt_embeds.shape != negative_prompt_embeds.shape:
            raise ValueError(
                "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
                f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
                f" {negative_prompt_embeds.shape}."
            )
263
+
264
+ def __call__ (
265
+ self ,
266
+ prompt : Union [str , List [str ]] = None ,
194
267
height : Optional [int ] = 512 ,
195
268
width : Optional [int ] = 512 ,
196
269
num_inference_steps : Optional [int ] = 50 ,
@@ -200,28 +273,86 @@ def __call__(
200
273
eta : Optional [float ] = 0.0 ,
201
274
generator : Optional [np .random .RandomState ] = None ,
202
275
latents : Optional [np .ndarray ] = None ,
276
+ prompt_embeds : Optional [np .ndarray ] = None ,
277
+ negative_prompt_embeds : Optional [np .ndarray ] = None ,
203
278
output_type : Optional [str ] = "pil" ,
204
279
return_dict : bool = True ,
205
280
callback : Optional [Callable [[int , int , np .ndarray ], None ]] = None ,
206
281
callback_steps : int = 1 ,
207
282
):
208
- if isinstance (prompt , str ):
283
+ r"""
284
+ Function invoked when calling the pipeline for generation.
285
+
286
+ Args:
287
+ prompt (`str` or `List[str]`, *optional*):
288
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
289
+ instead.
290
+ image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`):
291
+ `Image`, or tensor representing an image batch which will be upscaled. *
292
+ num_inference_steps (`int`, *optional*, defaults to 50):
293
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
294
+ expense of slower inference.
295
+ guidance_scale (`float`, *optional*, defaults to 7.5):
296
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
297
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
298
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
299
+ 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
300
+ usually at the expense of lower image quality.
301
+ negative_prompt (`str` or `List[str]`, *optional*):
302
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
303
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
304
+ is less than `1`).
305
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
306
+ The number of images to generate per prompt.
307
+ eta (`float`, *optional*, defaults to 0.0):
308
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
309
+ [`schedulers.DDIMScheduler`], will be ignored for others.
310
+ generator (`np.random.RandomState`, *optional*):
311
+ One or a list of [numpy generator(s)](TODO) to make generation deterministic.
312
+ latents (`np.ndarray`, *optional*):
313
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
314
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
315
+ tensor will be generated by sampling using the supplied random `generator`.
316
+ prompt_embeds (`np.ndarray`, *optional*):
317
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
318
+ provided, text embeddings will be generated from `prompt` input argument.
319
+ negative_prompt_embeds (`np.ndarray`, *optional*):
320
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
321
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
322
+ argument.
323
+ output_type (`str`, *optional*, defaults to `"pil"`):
324
+ The output format of the generated image. Choose between
325
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
326
+ return_dict (`bool`, *optional*, defaults to `True`):
327
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
328
+ plain tuple.
329
+ callback (`Callable`, *optional*):
330
+ A function that will be called every `callback_steps` steps during inference. The function will be
331
+ called with the following arguments: `callback(step: int, timestep: int, latents: np.ndarray)`.
332
+ callback_steps (`int`, *optional*, defaults to 1):
333
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
334
+ called at every step.
335
+
336
+ Returns:
337
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
338
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
339
+ When returning a tuple, the first element is a list with the generated images, and the second element is a
340
+ list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
341
+ (nsfw) content, according to the `safety_checker`.
342
+ """
343
+
344
+ # check inputs. Raise error if not correct
345
+ self .check_inputs (
346
+ prompt , height , width , callback_steps , negative_prompt , prompt_embeds , negative_prompt_embeds
347
+ )
348
+
349
+ # define call parameters
350
+ if prompt is not None and isinstance (prompt , str ):
209
351
batch_size = 1
210
- elif isinstance (prompt , list ):
352
+ elif prompt is not None and isinstance (prompt , list ):
211
353
batch_size = len (prompt )
212
354
else :
213
- raise ValueError (f"`prompt` has to be of type `str` or `list` but is { type (prompt )} " )
214
-
215
- if height % 8 != 0 or width % 8 != 0 :
216
- raise ValueError (f"`height` and `width` have to be divisible by 8 but are { height } and { width } ." )
217
-
218
- if (callback_steps is None ) or (
219
- callback_steps is not None and (not isinstance (callback_steps , int ) or callback_steps <= 0 )
220
- ):
221
- raise ValueError (
222
- f"`callback_steps` has to be a positive integer but is { callback_steps } of type"
223
- f" { type (callback_steps )} ."
224
- )
355
+ batch_size = prompt_embeds .shape [0 ]
225
356
226
357
if generator is None :
227
358
generator = np .random
@@ -232,7 +363,12 @@ def __call__(
232
363
do_classifier_free_guidance = guidance_scale > 1.0
233
364
234
365
prompt_embeds = self ._encode_prompt (
235
- prompt , num_images_per_prompt , do_classifier_free_guidance , negative_prompt
366
+ prompt ,
367
+ num_images_per_prompt ,
368
+ do_classifier_free_guidance ,
369
+ negative_prompt ,
370
+ prompt_embeds = prompt_embeds ,
371
+ negative_prompt_embeds = negative_prompt_embeds ,
236
372
)
237
373
238
374
# get the initial random noise unless the user supplied it
0 commit comments