Skip to content

Commit e459834

Browse files
StableDiffusionInpaintingPipeline - resize image w.r.t height and width (huggingface#3322)
* StableDiffusionInpaintingPipeline now resizes input images and masks w.r.t. the passed input height and width. The default is already set to 512. This addresses the common tensor mismatch error. Also moved the type check into the relevant function to keep the main pipeline body tidy. * Fixed StableDiffusionInpaintingPrepareMaskAndMaskedImageTests. Due to the previous commit these tests were failing, as height and width need to be passed into the prepare_mask_and_masked_image function. I have updated the code and added a height/width variable per unit test, as it seemed more appropriate than the current hard-coded solution. * Added a resolution test to StableDiffusionInpaintPipelineSlowTests. This unit test simply gets the input and resizes it to a size that would otherwise fail (e.g. would throw a tensor mismatch error / is not a multiple of 8). It then passes it through the pipeline and verifies that it produces output with the correct dimensions w.r.t. the passed height and width. --------- Co-authored-by: Patrick von Platen <[email protected]>
1 parent aabde54 commit e459834

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
3737

3838

39-
def prepare_mask_and_masked_image(image, mask):
39+
def prepare_mask_and_masked_image(image, mask, height, width):
4040
"""
4141
Prepares a pair (image, mask) to be consumed by the Stable Diffusion pipeline. This means that those inputs will be
4242
converted to ``torch.Tensor`` with shapes ``batch x channels x height x width`` where ``channels`` is ``3`` for the
@@ -64,6 +64,13 @@ def prepare_mask_and_masked_image(image, mask):
6464
tuple[torch.Tensor]: The pair (mask, masked_image) as ``torch.Tensor`` with 4
6565
dimensions: ``batch x channels x height x width``.
6666
"""
67+
68+
if image is None:
69+
raise ValueError("`image` input cannot be undefined.")
70+
71+
if mask is None:
72+
raise ValueError("`mask_image` input cannot be undefined.")
73+
6774
if isinstance(image, torch.Tensor):
6875
if not isinstance(mask, torch.Tensor):
6976
raise TypeError(f"`image` is a torch.Tensor but `mask` (type: {type(mask)} is not")
@@ -111,8 +118,9 @@ def prepare_mask_and_masked_image(image, mask):
111118
# preprocess image
112119
if isinstance(image, (PIL.Image.Image, np.ndarray)):
113120
image = [image]
114-
115121
if isinstance(image, list) and isinstance(image[0], PIL.Image.Image):
122+
# resize all images w.r.t. the passed height and width
123+
image = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in image]
116124
image = [np.array(i.convert("RGB"))[None, :] for i in image]
117125
image = np.concatenate(image, axis=0)
118126
elif isinstance(image, list) and isinstance(image[0], np.ndarray):
@@ -126,6 +134,7 @@ def prepare_mask_and_masked_image(image, mask):
126134
mask = [mask]
127135

128136
if isinstance(mask, list) and isinstance(mask[0], PIL.Image.Image):
137+
mask = [i.resize((width, height), resample=PIL.Image.LANCZOS) for i in mask]
129138
mask = np.concatenate([np.array(m.convert("L"))[None, None, :] for m in mask], axis=0)
130139
mask = mask.astype(np.float32) / 255.0
131140
elif isinstance(mask, list) and isinstance(mask[0], np.ndarray):
@@ -799,12 +808,6 @@ def __call__(
799808
negative_prompt_embeds,
800809
)
801810

802-
if image is None:
803-
raise ValueError("`image` input cannot be undefined.")
804-
805-
if mask_image is None:
806-
raise ValueError("`mask_image` input cannot be undefined.")
807-
808811
# 2. Define call parameters
809812
if prompt is not None and isinstance(prompt, str):
810813
batch_size = 1
@@ -830,8 +833,8 @@ def __call__(
830833
negative_prompt_embeds=negative_prompt_embeds,
831834
)
832835

833-
# 4. Preprocess mask and image
834-
mask, masked_image = prepare_mask_and_masked_image(image, mask_image)
836+
# 4. Preprocess mask and image - resizes image and mask w.r.t height and width
837+
mask, masked_image = prepare_mask_and_masked_image(image, mask_image, height, width)
835838

836839
# 5. set timesteps
837840
self.scheduler.set_timesteps(num_inference_steps, device=device)

0 commit comments

Comments
 (0)