|
2 | 2 |
|
3 | 3 | from __future__ import annotations
|
4 | 4 |
|
5 |
| -from typing import List, Union |
| 5 | +import base64 |
| 6 | +from typing import List, Union, cast |
6 | 7 | from typing_extensions import Literal
|
7 | 8 |
|
8 | 9 | from ..types import CreateEmbeddingResponse, embedding_create_params
|
9 | 10 | from .._types import NOT_GIVEN, Body, Query, Headers, NotGiven
|
10 |
| -from .._utils import maybe_transform |
| 11 | +from .._utils import is_given, maybe_transform |
| 12 | +from .._extras import numpy as np |
| 13 | +from .._extras import has_numpy |
11 | 14 | from .._resource import SyncAPIResource, AsyncAPIResource
|
12 | 15 | from .._base_client import make_request_options
|
13 | 16 |
|
@@ -61,23 +64,40 @@ def create(
|
61 | 64 |
|
62 | 65 | timeout: Override the client-level default timeout for this request, in seconds
|
63 | 66 | """
|
64 |
| - return self._post( |
| 67 | + params = { |
| 68 | + "input": input, |
| 69 | + "model": model, |
| 70 | + "user": user, |
| 71 | + "encoding_format": encoding_format, |
| 72 | + } |
| 73 | + if not is_given(encoding_format) and has_numpy(): |
| 74 | + params["encoding_format"] = "base64" |
| 75 | + |
| 76 | + response = self._post( |
65 | 77 | "/embeddings",
|
66 |
| - body=maybe_transform( |
67 |
| - { |
68 |
| - "input": input, |
69 |
| - "model": model, |
70 |
| - "encoding_format": encoding_format, |
71 |
| - "user": user, |
72 |
| - }, |
73 |
| - embedding_create_params.EmbeddingCreateParams, |
74 |
| - ), |
| 78 | + body=maybe_transform(params, embedding_create_params.EmbeddingCreateParams), |
75 | 79 | options=make_request_options(
|
76 | 80 | extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
|
77 | 81 | ),
|
78 | 82 | cast_to=CreateEmbeddingResponse,
|
79 | 83 | )
|
80 | 84 |
|
| 85 | + if is_given(encoding_format): |
| 86 | + # don't modify the response object if a user explicitly asked for a format |
| 87 | + return response |
| 88 | + |
| 89 | + for embedding in response.data: |
| 90 | + data = cast(object, embedding.embedding) |
| 91 | + if not isinstance(data, str): |
| 92 | + # numpy is not installed / base64 optimisation isn't enabled for this model yet |
| 93 | + continue |
| 94 | + |
| 95 | + embedding.embedding = np.frombuffer( # type: ignore[no-untyped-call] |
| 96 | + base64.b64decode(data), dtype="float32" |
| 97 | + ).tolist() |
| 98 | + |
| 99 | + return response |
| 100 | + |
81 | 101 |
|
82 | 102 | class AsyncEmbeddings(AsyncAPIResource):
|
83 | 103 | async def create(
|
@@ -126,19 +146,36 @@ async def create(
|
126 | 146 |
|
127 | 147 | timeout: Override the client-level default timeout for this request, in seconds
|
128 | 148 | """
|
129 |
| - return await self._post( |
| 149 | + params = { |
| 150 | + "input": input, |
| 151 | + "model": model, |
| 152 | + "user": user, |
| 153 | + "encoding_format": encoding_format, |
| 154 | + } |
| 155 | + if not is_given(encoding_format) and has_numpy(): |
| 156 | + params["encoding_format"] = "base64" |
| 157 | + |
| 158 | + response = await self._post( |
130 | 159 | "/embeddings",
|
131 |
| - body=maybe_transform( |
132 |
| - { |
133 |
| - "input": input, |
134 |
| - "model": model, |
135 |
| - "encoding_format": encoding_format, |
136 |
| - "user": user, |
137 |
| - }, |
138 |
| - embedding_create_params.EmbeddingCreateParams, |
139 |
| - ), |
| 160 | + body=maybe_transform(params, embedding_create_params.EmbeddingCreateParams), |
140 | 161 | options=make_request_options(
|
141 | 162 | extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout
|
142 | 163 | ),
|
143 | 164 | cast_to=CreateEmbeddingResponse,
|
144 | 165 | )
|
| 166 | + |
| 167 | + if is_given(encoding_format): |
| 168 | + # don't modify the response object if a user explicitly asked for a format |
| 169 | + return response |
| 170 | + |
| 171 | + for embedding in response.data: |
| 172 | + data = cast(object, embedding.embedding) |
| 173 | + if not isinstance(data, str): |
| 174 | + # numpy is not installed / base64 optimisation isn't enabled for this model yet |
| 175 | + continue |
| 176 | + |
| 177 | + embedding.embedding = np.frombuffer( # type: ignore[no-untyped-call] |
| 178 | + base64.b64decode(data), dtype="float32" |
| 179 | + ).tolist() |
| 180 | + |
| 181 | + return response |
0 commit comments