@@ -35,6 +35,7 @@
 from diffusers.utils import floats_tensor, load_image, load_numpy, nightly, slow, torch_device
 from diffusers.utils.testing_utils import require_torch_gpu
 
+from ...models.test_models_unet_2d_condition import create_lora_layers
 from ..pipeline_params import TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS, TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
 from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
 
@@ -155,6 +156,40 @@ def test_stable_diffusion_inpaint_image_tensor(self):
         assert out_pil.shape == (1, 64, 64, 3)
         assert np.abs(out_pil.flatten() - out_tensor.flatten()).max() < 5e-2
 
+    def test_stable_diffusion_inpaint_lora(self):
+        device = "cpu"  # ensure determinism for the device-dependent torch.Generator
+
+        components = self.get_dummy_components()
+        sd_pipe = StableDiffusionInpaintPipeline(**components)
+        sd_pipe = sd_pipe.to(torch_device)
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        # forward 1
+        inputs = self.get_dummy_inputs(device)
+        output = sd_pipe(**inputs)
+        image = output.images
+        image_slice = image[0, -3:, -3:, -1]
+
+        # set lora layers
+        lora_attn_procs = create_lora_layers(sd_pipe.unet)
+        sd_pipe.unet.set_attn_processor(lora_attn_procs)
+        sd_pipe = sd_pipe.to(torch_device)
+
+        # forward 2
+        inputs = self.get_dummy_inputs(device)
+        output = sd_pipe(**inputs, cross_attention_kwargs={"scale": 0.0})
+        image = output.images
+        image_slice_1 = image[0, -3:, -3:, -1]
+
+        # forward 3
+        inputs = self.get_dummy_inputs(device)
+        output = sd_pipe(**inputs, cross_attention_kwargs={"scale": 0.5})
+        image = output.images
+        image_slice_2 = image[0, -3:, -3:, -1]
+
+        assert np.abs(image_slice - image_slice_1).max() < 1e-2
+        assert np.abs(image_slice - image_slice_2).max() > 1e-2
+
     def test_inference_batch_single_identical(self):
         super().test_inference_batch_single_identical(expected_max_diff=3e-3)
 
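The test leans on the `create_lora_layers` helper imported from the UNet model tests, which is not part of this diff. The sketch below is a rough guess at what such a helper typically looks like in the diffusers test suite: it builds one `LoRAAttnProcessor` per attention processor and nudges the (zero-initialized) up-projection weights so the LoRA layers actually influence the output. The attribute names (`to_q_lora`, `to_out_lora`, ...) and the weight-bumping trick are assumptions based on `LoRAAttnProcessor`, not a copy of the real helper.

```python
# Hypothetical sketch of a create_lora_layers-style helper; the actual
# implementation lives in tests/models/test_models_unet_2d_condition.py
# and may differ in detail.
import torch

from diffusers.models.attention_processor import LoRAAttnProcessor


def create_lora_layers_sketch(unet):
    """Build one LoRAAttnProcessor per attention processor in the UNet."""
    lora_attn_procs = {}
    for name in unet.attn_processors.keys():
        # Self-attention ("attn1") has no cross-attention conditioning.
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim

        # Infer the hidden size of the block this processor belongs to.
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        else:  # down_blocks
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]

        lora_attn_procs[name] = LoRAAttnProcessor(
            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
        )

        # Bump the up-projection weights to mimic trained LoRA weights;
        # freshly initialized layers would leave the output unchanged.
        with torch.no_grad():
            lora_attn_procs[name].to_q_lora.up.weight += 1
            lora_attn_procs[name].to_k_lora.up.weight += 1
            lora_attn_procs[name].to_v_lora.up.weight += 1
            lora_attn_procs[name].to_out_lora.up.weight += 1

    return lora_attn_procs
```

Assuming the up-projections start at zero (as in `LoRALinearLayer`), this weight bump is what makes the `scale=0.5` run diverge from the baseline; untouched LoRA layers would produce an identical image at any scale and the last assertion would fail.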
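The two assertions encode how the LoRA scale is consumed: the pipeline forwards `cross_attention_kwargs` down to the UNet's attention processors, which add the low-rank update on top of the frozen projection, multiplied by `scale`. A minimal, self-contained illustration of that mixing (names and shapes are made up for the example, not diffusers internals):

```python
# Minimal illustration of the LoRA mixing controlled by the scale.
# frozen_proj / lora_down / lora_up are stand-ins, not diffusers internals.
import torch
import torch.nn as nn

hidden = torch.randn(2, 4, 32)

frozen_proj = nn.Linear(32, 32)            # pretrained projection (frozen)
lora_down = nn.Linear(32, 4, bias=False)   # rank-4 down-projection
lora_up = nn.Linear(4, 32, bias=False)     # up-projection


def project(x, scale):
    # Base output plus the scaled low-rank update.
    return frozen_proj(x) + scale * lora_up(lora_down(x))


# scale 0.0 reproduces the frozen projection exactly (forward 2 vs. forward 1) ...
assert torch.allclose(project(hidden, scale=0.0), frozen_proj(hidden))
# ... while a non-zero scale shifts the result (forward 3 vs. forward 1).
assert not torch.allclose(project(hidden, scale=0.5), frozen_proj(hidden))
```

With `scale=0.0` the update term vanishes, so forward 2 must match forward 1 within `1e-2`; with `scale=0.5` the mocked LoRA weights shift the projections, so forward 3 must differ from the baseline by more than `1e-2`.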