
Commit b024ebb

[SD-XL] Add inpainting (#4098)
* Add more
* more
* up
* Get ensemble of expert denoisers working
* Fix code
* add tests
* up
1 parent ad8f985 commit b024ebb


10 files changed: +1775 -19 lines changed


docs/source/en/api/pipelines/stable_diffusion/stable_diffusion_xl.mdx

Lines changed: 102 additions & 11 deletions
@@ -57,6 +57,50 @@ prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
 image = pipe(prompt=prompt).images[0]
 ```

+### Image-to-image
+
+You can use SDXL as follows for *image-to-image*:
+
+```py
+import torch
+from diffusers import StableDiffusionXLImg2ImgPipeline
+from diffusers.utils import load_image
+
+pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-refiner-0.9", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+pipe = pipe.to("cuda")
+url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png"
+
+init_image = load_image(url).convert("RGB")
+prompt = "a photo of an astronaut riding a horse on mars"
+image = pipe(prompt, image=init_image).images[0]
+```
+
+### Inpainting
+
+You can use SDXL as follows for *inpainting*:
+
+```py
+import torch
+from diffusers import StableDiffusionXLInpaintPipeline
+from diffusers.utils import load_image
+
+pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
+)
+pipe.to("cuda")
+
+img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+init_image = load_image(img_url).convert("RGB")
+mask_image = load_image(mask_url).convert("RGB")
+
+prompt = "A majestic tiger sitting on a bench"
+image = pipe(prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80).images[0]
+```
+
 ### Refining the image output

 In addition to the [base model checkpoint](https://huggingface.co/stabilityai/stable-diffusion-xl-base-0.9),
@@ -183,24 +227,65 @@ image = refiner(prompt=prompt, image=image[None, :]).images[0]
 |---|---|
 | ![](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/init_image.png) | ![](https://huggingface.co/datasets/diffusers/docs-images/resolve/main/sd_xl/refined_image.png) |

-### Image-to-image
+<Tip>

-```py
-import torch
-from diffusers import StableDiffusionXLImg2ImgPipeline
+The refiner also works well in an inpainting setting. To do so, just make sure you use the [`StableDiffusionXLInpaintPipeline`] class as shown below.
+
+</Tip>
+
+To use the refiner for inpainting in the Ensemble of Expert Denoisers setting, you can do the following:
+
+```py
+import torch
+from diffusers import StableDiffusionXLInpaintPipeline
 from diffusers.utils import load_image

-pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
-    "stabilityai/stable-diffusion-xl-refiner-0.9", torch_dtype=torch.float16
+pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
 )
-pipe = pipe.to("cuda")
-url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png"
+pipe.to("cuda")

-init_image = load_image(url).convert("RGB")
-prompt = "a photo of an astronaut riding a horse on mars"
-image = pipe(prompt, image=init_image).images[0]
+refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-xl-refiner-0.9",
+    text_encoder_2=pipe.text_encoder_2,
+    vae=pipe.vae,
+    torch_dtype=torch.float16,
+    use_safetensors=True,
+    variant="fp16",
+)
+refiner.to("cuda")
+
+img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
+mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
+
+init_image = load_image(img_url).convert("RGB")
+mask_image = load_image(mask_url).convert("RGB")
+
+prompt = "A majestic tiger sitting on a bench"
+num_inference_steps = 75
+high_noise_frac = 0.7
+
+image = pipe(
+    prompt=prompt,
+    image=init_image,
+    mask_image=mask_image,
+    num_inference_steps=num_inference_steps,
+    strength=0.80,
+    denoising_end=high_noise_frac,
+    output_type="latent",
+).images
+image = refiner(
+    prompt=prompt,
+    image=image,
+    mask_image=mask_image,
+    num_inference_steps=num_inference_steps,
+    denoising_start=high_noise_frac,
+).images[0]
 ```

+To use the refiner for inpainting in the standard SDE-style setting, simply remove `denoising_end` and `denoising_start` and choose a smaller
+number of inference steps for the refiner.
+
 ### Loading single file checkpoints / original file format

 By making use of [`~diffusers.loaders.FromSingleFileMixin.from_single_file`] you can also load the
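The trailing context above mentions [`~diffusers.loaders.FromSingleFileMixin.from_single_file`] without showing it in this hunk. The following is a minimal sketch of that loading path, not part of the commit; the local checkpoint path is an illustrative assumption.

```py
# Sketch only: load an SD-XL pipeline from a single original-format checkpoint file.
# "./sd_xl_base_0.9.safetensors" is an assumed local path, not a file from the commit.
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_single_file(
    "./sd_xl_base_0.9.safetensors", torch_dtype=torch.float16
)
pipe.to("cuda")
```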
@@ -271,3 +356,9 @@ pip install xformers
 [[autodoc]] StableDiffusionXLImg2ImgPipeline
 - all
 - __call__
+
+## StableDiffusionXLInpaintPipeline
+
+[[autodoc]] StableDiffusionXLInpaintPipeline
+- all
+- __call__
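The documentation added above ends with a note about the "standard SDE-style setting" (drop `denoising_end` / `denoising_start` and use fewer refiner steps) but does not show it. Here is a minimal sketch of that usage, not part of the commit: the checkpoint names and image URLs are taken from the examples above, while `num_inference_steps=25` and `strength=0.30` for the refiner pass are illustrative assumptions.

```py
# Sketch of the "standard" (non-ensemble) refiner inpainting flow described in the note above:
# no denoising_end / denoising_start, just a full base pass followed by a shorter refiner pass.
import torch
from diffusers import StableDiffusionXLInpaintPipeline
from diffusers.utils import load_image

base = StableDiffusionXLInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
).to("cuda")
refiner = StableDiffusionXLInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-refiner-0.9",
    text_encoder_2=base.text_encoder_2,
    vae=base.vae,
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
).to("cuda")

img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
init_image = load_image(img_url).convert("RGB")
mask_image = load_image(mask_url).convert("RGB")

prompt = "A majestic tiger sitting on a bench"

# Full inpainting pass with the base model.
image = base(
    prompt=prompt, image=init_image, mask_image=mask_image, num_inference_steps=50, strength=0.80
).images[0]
# Separate, shorter refinement pass over the base output (assumed step count and strength).
image = refiner(
    prompt=prompt, image=image, mask_image=mask_image, num_inference_steps=25, strength=0.30
).images[0]
```

Lowering the second pass's `strength` limits how much the refiner may change the masked region.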

src/diffusers/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -195,7 +195,11 @@
 except OptionalDependencyNotAvailable:
     from .utils.dummy_torch_and_transformers_and_invisible_watermark_objects import *  # noqa F403
 else:
-    from .pipelines import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline
+    from .pipelines import (
+        StableDiffusionXLImg2ImgPipeline,
+        StableDiffusionXLInpaintPipeline,
+        StableDiffusionXLPipeline,
+    )

 try:
     if not (is_torch_available() and is_transformers_available() and is_k_diffusion_available()):
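As a usage note (not part of the diff): the restructured import keeps the new pipeline behind the same optional-dependency guard, so it is importable from the top-level `diffusers` namespace; if `torch`, `transformers`, or `invisible_watermark` is missing, a dummy object of the same name is exported instead and only errors when actually used. A minimal sketch:

```py
# Minimal sketch: the new pipeline is now exposed at the top level of the package,
# alongside the existing SD-XL pipelines.
from diffusers import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLInpaintPipeline, StableDiffusionXLPipeline

print(StableDiffusionXLInpaintPipeline.__name__)  # "StableDiffusionXLInpaintPipeline"
```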

src/diffusers/pipelines/__init__.py

Lines changed: 5 additions & 1 deletion
@@ -119,7 +119,11 @@
 except OptionalDependencyNotAvailable:
     from ..utils.dummy_torch_and_transformers_and_invisible_watermark_objects import *  # noqa F403
 else:
-    from .stable_diffusion_xl import StableDiffusionXLImg2ImgPipeline, StableDiffusionXLPipeline
+    from .stable_diffusion_xl import (
+        StableDiffusionXLImg2ImgPipeline,
+        StableDiffusionXLInpaintPipeline,
+        StableDiffusionXLPipeline,
+    )

 try:
     if not is_onnx_available():

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 0 additions & 2 deletions
@@ -981,8 +981,6 @@ def __call__(
             generator,
             do_classifier_free_guidance,
         )
-        init_image = init_image.to(device=device, dtype=masked_image_latents.dtype)
-        init_image = self._encode_vae_image(init_image, generator=generator)

         # 8. Check that sizes of mask, masked image and latents match
         if num_channels_unet == 9:

src/diffusers/pipelines/stable_diffusion_xl/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -24,3 +24,4 @@ class StableDiffusionXLPipelineOutput(BaseOutput):
 if is_transformers_available() and is_torch_available() and is_invisible_watermark_available():
     from .pipeline_stable_diffusion_xl import StableDiffusionXLPipeline
     from .pipeline_stable_diffusion_xl_img2img import StableDiffusionXLImg2ImgPipeline
+    from .pipeline_stable_diffusion_xl_inpaint import StableDiffusionXLInpaintPipeline

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

Lines changed: 12 additions & 2 deletions
@@ -59,6 +59,7 @@
 """


+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
 def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
     """
     Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
@@ -75,7 +76,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):

 class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
     r"""
-    Pipeline for text-to-image generation using Stable Diffusion.
+    Pipeline for text-to-image generation using Stable Diffusion XL.

     This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
@@ -92,12 +93,21 @@ class StableDiffusionXLPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoad
         vae ([`AutoencoderKL`]):
             Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
         text_encoder ([`CLIPTextModel`]):
-            Frozen text-encoder. Stable Diffusion uses the text portion of
+            Frozen text-encoder. Stable Diffusion XL uses the text portion of
             [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
             the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        text_encoder_2 ([`CLIPTextModelWithProjection`]):
+            Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
+            specifically the
+            [laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
+            variant.
         tokenizer (`CLIPTokenizer`):
             Tokenizer of class
             [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        tokenizer_2 (`CLIPTokenizer`):
+            Second Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
         unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
         scheduler ([`SchedulerMixin`]):
             A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
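The updated docstring documents the second text encoder and tokenizer that SD-XL adds on top of the original CLIP components. Below is a quick, illustrative check of the corresponding pipeline attributes; it is not part of the commit, and the checkpoint name is taken from the documentation examples above.

```py
# Illustrative only: load an SD-XL pipeline and inspect the two text encoders /
# tokenizers described in the updated docstring.
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-0.9", torch_dtype=torch.float16, variant="fp16", use_safetensors=True
)

print(type(pipe.text_encoder).__name__)    # CLIPTextModel (clip-vit-large-patch14)
print(type(pipe.text_encoder_2).__name__)  # CLIPTextModelWithProjection (CLIP-ViT-bigG-14)
print(type(pipe.tokenizer).__name__, type(pipe.tokenizer_2).__name__)  # CLIPTokenizer CLIPTokenizer
```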

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@
6464
"""
6565

6666

67+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
6768
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
6869
"""
6970
Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
@@ -80,7 +81,7 @@ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
8081

8182
class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, LoraLoaderMixin):
8283
r"""
83-
Pipeline for text-to-image generation using Stable Diffusion.
84+
Pipeline for text-to-image generation using Stable Diffusion XL.
8485
8586
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
8687
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
@@ -97,12 +98,21 @@ class StableDiffusionXLImg2ImgPipeline(DiffusionPipeline, FromSingleFileMixin, L
9798
vae ([`AutoencoderKL`]):
9899
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
99100
text_encoder ([`CLIPTextModel`]):
100-
Frozen text-encoder. Stable Diffusion uses the text portion of
101+
Frozen text-encoder. Stable Diffusion XL uses the text portion of
101102
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
102103
the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
104+
text_encoder_2 ([` CLIPTextModelWithProjection`]):
105+
Second frozen text-encoder. Stable Diffusion XL uses the text and pool portion of
106+
[CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModelWithProjection),
107+
specifically the
108+
[laion/CLIP-ViT-bigG-14-laion2B-39B-b160k](https://huggingface.co/laion/CLIP-ViT-bigG-14-laion2B-39B-b160k)
109+
variant.
103110
tokenizer (`CLIPTokenizer`):
104111
Tokenizer of class
105112
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
113+
tokenizer_2 (`CLIPTokenizer`):
114+
Second Tokenizer of class
115+
[CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
106116
unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
107117
scheduler ([`SchedulerMixin`]):
108118
A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
