@@ -1893,6 +1893,112 @@ def __call__(
            return hidden_states


+class FluxAttnProcessor2_0_NPU:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU. To use it, please upgrade PyTorch to 2.0 and install torch NPU."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        if encoder_hidden_states is not None:
+            # `context` projections.
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        # use the fused NPU attention kernel for fp16/bf16 inputs; otherwise fall back to SDPA
+        if query.dtype in (torch.float16, torch.bfloat16):
+            hidden_states = torch_npu.npu_fusion_attention(
+                query,
+                key,
+                value,
+                attn.heads,
+                input_layout="BNSD",
+                pse=None,
+                scale=1.0 / math.sqrt(query.shape[-1]),
+                pre_tockens=65536,
+                next_tockens=65536,
+                keep_prob=1.0,
+                sync=False,
+                inner_precise=0,
+            )[0]
+        else:
+            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
 class FusedFluxAttnProcessor2_0:
     """Attention processor used typically in processing the SD3-like self-attention projections."""

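Usage note (not part of the diff): a minimal sketch of how the new `FluxAttnProcessor2_0_NPU` could be wired into a Flux pipeline. The checkpoint name, the `"npu"` device string, and the use of `set_attn_processor` are assumptions based on the existing diffusers API rather than anything this commit adds.

```python
# Hypothetical usage sketch -- assumes a working torch_npu install and an available `npu` device.
import torch
from diffusers import FluxPipeline
from diffusers.models.attention_processor import FluxAttnProcessor2_0_NPU

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,  # fp16/bf16 inputs take the npu_fusion_attention path
)
pipe.to("npu")

# Route every attention module in the Flux transformer through the NPU processor.
pipe.transformer.set_attn_processor(FluxAttnProcessor2_0_NPU())

image = pipe("a tiny astronaut hatching from an egg on the moon", num_inference_steps=28).images[0]
```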
@@ -1987,6 +2093,117 @@ def __call__(
            return hidden_states


+class FusedFluxAttnProcessor2_0_NPU:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "FusedFluxAttnProcessor2_0_NPU requires PyTorch 2.0 and torch NPU. To use it, please upgrade PyTorch to 2.0 and install torch NPU."
+            )
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+        # `sample` projections.
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        # `context` projections.
+        if encoder_hidden_states is not None:
+            encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+            split_size = encoder_qkv.shape[-1] // 3
+            (
+                encoder_hidden_states_query_proj,
+                encoder_hidden_states_key_proj,
+                encoder_hidden_states_value_proj,
+            ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        # use the fused NPU attention kernel for fp16/bf16 inputs; otherwise fall back to SDPA
+        if query.dtype in (torch.float16, torch.bfloat16):
+            hidden_states = torch_npu.npu_fusion_attention(
+                query,
+                key,
+                value,
+                attn.heads,
+                input_layout="BNSD",
+                pse=None,
+                scale=1.0 / math.sqrt(query.shape[-1]),
+                pre_tockens=65536,
+                next_tockens=65536,
+                keep_prob=1.0,
+                sync=False,
+                inner_precise=0,
+            )[0]
+        else:
+            hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
 class CogVideoXAttnProcessor2_0:
     r"""
     Processor for implementing scaled dot-product attention for the CogVideoX model. It applies a rotary embedding on
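Usage note (not part of the diff): the fused variant reads `attn.to_qkv` and `attn.to_added_qkv`, so the QKV projections must be fused before the processor is assigned. Below is a minimal sketch, assuming the existing `fuse_qkv_projections()` and `set_attn_processor()` APIs and an available `npu` device.

```python
# Hypothetical usage sketch for the fused NPU processor.
import torch
from diffusers import FluxTransformer2DModel
from diffusers.models.attention_processor import FusedFluxAttnProcessor2_0_NPU

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.float16
)
transformer.fuse_qkv_projections()  # creates the fused to_qkv / to_added_qkv projections
transformer.set_attn_processor(FusedFluxAttnProcessor2_0_NPU())
transformer.to("npu")
```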