Skip to content

Commit a387f63

Browse files
committed
Enforce json schema on more chat actors to improve schema compliance
Including infer webpage urls, gemini documents search, pick default mode tools chat actors
1 parent ccd9de7 commit a387f63

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

src/khoj/processor/conversation/google/gemini_chat.py

+5
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44

55
import pyjson5
66
from langchain.schema import ChatMessage
7+
from pydantic import BaseModel
78

89
from khoj.database.models import Agent, ChatModel, KhojUser
910
from khoj.processor.conversation import prompts
@@ -96,12 +97,16 @@ def extract_questions_gemini(
9697
messages.append(ChatMessage(content=prompt, role="user"))
9798
messages.append(ChatMessage(content=system_prompt, role="system"))
9899

100+
class DocumentQueries(BaseModel):
101+
queries: List[str]
102+
99103
response = gemini_send_message_to_model(
100104
messages,
101105
api_key,
102106
model,
103107
api_base_url=api_base_url,
104108
response_type="json_object",
109+
response_schema=DocumentQueries,
105110
tracer=tracer,
106111
)
107112

src/khoj/routers/helpers.py

+14 -3
Original file line number | Diff line number | Diff line change
@@ -399,10 +399,15 @@ async def aget_data_sources_and_output_format(
399399

400400
agent_chat_model = agent.chat_model if agent else None
401401

402+
class PickTools(BaseModel):
403+
source: List[str]
404+
output: str
405+
402406
with timer("Chat actor: Infer information sources to refer", logger):
403407
response = await send_message_to_model_wrapper(
404408
relevant_tools_prompt,
405409
response_type="json_object",
410+
response_schema=PickTools,
406411
user=user,
407412
query_files=query_files,
408413
agent_chat_model=agent_chat_model,
@@ -483,11 +488,15 @@ async def infer_webpage_urls(
483488

484489
agent_chat_model = agent.chat_model if agent else None
485490

491+
class WebpageUrls(BaseModel):
492+
links: List[str]
493+
486494
with timer("Chat actor: Infer webpage urls to read", logger):
487495
response = await send_message_to_model_wrapper(
488496
online_queries_prompt,
489497
query_images=query_images,
490498
response_type="json_object",
499+
response_schema=WebpageUrls,
491500
user=user,
492501
query_files=query_files,
493502
agent_chat_model=agent_chat_model,
@@ -563,11 +572,13 @@ class OnlineQueries(BaseModel):
563572
response = pyjson5.loads(response)
564573
response = {q.strip() for q in response["queries"] if q.strip()}
565574
if not isinstance(response, set) or not response or len(response) == 0:
566-
logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}")
575+
logger.error(
576+
f"Invalid response for constructing online subqueries: {response}. Returning original query: {q}"
577+
)
567578
return {q}
568579
return response
569580
except Exception as e:
570-
logger.error(f"Invalid response for constructing subqueries: {response}. Returning original query: {q}")
581+
logger.error(f"Invalid response for constructing online subqueries: {response}. Returning original query: {q}")
571582
return {q}
572583

573584

@@ -2054,7 +2065,7 @@ def schedule_automation(
20542065
try:
20552066
user_timezone = pytz.timezone(timezone)
20562067
except pytz.UnknownTimeZoneError:
2057-
logger.error(f"Invalid timezone: {timezone}. Fallback to use UTC to schedule automation.")
2068+
logger.warning(f"Invalid timezone: {timezone}. Fallback to use UTC to schedule automation.")
20582069
user_timezone = pytz.utc
20592070

20602071
trigger = CronTrigger.from_crontab(crontime, user_timezone)

0 commit comments

Comments
 (0)