-
Notifications
You must be signed in to change notification settings - Fork 6k
add PAG support #7944
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
add PAG support #7944
Changes from 11 commits
Commits
Show all changes
50 commits
Select commit
Hold shift + click to select a range
a6a0429
first draft
yiyixuxu 3605df9
refactor
yiyixuxu f94376c
update
yiyixuxu 54c3fd6
up
yiyixuxu f571430
style
yiyixuxu 91d0a5b
style
01585ab
update
yiyixuxu 03bdbcd
inpaint + controlnet
yiyixuxu b662207
Merge branch 'pag' of github.com:huggingface/diffusers into pag
yiyixuxu 1fb2c33
style
219f4b9
up
5641cb4
Update src/diffusers/pipelines/pag_utils.py
yiyixuxu 8950e80
fix controlnet
KKIEEK 4cc0b8b
fix compatability issue between PAG and IP-adapter (#8379)
sunovivid 5cbf226
up
yiyixuxu 58804a0
refactor ip-adapter
yiyixuxu 7bc9229
style
e09e079
Merge branch 'main' into pag
yiyixuxu 1fa54df
style
ba366f0
u[
d5a6761
up
854b70e
fix
9e4c1b6
add controlnet pag
623d237
copy
f30c2bc
add from pipe test for pag + controlnet
1df4391
up
191505e
support guess mode
yiyixuxu 58b8330
style
71cf2f7
add pag + img2img
6da3bb6
Merge branch 'main' into pag
sayakpaul 1e79c59
remove guess model support from pag controlnet pipeline
yiyixuxu 14b4ddd
noise_pred_uncond -> noise_pred_text
yiyixuxu 91c41e8
Apply suggestions from code review
yiyixuxu b72ef1c
fix more
yiyixuxu b7f4ccd
Merge branch 'main' into pag
d12b4a0
update docstring example
28e1301
add copied from
5653b2a
add doc
17520f2
up
e11180a
Merge branch 'main' into pag
074a4f0
fix copies
18d8b0e
up
0e337bf
up
434f63a
up
41b1ddc
up
24cadb4
Update docs/source/en/api/pipelines/pag.md
yiyixuxu c4ceee9
Apply suggestions from code review
yiyixuxu 9db27cf
Update src/diffusers/models/attention_processor.py
yiyixuxu 8ae87e2
add a tip about extending pag support and explain pag scale
19eb55f
Merge branch 'pag' of github.com:huggingface/diffusers into pag
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2651,6 +2651,232 @@ def __call__( | |
return hidden_states | ||
|
||
|
||
class PAGIdentitySelfAttnProcessor2_0: | ||
r""" | ||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). | ||
""" | ||
|
||
def __init__(self): | ||
if not hasattr(F, "scaled_dot_product_attention"): | ||
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") | ||
yiyixuxu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def __call__( | ||
self, | ||
attn: Attention, | ||
hidden_states: torch.FloatTensor, | ||
encoder_hidden_states: Optional[torch.FloatTensor] = None, | ||
attention_mask: Optional[torch.FloatTensor] = None, | ||
temb: Optional[torch.FloatTensor] = None, | ||
*args, | ||
**kwargs, | ||
yiyixuxu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) -> torch.FloatTensor: | ||
yiyixuxu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if len(args) > 0 or kwargs.get("scale", None) is not None: | ||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." | ||
deprecate("scale", "1.0.0", deprecation_message) | ||
yiyixuxu marked this conversation as resolved.
Show resolved
Hide resolved
yiyixuxu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
residual = hidden_states | ||
if attn.spatial_norm is not None: | ||
hidden_states = attn.spatial_norm(hidden_states, temb) | ||
|
||
input_ndim = hidden_states.ndim | ||
if input_ndim == 4: | ||
batch_size, channel, height, width = hidden_states.shape | ||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) | ||
|
||
# chunk | ||
hidden_states_org, hidden_states_ptb = hidden_states.chunk(2) | ||
|
||
# original path | ||
batch_size, sequence_length, _ = hidden_states_org.shape | ||
|
||
if attention_mask is not None: | ||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) | ||
# scaled_dot_product_attention expects attention_mask shape to be | ||
# (batch, heads, source_length, target_length) | ||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) | ||
|
||
if attn.group_norm is not None: | ||
hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) | ||
|
||
query = attn.to_q(hidden_states_org) | ||
key = attn.to_k(hidden_states_org) | ||
value = attn.to_v(hidden_states_org) | ||
|
||
inner_dim = key.shape[-1] | ||
head_dim = inner_dim // attn.heads | ||
|
||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) | ||
|
||
yiyixuxu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) | ||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) | ||
|
||
# the output of sdp = (batch, num_heads, seq_len, head_dim) | ||
# TODO: add support for attn.scale when we move to Torch 2.1 | ||
hidden_states_org = F.scaled_dot_product_attention( | ||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False | ||
) | ||
|
||
yiyixuxu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) | ||
hidden_states_org = hidden_states_org.to(query.dtype) | ||
|
||
# linear proj | ||
hidden_states_org = attn.to_out[0](hidden_states_org) | ||
# dropout | ||
hidden_states_org = attn.to_out[1](hidden_states_org) | ||
|
||
if input_ndim == 4: | ||
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) | ||
|
||
# perturbed path (identity attention) | ||
batch_size, sequence_length, _ = hidden_states_ptb.shape | ||
|
||
if attention_mask is not None: | ||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) | ||
# scaled_dot_product_attention expects attention_mask shape to be | ||
# (batch, heads, source_length, target_length) | ||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) | ||
|
||
if attn.group_norm is not None: | ||
hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) | ||
|
||
value = attn.to_v(hidden_states_ptb) | ||
|
||
# hidden_states_ptb = torch.zeros(value.shape).to(value.get_device()) | ||
yiyixuxu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
hidden_states_ptb = value | ||
|
||
hidden_states_ptb = hidden_states_ptb.to(query.dtype) | ||
|
||
# linear proj | ||
hidden_states_ptb = attn.to_out[0](hidden_states_ptb) | ||
# dropout | ||
hidden_states_ptb = attn.to_out[1](hidden_states_ptb) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (nit): I think these lines of code could be clubbed together and accompanied with a comment. |
||
|
||
if input_ndim == 4: | ||
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) | ||
|
||
# cat | ||
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) | ||
|
||
if attn.residual_connection: | ||
hidden_states = hidden_states + residual | ||
|
||
hidden_states = hidden_states / attn.rescale_output_factor | ||
|
||
return hidden_states | ||
|
||
|
||
class PAGCFGIdentitySelfAttnProcessor2_0: | ||
r""" | ||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). | ||
yiyixuxu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
""" | ||
|
||
def __init__(self): | ||
if not hasattr(F, "scaled_dot_product_attention"): | ||
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.") | ||
yiyixuxu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
def __call__( | ||
self, | ||
attn: Attention, | ||
hidden_states: torch.FloatTensor, | ||
encoder_hidden_states: Optional[torch.FloatTensor] = None, | ||
attention_mask: Optional[torch.FloatTensor] = None, | ||
temb: Optional[torch.FloatTensor] = None, | ||
*args, | ||
**kwargs, | ||
yiyixuxu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
) -> torch.FloatTensor: | ||
yiyixuxu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
if len(args) > 0 or kwargs.get("scale", None) is not None: | ||
deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." | ||
deprecate("scale", "1.0.0", deprecation_message) | ||
yiyixuxu marked this conversation as resolved.
Show resolved
Hide resolved
yiyixuxu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
residual = hidden_states | ||
if attn.spatial_norm is not None: | ||
hidden_states = attn.spatial_norm(hidden_states, temb) | ||
|
||
input_ndim = hidden_states.ndim | ||
if input_ndim == 4: | ||
batch_size, channel, height, width = hidden_states.shape | ||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2) | ||
|
||
# chunk | ||
hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3) | ||
hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org]) | ||
|
||
# original path | ||
batch_size, sequence_length, _ = hidden_states_org.shape | ||
|
||
if attention_mask is not None: | ||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) | ||
# scaled_dot_product_attention expects attention_mask shape to be | ||
# (batch, heads, source_length, target_length) | ||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) | ||
|
||
if attn.group_norm is not None: | ||
hidden_states_org = attn.group_norm(hidden_states_org.transpose(1, 2)).transpose(1, 2) | ||
|
||
query = attn.to_q(hidden_states_org) | ||
key = attn.to_k(hidden_states_org) | ||
value = attn.to_v(hidden_states_org) | ||
|
||
inner_dim = key.shape[-1] | ||
head_dim = inner_dim // attn.heads | ||
|
||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) | ||
|
||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) | ||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) | ||
|
||
# the output of sdp = (batch, num_heads, seq_len, head_dim) | ||
# TODO: add support for attn.scale when we move to Torch 2.1 | ||
hidden_states_org = F.scaled_dot_product_attention( | ||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False | ||
) | ||
|
||
hidden_states_org = hidden_states_org.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim) | ||
hidden_states_org = hidden_states_org.to(query.dtype) | ||
|
||
# linear proj | ||
hidden_states_org = attn.to_out[0](hidden_states_org) | ||
# dropout | ||
hidden_states_org = attn.to_out[1](hidden_states_org) | ||
|
||
if input_ndim == 4: | ||
hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width) | ||
|
||
# perturbed path (identity attention) | ||
batch_size, sequence_length, _ = hidden_states_ptb.shape | ||
|
||
if attention_mask is not None: | ||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) | ||
# scaled_dot_product_attention expects attention_mask shape to be | ||
# (batch, heads, source_length, target_length) | ||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1]) | ||
|
||
if attn.group_norm is not None: | ||
hidden_states_ptb = attn.group_norm(hidden_states_ptb.transpose(1, 2)).transpose(1, 2) | ||
|
||
value = attn.to_v(hidden_states_ptb) | ||
hidden_states_ptb = value | ||
hidden_states_ptb = hidden_states_ptb.to(query.dtype) | ||
|
||
# linear proj | ||
hidden_states_ptb = attn.to_out[0](hidden_states_ptb) | ||
# dropout | ||
hidden_states_ptb = attn.to_out[1](hidden_states_ptb) | ||
|
||
if input_ndim == 4: | ||
hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width) | ||
|
||
# cat | ||
hidden_states = torch.cat([hidden_states_org, hidden_states_ptb]) | ||
yiyixuxu marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
if attn.residual_connection: | ||
hidden_states = hidden_states + residual | ||
|
||
hidden_states = hidden_states / attn.rescale_output_factor | ||
|
||
return hidden_states | ||
|
||
|
||
LORA_ATTENTION_PROCESSORS = ( | ||
LoRAAttnProcessor, | ||
LoRAAttnProcessor2_0, | ||
|
@@ -2691,6 +2917,8 @@ def __call__( | |
CustomDiffusionAttnProcessor, | ||
CustomDiffusionXFormersAttnProcessor, | ||
CustomDiffusionAttnProcessor2_0, | ||
PAGCFGIdentitySelfAttnProcessor2_0, | ||
PAGIdentitySelfAttnProcessor2_0, | ||
# deprecated | ||
LoRAAttnProcessor, | ||
LoRAAttnProcessor2_0, | ||
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.