-
-
Notifications
You must be signed in to change notification settings - Fork 7.8k
[Bugfix] Fix vllm_flash_attn rotary import #17247
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -280,6 +280,14 @@ def run(self): | |
print(f"Copying {file} to {dst_file}") | ||
self.copy_file(file, dst_file) | ||
|
||
# copy these folders to use the vllm_flash_attn rotary_kernel. | ||
for folder in ("layers", "ops"): | ||
src = os.path.join(self.build_lib, "vllm", "vllm_flash_attn", | ||
folder) | ||
out = os.path.join("vllm", "vllm_flash_attn", folder) | ||
print(f"Copying {folder} from vllm/vllm_flash_attn") | ||
self.copy_tree(src, out) | ||
|
||
|
||
class repackage_wheel(build_ext): | ||
"""Extracts libraries and other files from an existing wheel.""" | ||
|
@@ -400,6 +408,31 @@ def run(self) -> None: | |
|
||
package_data[package_name].append(file_name) | ||
|
||
# Extract and include the layers and ops of rotary embedding. | ||
folders_to_copy = {"layers", "ops"} | ||
for folder in folders_to_copy: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This also seems unnecessarily messy / complicated; I think we can just use [inline code snippet not captured]. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I agree, that was an oversight on my end. |
||
folder_path = f"vllm/vllm_flash_attn/{folder}" | ||
folder_files = [ | ||
f for f in wheel.filelist | ||
if f.filename.startswith(folder_path) | ||
] | ||
|
||
if folder_files: | ||
print(f"Include {folder} folder from vllm/vllm_flash_attn") | ||
for file in folder_files: | ||
wheel.extract(file) | ||
|
||
# Add the file to package_data if it's not a Python file | ||
rel_path = file.filename.split("/") | ||
# vllm/vllm_flash_attn/folder/file | ||
if len(rel_path) >= 4: | ||
package_name = "vllm.vllm_flash_attn." + folder | ||
file_name = rel_path[-1] | ||
if not file_name.endswith(".py"): | ||
if package_name not in package_data: | ||
package_data[package_name] = [] | ||
package_data[package_name].append(file_name) | ||
|
||
|
||
def _is_hpu() -> bool: | ||
# if VLLM_TARGET_DEVICE env var was set explicitly, skip HPU autodetection | ||
|
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems unnecessary; why not just do: [code snippet not captured]
on lines 275-276 instead?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That would only grab the Python files; we would still need to check whether the folder exists and then copy it over here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would this conflict with #17260?