
Commit ad754e6

leisuzz (蒋硕) authored and committed
NPU Adaption for FLUX (#9751)
* NPU implementation for FLUX

---------

Co-authored-by: 蒋硕 <[email protected]>
1 parent c538dea commit ad754e6
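
At a high level, the commit routes FLUX attention and the DreamBooth example through Ascend NPU kernels whenever torch_npu is available. A minimal sketch of how the new path is exercised, assuming an Ascend environment with torch_npu installed; the checkpoint id and dtype below are illustrative, not part of the commit:

import torch
from diffusers import FluxTransformer2DModel
from diffusers.utils.import_utils import is_torch_npu_available

# Illustrative checkpoint; any FLUX transformer behaves the same way.
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

if is_torch_npu_available():
    # After this commit, FluxSingleTransformerBlock constructs itself with
    # FluxAttnProcessor2_0_NPU, so moving the model to the NPU is enough for
    # fp16/bf16 attention to go through torch_npu.npu_fusion_attention.
    transformer = transformer.to("npu")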

3 files changed (+243, −3 lines changed)

examples/dreambooth/train_dreambooth_flux.py

Lines changed: 20 additions & 2 deletions
@@ -57,6 +57,7 @@
     is_wandb_available,
 )
 from diffusers.utils.hub_utils import load_or_create_model_card, populate_model_card
+from diffusers.utils.import_utils import is_torch_npu_available
 from diffusers.utils.torch_utils import is_compiled_module
 
 
@@ -68,6 +69,12 @@
 
 logger = get_logger(__name__)
 
+if is_torch_npu_available():
+    import torch_npu
+
+    torch.npu.config.allow_internal_format = False
+    torch.npu.set_compile_mode(jit_compile=False)
+
 
 def save_model_card(
     repo_id: str,
@@ -189,6 +196,8 @@ def log_validation(
     del pipeline
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
+    elif is_torch_npu_available():
+        torch_npu.npu.empty_cache()
 
     return images
 
@@ -1035,7 +1044,9 @@ def main(args):
         cur_class_images = len(list(class_images_dir.iterdir()))
 
         if cur_class_images < args.num_class_images:
-            has_supported_fp16_accelerator = torch.cuda.is_available() or torch.backends.mps.is_available()
+            has_supported_fp16_accelerator = (
+                torch.cuda.is_available() or torch.backends.mps.is_available() or is_torch_npu_available()
+            )
             torch_dtype = torch.float16 if has_supported_fp16_accelerator else torch.float32
             if args.prior_generation_precision == "fp32":
                 torch_dtype = torch.float32
@@ -1073,6 +1084,8 @@ def main(args):
             del pipeline
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
+            elif is_torch_npu_available():
+                torch_npu.npu.empty_cache()
 
     # Handle the repository creation
     if accelerator.is_main_process:
@@ -1354,6 +1367,8 @@ def compute_text_embeddings(prompt, text_encoders, tokenizers):
         gc.collect()
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
+        elif is_torch_npu_available():
+            torch_npu.npu.empty_cache()
 
     # If custom instance prompts are NOT provided (i.e. the instance prompt is used for all images),
     # pack the statically computed variables appropriately here. This is so that we don't
@@ -1719,7 +1734,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 )
                 if not args.train_text_encoder:
                     del text_encoder_one, text_encoder_two
-                    torch.cuda.empty_cache()
+                    if torch.cuda.is_available():
+                        torch.cuda.empty_cache()
+                    elif is_torch_npu_available():
+                        torch_npu.npu.empty_cache()
                     gc.collect()
 
    # Save the lora layers
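
The script-side changes all follow one pattern: wherever the example previously called torch.cuda.empty_cache() unconditionally, it now branches on the available backend. A hedged sketch of that pattern as a standalone helper (flush_memory is a name introduced here for illustration; the script itself inlines the branches):

import gc

import torch
from diffusers.utils.import_utils import is_torch_npu_available

if is_torch_npu_available():
    import torch_npu  # registers the "npu" backend and exposes torch_npu.npu


def flush_memory():
    """Free cached accelerator memory on whichever backend is present."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    elif is_torch_npu_available():
        torch_npu.npu.empty_cache()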

src/diffusers/models/attention_processor.py

Lines changed: 217 additions & 0 deletions
@@ -1893,6 +1893,112 @@ def __call__(
         return hidden_states
 
 
+class FluxAttnProcessor2_0_NPU:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0 and install torch NPU"
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        if encoder_hidden_states is not None:
+            # `context` projections.
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        if query.dtype in (torch.float16, torch.bfloat16):
+            hidden_states = torch_npu.npu_fusion_attention(
+                query,
+                key,
+                value,
+                attn.heads,
+                input_layout="BNSD",
+                pse=None,
+                scale=1.0 / math.sqrt(query.shape[-1]),
+                pre_tockens=65536,
+                next_tockens=65536,
+                keep_prob=1.0,
+                sync=False,
+                inner_precise=0,
+            )[0]
+        else:
+            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
 class FusedFluxAttnProcessor2_0:
     """Attention processor used typically in processing the SD3-like self-attention projections."""
 
@@ -1987,6 +2093,117 @@ def __call__(
         return hidden_states
 
 
+class FusedFluxAttnProcessor2_0_NPU:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU, to use it, please upgrade PyTorch to 2.0, and install torch NPU"
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+        # `sample` projections.
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+            split_size = encoder_qkv.shape[-1] // 3
+            (
+                encoder_hidden_states_query_proj,
+                encoder_hidden_states_key_proj,
+                encoder_hidden_states_value_proj,
+            ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        if query.dtype in (torch.float16, torch.bfloat16):
+            hidden_states = torch_npu.npu_fusion_attention(
+                query,
+                key,
+                value,
+                attn.heads,
+                input_layout="BNSD",
+                pse=None,
+                scale=1.0 / math.sqrt(query.shape[-1]),
+                pre_tockens=65536,
+                next_tockens=65536,
+                keep_prob=1.0,
+                sync=False,
+                inner_precise=0,
+            )[0]
+        else:
+            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
 class CogVideoXAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
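
The diff wires only the unfused processor into FluxSingleTransformerBlock; the fused NPU variant is not selected automatically anywhere, so it has to be installed by hand. A sketch, assuming the usual diffusers set_attn_processor / fuse_qkv_projections model APIs (the checkpoint id is illustrative):

import torch
from diffusers import FluxTransformer2DModel
from diffusers.models.attention_processor import (
    FluxAttnProcessor2_0_NPU,
    FusedFluxAttnProcessor2_0_NPU,
)

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
)

# Unfused variant: keeps separate to_q/to_k/to_v projections.
transformer.set_attn_processor(FluxAttnProcessor2_0_NPU())

# Fused variant: needs the concatenated to_qkv / to_added_qkv projections that
# fuse_qkv_projections() creates, then the NPU processor is swapped in explicitly.
transformer.fuse_qkv_projections()
transformer.set_attn_processor(FusedFluxAttnProcessor2_0_NPU())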

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 6 additions & 1 deletion
@@ -27,11 +27,13 @@
     Attention,
     AttentionProcessor,
     FluxAttnProcessor2_0,
+    FluxAttnProcessor2_0_NPU,
     FusedFluxAttnProcessor2_0,
 )
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
 from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.import_utils import is_torch_npu_available
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
 from ..modeling_outputs import Transformer2DModelOutput
@@ -64,7 +66,10 @@ def __init__(self, dim, num_attention_heads, attention_head_dim, mlp_ratio=4.0):
         self.act_mlp = nn.GELU(approximate="tanh")
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
 
-        processor = FluxAttnProcessor2_0()
+        if is_torch_npu_available():
+            processor = FluxAttnProcessor2_0_NPU()
+        else:
+            processor = FluxAttnProcessor2_0()
         self.attn = Attention(
             query_dim=dim,
             cross_attention_dim=None,
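
Because the selection happens inside FluxSingleTransformerBlock.__init__, no user-facing configuration changes. A quick way to confirm which processor a freshly built block ends up with (the dimensions below mirror FLUX.1 defaults and are illustrative):

from diffusers.models.transformers.transformer_flux import FluxSingleTransformerBlock

# FLUX.1-style dimensions: 24 heads x 128 head dim = 3072 hidden size.
block = FluxSingleTransformerBlock(dim=3072, num_attention_heads=24, attention_head_dim=128)

# Prints "FluxAttnProcessor2_0_NPU" when torch_npu is available, otherwise
# "FluxAttnProcessor2_0".
print(type(block.attn.processor).__name__)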
