@@ -14,7 +14,7 @@
 
 import inspect
 import warnings
-from typing import Callable, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
 import PIL
@@ -744,6 +744,7 @@ def __call__(
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -815,7 +816,10 @@ def __call__(
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.
-
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
         Examples:
 
         ```py
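For context, here is a minimal usage sketch of the new argument. The model id, file paths, and prompt are placeholders, and the `"scale"` entry is only accepted by processors whose `__call__` takes a `scale` keyword (for example LoRA-style processors, or the toy processor sketched after the next hunk); the stock processor would reject unknown keys.

```python
import torch
from PIL import Image
from diffusers import StableDiffusionInpaintPipeline

pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
).to("cuda")

init_image = Image.open("scene.png").convert("RGB")       # placeholder input image
mask_image = Image.open("scene_mask.png").convert("RGB")  # white = region to repaint

result = pipe(
    prompt="a white cat sitting on a park bench",
    image=init_image,
    mask_image=mask_image,
    # forwarded verbatim to every AttentionProcessor in the UNet
    cross_attention_kwargs={"scale": 0.5},
).images[0]
```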
@@ -966,9 +970,13 @@ def __call__(
                 latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
 
                 # predict the noise residual
-                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False)[
-                    0
-                ]
+                noise_pred = self.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    return_dict=False,
+                )[0]
 
                 # perform guidance
                 if do_classifier_free_guidance:
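The hunk above threads the dict into the UNet, which ultimately unpacks it as keyword arguments into `self.processor.__call__`. Below is a toy sketch of a processor that consumes such a kwarg, assuming the `CrossAttnProcessor` base class from the `diffusers.models.cross_attention` module that this diff links to; `ScaledAttnProcessor` and the output scaling are illustrative, not part of the library. It reuses `pipe` from the sketch above.

```python
from diffusers.models.cross_attention import CrossAttnProcessor

class ScaledAttnProcessor(CrossAttnProcessor):
    """Illustrative only: scales the attention output by a caller-supplied factor."""

    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
        # `scale` arrives here when the caller passes
        # cross_attention_kwargs={"scale": ...} to the pipeline.
        hidden_states = super().__call__(
            attn,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
        )
        return hidden_states * scale

# Install it on every attention layer of the UNet.
pipe.unet.set_attn_processor(ScaledAttnProcessor())
```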