@@ -241,7 +241,45 @@ from diffusers import StableDiffusionPipeline
241
241
from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks
242
242
from diffusers.configuration_utils import register_to_config
243
243
import torch
244
- from typing import Any, Dict, Optional
244
+ from typing import Any, Dict, Tuple, Union
245
+
246
+
247
+ class SDPromptSchedulingCallback (PipelineCallback ):
248
+ @register_to_config
249
+ def __init__ (
250
+ self ,
251
+ encoded_prompt : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
252
+ cutoff_step_ratio = None ,
253
+ cutoff_step_index = None ,
254
+ ):
255
+ super ().__init__ (
256
+ cutoff_step_ratio = cutoff_step_ratio, cutoff_step_index = cutoff_step_index
257
+ )
258
+
259
+ tensor_inputs = [" prompt_embeds" ]
260
+
261
+ def callback_fn (
262
+ self , pipeline , step_index , timestep , callback_kwargs
263
+ ) -> Dict[str , Any]:
264
+ cutoff_step_ratio = self .config.cutoff_step_ratio
265
+ cutoff_step_index = self .config.cutoff_step_index
266
+ if isinstance (self .config.encoded_prompt, tuple ):
267
+ prompt_embeds, negative_prompt_embeds = self .config.encoded_prompt
268
+ else :
269
+ prompt_embeds = self .config.encoded_prompt
270
+
271
+ # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
272
+ cutoff_step = (
273
+ cutoff_step_index
274
+ if cutoff_step_index is not None
275
+ else int (pipeline.num_timesteps * cutoff_step_ratio)
276
+ )
277
+
278
+ if step_index == cutoff_step:
279
+ if pipeline.do_classifier_free_guidance:
280
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
281
+ callback_kwargs[self .tensor_inputs[0 ]] = prompt_embeds
282
+ return callback_kwargs
245
283
246
284
247
285
pipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
@@ -253,28 +291,73 @@ pipeline: StableDiffusionPipeline = StableDiffusionPipeline.from_pretrained(
253
291
pipeline.safety_checker = None
254
292
pipeline.requires_safety_checker = False
255
293
294
+ callback = MultiPipelineCallbacks(
295
+ [
296
+ SDPromptSchedulingCallback(
297
+ encoded_prompt = pipeline.encode_prompt(
298
+ prompt = f " prompt { index} " ,
299
+ negative_prompt = f " negative prompt { index} " ,
300
+ device = pipeline._execution_device,
301
+ num_images_per_prompt = 1 ,
302
+ # pipeline.do_classifier_free_guidance can't be accessed until after pipeline is ran
303
+ do_classifier_free_guidance = True ,
304
+ ),
305
+ cutoff_step_index = index,
306
+ ) for index in range (1 , 20 )
307
+ ]
308
+ )
309
+
310
+ image = pipeline(
311
+ prompt = " prompt"
312
+ negative_prompt = " negative prompt" ,
313
+ callback_on_step_end = callback,
314
+ callback_on_step_end_tensor_inputs = [" prompt_embeds" ],
315
+ ).images[0 ]
316
+ torch.cuda.empty_cache()
317
+ image.save(' image.png' )
318
+ ```
256
319
257
- class SDPromptScheduleCallback (PipelineCallback ):
320
+ ``` python
321
+ from diffusers import StableDiffusionXLPipeline
322
+ from diffusers.callbacks import PipelineCallback, MultiPipelineCallbacks
323
+ from diffusers.configuration_utils import register_to_config
324
+ import torch
325
+ from typing import Any, Dict, Tuple, Union
326
+
327
+
328
+ class SDXLPromptSchedulingCallback (PipelineCallback ):
258
329
@register_to_config
259
330
def __init__ (
260
331
self ,
261
- prompt : str ,
262
- negative_prompt : Optional[ str ] = None ,
263
- num_images_per_prompt : int = 1 ,
264
- cutoff_step_ratio = 1.0 ,
332
+ encoded_prompt : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] ,
333
+ add_text_embeds : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] ,
334
+ add_time_ids : Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]] ,
335
+ cutoff_step_ratio = None ,
265
336
cutoff_step_index = None ,
266
337
):
267
338
super ().__init__ (
268
339
cutoff_step_ratio = cutoff_step_ratio, cutoff_step_index = cutoff_step_index
269
340
)
270
341
271
- tensor_inputs = [" prompt_embeds" ]
342
+ tensor_inputs = [" prompt_embeds" , " add_text_embeds " , " add_time_ids " ]
272
343
273
344
def callback_fn (
274
345
self , pipeline , step_index , timestep , callback_kwargs
275
346
) -> Dict[str , Any]:
276
347
cutoff_step_ratio = self .config.cutoff_step_ratio
277
348
cutoff_step_index = self .config.cutoff_step_index
349
+ if isinstance (self .config.encoded_prompt, tuple ):
350
+ prompt_embeds, negative_prompt_embeds = self .config.encoded_prompt
351
+ else :
352
+ prompt_embeds = self .config.encoded_prompt
353
+ if isinstance (self .config.add_text_embeds, tuple ):
354
+ add_text_embeds, negative_add_text_embeds = self .config.add_text_embeds
355
+ else :
356
+ add_text_embeds = self .config.add_text_embeds
357
+ if isinstance (self .config.add_time_ids, tuple ):
358
+ add_time_ids, negative_add_time_ids = self .config.add_time_ids
359
+ else :
360
+ add_time_ids = self .config.add_time_ids
278
361
279
362
# Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
280
363
cutoff_step = (
@@ -284,34 +367,73 @@ class SDPromptScheduleCallback(PipelineCallback):
284
367
)
285
368
286
369
if step_index == cutoff_step:
287
- prompt_embeds, negative_prompt_embeds = pipeline.encode_prompt(
288
- prompt = self .config.prompt,
289
- negative_prompt = self .config.negative_prompt,
290
- device = pipeline._execution_device,
291
- num_images_per_prompt = self .config.num_images_per_prompt,
292
- do_classifier_free_guidance = pipeline.do_classifier_free_guidance,
293
- )
294
370
if pipeline.do_classifier_free_guidance:
295
371
prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
372
+ add_text_embeds = torch.cat([negative_add_text_embeds, add_text_embeds])
373
+ add_time_ids = torch.cat([negative_add_time_ids, add_time_ids])
296
374
callback_kwargs[self .tensor_inputs[0 ]] = prompt_embeds
375
+ callback_kwargs[self .tensor_inputs[1 ]] = add_text_embeds
376
+ callback_kwargs[self .tensor_inputs[2 ]] = add_time_ids
297
377
return callback_kwargs
298
378
299
- callback = MultiPipelineCallbacks(
300
- [
301
- SDPromptScheduleCallback(
302
- prompt = " Official portrait of a smiling world war ii general, female, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski" ,
303
- negative_prompt = " Deformed, ugly, bad anatomy" ,
304
- cutoff_step_ratio = 0.25 ,
379
+
380
+ pipeline: StableDiffusionXLPipeline = StableDiffusionXLPipeline.from_pretrained(
381
+ " stabilityai/stable-diffusion-xl-base-1.0" ,
382
+ torch_dtype = torch.float16,
383
+ variant = " fp16" ,
384
+ use_safetensors = True ,
385
+ ).to(" cuda" )
386
+
387
+ callbacks = []
388
+ for index in range (1 , 20 ):
389
+ (
390
+ prompt_embeds,
391
+ negative_prompt_embeds,
392
+ pooled_prompt_embeds,
393
+ negative_pooled_prompt_embeds,
394
+ ) = pipeline.encode_prompt(
395
+ prompt = f " prompt { index} " ,
396
+ negative_prompt = f " prompt { index} " ,
397
+ device = pipeline._execution_device,
398
+ num_images_per_prompt = 1 ,
399
+ # pipeline.do_classifier_free_guidance can't be accessed until after pipeline is ran
400
+ do_classifier_free_guidance = True ,
401
+ )
402
+ text_encoder_projection_dim = int (pooled_prompt_embeds.shape[- 1 ])
403
+ add_time_ids = pipeline._get_add_time_ids(
404
+ (1024 , 1024 ),
405
+ (0 , 0 ),
406
+ (1024 , 1024 ),
407
+ dtype = prompt_embeds.dtype,
408
+ text_encoder_projection_dim = text_encoder_projection_dim,
409
+ )
410
+ negative_add_time_ids = pipeline._get_add_time_ids(
411
+ (1024 , 1024 ),
412
+ (0 , 0 ),
413
+ (1024 , 1024 ),
414
+ dtype = prompt_embeds.dtype,
415
+ text_encoder_projection_dim = text_encoder_projection_dim,
416
+ )
417
+ callbacks.append(
418
+ SDXLPromptSchedulingCallback(
419
+ encoded_prompt = (prompt_embeds, negative_prompt_embeds),
420
+ add_text_embeds = (pooled_prompt_embeds, negative_pooled_prompt_embeds),
421
+ add_time_ids = (add_time_ids, negative_add_time_ids),
422
+ cutoff_step_index = index,
305
423
)
306
- ]
307
- )
424
+ )
425
+
426
+
427
+ callback = MultiPipelineCallbacks(callbacks)
308
428
309
429
image = pipeline(
310
- prompt = " Official portrait of a smiling world war ii general, male, cheerful, happy, detailed face, 20th century, highly detailed, cinematic lighting, digital art painting by Greg Rutkowski " ,
311
- negative_prompt = " Deformed, ugly, bad anatomy " ,
430
+ prompt = " prompt " ,
431
+ negative_prompt = " negative prompt " ,
312
432
callback_on_step_end = callback,
313
- callback_on_step_end_tensor_inputs = [" prompt_embeds" ],
433
+ callback_on_step_end_tensor_inputs = [
434
+ " prompt_embeds" ,
435
+ " add_text_embeds" ,
436
+ " add_time_ids" ,
437
+ ],
314
438
).images[0 ]
315
- torch.cuda.empty_cache()
316
- image.save(' image.png' )
317
439
```
0 commit comments