
Add StableDiffusion3PAGImg2Img Pipeline + Fix SD3 Unconditional PAG #9932


Status: Merged (23 commits, Dec 3, 2024). The diff below shows changes from 7 of the 23 commits.

Commits
adefc1b  fix progress bar updates in SD 1.5 PAG Img2Img pipeline (painebenjamin, Nov 14, 2024)
4bb5b73  Merge branch 'huggingface:main' into main (painebenjamin, Nov 14, 2024)
47563bc  catch attention mask in PAG-only attention processor for SD3 pipelines (painebenjamin, Nov 14, 2024)
63a31d5  Merge branch 'main' of github.com:painebenjamin/diffusers (painebenjamin, Nov 15, 2024)
d49eb5b  Add SD3PAGImg2Img Pipeline and tests (painebenjamin, Nov 15, 2024)
399f6cd  add pipeline to docs and correct documentation, ruff (painebenjamin, Nov 15, 2024)
38c2c7b  add autogenerated stub (painebenjamin, Nov 15, 2024)
357beb1  remove typo (painebenjamin, Nov 15, 2024)
628bbbf  accidental delete! (painebenjamin, Nov 15, 2024)
f4448b3  Merge branch 'main' into main (painebenjamin, Nov 16, 2024)
7f1058a  Merge branch 'main' into main (painebenjamin, Nov 19, 2024)
2a8aa20  Merge branch 'main' into main (painebenjamin, Nov 19, 2024)
de3afaf  Merge branch 'main' into main (painebenjamin, Nov 20, 2024)
6ed9607  Update src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py (painebenjamin, Nov 20, 2024)
d49cbc9  Update src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py (painebenjamin, Nov 20, 2024)
66c2ab4  style changes as suggested (painebenjamin, Nov 20, 2024)
1d192dd  Merge branch 'main' of github.com:painebenjamin/diffusers (painebenjamin, Nov 20, 2024)
664d47b  Merge branch 'main' into main (painebenjamin, Nov 21, 2024)
3736af3  update expected values (painebenjamin, Nov 22, 2024)
d35f725  Merge branch 'main' into main (painebenjamin, Nov 22, 2024)
43c34aa  Merge branch 'main' into main (painebenjamin, Nov 25, 2024)
561844b  Merge branch 'main' into main (painebenjamin, Nov 26, 2024)
340493a  Merge branch 'main' into main (sayakpaul, Nov 28, 2024)
Files changed
4 changes: 4 additions & 0 deletions docs/source/en/api/pipelines/pag.md
@@ -96,6 +96,10 @@ Since RegEx is supported as a way for matching layer identifiers, it is crucial
- all
- __call__

## StableDiffusion3PAGImg2ImgPipeline
[[autodoc]] StableDiffusion3PAGImg2ImgPipeline
- all
- __call__

## PixArtSigmaPAGPipeline
[[autodoc]] PixArtSigmaPAGPipeline
2 changes: 2 additions & 0 deletions src/diffusers/__init__.py
@@ -333,6 +333,7 @@
"StableDiffusion3Img2ImgPipeline",
"StableDiffusion3InpaintPipeline",
"StableDiffusion3PAGPipeline",
"StableDiffusion3PAGImg2ImgPipeline",
"StableDiffusion3Pipeline",
"StableDiffusionAdapterPipeline",
"StableDiffusionAttendAndExcitePipeline",
@@ -795,6 +796,7 @@
StableDiffusion3ControlNetPipeline,
StableDiffusion3Img2ImgPipeline,
StableDiffusion3InpaintPipeline,
StableDiffusion3PAGImg2ImgPipeline,
StableDiffusion3PAGPipeline,
StableDiffusion3Pipeline,
StableDiffusionAdapterPipeline,
1 change: 1 addition & 0 deletions src/diffusers/models/attention_processor.py
@@ -1171,6 +1171,7 @@ def __call__(
attn: Attention,
hidden_states: torch.FloatTensor,
encoder_hidden_states: torch.FloatTensor = None,
attention_mask: Optional[torch.FloatTensor] = None,
Review comment from a Collaborator on the added attention_mask line:

why do we add this here? it is not used, no?

Reply from @painebenjamin (Contributor, PR author), Nov 21, 2024:

That's the second part of what I wrote above: when using SD3+PAG and forgoing CFG (e.g. calling a PAG pipeline with guidance_scale=0), PAGJointAttnProcessor2_0 is used instead of PAGCFGJointAttnProcessor2_0, and the following error is produced:

StableDiffusion3PAGImg2ImgPipelineIntegrationTests.test_pag_uncond 
__________________________________________________

self = <tests.pipelines.pag.test_pag_sd3_img2img.StableDiffusion3PAGImg2ImgPipelineIntegrationTests testMethod=test_pag_uncond>

    def test_pag_uncond(self):
        pipeline = AutoPipelineForImage2Image.from_pretrained(
            self.repo_id, enable_pag=True, torch_dtype=torch.float16, pag_applied_layers=["blocks.(4|17)"]
        )
        pipeline.enable_model_cpu_offload()
        pipeline.set_progress_bar_config(disable=None)
    
        inputs = self.get_inputs(torch_device, guidance_scale=0.0, pag_scale=1.8)
>       image = pipeline(**inputs).images

tests/pipelines/pag/test_pag_sd3_img2img.py:261: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
../miniconda3/envs/taproot/lib/python3.10/site-packages/torch/utils/_contextlib.py:116: in decorate_context
    return func(*args, **kwargs)
src/diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py:975: in __call__
    noise_pred = self.transformer(
../miniconda3/envs/taproot/lib/python3.10/site-packages/torch/nn/modules/module.py:1553: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../miniconda3/envs/taproot/lib/python3.10/site-packages/torch/nn/modules/module.py:1562: in _call_impl
    return forward_call(*args, **kwargs)
../miniconda3/envs/taproot/lib/python3.10/site-packages/accelerate/hooks.py:170: in new_forward
    output = module._old_forward(*args, **kwargs)
src/diffusers/models/transformers/transformer_sd3.py:346: in forward
    encoder_hidden_states, hidden_states = block(
../miniconda3/envs/taproot/lib/python3.10/site-packages/torch/nn/modules/module.py:1553: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../miniconda3/envs/taproot/lib/python3.10/site-packages/torch/nn/modules/module.py:1562: in _call_impl
    return forward_call(*args, **kwargs)
src/diffusers/models/attention.py:208: in forward
    attn_output, context_attn_output = self.attn(
../miniconda3/envs/taproot/lib/python3.10/site-packages/torch/nn/modules/module.py:1553: in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
../miniconda3/envs/taproot/lib/python3.10/site-packages/torch/nn/modules/module.py:1562: in _call_impl
    return forward_call(*args, **kwargs)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _

self = Attention(
  (to_q): Linear(in_features=1536, out_features=1536, bias=True)
  (to_k): Linear(in_features=1536, out_fea...ue)
    (1): Dropout(p=0.0, inplace=False)
  )
  (to_add_out): Linear(in_features=1536, out_features=1536, bias=True)
)
hidden_states = tensor([[[-0.0430, -3.7031,  0.2078,  ...,  0.3115,  0.0703,  0.0383],
         [ 0.0179, -2.4727,  0.1594,  ..., -0.0...,
         [-0.0490, -0.3691,  0.2568,  ..., -1.0303, -0.0298,  0.5527]]],
       device='cuda:0', dtype=torch.float16)
encoder_hidden_states = tensor([[[-0.0503,  0.0515, -0.0623,  ..., -0.0044, -0.0186, -0.0752],
         [ 0.1860, -0.2595,  0.0835,  ...,  0.1...,
         [ 0.6958, -0.4875, -0.1246,  ...,  0.2664, -0.1700,  0.0030]]],
       device='cuda:0', dtype=torch.float16)
attention_mask = None, cross_attention_kwargs = {}, unused_kwargs = []

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **cross_attention_kwargs,
    ) -> torch.Tensor:
        r"""
        The forward method of the `Attention` class.
    
        Args:
            hidden_states (`torch.Tensor`):
                The hidden states of the query.
            encoder_hidden_states (`torch.Tensor`, *optional*):
                The hidden states of the encoder.
            attention_mask (`torch.Tensor`, *optional*):
                The attention mask to use. If `None`, no mask is applied.
            **cross_attention_kwargs:
                Additional keyword arguments to pass along to the cross attention.
    
        Returns:
            `torch.Tensor`: The output of the attention layer.
        """
        # The `Attention` class can call different attention processors / attention functions
        # here we simply pass along all tensors to the selected processor class
        # For standard processors that are defined here, `**cross_attention_kwargs` is empty
    
        attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
        quiet_attn_parameters = {"ip_adapter_masks"}
        unused_kwargs = [
            k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
        ]
        if len(unused_kwargs) > 0:
            logger.warning(
                f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
            )
        cross_attention_kwargs = {k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters}
    
>       return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )
E       TypeError: PAGJointAttnProcessor2_0.__call__() got an unexpected keyword argument 'attention_mask'

src/diffusers/models/attention_processor.py:530: TypeError
======================================================================= short test summary info =======================================================================
FAILED tests/pipelines/pag/test_pag_sd3_img2img.py::StableDiffusion3PAGImg2ImgPipelineIntegrationTests::test_pag_uncond - TypeError: PAGJointAttnProcessor2_0.__call__() got an unexpected keyword argument 'attention_mask'

An alternative to adding this particular keyword argument would be to catch all extra keyword arguments with **kwargs, for which there is precedent in other attention processors, but I generally default to being more restrictive rather than less. For what it's worth, PAGCFGJointAttnProcessor2_0 does both of those things: it accepts attention_mask and does nothing with it, and it also takes *args and **kwargs.

If there is any particular way that you think is most in line with the rest of the codebase, I'll be happy to adjust.
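
To make the failure mode concrete, here is a minimal standalone sketch, not diffusers code (the class and function names are hypothetical), of the dispatch shown in the traceback: Attention.forward filters only **cross_attention_kwargs against the processor's signature, while attention_mask is always passed explicitly, so a processor that does not declare it raises TypeError.

import inspect

class ProcessorWithoutMask:
    # Analogous to PAGJointAttnProcessor2_0 before this change.
    def __call__(self, attn, hidden_states, encoder_hidden_states=None):
        return hidden_states

class ProcessorWithMask:
    # Analogous to the processor after this change: attention_mask is
    # accepted for interface compatibility and intentionally unused.
    def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None):
        return hidden_states

def forward(processor, hidden_states, attention_mask=None, **cross_attention_kwargs):
    # Mirrors the filtering in Attention.forward from the traceback above:
    # unexpected cross_attention_kwargs are dropped, but attention_mask is
    # always forwarded to the processor.
    attn_parameters = set(inspect.signature(processor.__call__).parameters.keys())
    cross_attention_kwargs = {k: v for k, v in cross_attention_kwargs.items() if k in attn_parameters}
    return processor(None, hidden_states, attention_mask=attention_mask, **cross_attention_kwargs)

forward(ProcessorWithMask(), "h")  # works
try:
    forward(ProcessorWithoutMask(), "h")
except TypeError as err:
    print(err)  # __call__() got an unexpected keyword argument 'attention_mask'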

) -> torch.FloatTensor:
residual = hidden_states

2 changes: 2 additions & 0 deletions src/diffusers/pipelines/__init__.py
@@ -166,6 +166,7 @@
"KolorsPAGPipeline",
"HunyuanDiTPAGPipeline",
"StableDiffusion3PAGPipeline",
"StableDiffusion3PAGImg2ImgPipeline",
"StableDiffusionPAGPipeline",
"StableDiffusionPAGImg2ImgPipeline",
"StableDiffusionControlNetPAGPipeline",
@@ -579,6 +580,7 @@
HunyuanDiTPAGPipeline,
KolorsPAGPipeline,
PixArtSigmaPAGPipeline,
StableDiffusion3PAGImg2ImgPipeline,
StableDiffusion3PAGPipeline,
StableDiffusionControlNetPAGInpaintPipeline,
StableDiffusionControlNetPAGPipeline,
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/auto_pipeline.py
@@ -61,6 +61,7 @@
from .pag import (
HunyuanDiTPAGPipeline,
PixArtSigmaPAGPipeline,
StableDiffusion3PAGImg2ImgPipeline,
StableDiffusion3PAGPipeline,
StableDiffusionControlNetPAGInpaintPipeline,
StableDiffusionControlNetPAGPipeline,
@@ -129,6 +130,7 @@
("stable-diffusion", StableDiffusionImg2ImgPipeline),
("stable-diffusion-xl", StableDiffusionXLImg2ImgPipeline),
("stable-diffusion-3", StableDiffusion3Img2ImgPipeline),
("stable-diffusion-3-pag", StableDiffusion3PAGImg2ImgPipeline),
("if", IFImg2ImgPipeline),
("kandinsky", KandinskyImg2ImgCombinedPipeline),
("kandinsky22", KandinskyV22Img2ImgCombinedPipeline),
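
For reference, the mapping entry above is how AutoPipelineForImage2Image resolves an SD3 checkpoint to the new PAG img2img pipeline when enable_pag=True is passed. A minimal usage sketch modeled on the integration test quoted earlier; the checkpoint id and image URL are illustrative, not taken from this PR:

import torch
from diffusers import AutoPipelineForImage2Image
from diffusers.utils import load_image

# enable_pag=True makes the auto-pipeline look up the "-pag" variant of the
# img2img mapping, i.e. StableDiffusion3PAGImg2ImgPipeline for SD3 checkpoints.
pipeline = AutoPipelineForImage2Image.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # illustrative SD3 checkpoint
    enable_pag=True,
    torch_dtype=torch.float16,
    pag_applied_layers=["blocks.(4|17)"],  # regex layer match, as in the test above
)
pipeline.enable_model_cpu_offload()

init_image = load_image("https://example.com/input.png")  # illustrative input image

# guidance_scale=0.0 exercises the unconditional (CFG-free) PAG path fixed by
# this PR; pag_scale sets the perturbed attention guidance strength.
image = pipeline(
    prompt="a photo of an astronaut riding a horse",
    image=init_image,
    guidance_scale=0.0,
    pag_scale=1.8,
).images[0]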
2 changes: 2 additions & 0 deletions src/diffusers/pipelines/pag/__init__.py
@@ -31,6 +31,7 @@
_import_structure["pipeline_pag_pixart_sigma"] = ["PixArtSigmaPAGPipeline"]
_import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
_import_structure["pipeline_pag_sd_3"] = ["StableDiffusion3PAGPipeline"]
_import_structure["pipeline_pag_sd_3_img2img"] = ["StableDiffusion3PAGImg2ImgPipeline"]
_import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"]
_import_structure["pipeline_pag_sd_img2img"] = ["StableDiffusionPAGImg2ImgPipeline"]
_import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"]
@@ -54,6 +55,7 @@
from .pipeline_pag_pixart_sigma import PixArtSigmaPAGPipeline
from .pipeline_pag_sd import StableDiffusionPAGPipeline
from .pipeline_pag_sd_3 import StableDiffusion3PAGPipeline
from .pipeline_pag_sd_3_img2img import StableDiffusion3PAGImg2ImgPipeline
from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline
from .pipeline_pag_sd_img2img import StableDiffusionPAGImg2ImgPipeline
from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline
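
With both registrations in place, the pipeline should be importable from the top-level package as well as the pag subpackage, and both names should resolve to the same class. A quick sanity check, assuming an install from this branch:

# Both import paths are wired up by this PR's __init__.py changes.
from diffusers import StableDiffusion3PAGImg2ImgPipeline
from diffusers.pipelines.pag import StableDiffusion3PAGImg2ImgPipeline as FromSubpackage

assert StableDiffusion3PAGImg2ImgPipeline is FromSubpackage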