Skip to content

Commit 1d53c56

Browse files
yiyixuxuyiyixuxupatrickvonplaten
authored
updated doc for stable diffusion pipelines (huggingface#1770)
* add a doc page for each pipeline under api/pipelines/stable_diffusion * add pipeline examples to docstrings * updated stable_diffusion_2 page * updated default markdown syntax to list methods based on huggingface#1870 * add function decorator Co-authored-by: yiyixuxu <[email protected]> Co-authored-by: Patrick von Platen <[email protected]>
1 parent 1d3b169 commit 1d53c56

11 files changed

+221
-6
lines changed

models/embeddings_flax.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ def get_sinusoidal_embeddings(
2727
scale: float = 1.0,
2828
) -> jnp.ndarray:
2929
"""Returns the positional encoding (same as Tensor2Tensor).
30+
3031
Args:
3132
timesteps: a 1-D Tensor of N indices, one per batch element.
3233
These may be fractional.

pipelines/alt_diffusion/pipeline_alt_diffusion.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,29 @@
3131
LMSDiscreteScheduler,
3232
PNDMScheduler,
3333
)
34-
from ...utils import deprecate, logging
34+
from ...utils import deprecate, logging, replace_example_docstring
3535
from ..pipeline_utils import DiffusionPipeline
3636
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
3737
from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
3838

3939

4040
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
4141

42+
EXAMPLE_DOC_STRING = """
43+
Examples:
44+
```py
45+
>>> import torch
46+
>>> from diffusers import AltDiffusionPipeline
47+
48+
>>> pipe = AltDiffusionPipeline.from_pretrained("BAAI/AltDiffusion-m9", torch_dtype=torch.float16)
49+
>>> pipe = pipe.to("cuda")
50+
51+
>>> # "dark elf princess, highly detailed, d & d, fantasy, highly detailed, digital painting, trending on artstation, concept art, sharp focus, illustration, art by artgerm and greg rutkowski and fuji choko and viktoria gavrilenko and hoang lap"
52+
>>> prompt = "黑暗精灵公主,非常详细,幻想,非常详细,数字绘画,概念艺术,敏锐的焦点,插图"
53+
>>> image = pipe(prompt).images[0]
54+
```
55+
"""
56+
4257

4358
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline with Stable->Alt, CLIPTextModel->RobertaSeriesModelWithTransformation, CLIPTokenizer->XLMRobertaTokenizer, AltDiffusionSafetyChecker->StableDiffusionSafetyChecker
4459
class AltDiffusionPipeline(DiffusionPipeline):
@@ -407,6 +422,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
407422
return latents
408423

409424
@torch.no_grad()
425+
@replace_example_docstring(EXAMPLE_DOC_STRING)
410426
def __call__(
411427
self,
412428
prompt: Union[str, List[str]],
@@ -471,6 +487,8 @@ def __call__(
471487
The frequency at which the `callback` function will be called. If not specified, the callback will be
472488
called at every step.
473489
490+
Examples:
491+
474492
Returns:
475493
[`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] or `tuple`:
476494
[`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.

pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,43 @@
3333
LMSDiscreteScheduler,
3434
PNDMScheduler,
3535
)
36-
from ...utils import PIL_INTERPOLATION, deprecate, logging
36+
from ...utils import PIL_INTERPOLATION, deprecate, logging, replace_example_docstring
3737
from ..pipeline_utils import DiffusionPipeline
3838
from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
3939
from . import AltDiffusionPipelineOutput, RobertaSeriesModelWithTransformation
4040

4141

4242
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
4343

44+
EXAMPLE_DOC_STRING = """
45+
Examples:
46+
```py
47+
>>> import requests
48+
>>> import torch
49+
>>> from PIL import Image
50+
>>> from io import BytesIO
51+
52+
>>> from diffusers import AltDiffusionImg2ImgPipeline
53+
54+
>>> device = "cuda"
55+
>>> model_id_or_path = "BAAI/AltDiffusion-m9"
56+
>>> pipe = AltDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
57+
>>> pipe = pipe.to(device)
58+
59+
>>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
60+
61+
>>> response = requests.get(url)
62+
>>> init_image = Image.open(BytesIO(response.content)).convert("RGB")
63+
>>> init_image = init_image.resize((768, 512))
64+
65+
>>> # "A fantasy landscape, trending on artstation"
66+
>>> prompt = "幻想风景, artstation"
67+
68+
>>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
69+
>>> images[0].save("幻想风景.png")
70+
```
71+
"""
72+
4473

4574
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.preprocess
4675
def preprocess(image):
@@ -450,6 +479,7 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
450479
return latents
451480

452481
@torch.no_grad()
482+
@replace_example_docstring(EXAMPLE_DOC_STRING)
453483
def __call__(
454484
self,
455485
prompt: Union[str, List[str]],
@@ -514,6 +544,7 @@ def __call__(
514544
callback_steps (`int`, *optional*, defaults to 1):
515545
The frequency at which the `callback` function will be called. If not specified, the callback will be
516546
called at every step.
547+
Examples:
517548
518549
Returns:
519550
[`~pipelines.stable_diffusion.AltDiffusionPipelineOutput`] or `tuple`:

pipelines/stable_diffusion/pipeline_stable_diffusion.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
import torch
1919

20-
from diffusers.utils import is_accelerate_available
2120
from packaging import version
2221
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
2322

@@ -31,14 +30,28 @@
3130
LMSDiscreteScheduler,
3231
PNDMScheduler,
3332
)
34-
from ...utils import deprecate, logging
33+
from ...utils import deprecate, is_accelerate_available, logging, replace_example_docstring
3534
from ..pipeline_utils import DiffusionPipeline
3635
from . import StableDiffusionPipelineOutput
3736
from .safety_checker import StableDiffusionSafetyChecker
3837

3938

4039
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
4140

41+
EXAMPLE_DOC_STRING = """
42+
Examples:
43+
```py
44+
>>> import torch
45+
>>> from diffusers import StableDiffusionPipeline
46+
47+
>>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16)
48+
>>> pipe = pipe.to("cuda")
49+
50+
>>> prompt = "a photo of an astronaut riding a horse on mars"
51+
>>> image = pipe(prompt).images[0]
52+
```
53+
"""
54+
4255

4356
class StableDiffusionPipeline(DiffusionPipeline):
4457
r"""
@@ -406,6 +419,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
406419
return latents
407420

408421
@torch.no_grad()
422+
@replace_example_docstring(EXAMPLE_DOC_STRING)
409423
def __call__(
410424
self,
411425
prompt: Union[str, List[str]],
@@ -470,6 +484,8 @@ def __call__(
470484
The frequency at which the `callback` function will be called. If not specified, the callback will be
471485
called at every step.
472486
487+
Examples:
488+
473489
Returns:
474490
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
475491
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.

pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,29 @@ def __call__(
505505
The frequency at which the `callback` function will be called. If not specified, the callback will be
506506
called at every step.
507507
508+
Examples:
509+
510+
```py
511+
>>> import torch
512+
>>> import requests
513+
>>> from PIL import Image
514+
515+
>>> from diffusers import StableDiffusionDepth2ImgPipeline
516+
517+
>>> pipe = StableDiffusionDepth2ImgPipeline.from_pretrained(
518+
... "stabilityai/stable-diffusion-2-depth",
519+
... torch_dtype=torch.float16,
520+
... )
521+
>>> pipe.to("cuda")
522+
523+
524+
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
525+
>>> init_image = Image.open(requests.get(url, stream=True).raw)
526+
>>> prompt = "two tigers"
527+
>>> n_prompt = "bad, deformed, ugly, bad anatomy"
528+
>>> image = pipe(prompt=prompt, image=init_image, negative_prompt=n_prompt, strength=0.7).images[0]
529+
```
530+
508531
Returns:
509532
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
510533
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.

pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import torch
2020

2121
import PIL
22-
from diffusers.utils import is_accelerate_available
2322
from packaging import version
2423
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
2524

@@ -33,14 +32,42 @@
3332
LMSDiscreteScheduler,
3433
PNDMScheduler,
3534
)
36-
from ...utils import PIL_INTERPOLATION, deprecate, logging
35+
from ...utils import PIL_INTERPOLATION, deprecate, is_accelerate_available, logging, replace_example_docstring
3736
from ..pipeline_utils import DiffusionPipeline
3837
from . import StableDiffusionPipelineOutput
3938
from .safety_checker import StableDiffusionSafetyChecker
4039

4140

4241
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
4342

43+
EXAMPLE_DOC_STRING = """
44+
Examples:
45+
```py
46+
>>> import requests
47+
>>> import torch
48+
>>> from PIL import Image
49+
>>> from io import BytesIO
50+
51+
>>> from diffusers import StableDiffusionImg2ImgPipeline
52+
53+
>>> device = "cuda"
54+
>>> model_id_or_path = "runwayml/stable-diffusion-v1-5"
55+
>>> pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16)
56+
>>> pipe = pipe.to(device)
57+
58+
>>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg"
59+
60+
>>> response = requests.get(url)
61+
>>> init_image = Image.open(BytesIO(response.content)).convert("RGB")
62+
>>> init_image = init_image.resize((768, 512))
63+
64+
>>> prompt = "A fantasy landscape, trending on artstation"
65+
66+
>>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images
67+
>>> images[0].save("fantasy_landscape.png")
68+
```
69+
"""
70+
4471

4572
def preprocess(image):
4673
if isinstance(image, torch.Tensor):
@@ -455,6 +482,7 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
455482
return latents
456483

457484
@torch.no_grad()
485+
@replace_example_docstring(EXAMPLE_DOC_STRING)
458486
def __call__(
459487
self,
460488
prompt: Union[str, List[str]],
@@ -519,6 +547,7 @@ def __call__(
519547
callback_steps (`int`, *optional*, defaults to 1):
520548
The frequency at which the `callback` function will be called. If not specified, the callback will be
521549
called at every step.
550+
Examples:
522551
523552
Returns:
524553
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:

pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -616,6 +616,37 @@ def __call__(
616616
The frequency at which the `callback` function will be called. If not specified, the callback will be
617617
called at every step.
618618
619+
Examples:
620+
621+
```py
622+
>>> import PIL
623+
>>> import requests
624+
>>> import torch
625+
>>> from io import BytesIO
626+
627+
>>> from diffusers import StableDiffusionInpaintPipeline
628+
629+
630+
>>> def download_image(url):
631+
... response = requests.get(url)
632+
... return PIL.Image.open(BytesIO(response.content)).convert("RGB")
633+
634+
635+
>>> img_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
636+
>>> mask_url = "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
637+
638+
>>> init_image = download_image(img_url).resize((512, 512))
639+
>>> mask_image = download_image(mask_url).resize((512, 512))
640+
641+
>>> pipe = StableDiffusionInpaintPipeline.from_pretrained(
642+
... "runwayml/stable-diffusion-inpainting", torch_dtype=torch.float16
643+
... )
644+
>>> pipe = pipe.to("cuda")
645+
646+
>>> prompt = "Face of a yellow cat, high resolution, sitting on a park bench"
647+
>>> image = pipe(prompt=prompt, image=init_image, mask_image=mask_image).images[0]
648+
```
649+
619650
Returns:
620651
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
621652
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.

pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,32 @@ def __call__(
390390
The frequency at which the `callback` function will be called. If not specified, the callback will be
391391
called at every step.
392392
393+
Examples:
394+
```py
395+
>>> import requests
396+
>>> from PIL import Image
397+
>>> from io import BytesIO
398+
>>> from diffusers import StableDiffusionUpscalePipeline
399+
>>> import torch
400+
401+
>>> # load model and scheduler
402+
>>> model_id = "stabilityai/stable-diffusion-x4-upscaler"
403+
>>> pipeline = StableDiffusionUpscalePipeline.from_pretrained(
404+
... model_id, revision="fp16", torch_dtype=torch.float16
405+
... )
406+
>>> pipeline = pipeline.to("cuda")
407+
408+
>>> # let's download an image
409+
>>> url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png"
410+
>>> response = requests.get(url)
411+
>>> low_res_img = Image.open(BytesIO(response.content)).convert("RGB")
412+
>>> low_res_img = low_res_img.resize((128, 128))
413+
>>> prompt = "a white cat"
414+
415+
>>> upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0]
416+
>>> upscaled_image.save("upsampled_cat.png")
417+
```
418+
393419
Returns:
394420
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
395421
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.

training_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ def __init__(
6060
to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
6161
gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
6262
at 215.4k steps).
63+
6364
Args:
6465
inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
6566
power (float): Exponential factor of EMA warmup. Default: 2/3.

utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
WEIGHTS_NAME,
3333
)
3434
from .deprecation_utils import deprecate
35+
from .doc_utils import replace_example_docstring
3536
from .dynamic_modules_utils import get_class_from_dynamic_module
3637
from .hub_utils import HF_HUB_OFFLINE, http_user_agent
3738
from .import_utils import (

utils/doc_utils.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright 2022 The HuggingFace Team. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
Doc utilities: Utilities related to documentation
16+
"""
17+
import re
18+
19+
20+
def replace_example_docstring(example_docstring):
    """Build a decorator that fills in a function's ``Examples:`` placeholder.

    The decorated function's docstring must contain a line consisting solely
    of ``Example:`` or ``Examples:`` (surrounding whitespace allowed). The
    first such line is replaced by ``example_docstring``; every other line is
    left untouched.

    Args:
        example_docstring (str): Text substituted for the placeholder line.
            It is inserted verbatim, so it must carry its own heading and
            indentation.

    Returns:
        A decorator that rewrites ``fn.__doc__`` in place and returns ``fn``.

    Raises:
        ValueError: At decoration time, if the function has no placeholder
            line (or no docstring at all while docstrings are enabled).
    """
    # Function-scope import so this block does not depend on the module's
    # top-level import list.
    import sys

    def docstring_decorator(fn):
        func_doc = fn.__doc__

        # Under ``python -OO`` docstrings are stripped and ``__doc__`` is None.
        # There is nothing to patch in that case, and failing here would make
        # every module that uses this decorator unimportable.
        if func_doc is None and sys.flags.optimize >= 2:
            return fn

        lines = [] if func_doc is None else func_doc.split("\n")
        for index, line in enumerate(lines):
            if re.search(r"^\s*Examples?:\s*$", line) is not None:
                # Only the first placeholder is replaced.
                lines[index] = example_docstring
                fn.__doc__ = "\n".join(lines)
                return fn

        raise ValueError(
            f"The function {fn} should have an empty 'Examples:' in its docstring as placeholder, "
            f"current docstring is:\n{func_doc}"
        )

    return docstring_decorator

0 commit comments

Comments
 (0)