Skip to content

Commit af28ae2

Browse files
add PAG support for SD Img2Img (#9463)
* added pag to sd img2img pipeline --------- Co-authored-by: YiYi Xu <[email protected]>
1 parent 31058cd commit af28ae2

File tree

8 files changed

+1401
-0
lines changed

8 files changed

+1401
-0
lines changed

docs/source/en/api/pipelines/pag.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
5353
- all
5454
- __call__
5555

56+
## StableDiffusionPAGImg2ImgPipeline
57+
[[autodoc]] StableDiffusionPAGImg2ImgPipeline
58+
- all
59+
- __call__
60+
5661
## StableDiffusionControlNetPAGPipeline
5762
[[autodoc]] StableDiffusionControlNetPAGPipeline
5863

src/diffusers/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,7 @@
344344
"StableDiffusionLatentUpscalePipeline",
345345
"StableDiffusionLDM3DPipeline",
346346
"StableDiffusionModelEditingPipeline",
347+
"StableDiffusionPAGImg2ImgPipeline",
347348
"StableDiffusionPAGPipeline",
348349
"StableDiffusionPanoramaPipeline",
349350
"StableDiffusionParadigmsPipeline",
@@ -795,6 +796,7 @@
795796
StableDiffusionLatentUpscalePipeline,
796797
StableDiffusionLDM3DPipeline,
797798
StableDiffusionModelEditingPipeline,
799+
StableDiffusionPAGImg2ImgPipeline,
798800
StableDiffusionPAGPipeline,
799801
StableDiffusionPanoramaPipeline,
800802
StableDiffusionParadigmsPipeline,

src/diffusers/pipelines/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,7 @@
164164
"HunyuanDiTPAGPipeline",
165165
"StableDiffusion3PAGPipeline",
166166
"StableDiffusionPAGPipeline",
167+
"StableDiffusionPAGImg2ImgPipeline",
167168
"StableDiffusionControlNetPAGPipeline",
168169
"StableDiffusionXLPAGPipeline",
169170
"StableDiffusionXLPAGInpaintPipeline",
@@ -569,6 +570,7 @@
569570
StableDiffusion3PAGPipeline,
570571
StableDiffusionControlNetPAGInpaintPipeline,
571572
StableDiffusionControlNetPAGPipeline,
573+
StableDiffusionPAGImg2ImgPipeline,
572574
StableDiffusionPAGPipeline,
573575
StableDiffusionXLControlNetPAGImg2ImgPipeline,
574576
StableDiffusionXLControlNetPAGPipeline,

src/diffusers/pipelines/auto_pipeline.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@
6363
StableDiffusion3PAGPipeline,
6464
StableDiffusionControlNetPAGInpaintPipeline,
6565
StableDiffusionControlNetPAGPipeline,
66+
StableDiffusionPAGImg2ImgPipeline,
6667
StableDiffusionPAGPipeline,
6768
StableDiffusionXLControlNetPAGImg2ImgPipeline,
6869
StableDiffusionXLControlNetPAGPipeline,
@@ -131,6 +132,7 @@
131132
("kandinsky22", KandinskyV22Img2ImgCombinedPipeline),
132133
("kandinsky3", Kandinsky3Img2ImgPipeline),
133134
("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline),
135+
("stable-diffusion-pag", StableDiffusionPAGImg2ImgPipeline),
134136
("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline),
135137
("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline),
136138
("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),

src/diffusers/pipelines/pag/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
_import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
3333
_import_structure["pipeline_pag_sd_3"] = ["StableDiffusion3PAGPipeline"]
3434
_import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"]
35+
_import_structure["pipeline_pag_sd_img2img"] = ["StableDiffusionPAGImg2ImgPipeline"]
3536
_import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"]
3637
_import_structure["pipeline_pag_sd_xl_img2img"] = ["StableDiffusionXLPAGImg2ImgPipeline"]
3738
_import_structure["pipeline_pag_sd_xl_inpaint"] = ["StableDiffusionXLPAGInpaintPipeline"]
@@ -54,6 +55,7 @@
5455
from .pipeline_pag_sd import StableDiffusionPAGPipeline
5556
from .pipeline_pag_sd_3 import StableDiffusion3PAGPipeline
5657
from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline
58+
from .pipeline_pag_sd_img2img import StableDiffusionPAGImg2ImgPipeline
5759
from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline
5860
from .pipeline_pag_sd_xl_img2img import StableDiffusionXLPAGImg2ImgPipeline
5961
from .pipeline_pag_sd_xl_inpaint import StableDiffusionXLPAGInpaintPipeline

src/diffusers/pipelines/pag/pipeline_pag_sd_img2img.py

Lines changed: 1091 additions & 0 deletions
Large diffs are not rendered by default.

src/diffusers/utils/dummy_torch_and_transformers_objects.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1592,6 +1592,21 @@ def from_pretrained(cls, *args, **kwargs):
15921592
requires_backends(cls, ["torch", "transformers"])
15931593

15941594

1595+
class StableDiffusionPAGImg2ImgPipeline(metaclass=DummyObject):
1596+
_backends = ["torch", "transformers"]
1597+
1598+
def __init__(self, *args, **kwargs):
1599+
requires_backends(self, ["torch", "transformers"])
1600+
1601+
@classmethod
1602+
def from_config(cls, *args, **kwargs):
1603+
requires_backends(cls, ["torch", "transformers"])
1604+
1605+
@classmethod
1606+
def from_pretrained(cls, *args, **kwargs):
1607+
requires_backends(cls, ["torch", "transformers"])
1608+
1609+
15951610
class StableDiffusionPAGPipeline(metaclass=DummyObject):
15961611
_backends = ["torch", "transformers"]
15971612

Lines changed: 282 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,282 @@
1+
# coding=utf-8
2+
# Copyright 2024 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import gc
17+
import inspect
18+
import random
19+
import unittest
20+
21+
import numpy as np
22+
import torch
23+
from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
24+
25+
from diffusers import (
26+
AutoencoderKL,
27+
AutoencoderTiny,
28+
AutoPipelineForImage2Image,
29+
EulerDiscreteScheduler,
30+
StableDiffusionImg2ImgPipeline,
31+
StableDiffusionPAGImg2ImgPipeline,
32+
UNet2DConditionModel,
33+
)
34+
from diffusers.utils.testing_utils import (
35+
enable_full_determinism,
36+
floats_tensor,
37+
load_image,
38+
require_torch_gpu,
39+
slow,
40+
torch_device,
41+
)
42+
43+
from ..pipeline_params import (
44+
IMAGE_TO_IMAGE_IMAGE_PARAMS,
45+
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
46+
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
47+
TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
48+
)
49+
from ..test_pipelines_common import (
50+
IPAdapterTesterMixin,
51+
PipelineKarrasSchedulerTesterMixin,
52+
PipelineLatentTesterMixin,
53+
PipelineTesterMixin,
54+
)
55+
56+
57+
enable_full_determinism()
58+
59+
60+
class StableDiffusionPAGImg2ImgPipelineFastTests(
61+
IPAdapterTesterMixin,
62+
PipelineLatentTesterMixin,
63+
PipelineKarrasSchedulerTesterMixin,
64+
PipelineTesterMixin,
65+
unittest.TestCase,
66+
):
67+
pipeline_class = StableDiffusionPAGImg2ImgPipeline
68+
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS.union({"pag_scale", "pag_adaptive_scale"}) - {"height", "width"}
69+
required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
70+
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS
71+
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
72+
image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
73+
callback_cfg_params = TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS
74+
75+
def get_dummy_components(self, time_cond_proj_dim=None):
76+
torch.manual_seed(0)
77+
unet = UNet2DConditionModel(
78+
block_out_channels=(32, 64),
79+
layers_per_block=2,
80+
time_cond_proj_dim=time_cond_proj_dim,
81+
sample_size=32,
82+
in_channels=4,
83+
out_channels=4,
84+
down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
85+
up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
86+
cross_attention_dim=32,
87+
)
88+
scheduler = EulerDiscreteScheduler(
89+
beta_start=0.00085,
90+
beta_end=0.012,
91+
steps_offset=1,
92+
beta_schedule="scaled_linear",
93+
timestep_spacing="leading",
94+
)
95+
torch.manual_seed(0)
96+
vae = AutoencoderKL(
97+
block_out_channels=[32, 64],
98+
in_channels=3,
99+
out_channels=3,
100+
down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
101+
up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
102+
latent_channels=4,
103+
sample_size=128,
104+
)
105+
text_encoder_config = CLIPTextConfig(
106+
bos_token_id=0,
107+
eos_token_id=2,
108+
hidden_size=32,
109+
intermediate_size=37,
110+
layer_norm_eps=1e-05,
111+
num_attention_heads=4,
112+
num_hidden_layers=5,
113+
pad_token_id=1,
114+
vocab_size=1000,
115+
)
116+
text_encoder = CLIPTextModel(text_encoder_config)
117+
tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
118+
119+
components = {
120+
"unet": unet,
121+
"scheduler": scheduler,
122+
"vae": vae,
123+
"text_encoder": text_encoder,
124+
"tokenizer": tokenizer,
125+
"safety_checker": None,
126+
"feature_extractor": None,
127+
"image_encoder": None,
128+
}
129+
return components
130+
131+
def get_dummy_tiny_autoencoder(self):
132+
return AutoencoderTiny(in_channels=3, out_channels=3, latent_channels=4)
133+
134+
def get_dummy_inputs(self, device, seed=0):
135+
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
136+
image = image / 2 + 0.5
137+
if str(device).startswith("mps"):
138+
generator = torch.manual_seed(seed)
139+
else:
140+
generator = torch.Generator(device=device).manual_seed(seed)
141+
inputs = {
142+
"prompt": "A painting of a squirrel eating a burger",
143+
"image": image,
144+
"generator": generator,
145+
"num_inference_steps": 2,
146+
"guidance_scale": 6.0,
147+
"pag_scale": 0.9,
148+
"output_type": "np",
149+
}
150+
return inputs
151+
152+
def test_pag_disable_enable(self):
153+
device = "cpu" # ensure determinism for the device-dependent torch.Generator
154+
components = self.get_dummy_components()
155+
156+
# base pipeline (expect same output when pag is disabled)
157+
pipe_sd = StableDiffusionImg2ImgPipeline(**components)
158+
pipe_sd = pipe_sd.to(device)
159+
pipe_sd.set_progress_bar_config(disable=None)
160+
161+
inputs = self.get_dummy_inputs(device)
162+
del inputs["pag_scale"]
163+
assert (
164+
"pag_scale" not in inspect.signature(pipe_sd.__call__).parameters
165+
), f"`pag_scale` should not be a call parameter of the base pipeline {pipe_sd.__class__.__name__}."
166+
out = pipe_sd(**inputs).images[0, -3:, -3:, -1]
167+
168+
# pag disabled with pag_scale=0.0
169+
pipe_pag = self.pipeline_class(**components)
170+
pipe_pag = pipe_pag.to(device)
171+
pipe_pag.set_progress_bar_config(disable=None)
172+
173+
inputs = self.get_dummy_inputs(device)
174+
inputs["pag_scale"] = 0.0
175+
out_pag_disabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]
176+
177+
# pag enabled
178+
pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"])
179+
pipe_pag = pipe_pag.to(device)
180+
pipe_pag.set_progress_bar_config(disable=None)
181+
182+
inputs = self.get_dummy_inputs(device)
183+
out_pag_enabled = pipe_pag(**inputs).images[0, -3:, -3:, -1]
184+
185+
assert np.abs(out.flatten() - out_pag_disabled.flatten()).max() < 1e-3
186+
assert np.abs(out.flatten() - out_pag_enabled.flatten()).max() > 1e-3
187+
188+
def test_pag_inference(self):
189+
device = "cpu" # ensure determinism for the device-dependent torch.Generator
190+
components = self.get_dummy_components()
191+
192+
pipe_pag = self.pipeline_class(**components, pag_applied_layers=["mid", "up", "down"])
193+
pipe_pag = pipe_pag.to(device)
194+
pipe_pag.set_progress_bar_config(disable=None)
195+
196+
inputs = self.get_dummy_inputs(device)
197+
image = pipe_pag(**inputs).images
198+
image_slice = image[0, -3:, -3:, -1]
199+
200+
assert image.shape == (
201+
1,
202+
32,
203+
32,
204+
3,
205+
), f"the shape of the output image should be (1, 32, 32, 3) but got {image.shape}"
206+
207+
expected_slice = np.array(
208+
[0.44203848, 0.49598145, 0.42248967, 0.6707724, 0.5683791, 0.43603387, 0.58316565, 0.60077155, 0.5174199]
209+
)
210+
max_diff = np.abs(image_slice.flatten() - expected_slice).max()
211+
self.assertLessEqual(max_diff, 1e-3)
212+
213+
214+
@slow
215+
@require_torch_gpu
216+
class StableDiffusionPAGImg2ImgPipelineIntegrationTests(unittest.TestCase):
217+
pipeline_class = StableDiffusionPAGImg2ImgPipeline
218+
repo_id = "Jiali/stable-diffusion-1.5"
219+
220+
def setUp(self):
221+
super().setUp()
222+
gc.collect()
223+
torch.cuda.empty_cache()
224+
225+
def tearDown(self):
226+
super().tearDown()
227+
gc.collect()
228+
torch.cuda.empty_cache()
229+
230+
def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0):
231+
generator = torch.Generator(device=generator_device).manual_seed(seed)
232+
init_image = load_image(
233+
"https://huggingface.co/datasets/diffusers/test-arrays/resolve/main"
234+
"/stable_diffusion_img2img/sketch-mountains-input.png"
235+
)
236+
inputs = {
237+
"prompt": "a fantasy landscape, concept art, high resolution",
238+
"image": init_image,
239+
"generator": generator,
240+
"num_inference_steps": 3,
241+
"strength": 0.75,
242+
"guidance_scale": 7.5,
243+
"pag_scale": 3.0,
244+
"output_type": "np",
245+
}
246+
return inputs
247+
248+
def test_pag_cfg(self):
249+
pipeline = AutoPipelineForImage2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
250+
pipeline.enable_model_cpu_offload()
251+
pipeline.set_progress_bar_config(disable=None)
252+
253+
inputs = self.get_inputs(torch_device)
254+
image = pipeline(**inputs).images
255+
256+
image_slice = image[0, -3:, -3:, -1].flatten()
257+
assert image.shape == (1, 512, 512, 3)
258+
print(image_slice.flatten())
259+
expected_slice = np.array(
260+
[0.58251953, 0.5722656, 0.5683594, 0.55029297, 0.52001953, 0.52001953, 0.49951172, 0.45410156, 0.50146484]
261+
)
262+
assert (
263+
np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
264+
), f"output is different from expected, {image_slice.flatten()}"
265+
266+
def test_pag_uncond(self):
267+
pipeline = AutoPipelineForImage2Image.from_pretrained(self.repo_id, enable_pag=True, torch_dtype=torch.float16)
268+
pipeline.enable_model_cpu_offload()
269+
pipeline.set_progress_bar_config(disable=None)
270+
271+
inputs = self.get_inputs(torch_device, guidance_scale=0.0)
272+
image = pipeline(**inputs).images
273+
274+
image_slice = image[0, -3:, -3:, -1].flatten()
275+
assert image.shape == (1, 512, 512, 3)
276+
expected_slice = np.array(
277+
[0.5986328, 0.52441406, 0.3972168, 0.4741211, 0.34985352, 0.22705078, 0.4128418, 0.2866211, 0.31713867]
278+
)
279+
print(image_slice.flatten())
280+
assert (
281+
np.abs(image_slice.flatten() - expected_slice).max() < 1e-3
282+
), f"output is different from expected, {image_slice.flatten()}"

0 commit comments

Comments
 (0)