Commit 93fe711

stainless-bot authored and RobertCraigie committed
feat: use numpy for faster embeddings decoding
1 parent 316b6df commit 93fe711
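
The optimisation below asks the API for base64-encoded embeddings whenever numpy is installed and the caller did not choose an encoding_format themselves, then decodes the packed float32 bytes with np.frombuffer instead of parsing a long JSON array of numbers. A minimal sketch of that decode path, with an equivalent pure-Python fallback for comparison; the helper names and the little-endian struct format are illustrative, not part of this commit:

import base64
import struct

import numpy as np


def decode_embedding_with_numpy(data: str) -> list:
    # Reinterpret the decoded bytes as a flat array of 32-bit floats, as the diff does.
    return np.frombuffer(base64.b64decode(data), dtype="float32").tolist()


def decode_embedding_pure_python(data: str) -> list:
    # Fallback without numpy, unpacking the same bytes with struct (assumes a little-endian payload).
    raw = base64.b64decode(data)
    return list(struct.unpack(f"<{len(raw) // 4}f", raw))


if __name__ == "__main__":
    payload = base64.b64encode(np.arange(4, dtype="float32").tobytes()).decode()
    assert decode_embedding_with_numpy(payload) == decode_embedding_pure_python(payload)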

File tree

3 files changed: +69 −22 lines

src/openai/_extras/__init__.py
src/openai/_extras/numpy_proxy.py
src/openai/resources/embeddings.py

src/openai/_extras/__init__.py (+1)
@@ -1,2 +1,3 @@
 from .numpy_proxy import numpy as numpy
+from .numpy_proxy import has_numpy as has_numpy
 from .pandas_proxy import pandas as pandas

src/openai/_extras/numpy_proxy.py (+9)
@@ -27,3 +27,12 @@ def __load__(self) -> Any:
 
 if not TYPE_CHECKING:
     numpy = NumpyProxy()
+
+
+def has_numpy() -> bool:
+    try:
+        import numpy  # noqa: F401 # pyright: ignore[reportUnusedImport]
+    except ImportError:
+        return False
+
+    return True
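
A hedged usage sketch of the new helper (the wrapper function here is illustrative, not part of the diff): has_numpy() is a cheap capability check that lets calling code opt into the base64 fast path only when the optional numpy dependency can actually be imported, mirroring the branch added to embeddings.py below.

from typing import Optional

from openai._extras import has_numpy


def pick_encoding_format(requested: Optional[str]) -> Optional[str]:
    # Only switch to base64 when the caller did not ask for a specific format
    # and numpy is available to decode the packed floats afterwards.
    if requested is None and has_numpy():
        return "base64"
    return requested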

src/openai/resources/embeddings.py (+59 −22)
@@ -2,12 +2,15 @@
 
 from __future__ import annotations
 
-from typing import List, Union
+import base64
+from typing import List, Union, cast
 from typing_extensions import Literal
 
 from ..types import CreateEmbeddingResponse, embedding_create_params
 from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
-from .._utils import maybe_transform
+from .._utils import is_given, maybe_transform
+from .._extras import numpy as np
+from .._extras import has_numpy
 from .._resource import SyncAPIResource, AsyncAPIResource
 from .._base_client import make_request_options
 
@@ -61,23 +64,40 @@ def create(
 
           timeout: Override the client-level default timeout for this request, in seconds
         """
-        return self._post(
+        params = {
+            "input": input,
+            "model": model,
+            "user": user,
+            "encoding_format": encoding_format,
+        }
+        if not is_given(encoding_format) and has_numpy():
+            params["encoding_format"] = "base64"
+
+        response = self._post(
             "/embeddings",
-            body=maybe_transform(
-                {
-                    "input": input,
-                    "model": model,
-                    "encoding_format": encoding_format,
-                    "user": user,
-                },
-                embedding_create_params.EmbeddingCreateParams,
-            ),
+            body=maybe_transform(params, embedding_create_params.EmbeddingCreateParams),
             options=make_request_options(
                 extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
             ),
             cast_to=CreateEmbeddingResponse,
         )
 
+        if is_given(encoding_format):
+            # don't modify the response object if a user explicitly asked for a format
+            return response
+
+        for embedding in response.data:
+            data = cast(object, embedding.embedding)
+            if not isinstance(data, str):
+                # numpy is not installed / base64 optimisation isn't enabled for this model yet
+                continue
+
+            embedding.embedding = np.frombuffer(  # type: ignore[no-untyped-call]
+                base64.b64decode(data), dtype="float32"
+            ).tolist()
+
+        return response
+
 
 class AsyncEmbeddings(AsyncAPIResource):
     async def create(
@@ -126,19 +146,36 @@ async def create(
 
           timeout: Override the client-level default timeout for this request, in seconds
         """
-        return await self._post(
+        params = {
+            "input": input,
+            "model": model,
+            "user": user,
+            "encoding_format": encoding_format,
+        }
+        if not is_given(encoding_format) and has_numpy():
+            params["encoding_format"] = "base64"
+
+        response = await self._post(
             "/embeddings",
-            body=maybe_transform(
-                {
-                    "input": input,
-                    "model": model,
-                    "encoding_format": encoding_format,
-                    "user": user,
-                },
-                embedding_create_params.EmbeddingCreateParams,
-            ),
+            body=maybe_transform(params, embedding_create_params.EmbeddingCreateParams),
             options=make_request_options(
                 extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
             ),
             cast_to=CreateEmbeddingResponse,
         )
+
+        if is_given(encoding_format):
+            # don't modify the response object if a user explicitly asked for a format
+            return response
+
+        for embedding in response.data:
+            data = cast(object, embedding.embedding)
+            if not isinstance(data, str):
+                # numpy is not installed / base64 optimisation isn't enabled for this model yet
+                continue
+
+            embedding.embedding = np.frombuffer(  # type: ignore[no-untyped-call]
+                base64.b64decode(data), dtype="float32"
+            ).tolist()
+
+        return response
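
From the caller's side nothing changes when the format is left unset: the SDK requests base64, decodes it, and still hands back plain lists of floats. A usage sketch against the v1-style client (the model name is illustrative):

from openai import OpenAI

client = OpenAI()

# encoding_format omitted: with numpy installed the request goes out as base64 and is
# decoded back to a list of floats before the response is returned.
resp = client.embeddings.create(model="text-embedding-ada-002", input="hello world")
print(type(resp.data[0].embedding[0]))  # <class 'float'>

# Explicitly passing a format skips the post-processing above, so a base64 response
# is left as the raw base64 string for the caller to decode themselves.
raw = client.embeddings.create(
    model="text-embedding-ada-002", input="hello world", encoding_format="base64"
)
print(type(raw.data[0].embedding))  # <class 'str'> in this case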
