
Commit 2d45a2a

noooop authored and liuzijing2014 committed
[Frontend] Using matryoshka_dimensions control the allowed output dimensions. (vllm-project#16970)
1 parent 2363551 commit 2d45a2a

File tree

8 files changed (+177, -81 lines)

docs/source/models/pooling_models.md

Lines changed: 5 additions & 5 deletions

````diff
@@ -159,14 +159,14 @@ For example, setting `dimensions` parameter while using the `BAAI/bge-m3` model
 
 ### Manually enable Matryoshka Embeddings
 
-There is currently no official interface for specifying support for Matryoshka Embeddings. In vLLM, we simply check the existence of the fields `is_matryoshka` or `matryoshka_dimensions` inside `config.json`.
+There is currently no official interface for specifying support for Matryoshka Embeddings. In vLLM, if `is_matryoshka` is `True` in `config.json`, the output can be changed to arbitrary dimensions. `matryoshka_dimensions` can be used to control the allowed output dimensions.
 
-For models that support Matryoshka Embeddings but not recognized by vLLM, please manually override the config using `hf_overrides={"is_matryoshka": True}` (offline) or `--hf_overrides '{"is_matryoshka": true}'` (online).
+For models that support Matryoshka Embeddings but are not recognized by vLLM, please manually override the config using `hf_overrides={"is_matryoshka": True}` or `hf_overrides={"matryoshka_dimensions": [<allowed output dimensions>]}` (offline), or `--hf_overrides '{"is_matryoshka": true}'` or `--hf_overrides '{"matryoshka_dimensions": [<allowed output dimensions>]}'` (online).
 
 Here is an example to serve a model with Matryoshka Embeddings enabled.
 
 ```text
-vllm serve Snowflake/snowflake-arctic-embed-m-v1.5 --hf_overrides '{"is_matryoshka":true}'
+vllm serve Snowflake/snowflake-arctic-embed-m-v1.5 --hf_overrides '{"matryoshka_dimensions":[256]}'
 ```
 
 ### Offline Inference
@@ -204,14 +204,14 @@ curl http://127.0.0.1:8000/v1/embeddings \
     "input": "Follow the white rabbit.",
     "model": "jinaai/jina-embeddings-v3",
     "encoding_format": "float",
-    "dimensions": 1
+    "dimensions": 32
   }'
 ```
 
 Expected output:
 
 ```json
-{"id":"embd-0aab28c384d348c3b8f0eb783109dc5f","object":"list","created":1744195454,"model":"jinaai/jina-embeddings-v3","data":[{"index":0,"object":"embedding","embedding":[-1.0]}],"usage":{"prompt_tokens":10,"total_tokens":10,"completion_tokens":0,"prompt_tokens_details":null}}
+{"id":"embd-5c21fc9a5c9d4384a1b021daccaf9f64","object":"list","created":1745476417,"model":"jinaai/jina-embeddings-v3","data":[{"index":0,"object":"embedding","embedding":[-0.3828125,-0.1357421875,0.03759765625,0.125,0.21875,0.09521484375,-0.003662109375,0.1591796875,-0.130859375,-0.0869140625,-0.1982421875,0.1689453125,-0.220703125,0.1728515625,-0.2275390625,-0.0712890625,-0.162109375,-0.283203125,-0.055419921875,-0.0693359375,0.031982421875,-0.04052734375,-0.2734375,0.1826171875,-0.091796875,0.220703125,0.37890625,-0.0888671875,-0.12890625,-0.021484375,-0.0091552734375,0.23046875]}],"usage":{"prompt_tokens":8,"total_tokens":8,"completion_tokens":0,"prompt_tokens_details":null}}
 ```
 
 A openai client example can be found here: <gh-file:examples/online_serving/openai_embedding_matryoshka_fy.py>
````
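For context, here is a minimal offline sketch of the workflow the updated docs describe, using vLLM's `LLM` and `PoolingParams` APIs. The model name and the 256-dimension override mirror the doc example above; the exact result-access attributes are an assumption based on vLLM's embedding output objects, not part of this commit:

```python
# Minimal offline sketch, assuming vLLM's LLM / PoolingParams public API.
from vllm import LLM, PoolingParams

# Manually enable Matryoshka Embeddings, restricting output to 256 dims.
llm = LLM(model="Snowflake/snowflake-arctic-embed-m-v1.5",
          task="embed",
          hf_overrides={"matryoshka_dimensions": [256]})

# Allowed: 256 is listed in matryoshka_dimensions.
outputs = llm.embed(["Follow the white rabbit."],
                    pooling_params=PoolingParams(dimensions=256))
print(len(outputs[0].outputs.embedding))  # 256

# Rejected: per this commit, other values raise ValueError offline
# (see the test_jina.py diff below).
try:
    llm.embed(["Follow the white rabbit."],
              pooling_params=PoolingParams(dimensions=64))
except ValueError as e:
    print(f"rejected: {e}")
```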

examples/online_serving/openai_embedding_matryoshka_fy.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -25,11 +25,11 @@ def main():
     responses = client.embeddings.create(
         input=["Follow the white rabbit."],
         model=model,
-        dimensions=1,
+        dimensions=32,
     )
 
     for data in responses.data:
-        print(data.embedding)  # List of float of len 1
+        print(data.embedding)  # List of float of len 32
 
 
 if __name__ == "__main__":
```
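On the server side of this example, a request for a dimension outside the allowed list should now come back as HTTP 400 (`BadRequestError` in the OpenAI client). A minimal sketch, assuming a local server started with `--hf_overrides '{"matryoshka_dimensions":[256]}'` as in the docs above; the dimension value 5 is arbitrary and mirrors the invalid case in the tests below:

```python
# Sketch: invalid dimensions are rejected with HTTP 400.
import openai

client = openai.OpenAI(base_url="http://127.0.0.1:8000/v1", api_key="EMPTY")

try:
    client.embeddings.create(
        input=["Follow the white rabbit."],
        model="Snowflake/snowflake-arctic-embed-m-v1.5",
        dimensions=5,  # not in matryoshka_dimensions=[256]
    )
except openai.BadRequestError as e:
    print(f"rejected as expected: {e}")
```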

tests/entrypoints/openai/test_embedding.py

Lines changed: 24 additions & 18 deletions

```diff
@@ -11,11 +11,12 @@
 from vllm.entrypoints.openai.protocol import EmbeddingResponse
 from vllm.transformers_utils.tokenizer import get_tokenizer
 
-from ...models.embedding.utils import check_embeddings_close
+from ...models.embedding.utils import correctness_test
 from ...utils import RemoteOpenAIServer
 
 MODEL_NAME = "intfloat/multilingual-e5-small"
 DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}"""  # noqa: E501
+DTYPE = "bfloat16"
 
 
 @pytest.fixture(scope="module")
@@ -25,7 +26,7 @@ def server():
         "embed",
         # use half precision for speed and memory savings in CI environment
         "--dtype",
-        "bfloat16",
+        DTYPE,
         "--enforce-eager",
         "--max-model-len",
         "512",
@@ -43,9 +44,17 @@ async def client(server):
         yield async_client
 
 
+@pytest.fixture(scope="module")
+def hf_model(hf_runner):
+    with hf_runner(MODEL_NAME, dtype=DTYPE,
+                   is_sentence_transformer=True) as hf_model:
+        yield hf_model
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str):
+async def test_single_embedding(hf_model, client: openai.AsyncOpenAI,
+                                model_name: str):
     input_texts = [
         "The chef prepared a delicious meal.",
     ]
@@ -66,6 +75,9 @@ async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str):
     assert embeddings.usage.prompt_tokens == 11
     assert embeddings.usage.total_tokens == 11
 
+    vllm_outputs = [d.embedding for d in embeddings.data]
+    correctness_test(hf_model, input_texts, vllm_outputs)
+
     # test using token IDs
     input_tokens = [1, 1, 1, 1, 1]
     embedding_response = await client.embeddings.create(
@@ -86,7 +98,8 @@ async def test_single_embedding(client: openai.AsyncOpenAI, model_name: str):
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_batch_embedding(client: openai.AsyncOpenAI, model_name: str):
+async def test_batch_embedding(hf_model, client: openai.AsyncOpenAI,
+                               model_name: str):
     # test list[str]
     input_texts = [
         "The cat sat on the mat.", "A feline was resting on a rug.",
@@ -107,6 +120,9 @@ async def test_batch_embedding(client: openai.AsyncOpenAI, model_name: str):
     assert embeddings.usage.prompt_tokens == 33
     assert embeddings.usage.total_tokens == 33
 
+    vllm_outputs = [d.embedding for d in embeddings.data]
+    correctness_test(hf_model, input_texts, vllm_outputs)
+
     # test list[list[int]]
     input_tokens = [[4, 5, 7, 9, 20], [15, 29, 499], [24, 24, 24, 24, 24],
                     [25, 32, 64, 77]]
@@ -181,7 +197,7 @@ async def test_conversation_embedding(server: RemoteOpenAIServer,
 
 @pytest.mark.asyncio
 @pytest.mark.parametrize("model_name", [MODEL_NAME])
-async def test_batch_base64_embedding(client: openai.AsyncOpenAI,
+async def test_batch_base64_embedding(hf_model, client: openai.AsyncOpenAI,
                                       model_name: str):
     input_texts = [
         "Hello my name is",
@@ -192,6 +208,7 @@ async def test_batch_base64_embedding(client: openai.AsyncOpenAI,
                                                   model=model_name,
                                                   encoding_format="float")
     float_data = [d.embedding for d in responses_float.data]
+    correctness_test(hf_model, input_texts, float_data)
 
     responses_base64 = await client.embeddings.create(input=input_texts,
                                                       model=model_name,
@@ -202,24 +219,13 @@ async def test_batch_base64_embedding(client: openai.AsyncOpenAI,
             np.frombuffer(base64.b64decode(data.embedding),
                           dtype="float32").tolist())
 
-    check_embeddings_close(
-        embeddings_0_lst=float_data,
-        embeddings_1_lst=base64_data,
-        name_0="float",
-        name_1="base64",
-    )
+    correctness_test(hf_model, input_texts, base64_data)
 
     # Default response is float32 decoded from base64 by OpenAI Client
     responses_default = await client.embeddings.create(input=input_texts,
                                                        model=model_name)
     default_data = [d.embedding for d in responses_default.data]
-
-    check_embeddings_close(
-        embeddings_0_lst=float_data,
-        embeddings_1_lst=default_data,
-        name_0="float",
-        name_1="default",
-    )
+    correctness_test(hf_model, input_texts, default_data)
 
 
 @pytest.mark.asyncio
```

tests/entrypoints/openai/test_embedding_dimensions.py

Lines changed: 91 additions & 43 deletions

```diff
@@ -3,73 +3,121 @@
 Run `pytest tests/entrypoints/openai/test_embedding_dimensions.py`.
 """
 
+from typing import Optional
+
 import openai
 import pytest
 
 from vllm.entrypoints.openai.protocol import EmbeddingResponse
 
-from ...models.embedding.utils import EmbedModelInfo
+from ...conftest import HfRunner
+from ...models.embedding.utils import EmbedModelInfo, correctness_test
 from ...utils import RemoteOpenAIServer
 
 MODELS = [
-    EmbedModelInfo(name="BAAI/bge-m3", is_matryoshka=False),
-    EmbedModelInfo(name="jinaai/jina-embeddings-v3", is_matryoshka=True),
+    EmbedModelInfo("intfloat/multilingual-e5-small", is_matryoshka=False),
+    EmbedModelInfo("Snowflake/snowflake-arctic-embed-m-v1.5",
+                   is_matryoshka=True,
+                   matryoshka_dimensions=[256]),
 ]
 
 input_texts = [
     "The chef prepared a delicious meal.",
-] * 3
+]
 
 
-@pytest.mark.asyncio
-@pytest.mark.parametrize("model", MODELS)
-async def test_validating_dimensions(model: EmbedModelInfo):
+@pytest.fixture(scope="module", params=MODELS)
+def model_info(request):
+    return request.param
+
+
+@pytest.fixture(scope="module", params=["bfloat16"])
+def dtype(request):
+    return request.param
+
+
+@pytest.fixture(scope="module")
+def server(model_info, dtype: str):
     args = [
         "--task",
         "embed",
         # use half precision for speed and memory savings in CI environment
         "--dtype",
-        "bfloat16",
+        dtype,
         "--enforce-eager",
         "--max-model-len",
-        "512",
-        "--trust_remote_code"
+        "512"
     ]
-    with RemoteOpenAIServer(model.name, args) as remote_server:
-        client = remote_server.get_async_client()
-
-        async def make_request(dimensions):
-            embedding_response = await client.embeddings.create(
-                model=model.name,
-                input=input_texts,
-                dimensions=dimensions,
-                encoding_format="float",
-            )
-            embeddings = EmbeddingResponse.model_validate(
-                embedding_response.model_dump(mode="json"))
-
-            assert embeddings.id is not None
-            assert len(embeddings.data) == 3
-            assert len(embeddings.data[0].embedding) > 0
-            assert embeddings.usage.completion_tokens == 0
-            assert embeddings.usage.prompt_tokens > 0
-            assert embeddings.usage.total_tokens > 0
-
-            if dimensions is not None:
-                assert len(embeddings.data[0].embedding) == dimensions
-
-        if model.is_matryoshka:
-            for dimensions in [None, 16]:
-                await make_request(dimensions)
 
+    if model_info.name == "Snowflake/snowflake-arctic-embed-m-v1.5":
+        # Manually enable Matryoshka Embeddings
+        args.extend([
+            "--trust_remote_code", "--hf_overrides",
+            '{"matryoshka_dimensions":[256]}'
+        ])
+
+    with RemoteOpenAIServer(model_info.name, args) as remote_server:
+        yield remote_server
+
+
+@pytest.fixture(scope="module")
+def hf_model(hf_runner, model_info, dtype: str):
+    with hf_runner(model_info.name, dtype=dtype,
+                   is_sentence_transformer=True) as hf_model:
+        yield hf_model
+
+
+@pytest.mark.asyncio
+async def test_matryoshka(model_info: EmbedModelInfo,
+                          server: RemoteOpenAIServer, hf_model: HfRunner):
+    client = server.get_async_client()
+
+    async def make_request_and_correctness_test(dimensions):
+        prompts = input_texts * 3
+
+        embedding_response = await client.embeddings.create(
+            model=model_info.name,
+            input=prompts,
+            dimensions=dimensions,
+            encoding_format="float",
+        )
+        embeddings = EmbeddingResponse.model_validate(
+            embedding_response.model_dump(mode="json"))
+
+        assert embeddings.id is not None
+        assert len(embeddings.data) == 3
+        assert len(embeddings.data[0].embedding) > 0
+        assert embeddings.usage.completion_tokens == 0
+        assert embeddings.usage.prompt_tokens > 0
+        assert embeddings.usage.total_tokens > 0
+
+        if dimensions is not None:
+            assert len(embeddings.data[0].embedding) == dimensions
+
+        vllm_outputs = [d.embedding for d in embeddings.data]
+        correctness_test(hf_model, prompts, vllm_outputs, dimensions)
+
+    if model_info.is_matryoshka:
+        valid_dimensions: list[Optional[int]] = [None]
+        if model_info.matryoshka_dimensions is not None:
+            valid_dimensions += model_info.matryoshka_dimensions[:2]
+
+        for dimensions in valid_dimensions:
+            await make_request_and_correctness_test(dimensions)
+
+        invalid_dimensions: list[Optional[int]] = [-1]
+        if model_info.matryoshka_dimensions is not None:
+            assert 5 not in model_info.matryoshka_dimensions
+            invalid_dimensions.append(5)
+
+        for dimensions in invalid_dimensions:
             with pytest.raises(openai.BadRequestError):
-                for dimensions in [-1]:
-                    await make_request(dimensions)
+                await make_request_and_correctness_test(dimensions)
 
-        else:
-            for dimensions in [None]:
-                await make_request(dimensions)
+    else:
+        for dimensions in [None]:
+            await make_request_and_correctness_test(dimensions)
 
+        for dimensions in [-1, 16]:
             with pytest.raises(openai.BadRequestError):
-                for dimensions in [-1, 16]:
-                    await make_request(dimensions)
+                await make_request_and_correctness_test(dimensions)
```
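Taken together, the valid and invalid cases above pin down a simple validation rule. As a hypothetical sketch (not vLLM's actual code path; the function name and signature are invented for illustration), the server-side check behaves roughly like:

```python
from typing import Optional

def validate_dimensions(dimensions: Optional[int],
                        is_matryoshka: bool,
                        matryoshka_dimensions: Optional[list[int]]) -> None:
    """Hypothetical sketch of the rule the tests above exercise."""
    if dimensions is None:
        return  # omitting dimensions always yields the full-size embedding
    if not is_matryoshka:
        # e.g. dimensions=16 for intfloat/multilingual-e5-small -> 400
        raise ValueError("Model does not support Matryoshka Embeddings.")
    if dimensions < 1:
        # e.g. dimensions=-1 -> 400
        raise ValueError("dimensions must be a positive integer.")
    if (matryoshka_dimensions is not None
            and dimensions not in matryoshka_dimensions):
        # e.g. dimensions=5 when matryoshka_dimensions=[256] -> 400
        raise ValueError(f"dimensions must be one of {matryoshka_dimensions}.")
```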

tests/models/embedding/language/test_jina.py

Lines changed: 21 additions & 11 deletions

```diff
@@ -153,14 +153,24 @@ def test_matryoshka(
 
     with vllm_runner(model, task="embed", dtype=dtype,
                      max_model_len=None) as vllm_model:
-        vllm_outputs = vllm_model.encode(
-            example_prompts,
-            pooling_params=PoolingParams(dimensions=dimensions))
-
-        check_embeddings_close(
-            embeddings_0_lst=hf_outputs,
-            embeddings_1_lst=vllm_outputs,
-            name_0="hf",
-            name_1="vllm",
-            tol=1e-2,
-        )
+        matryoshka_dimensions = (
+            vllm_model.model.llm_engine.model_config.matryoshka_dimensions)
+        assert matryoshka_dimensions is not None
+
+        if dimensions not in matryoshka_dimensions:
+            with pytest.raises(ValueError):
+                vllm_model.encode(
+                    example_prompts,
+                    pooling_params=PoolingParams(dimensions=dimensions))
+        else:
+            vllm_outputs = vllm_model.encode(
+                example_prompts,
+                pooling_params=PoolingParams(dimensions=dimensions))
+
+            check_embeddings_close(
+                embeddings_0_lst=hf_outputs,
+                embeddings_1_lst=vllm_outputs,
+                name_0="hf",
+                name_1="vllm",
+                tol=1e-2,
+            )
```

tests/models/embedding/utils.py

Lines changed: 20 additions & 1 deletion

```diff
@@ -1,7 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 
 from collections.abc import Sequence
-from typing import NamedTuple
+from typing import NamedTuple, Optional
 
 import torch
 import torch.nn.functional as F
@@ -43,5 +43,24 @@ def matryoshka_fy(tensor, dimensions):
 class EmbedModelInfo(NamedTuple):
     name: str
     is_matryoshka: bool
+    matryoshka_dimensions: Optional[list[int]] = None
     architecture: str = ""
     enable_test: bool = True
+
+
+def correctness_test(hf_model,
+                     inputs,
+                     vllm_outputs: Sequence[list[float]],
+                     dimensions: Optional[int] = None):
+
+    hf_outputs = hf_model.encode(inputs)
+    if dimensions:
+        hf_outputs = matryoshka_fy(hf_outputs, dimensions)
+
+    check_embeddings_close(
+        embeddings_0_lst=hf_outputs,
+        embeddings_1_lst=vllm_outputs,
+        name_0="hf",
+        name_1="vllm",
+        tol=1e-2,
+    )
```
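`correctness_test` calls the pre-existing `matryoshka_fy` helper (visible in the hunk header above but unchanged by this commit, so its body is not shown in the diff). For reference, a sketch of the standard Matryoshka truncate-then-renormalize operation it is expected to perform, consistent with the `torch`/`F` imports already in this file; treat the body as an assumption, not the committed implementation:

```python
def matryoshka_fy(tensor, dimensions):
    # Sketch: keep the first `dimensions` components of each embedding,
    # then L2-normalize so it is comparable to a natively-sized embedding.
    tensor = torch.as_tensor(tensor)
    tensor = tensor[..., :dimensions]
    return F.normalize(tensor, p=2, dim=-1)
```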
