Skip to content

Commit 28cc146

Browse files
[azure] enable audio/whisper support (openai#613)
* enable azure for audio * reorder overloads * add additional tests * add helper function to utils * simplify - azure users will just need to pass model and deployment_id
1 parent b0dd091 commit 28cc146

File tree

1 file changed

+35
-9
lines changed

1 file changed

+35
-9
lines changed

openai/api_resources/audio.py

Lines changed: 35 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,9 @@ class Audio(APIResource):
99
OBJECT_NAME = "audio"
1010

1111
@classmethod
12-
def _get_url(cls, action):
12+
def _get_url(cls, action, deployment_id=None, api_type=None, api_version=None):
13+
if api_type in (util.ApiType.AZURE, util.ApiType.AZURE_AD):
14+
return f"/{cls.azure_api_prefix}/deployments/{deployment_id}/audio/{action}?api-version={api_version}"
1315
return cls.class_url() + f"/{action}"
1416

1517
@classmethod
@@ -50,6 +52,8 @@ def transcribe(
5052
api_type=None,
5153
api_version=None,
5254
organization=None,
55+
*,
56+
deployment_id=None,
5357
**params,
5458
):
5559
requestor, files, data = cls._prepare_request(
@@ -63,7 +67,8 @@ def transcribe(
6367
organization=organization,
6468
**params,
6569
)
66-
url = cls._get_url("transcriptions")
70+
api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
71+
url = cls._get_url("transcriptions", deployment_id=deployment_id, api_type=api_type, api_version=api_version)
6772
response, _, api_key = requestor.request("post", url, files=files, params=data)
6873
return util.convert_to_openai_object(
6974
response, api_key, api_version, organization
@@ -79,6 +84,8 @@ def translate(
7984
api_type=None,
8085
api_version=None,
8186
organization=None,
87+
*,
88+
deployment_id=None,
8289
**params,
8390
):
8491
requestor, files, data = cls._prepare_request(
@@ -92,7 +99,8 @@ def translate(
9299
organization=organization,
93100
**params,
94101
)
95-
url = cls._get_url("translations")
102+
api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
103+
url = cls._get_url("translations", deployment_id=deployment_id, api_type=api_type, api_version=api_version)
96104
response, _, api_key = requestor.request("post", url, files=files, params=data)
97105
return util.convert_to_openai_object(
98106
response, api_key, api_version, organization
@@ -109,6 +117,8 @@ def transcribe_raw(
109117
api_type=None,
110118
api_version=None,
111119
organization=None,
120+
*,
121+
deployment_id=None,
112122
**params,
113123
):
114124
requestor, files, data = cls._prepare_request(
@@ -122,7 +132,8 @@ def transcribe_raw(
122132
organization=organization,
123133
**params,
124134
)
125-
url = cls._get_url("transcriptions")
135+
api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
136+
url = cls._get_url("transcriptions", deployment_id=deployment_id, api_type=api_type, api_version=api_version)
126137
response, _, api_key = requestor.request("post", url, files=files, params=data)
127138
return util.convert_to_openai_object(
128139
response, api_key, api_version, organization
@@ -139,6 +150,8 @@ def translate_raw(
139150
api_type=None,
140151
api_version=None,
141152
organization=None,
153+
*,
154+
deployment_id=None,
142155
**params,
143156
):
144157
requestor, files, data = cls._prepare_request(
@@ -152,7 +165,8 @@ def translate_raw(
152165
organization=organization,
153166
**params,
154167
)
155-
url = cls._get_url("translations")
168+
api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
169+
url = cls._get_url("translations", deployment_id=deployment_id, api_type=api_type, api_version=api_version)
156170
response, _, api_key = requestor.request("post", url, files=files, params=data)
157171
return util.convert_to_openai_object(
158172
response, api_key, api_version, organization
@@ -168,6 +182,8 @@ async def atranscribe(
168182
api_type=None,
169183
api_version=None,
170184
organization=None,
185+
*,
186+
deployment_id=None,
171187
**params,
172188
):
173189
requestor, files, data = cls._prepare_request(
@@ -181,7 +197,8 @@ async def atranscribe(
181197
organization=organization,
182198
**params,
183199
)
184-
url = cls._get_url("transcriptions")
200+
api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
201+
url = cls._get_url("transcriptions", deployment_id=deployment_id, api_type=api_type, api_version=api_version)
185202
response, _, api_key = await requestor.arequest(
186203
"post", url, files=files, params=data
187204
)
@@ -199,6 +216,8 @@ async def atranslate(
199216
api_type=None,
200217
api_version=None,
201218
organization=None,
219+
*,
220+
deployment_id=None,
202221
**params,
203222
):
204223
requestor, files, data = cls._prepare_request(
@@ -212,7 +231,8 @@ async def atranslate(
212231
organization=organization,
213232
**params,
214233
)
215-
url = cls._get_url("translations")
234+
api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
235+
url = cls._get_url("translations", deployment_id=deployment_id, api_type=api_type, api_version=api_version)
216236
response, _, api_key = await requestor.arequest(
217237
"post", url, files=files, params=data
218238
)
@@ -231,6 +251,8 @@ async def atranscribe_raw(
231251
api_type=None,
232252
api_version=None,
233253
organization=None,
254+
*,
255+
deployment_id=None,
234256
**params,
235257
):
236258
requestor, files, data = cls._prepare_request(
@@ -244,7 +266,8 @@ async def atranscribe_raw(
244266
organization=organization,
245267
**params,
246268
)
247-
url = cls._get_url("transcriptions")
269+
api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
270+
url = cls._get_url("transcriptions", deployment_id=deployment_id, api_type=api_type, api_version=api_version)
248271
response, _, api_key = await requestor.arequest(
249272
"post", url, files=files, params=data
250273
)
@@ -263,6 +286,8 @@ async def atranslate_raw(
263286
api_type=None,
264287
api_version=None,
265288
organization=None,
289+
*,
290+
deployment_id=None,
266291
**params,
267292
):
268293
requestor, files, data = cls._prepare_request(
@@ -276,7 +301,8 @@ async def atranslate_raw(
276301
organization=organization,
277302
**params,
278303
)
279-
url = cls._get_url("translations")
304+
api_type, api_version = cls._get_api_type_and_version(api_type, api_version)
305+
url = cls._get_url("translations", deployment_id=deployment_id, api_type=api_type, api_version=api_version)
280306
response, _, api_key = await requestor.arequest(
281307
"post", url, files=files, params=data
282308
)

0 commit comments

Comments
 (0)