Skip to content

Add inpaint lora scale support #3460

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import inspect
import warnings
from typing import Callable, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

import numpy as np
import PIL
Expand Down Expand Up @@ -690,6 +690,7 @@ def __call__(
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
):
r"""
Function invoked when calling the pipeline for generation.
Expand Down Expand Up @@ -754,7 +755,10 @@ def __call__(
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.

cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
Examples:

```py
Expand Down Expand Up @@ -893,9 +897,13 @@ def __call__(
latent_model_input = torch.cat([latent_model_input, mask, masked_image_latents], dim=1)

# predict the noise residual
noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds, return_dict=False)[
0
]
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
return_dict=False,
)[0]

# perform guidance
if do_classifier_free_guidance:
Expand Down
35 changes: 35 additions & 0 deletions tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
from diffusers.utils.testing_utils import require_torch_gpu

from ...models.test_models_unet_2d_condition import create_lora_layers
from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin

Expand Down Expand Up @@ -155,6 +156,40 @@ def test_stable_diffusion_inpaint_image_tensor(self):
assert out_pil.shape == (1, 64, 64, 3)
assert np.abs(out_pil.flatten() - out_tensor.flatten()).max() < 5e-2

def test_stable_diffusion_inpaint_lora(self):
device = "cpu" # ensure determinism for the device-dependent torch.Generator

components = self.get_dummy_components()
sd_pipe = StableDiffusionInpaintPipeline(**components)
sd_pipe = sd_pipe.to(torch_device)
sd_pipe.set_progress_bar_config(disable=None)

# forward 1
inputs = self.get_dummy_inputs(device)
output = sd_pipe(**inputs)
image = output.images
image_slice = image[0, -3:, -3:, -1]

# set lora layers
lora_attn_procs = create_lora_layers(sd_pipe.unet)
sd_pipe.unet.set_attn_processor(lora_attn_procs)
sd_pipe = sd_pipe.to(torch_device)

Comment on lines +174 to +177
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Clean!

# forward 2
inputs = self.get_dummy_inputs(device)
output = sd_pipe(**inputs, cross_attention_kwargs={"scale": 0.0})
image = output.images
image_slice_1 = image[0, -3:, -3:, -1]

# forward 3
inputs = self.get_dummy_inputs(device)
output = sd_pipe(**inputs, cross_attention_kwargs={"scale": 0.5})
image = output.images
image_slice_2 = image[0, -3:, -3:, -1]

assert np.abs(image_slice - image_slice_1).max() < 1e-2
assert np.abs(image_slice - image_slice_2).max() > 1e-2

def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical(expected_max_diff=3e-3)

Expand Down