[SD-XL] Add inpainting #4098

Merged (8 commits) on Jul 14, 2023.
113 changes: 102 additions & 11 deletions docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.mdx
@@ -57,6 +57,50 @@ prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt=prompt).images[0]
```

### Image-to-image

You can use SDXL for *image-to-image* as follows:

```py
import torch
from diffusers import StableDiffusionXLImg2ImgPipeline
from diffusers.utils import load_image

pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-0.9", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
pipe = pipe.to("cuda")
url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png"

init_image = load_image(url).convert("RGB")
prompt = "a photo of an astronaut riding a horse on mars"
image = pipe(prompt, image=init_image).images[0]
```

**Reviewer:** Is it really intended to use the refiner model for general img2img? I've been trying to understand this (I've seen it used this way elsewhere too), but I think I'm missing something. My understanding is that the refiner model is intended as a de-noising and/or fidelity-increasing step, and that it isn't good at generating the baseline content of the image. If that's correct, it feels like it would perform poorly for img2img at lower strength values.

**Contributor Author:** You can use both! The refiner might be better suited for images that already look like the prompt, which is the case here. We should maybe improve the docs after the official release.

**Reviewer:** Gotcha, thanks for clarifying!
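Following up on the thread above: if you want img2img that invents more new content, you can run the same pipeline class with the *base* checkpoint instead. This is an editorial sketch, not part of the PR; it only assumes that [`StableDiffusionXLImg2ImgPipeline`] also accepts the base checkpoint, which it does, since the pipeline simply loads whichever SDXL checkpoint you point it at:

```py
import torch
from diffusers import StableDiffusionXLImg2ImgPipeline
from diffusers.utils import load_image

# Same pipeline class as above, but loading the base checkpoint.
pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
pipe = pipe.to("cuda")

init_image = load_image(
    "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png"
).convert("RGB")

# The base model is generally stronger at generating new content, while the
# refiner excels at polishing images that already match the prompt.
image = pipe(
    "a photo of an astronaut riding a horse on mars",
    image=init_image,
    strength=0.5,
).images[0]
```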

### Inpainting

You can use SDXL for *inpainting* as follows:

```py
import torch
from diffusers import StableDiffusionXLInpaintPipeline
from diffusers.utils import load_image

pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
"stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
pipe.to("cuda")

img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

init_image = load_image(img_url).convert("RGB")
mask_image = load_image(mask_url).convert("RGB")

prompt = "A majestic tiger sitting on a bench"
image = pipe(prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80).images[0]
```
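A note on `strength` (an editorial gloss; this mirrors how `strength` behaves in the other img2img-style pipelines in diffusers): it controls how much noise is added to the masked region before denoising, and roughly `int(num_inference_steps * strength)` denoising steps actually run. Setting it to 1.0 repaints the masked area from scratch:

```py
# strength=1.0 regenerates the masked region from (almost) pure noise;
# with num_inference_steps=50, all 50 denoising steps are run.
image = pipe(
    prompt=prompt,
    image=init_image,
    mask_image=mask_image,
    num_inference_steps=50,
    strength=1.0,
).images[0]
```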

### Refining the image output

In addition to the [base model checkpoint](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9),
@@ -183,24 +227,65 @@ image = refiner(prompt=prompt, image=image[None, :]).images[0]
| Original image | Refined image |
|---|---|
| ![](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/init_image.png) | ![](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/refined_image.png) |

<Tip>

The refiner can also be used very effectively for *inpainting*. To do so, just make
sure you use the [`StableDiffusionXLInpaintPipeline`] class, as shown below.

</Tip>

To use the refiner for inpainting in the Ensemble of Expert Denoisers setting, you can do the following:

```py
import torch
from diffusers import StableDiffusionXLInpaintPipeline
from diffusers.utils import load_image

pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)
pipe.to("cuda")

refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-0.9",
    text_encoder_2=pipe.text_encoder_2,
    vae=pipe.vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
)
refiner.to("cuda")

img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"

init_image = load_image(img_url).convert("RGB")
mask_image = load_image(mask_url).convert("RGB")

prompt = "A majestic tiger sitting on a bench"
num_inference_steps = 75
high_noise_frac = 0.7

# The base model handles the high-noise part of the schedule, up to the
# denoising fraction `high_noise_frac`, and hands over latents.
image = pipe(
    prompt=prompt,
    image=init_image,
    mask_image=mask_image,
    num_inference_steps=num_inference_steps,
    strength=0.80,
    denoising_end=high_noise_frac,
    output_type="latent",
).images
# The refiner then takes over from that same fraction onward.
image = refiner(
    prompt=prompt,
    image=image,
    mask_image=mask_image,
    num_inference_steps=num_inference_steps,
    denoising_start=high_noise_frac,
).images[0]
```

To use the refiner for inpainting in the standard SDE-style setting, simply remove `denoising_end` and `denoising_start` and choose a smaller
number of inference steps for the refiner.
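For concreteness, a minimal sketch of that standard two-stage variant (editorial, reusing the `pipe`, `refiner`, and inputs defined above; the step counts and the refiner `strength` are illustrative choices, not values from the PR):

```py
# Stage 1: the base model produces a fully denoised inpainted latent.
image = pipe(
    prompt=prompt,
    image=init_image,
    mask_image=mask_image,
    num_inference_steps=50,
    strength=0.80,
    output_type="latent",
).images

# Stage 2: the refiner lightly re-noises and re-denoises the result with a
# smaller step budget, polishing details in the masked region.
image = refiner(
    prompt=prompt,
    image=image,
    mask_image=mask_image,
    num_inference_steps=25,
    strength=0.30,
).images[0]
```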

### Loading single file checkpoints / original file format

By making use of [`~diffusers.loaders.FromSingleFileMixin.from_single_file`] you can also load the
original single-file checkpoint format.
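For example (an editorial sketch; the local checkpoint path is a hypothetical placeholder):

```py
import torch
from diffusers import StableDiffusionXLPipeline

# Load an original-format (single .safetensors file) checkpoint directly,
# without first converting it to the diffusers folder layout.
pipe = StableDiffusionXLPipeline.from_single_file(
    "./sd_xl_base.safetensors", torch_dtype=torch.float16
)
pipe.to("cuda")
```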
@@ -271,3 +356,9 @@ pip install xformers
[[autodoc]] StableDiffusionXLImg2ImgPipeline
- all
- __call__

## StableDiffusionXLInpaintPipeline

[[autodoc]] StableDiffusionXLInpaintPipeline
- all
- __call__
6 changes: 5 additions & 1 deletion src/diffusers/__init__.py
@@ -195,7 +195,11 @@
except OptionalDependencyNotAvailable:
    from .utils.dummy_torch_and_transformers_and_invisible_watermark_objects import *  # noqa F403
else:
    from .pipelines import (
        StableDiffusionXLImg2ImgPipeline,
        StableDiffusionXLInpaintPipeline,
        StableDiffusionXLPipeline,
    )

try:
if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
6 changes: 5 additions & 1 deletion src/diffusers/pipelines/__init__.py
@@ -119,7 +119,11 @@
except OptionalDependencyNotAvailable:
    from ..utils.dummy_torch_and_transformers_and_invisible_watermark_objects import *  # noqa F403
else:
    from .stable_diffusion_xl import (
        StableDiffusionXLImg2ImgPipeline,
        StableDiffusionXLInpaintPipeline,
        StableDiffusionXLPipeline,
    )

try:
if not is_onnx_available():
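An aside on what the guarded imports above mean in practice (an editorial sketch; `is_invisible_watermark_available` is the same helper used in these guards and is importable from `diffusers.utils`):

```py
from diffusers.utils import is_invisible_watermark_available

# The SDXL pipelines are only usable when the optional invisible-watermark
# dependency (in addition to torch and transformers) is installed; otherwise
# diffusers substitutes dummy objects that raise an error when instantiated.
if is_invisible_watermark_available():
    from diffusers import StableDiffusionXLInpaintPipeline  # noqa: F401
else:
    print("Run `pip install invisible-watermark` to enable the SDXL pipelines.")
```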
@@ -981,8 +981,6 @@ def __call__(
            generator,
            do_classifier_free_guidance,
        )
-       init_image = init_image.to(device=device, dtype=masked_image_latents.dtype)
        init_image = self._encode_vae_image(init_image, generator=generator)

        # 8. Check that sizes of mask, masked image and latents match
        if num_channels_unet == 9:

**Contributor Author** (on the removed line): This was never actually used, and I think it was a copy-paste bug.
1 change: 1 addition & 0 deletions src/diffusers/pipelines/stable_diffusion_xl/__init__.py
@@ -24,3 +24,4 @@ class StableDiffusionXLPipelineOutput(BaseOutput):
if is_transformers_available() and is_torch_available() and is_invisible_watermark_available():
    from .pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
    from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline
    from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline
src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py
@@ -59,6 +59,7 @@
"""


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
@@ -75,7 +76,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):

class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
r"""
Pipeline for text-to-image generation using Stable Diffusion XL.

This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
@@ -92,12 +93,21 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion XL uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
text_encoder_2 ([`CLIPTextModelWithProjection`]):
Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
specifically the
[laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
tokenizer_2 (`CLIPTokenizer`):
Second Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
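The two text encoders documented in this docstring can be inspected on a loaded pipeline. An editorial sketch (the hidden sizes, 768 for CLIP ViT-L and 1280 for OpenCLIP bigG, come from the respective model configs):

```py
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-0.9")

# First encoder: CLIP ViT-L text model.
print(type(pipe.text_encoder).__name__, pipe.text_encoder.config.hidden_size)  # CLIPTextModel 768
# Second encoder: OpenCLIP bigG text model with a projection head.
print(type(pipe.text_encoder_2).__name__, pipe.text_encoder_2.config.hidden_size)  # CLIPTextModelWithProjection 1280
```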
src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py
@@ -64,6 +64,7 @@
"""


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
"""
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
@@ -80,7 +81,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):

class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
r"""
Pipeline for text-to-image generation using Stable Diffusion XL.

This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
@@ -97,12 +98,21 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
vae ([`AutoencoderKL`]):
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
text_encoder ([`CLIPTextModel`]):
Frozen text-encoder. Stable Diffusion XL uses the text portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
text_encoder_2 ([`CLIPTextModelWithProjection`]):
Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
specifically the
[laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
variant.
tokenizer (`CLIPTokenizer`):
Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
tokenizer_2 (`CLIPTokenizer`):
Second Tokenizer of class
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
scheduler ([`SchedulerMixin`]):
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of