Skip to content

Commit d88116c

Browse files
Glaceon-HyyJimmy
authored and
Jimmy
committed
Add inpaint lora scale support (huggingface#3460)
* add inpaint lora scale support * add inpaint lora scale test --------- Co-authored-by: yueyang.hyy <[email protected]>
1 parent 4bb129d commit d88116c

File tree

2 files changed

+48
-5
lines changed

2 files changed

+48
-5
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
import inspect
1616
import warnings
17-
from typing import Callable, List, Optional, Union
17+
from typing import Any, Callable, Dict, List, Optional, Union
1818

1919
import numpy as np
2020
import PIL
@@ -744,6 +744,7 @@ def __call__(
744744
return_dict: bool = True,
745745
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
746746
callback_steps: int = 1,
747+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
747748
):
748749
r"""
749750
Function invoked when calling the pipeline for generation.
@@ -815,7 +816,10 @@ def __call__(
815816
callback_steps (`int`, *optional*, defaults to 1):
816817
The frequency at which the `callback` function will be called. If not specified, the callback will be
817818
called at every step.
818-
819+
cross_attention_kwargs (`dict`, *optional*):
820+
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
821+
`self.processor` in
822+
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
819823
Examples:
820824
821825
```py
@@ -966,9 +970,13 @@ def __call__(
966970
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)
967971

968972
# predict the noise residual
969-
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False)[
970-
0
971-
]
973+
noise_pred = self.unet(
974+
latent_model_input,
975+
t,
976+
encoder_hidden_states=prompt_embeds,
977+
cross_attention_kwargs=cross_attention_kwargs,
978+
return_dict=False,
979+
)[0]
972980

973981
# perform guidance
974982
if do_classifier_free_guidance:

tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
3636
from diffusers.utils.testing_utils import require_torch_gpu
3737

38+
from ...models.test_models_unet_2d_condition import create_lora_layers
3839
from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
3940
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
4041

@@ -155,6 +156,40 @@ def test_stable_diffusion_inpaint_image_tensor(self):
155156
assert out_pil.shape == (1, 64, 64, 3)
156157
assert np.abs(out_pil.flatten() - out_tensor.flatten()).max() < 5e-2
157158

159+
def test_stable_diffusion_inpaint_lora(self):
160+
device = "cpu" # ensure determinism for the device-dependent torch.Generator
161+
162+
components = self.get_dummy_components()
163+
sd_pipe = StableDiffusionInpaintPipeline(**components)
164+
sd_pipe = sd_pipe.to(torch_device)
165+
sd_pipe.set_progress_bar_config(disable=None)
166+
167+
# forward 1
168+
inputs = self.get_dummy_inputs(device)
169+
output = sd_pipe(**inputs)
170+
image = output.images
171+
image_slice = image[0, -3:, -3:, -1]
172+
173+
# set lora layers
174+
lora_attn_procs = create_lora_layers(sd_pipe.unet)
175+
sd_pipe.unet.set_attn_processor(lora_attn_procs)
176+
sd_pipe = sd_pipe.to(torch_device)
177+
178+
# forward 2
179+
inputs = self.get_dummy_inputs(device)
180+
output = sd_pipe(**inputs, cross_attention_kwargs={"scale": 0.0})
181+
image = output.images
182+
image_slice_1 = image[0, -3:, -3:, -1]
183+
184+
# forward 3
185+
inputs = self.get_dummy_inputs(device)
186+
output = sd_pipe(**inputs, cross_attention_kwargs={"scale": 0.5})
187+
image = output.images
188+
image_slice_2 = image[0, -3:, -3:, -1]
189+
190+
assert np.abs(image_slice - image_slice_1).max() < 1e-2
191+
assert np.abs(image_slice - image_slice_2).max() > 1e-2
192+
158193
def test_inference_batch_single_identical(self):
159194
super().test_inference_batch_single_identical(expected_max_diff=3e-3)
160195

0 commit comments

Comments
 (0)