Commit 3992b0a

Add inpaint lora scale support (huggingface#3460)

* add inpaint lora scale support
* add inpaint lora scale test

Co-authored-by: yueyang.hyy <[email protected]>

1 parent 691d40d commit 3992b0a

File tree

1 file changed: +13 −5 lines changed

pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 13 additions & 5 deletions
@@ -14,7 +14,7 @@
 
 import inspect
 import warnings
-from typing import Callable, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
 import PIL
@@ -744,6 +744,7 @@ def __call__(
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
         callback_steps: int = 1,
+        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -815,7 +816,10 @@ def __call__(
             callback_steps (`int`, *optional*, defaults to 1):
                 The frequency at which the `callback` function will be called. If not specified, the callback will be
                 called at every step.
-
+            cross_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
 
         Examples:
 
         ```py
@@ -966,9 +970,13 @@ def __call__(
                 latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
 
                 # predict the noise residual
-                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False)[
-                    0
-                ]
+                noise_pred = self.unet(
+                    latent_model_input,
+                    t,
+                    encoder_hidden_states=prompt_embeds,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    return_dict=False,
+                )[0]
 
                 # perform guidance
                 if do_classifier_free_guidance:
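
For context, a minimal usage sketch of the new argument (not part of the commit itself): after loading LoRA layers into the pipeline's UNet, the LoRA contribution can be scaled at call time by passing `cross_attention_kwargs={"scale": ...}`, which the change above forwards to the UNet and, from there, to its attention processors. The checkpoint id, LoRA path, image URLs, and scale value below are placeholders.

```py
import torch
from diffusers import StableDiffusionInpaintPipeline
from diffusers.utils import load_image

# Load the inpaint pipeline (checkpoint id is illustrative).
pipe = StableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
).to("cuda")

# Load LoRA layers into the UNet's attention processors.
# "path/to/lora" is a placeholder for a real LoRA checkpoint.
pipe.unet.load_attn_procs("path/to/lora")

# Placeholder inputs: a base image and a mask of the region to repaint.
init_image = load_image("https://example.com/dog.png")
mask_image = load_image("https://example.com/dog_mask.png")

# The new cross_attention_kwargs argument is forwarded to self.unet(...),
# so {"scale": 0.5} reaches the LoRA attention processors and scales the
# LoRA contribution during inpainting.
image = pipe(
    prompt="a photo of a cat sitting on a bench",
    image=init_image,
    mask_image=mask_image,
    cross_attention_kwargs={"scale": 0.5},
).images[0]
image.save("inpaint_lora.png")
```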
