Skip to content

Commit d74c3a1

Browse files
committed
Simplify OpenAI reasoning model specific arguments to OpenAI API
Previously, OpenAI reasoning models didn't support stream_options and response_format. Add a reasoning_effort arg for calls to OpenAI reasoning models via the API. It currently defaults to medium but can be changed to low or high.
1 parent 9b6d626 commit d74c3a1

File tree

1 file changed

+12
-25
lines changed
  • src/khoj/processor/conversation/openai

1 file changed

+12
-25
lines changed

src/khoj/processor/conversation/openai/utils.py

+12-25
Original file line numberDiff line numberDiff line change
@@ -60,20 +60,13 @@ def completion_with_backoff(
6060

6161
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
6262

63-
# Update request parameters for compatability with o1 model series
64-
# Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
65-
stream = True
66-
model_kwargs["stream_options"] = {"include_usage": True}
67-
if model_name == "o1":
68-
temperature = 1
69-
stream = False
70-
model_kwargs.pop("stream_options", None)
71-
elif model_name.startswith("o1"):
72-
temperature = 1
73-
model_kwargs.pop("response_format", None)
74-
elif model_name.startswith("o3-"):
63+
# Tune reasoning models arguments
64+
if model_name.startswith("o1") or model_name.startswith("o3"):
7565
temperature = 1
66+
model_kwargs["reasoning_effort"] = "medium"
7667

68+
stream = True
69+
model_kwargs["stream_options"] = {"include_usage": True}
7770
if os.getenv("KHOJ_LLM_SEED"):
7871
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
7972

@@ -172,20 +165,13 @@ def llm_thread(
172165

173166
formatted_messages = [{"role": message.role, "content": message.content} for message in messages]
174167

175-
# Update request parameters for compatability with o1 model series
176-
# Refer: https://platform.openai.com/docs/guides/reasoning/beta-limitations
177-
stream = True
178-
model_kwargs["stream_options"] = {"include_usage": True}
179-
if model_name == "o1":
180-
temperature = 1
181-
stream = False
182-
model_kwargs.pop("stream_options", None)
183-
elif model_name.startswith("o1-"):
168+
# Tune reasoning models arguments
169+
if model_name.startswith("o1"):
184170
temperature = 1
185-
model_kwargs.pop("response_format", None)
186-
elif model_name.startswith("o3-"):
171+
elif model_name.startswith("o3"):
187172
temperature = 1
188-
# Get the first system message and add the string `Formatting re-enabled` to it. See https://platform.openai.com/docs/guides/reasoning-best-practices
173+
# Get the first system message and add the string `Formatting re-enabled` to it.
174+
# See https://platform.openai.com/docs/guides/reasoning-best-practices
189175
if len(formatted_messages) > 0:
190176
system_messages = [
191177
(i, message) for i, message in enumerate(formatted_messages) if message["role"] == "system"
@@ -195,7 +181,6 @@ def llm_thread(
195181
formatted_messages[first_system_message_index][
196182
"content"
197183
] = f"{first_system_message} Formatting re-enabled"
198-
199184
elif model_name.startswith("deepseek-reasoner"):
200185
# Two successive messages cannot be from the same role. Should merge any back-to-back messages from the same role.
201186
# The first message should always be a user message (except system message).
@@ -210,6 +195,8 @@ def llm_thread(
210195

211196
formatted_messages = updated_messages
212197

198+
stream = True
199+
model_kwargs["stream_options"] = {"include_usage": True}
213200
if os.getenv("KHOJ_LLM_SEED"):
214201
model_kwargs["seed"] = int(os.getenv("KHOJ_LLM_SEED"))
215202

0 commit comments

Comments (0)