Commit ba3d8aa

[Model][MiniCPM] support MiniCPM (#645)
### What this PR does / why we need it?
This PR supports MiniCPM on branch main. See #164.

### How was this patch tested?
Tested locally with MiniCPM.

Signed-off-by: MengqingCao <[email protected]>
1 parent 742f679 commit ba3d8aa

File tree

5 files changed: +134 −0 lines changed
vllm_ascend/patch/__init__.py

Lines changed: 28 additions & 0 deletions
@@ -96,6 +96,8 @@
 #   Related PR (if no, explain why): no related PR, we want to add this ability into vllm
 #   Future Plan:
 #       Remove these patches when vllm has merged them
+#
+#
 # * Worker Patch:
 # ===============
 # ** File: worker/patch_0_8_4/patch_metrics.py **
@@ -125,6 +127,20 @@
 #   Future Plan:
 #       Revert it when the related pr is merged in vllm.
 #
+# ** File: worker/patch_common/patch_minicpm.py **
+#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#   1. `vllm.model_executor.models.minicpm.MiniCPMAttention.forward`
+#    Why:
+#       The forward function of MiniCPMAttention in vllm performs a datatype
+#       conversion (original dtype --> float32) to preserve precision on CUDA.
+#       However, float32 is not supported by the CANN rope op, so we keep this patch.
+#    How:
+#       Remove the dtype conversion operations from forward.
+#    Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
+#       No, this patch is NPU-only because of the rope op.
+#    Future Plan:
+#       Keep this patch in vllm-ascend.
+#
 # ** File: worker/patch_common/patch_multi_step_worker.py **
 #    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
 #   1. `vllm.spec_decode.multi_step_worker.MultiStepWorker.sampler_output`
@@ -156,3 +172,15 @@
 #    - https://github.com/vllm-project/vllm-ascend/pull/395
 #   Future Plan:
 #       Revert it when the related pr is merged in vllm and vllm-ascend.
+#
+# ** File: worker/patch_0_8_4/patch_tritonplaceholder.py **
+#    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
+#   1. `triton` Module
+#    Why:
+#       Triton is not supported on NPU currently; importing triton would break vllm-ascend.
+#    How:
+#       Replace the `triton` module with a placeholder (see patch_tritonplaceholder.py below).
+#    Related PR (if no, explain why): 1. refused by vllm. 2. vllm doesn't support 3. prepare to submit....
+#       TritonPlaceholder is only available in vllm > 0.8.4.
+#    Future Plan:
+#       Revert it when branch main no longer maintains v0.8.4.
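The patch_common/patch_minicpm.py entry above refers to a dtype round-trip (original dtype → float32 → original dtype) around the rope call in upstream vLLM's MiniCPMAttention.forward. The snippet below is only an illustrative sketch of that pattern, not the literal upstream code; it shows the two casts that the patch file later in this commit removes.

# Illustrative sketch only -- not the literal upstream vLLM code.
# It shows the float32 round-trip around rotary_emb that the NPU patch drops,
# because the CANN rope op does not accept float32 inputs.
import torch

def forward_with_dtype_roundtrip(self, positions: torch.Tensor,
                                 hidden_states: torch.Tensor) -> torch.Tensor:
    qkv, _ = self.qkv_proj(hidden_states)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
    orig_dtype = q.dtype
    # Upstream promotes to float32 before rope to preserve precision on CUDA ...
    q, k = q.float(), k.float()
    q, k = self.rotary_emb(positions, q, k)
    # ... and casts back afterwards; both casts are what patch_minicpm.py removes.
    q, k = q.to(orig_dtype), k.to(orig_dtype)
    attn_output = self.attn(q, k, v)
    output, _ = self.o_proj(attn_output)
    return output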

vllm_ascend/patch/worker/patch_0_8_4/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -16,3 +16,4 @@
 #
 
 import vllm_ascend.patch.worker.patch_0_8_4.patch_metrics  # noqa
+import vllm_ascend.patch.worker.patch_0_8_4.patch_tritonplaceholder  # noqa

vllm_ascend/patch/worker/patch_0_8_4/patch_tritonplaceholder.py (new file)

Lines changed: 68 additions & 0 deletions

@@ -0,0 +1,68 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/triton_utils/importing.py
#

import sys
import types
from importlib.util import find_spec

from vllm.logger import logger

HAS_TRITON = (
    find_spec("triton") is not None
    or find_spec("pytorch-triton-xpu") is not None  # Not compatible
)

if not HAS_TRITON:
    logger.info("Triton not installed or not compatible; certain GPU-related"
                " functions will not be available.")

    class TritonPlaceholder(types.ModuleType):

        def __init__(self):
            super().__init__("triton")
            self.jit = self._dummy_decorator("jit")
            self.autotune = self._dummy_decorator("autotune")
            self.heuristics = self._dummy_decorator("heuristics")
            self.language = TritonLanguagePlaceholder()
            logger.warning_once(
                "Triton is not installed. Using dummy decorators. "
                "Install it via `pip install triton` to enable kernel"
                " compilation.")

        def _dummy_decorator(self, name):

            def decorator(func=None, **kwargs):
                if func is None:
                    return lambda f: f
                return func

            return decorator

    class TritonLanguagePlaceholder(types.ModuleType):

        def __init__(self):
            super().__init__("triton.language")
            self.constexpr = None
            self.dtype = None

    sys.modules['triton'] = TritonPlaceholder()
    sys.modules['triton.language'] = TritonLanguagePlaceholder()

if 'triton' in sys.modules:
    logger.info("Triton module has been replaced with a placeholder.")
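A minimal, hypothetical check (not part of this commit) of what importing triton looks like once the placeholder above has been installed, i.e. on an NPU host where real Triton is unavailable: the usual decorators degrade to no-ops, so modules that decorate kernels at import time keep working.

# Hypothetical usage check (not part of the commit), assuming the placeholder
# above was registered because Triton could not be found.
import triton  # resolves to the TritonPlaceholder instance in sys.modules

@triton.jit  # dummy decorator: returns the function unchanged, nothing is compiled
def add_one(x):
    return x + 1

print(add_one(41))                # 42 -- executed as plain Python
print(triton.language.constexpr)  # None -- attribute of TritonLanguagePlaceholder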

vllm_ascend/patch/worker/patch_common/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -16,5 +16,6 @@
 #
 
 import vllm_ascend.patch.worker.patch_common.patch_metrics  # noqa
+import vllm_ascend.patch.worker.patch_common.patch_minicpm  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_multi_step_worker  # noqa
 import vllm_ascend.patch.worker.patch_common.patch_spec_decode_worker  # noqa

vllm_ascend/patch/worker/patch_common/patch_minicpm.py (new file)

Lines changed: 36 additions & 0 deletions

@@ -0,0 +1,36 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# This file is a part of the vllm-ascend project.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import torch
from vllm.model_executor.models.minicpm import MiniCPMAttention


def forward(
    self,
    positions: torch.Tensor,
    hidden_states: torch.Tensor,
) -> torch.Tensor:
    qkv, _ = self.qkv_proj(hidden_states)
    q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
    q, k = self.rotary_emb(positions, q, k)
    attn_output = self.attn(q, k, v)
    output, _ = self.o_proj(attn_output)
    return output


# The float32 conversion in the original forward is removed so the NPU rope operator can be used.
MiniCPMAttention.forward = forward
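Because the patch simply rebinds the class attribute, it takes effect process-wide as soon as the module is imported (which patch_common/__init__.py above does). A small hypothetical sanity check, not part of the commit, that the override is in place:

# Hypothetical sanity check (not part of the commit): importing the patch module
# is enough to swap the bound forward for every MiniCPMAttention instance.
import vllm_ascend.patch.worker.patch_common.patch_minicpm as patch_minicpm
from vllm.model_executor.models.minicpm import MiniCPMAttention

assert MiniCPMAttention.forward is patch_minicpm.forward
print("MiniCPMAttention.forward patched by:", MiniCPMAttention.forward.__module__)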
