Skip to content

Commit 21662d8

Browse files
authored
add support for pre-calculated prompt embeds to Stable Diffusion ONNX pipelines (huggingface#2597)
* add support for prompt embeds to SD ONNX pipeline * fix up the pipeline copies * add prompt embeds param to other ONNX pipelines * fix up prompt embeds param for SD upscaling ONNX pipeline * add missing type annotations to ONNX pipes
1 parent 79da679 commit 21662d8

5 files changed

+643
-176
lines changed

pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py

Lines changed: 173 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,15 @@ def __init__(
111111
)
112112
self.register_to_config(requires_safety_checker=requires_safety_checker)
113113

114-
def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
114+
def _encode_prompt(
115+
self,
116+
prompt: Union[str, List[str]],
117+
num_images_per_prompt: Optional[int],
118+
do_classifier_free_guidance: bool,
119+
negative_prompt: Optional[str],
120+
prompt_embeds: Optional[np.ndarray] = None,
121+
negative_prompt_embeds: Optional[np.ndarray] = None,
122+
):
115123
r"""
116124
Encodes the prompt into text encoder hidden states.
117125
@@ -125,32 +133,48 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida
125133
negative_prompt (`str` or `List[str]`):
126134
The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
127135
if `guidance_scale` is less than `1`).
136+
prompt_embeds (`np.ndarray`, *optional*):
137+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
138+
provided, text embeddings will be generated from `prompt` input argument.
139+
negative_prompt_embeds (`np.ndarray`, *optional*):
140+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
141+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
142+
argument.
128143
"""
129-
batch_size = len(prompt) if isinstance(prompt, list) else 1
130-
131-
# get prompt text embeddings
132-
text_inputs = self.tokenizer(
133-
prompt,
134-
padding="max_length",
135-
max_length=self.tokenizer.model_max_length,
136-
truncation=True,
137-
return_tensors="np",
138-
)
139-
text_input_ids = text_inputs.input_ids
140-
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
144+
if prompt is not None and isinstance(prompt, str):
145+
batch_size = 1
146+
elif prompt is not None and isinstance(prompt, list):
147+
batch_size = len(prompt)
148+
else:
149+
batch_size = prompt_embeds.shape[0]
141150

142-
if not np.array_equal(text_input_ids, untruncated_ids):
143-
removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1])
144-
logger.warning(
145-
"The following part of your input was truncated because CLIP can only handle sequences up to"
146-
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
151+
if prompt_embeds is None:
152+
# get prompt text embeddings
153+
text_inputs = self.tokenizer(
154+
prompt,
155+
padding="max_length",
156+
max_length=self.tokenizer.model_max_length,
157+
truncation=True,
158+
return_tensors="np",
147159
)
160+
text_input_ids = text_inputs.input_ids
161+
untruncated_ids = self.tokenizer(prompt, padding="max_length", return_tensors="np").input_ids
162+
163+
if not np.array_equal(text_input_ids, untruncated_ids):
164+
removed_text = self.tokenizer.batch_decode(
165+
untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
166+
)
167+
logger.warning(
168+
"The following part of your input was truncated because CLIP can only handle sequences up to"
169+
f" {self.tokenizer.model_max_length} tokens: {removed_text}"
170+
)
171+
172+
prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
148173

149-
prompt_embeds = self.text_encoder(input_ids=text_input_ids.astype(np.int32))[0]
150174
prompt_embeds = np.repeat(prompt_embeds, num_images_per_prompt, axis=0)
151175

152176
# get unconditional embeddings for classifier free guidance
153-
if do_classifier_free_guidance:
177+
if do_classifier_free_guidance and negative_prompt_embeds is None:
154178
uncond_tokens: List[str]
155179
if negative_prompt is None:
156180
uncond_tokens = [""] * batch_size
@@ -170,7 +194,7 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida
170194
else:
171195
uncond_tokens = negative_prompt
172196

173-
max_length = text_input_ids.shape[-1]
197+
max_length = prompt_embeds.shape[1]
174198
uncond_input = self.tokenizer(
175199
uncond_tokens,
176200
padding="max_length",
@@ -179,6 +203,8 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida
179203
return_tensors="np",
180204
)
181205
negative_prompt_embeds = self.text_encoder(input_ids=uncond_input.input_ids.astype(np.int32))[0]
206+
207+
if do_classifier_free_guidance:
182208
negative_prompt_embeds = np.repeat(negative_prompt_embeds, num_images_per_prompt, axis=0)
183209

184210
# For classifier free guidance, we need to do two forward passes.
@@ -188,9 +214,56 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guida
188214

189215
return prompt_embeds
190216

191-
def __call__(
217+
def check_inputs(
192218
self,
193219
prompt: Union[str, List[str]],
220+
height: Optional[int],
221+
width: Optional[int],
222+
callback_steps: int,
223+
negative_prompt: Optional[str] = None,
224+
prompt_embeds: Optional[np.ndarray] = None,
225+
negative_prompt_embeds: Optional[np.ndarray] = None,
226+
):
227+
if height % 8 != 0 or width % 8 != 0:
228+
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
229+
230+
if (callback_steps is None) or (
231+
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
232+
):
233+
raise ValueError(
234+
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
235+
f" {type(callback_steps)}."
236+
)
237+
238+
if prompt is not None and prompt_embeds is not None:
239+
raise ValueError(
240+
f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
241+
" only forward one of the two."
242+
)
243+
elif prompt is None and prompt_embeds is None:
244+
raise ValueError(
245+
"Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
246+
)
247+
elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
248+
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
249+
250+
if negative_prompt is not None and negative_prompt_embeds is not None:
251+
raise ValueError(
252+
f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
253+
f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
254+
)
255+
256+
if prompt_embeds is not None and negative_prompt_embeds is not None:
257+
if prompt_embeds.shape != negative_prompt_embeds.shape:
258+
raise ValueError(
259+
"`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
260+
f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
261+
f" {negative_prompt_embeds.shape}."
262+
)
263+
264+
def __call__(
265+
self,
266+
prompt: Union[str, List[str]] = None,
194267
height: Optional[int] = 512,
195268
width: Optional[int] = 512,
196269
num_inference_steps: Optional[int] = 50,
@@ -200,28 +273,86 @@ def __call__(
200273
eta: Optional[float] = 0.0,
201274
generator: Optional[np.random.RandomState] = None,
202275
latents: Optional[np.ndarray] = None,
276+
prompt_embeds: Optional[np.ndarray] = None,
277+
negative_prompt_embeds: Optional[np.ndarray] = None,
203278
output_type: Optional[str] = "pil",
204279
return_dict: bool = True,
205280
callback: Optional[Callable[[int, int, np.ndarray], None]] = None,
206281
callback_steps: int = 1,
207282
):
208-
if isinstance(prompt, str):
283+
r"""
284+
Function invoked when calling the pipeline for generation.
285+
286+
Args:
287+
prompt (`str` or `List[str]`, *optional*):
288+
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
289+
instead.
290+
image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`):
291+
`Image`, or tensor representing an image batch which will be upscaled. *
292+
num_inference_steps (`int`, *optional*, defaults to 50):
293+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
294+
expense of slower inference.
295+
guidance_scale (`float`, *optional*, defaults to 7.5):
296+
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
297+
`guidance_scale` is defined as `w` of equation 2. of [Imagen
298+
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
299+
1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
300+
usually at the expense of lower image quality.
301+
negative_prompt (`str` or `List[str]`, *optional*):
302+
The prompt or prompts not to guide the image generation. If not defined, one has to pass
303+
`negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
304+
is less than `1`).
305+
num_images_per_prompt (`int`, *optional*, defaults to 1):
306+
The number of images to generate per prompt.
307+
eta (`float`, *optional*, defaults to 0.0):
308+
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
309+
[`schedulers.DDIMScheduler`], will be ignored for others.
310+
generator (`np.random.RandomState`, *optional*):
311+
One or a list of [numpy generator(s)](TODO) to make generation deterministic.
312+
latents (`np.ndarray`, *optional*):
313+
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
314+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
315+
tensor will ge generated by sampling using the supplied random `generator`.
316+
prompt_embeds (`np.ndarray`, *optional*):
317+
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
318+
provided, text embeddings will be generated from `prompt` input argument.
319+
negative_prompt_embeds (`np.ndarray`, *optional*):
320+
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
321+
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
322+
argument.
323+
output_type (`str`, *optional*, defaults to `"pil"`):
324+
The output format of the generate image. Choose between
325+
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
326+
return_dict (`bool`, *optional*, defaults to `True`):
327+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
328+
plain tuple.
329+
callback (`Callable`, *optional*):
330+
A function that will be called every `callback_steps` steps during inference. The function will be
331+
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
332+
callback_steps (`int`, *optional*, defaults to 1):
333+
The frequency at which the `callback` function will be called. If not specified, the callback will be
334+
called at every step.
335+
336+
Returns:
337+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
338+
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
339+
When returning a tuple, the first element is a list with the generated images, and the second element is a
340+
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
341+
(nsfw) content, according to the `safety_checker`.
342+
"""
343+
344+
# check inputs. Raise error if not correct
345+
self.check_inputs(
346+
prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
347+
)
348+
349+
# define call parameters
350+
if prompt is not None and isinstance(prompt, str):
209351
batch_size = 1
210-
elif isinstance(prompt, list):
352+
elif prompt is not None and isinstance(prompt, list):
211353
batch_size = len(prompt)
212354
else:
213-
raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
214-
215-
if height % 8 != 0 or width % 8 != 0:
216-
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
217-
218-
if (callback_steps is None) or (
219-
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
220-
):
221-
raise ValueError(
222-
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
223-
f" {type(callback_steps)}."
224-
)
355+
batch_size = prompt_embeds.shape[0]
225356

226357
if generator is None:
227358
generator = np.random
@@ -232,7 +363,12 @@ def __call__(
232363
do_classifier_free_guidance = guidance_scale > 1.0
233364

234365
prompt_embeds = self._encode_prompt(
235-
prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt
366+
prompt,
367+
num_images_per_prompt,
368+
do_classifier_free_guidance,
369+
negative_prompt,
370+
prompt_embeds=prompt_embeds,
371+
negative_prompt_embeds=negative_prompt_embeds,
236372
)
237373

238374
# get the initial random noise unless the user supplied it

0 commit comments

Comments
 (0)