
Commit 69f19ca

LucasWilkinson authored and committed (co-authored by jeejeelee and aarnphm)
[BugFix] Fix vllm_flash_attn install issues (vllm-project#17267)
Signed-off-by: Lucas Wilkinson <[email protected]>
Co-authored-by: Jee Jee Li <[email protected]>
Co-authored-by: Aaron Pham <[email protected]>
1 parent d338563 commit 69f19ca

File tree

11 files changed: +28 -284 lines

.github/CODEOWNERS (+1)

@@ -12,6 +12,7 @@
 /vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth
 /vllm/model_executor/guided_decoding @mgoin @russellb
 /vllm/multimodal @DarkLight1337 @ywang96
+/vllm/vllm_flash_attn @LucasWilkinson
 CMakeLists.txt @tlrmchlsmth
 
 # vLLM V1

.gitignore (-2)

@@ -3,8 +3,6 @@
 
 # vllm-flash-attn built from source
 vllm/vllm_flash_attn/*
-!vllm/vllm_flash_attn/__init__.py
-!vllm/vllm_flash_attn/fa_utils.py
 
 # Byte-compiled / optimized / DLL files
 __pycache__/

setup.py (+19 -7)

@@ -269,15 +269,17 @@ def run(self):
         # First, run the standard build_ext command to compile the extensions
         super().run()
 
-        # copy vllm/vllm_flash_attn/*.py from self.build_lib to current
+        # copy vllm/vllm_flash_attn/**/*.py from self.build_lib to current
         # directory so that they can be included in the editable build
         import glob
-        files = glob.glob(
-            os.path.join(self.build_lib, "vllm", "vllm_flash_attn", "*.py"))
+        files = glob.glob(os.path.join(self.build_lib, "vllm",
+                                       "vllm_flash_attn", "**", "*.py"),
+                          recursive=True)
         for file in files:
             dst_file = os.path.join("vllm/vllm_flash_attn",
-                                    os.path.basename(file))
+                                    file.split("vllm/vllm_flash_attn/")[-1])
             print(f"Copying {file} to {dst_file}")
+            os.makedirs(os.path.dirname(dst_file), exist_ok=True)
             self.copy_file(file, dst_file)
 
 
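The hunk above switches the editable-build copy step from a flat *.py glob to a recursive **/*.py glob and creates destination sub-directories before copying, so Python files in nested sub-packages of vllm/vllm_flash_attn are picked up as well. A minimal, standalone sketch of the same pattern (the temporary directories and the nested layers/rotary.py module are made up for illustration; this is not the project's actual build code):

# Standalone sketch of the recursive copy pattern used above.
# Paths are illustrative only: build_lib stands in for self.build_lib.
import glob
import os
import shutil
import tempfile

build_lib = tempfile.mkdtemp()  # pretend build output (self.build_lib)
checkout = tempfile.mkdtemp()   # pretend source checkout

# Fake build artifacts, including a hypothetical nested sub-package.
for rel in ("vllm/vllm_flash_attn/flash_attn_interface.py",
            "vllm/vllm_flash_attn/layers/rotary.py"):
    path = os.path.join(build_lib, rel)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    open(path, "w").close()

# Recursive glob: "**" descends into sub-directories, unlike the old "*.py".
files = glob.glob(os.path.join(build_lib, "vllm", "vllm_flash_attn",
                               "**", "*.py"),
                  recursive=True)
for file in files:
    # Keep the path relative to vllm/vllm_flash_attn/ so nested modules
    # land in the matching sub-directory instead of being flattened.
    rel = file.split("vllm/vllm_flash_attn/")[-1]
    dst_file = os.path.join(checkout, "vllm", "vllm_flash_attn", rel)
    os.makedirs(os.path.dirname(dst_file), exist_ok=True)
    shutil.copyfile(file, dst_file)
    print(f"Copying {file} to {dst_file}")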

@@ -377,12 +379,22 @@ def run(self) -> None:
             "vllm/_flashmla_C.abi3.so",
             "vllm/vllm_flash_attn/_vllm_fa2_C.abi3.so",
             "vllm/vllm_flash_attn/_vllm_fa3_C.abi3.so",
-            "vllm/vllm_flash_attn/flash_attn_interface.py",
             "vllm/cumem_allocator.abi3.so",
             # "vllm/_version.py", # not available in nightly wheels yet
         ]
-        file_members = filter(lambda x: x.filename in files_to_copy,
-                              wheel.filelist)
+
+        file_members = list(
+            filter(lambda x: x.filename in files_to_copy, wheel.filelist))
+
+        # vllm_flash_attn python code:
+        # Regex from
+        # `glob.translate('vllm/vllm_flash_attn/**/*.py', recursive=True)`
+        import re
+        compiled_regex = re.compile(
+            r"vllm/vllm_flash_attn/(?:[^/.][^/]*/)*(?!\.)[^/]*\.py")
+        file_members += list(
+            filter(lambda x: compiled_regex.match(x.filename),
+                   wheel.filelist))
 
         for file in file_members:
             print(f"Extracting and including {file.filename} "

vllm/attention/backends/flash_attn.py (+3 -3)

@@ -22,13 +22,13 @@
     compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens,
     get_seq_len_block_table_args, is_all_cross_attn_metadata_set,
     is_all_encoder_attn_metadata_set, is_block_tables_empty)
+from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
+                                           get_flash_attn_version)
 from vllm.logger import init_logger
 from vllm.multimodal import MultiModalPlaceholderMap
 from vllm.utils import async_tensor_h2d, make_tensor_with_pad
 from vllm.vllm_flash_attn import (flash_attn_varlen_func,
                                   flash_attn_with_kvcache)
-from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
-                                           get_flash_attn_version)
 
 if TYPE_CHECKING:
     from vllm.worker.model_runner import (ModelInputForGPUBuilder,

@@ -689,7 +689,7 @@ def forward(
         assert output is not None, "Output tensor must be provided."
 
         # NOTE(woosuk): FlashAttention2 does not support FP8 KV cache.
-        if self.vllm_flash_attn_version < 3 or output.dtype != torch.bfloat16:
+        if not flash_attn_supports_fp8() or output.dtype != torch.bfloat16:
             assert (
                 layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0), (
                 "key/v_scale is only supported in FlashAttention 3 with "

vllm/attention/backends/mla/common.py (+1 -1)

@@ -205,6 +205,7 @@
                                            compute_slot_mapping_start_idx,
                                            is_block_tables_empty)
 from vllm.attention.ops.merge_attn_states import merge_attn_states
+from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
                                                UnquantizedLinearMethod)

@@ -214,7 +215,6 @@
 from vllm.platforms import current_platform
 from vllm.triton_utils import HAS_TRITON
 from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
-from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
 
 if HAS_TRITON:
     from vllm.attention.ops.triton_flash_attention import triton_attention
vllm/attention/utils/fa_utils.py

File renamed without changes (moved from vllm/vllm_flash_attn/fa_utils.py).

vllm/engine/arg_utils.py (+1 -1)

@@ -1377,7 +1377,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
         ) or envs.VLLM_ATTENTION_BACKEND == "FLASH_ATTN_VLLM_V1"
         supported = False
         if fp8_attention and will_use_fa:
-            from vllm.vllm_flash_attn.fa_utils import (
+            from vllm.attention.utils.fa_utils import (
                 flash_attn_supports_fp8)
             supported = flash_attn_supports_fp8()
         if not supported:

vllm/v1/attention/backends/flash_attn.py (+2 -2)

@@ -11,11 +11,11 @@
                                               AttentionMetadata, AttentionType,
                                               is_quantized_kv_cache)
 from vllm.attention.ops.merge_attn_states import merge_attn_states
+from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
+                                           get_flash_attn_version)
 from vllm.logger import init_logger
 from vllm.platforms import current_platform
 from vllm.utils import cdiv
-from vllm.vllm_flash_attn.fa_utils import (flash_attn_supports_fp8,
-                                           get_flash_attn_version)
 
 if TYPE_CHECKING:
     from vllm.v1.core.sched.output import SchedulerOutput

vllm/v1/attention/backends/mla/common.py (+1 -1)

@@ -197,14 +197,14 @@
                                               MLAAttentionImpl)
 from vllm.attention.backends.utils import get_mla_dims
 from vllm.attention.ops.merge_attn_states import merge_attn_states
+from vllm.attention.utils.fa_utils import get_flash_attn_version
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
                                                LinearBase, RowParallelLinear,
                                                UnquantizedLinearMethod)
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 from vllm.platforms import current_platform
 from vllm.utils import cdiv, round_down
-from vllm.vllm_flash_attn.fa_utils import get_flash_attn_version
 
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func

vllm/vllm_flash_attn/__init__.py (-22)

This file was deleted.
