Skip to content

Commit 084470f

Browse files
author
ochafik
committed
agent: fix response_format
1 parent 34a0623 commit 084470f

File tree

6 files changed

+81
-17
lines changed

6 files changed

+81
-17
lines changed

examples/agent/README.md

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -138,19 +138,28 @@ If you'd like to debug each binary separately (rather than have an agent spawing
138138
```bash
139139
# C++ server
140140
make -j server
141-
./server --model mixtral.gguf --port 8081
141+
./server \
142+
--model mixtral.gguf \
143+
--metrics \
144+
-ctk q4_0 \
145+
-ctv f16 \
146+
-c 32768 \
147+
--port 8081
142148
143149
# OpenAI compatibility layer
144150
python -m examples.openai \
145-
--port 8080
151+
--port 8080 \
146152
--endpoint http://localhost:8081 \
147-
--template_hf_model_id_fallback mistralai/Mixtral-8x7B-Instruct-v0.1
153+
--template-hf-model-id-fallback mistralai/Mixtral-8x7B-Instruct-v0.1
148154
149155
# Or have the OpenAI compatibility layer spawn the C++ server under the hood:
150156
# python -m examples.openai --model mixtral.gguf
151157
152158
# Agent itself:
153159
python -m examples.agent --endpoint http://localhost:8080 \
160+
--tools examples/agent/tools/example_summaries.py \
161+
--format PyramidalSummary \
162+
--goal "Create a pyramidal summary of Mankind's recent advancements"
154163
```
155164
156165
## Use existing tools (WIP)

examples/agent/agent.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
from examples.json_schema_to_grammar import SchemaConverter
1212
from examples.agent.tools.std_tools import StandardTools
13-
from examples.openai.api import ChatCompletionRequest, ChatCompletionResponse, Message, Tool, ToolFunction
13+
from examples.openai.api import ChatCompletionRequest, ChatCompletionResponse, Message, ResponseFormat, Tool, ToolFunction
1414
from examples.agent.utils import collect_functions, load_module
1515
from examples.openai.prompting import ToolsPromptStyle
1616

@@ -46,7 +46,7 @@ def completion_with_tool_usage(
4646
else:
4747
type_adapter = TypeAdapter(response_model)
4848
schema = type_adapter.json_schema()
49-
response_format={"type": "json_object", "schema": schema }
49+
response_format=ResponseFormat(type="json_object", schema=schema)
5050

5151
tool_map = {fn.__name__: fn for fn in tools}
5252
tools_schemas = [
@@ -77,14 +77,15 @@ def completion_with_tool_usage(
7777
if auth:
7878
headers["Authorization"] = auth
7979
response = requests.post(
80-
endpoint,
80+
f'{endpoint}/v1/chat/completions',
8181
headers=headers,
8282
json=request.model_dump(),
8383
)
8484
if response.status_code != 200:
8585
raise Exception(f"Request failed ({response.status_code}): {response.text}")
8686

87-
response = ChatCompletionResponse(**response.json())
87+
response_json = response.json()
88+
response = ChatCompletionResponse(**response_json)
8889
if verbose:
8990
sys.stderr.write(f'# RESPONSE: {response.model_dump_json(indent=2)}\n')
9091
if response.error:
@@ -169,7 +170,7 @@ def main(
169170
if not endpoint:
170171
server_port = 8080
171172
server_host = 'localhost'
172-
endpoint: str = f'http://{server_host}:{server_port}/v1/chat/completions'
173+
endpoint = f'http://{server_host}:{server_port}'
173174
if verbose:
174175
sys.stderr.write(f"# Starting C++ server with model {model} on {endpoint}\n")
175176
cmd = [

examples/openai/api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ class Tool(BaseModel):
2828
function: ToolFunction
2929

3030
class ResponseFormat(BaseModel):
31-
type: str
32-
json_schema: Optional[Any] = None
31+
type: Literal["json_object"]
32+
schema: Optional[Dict] = None
3333

3434
class LlamaCppParams(BaseModel):
3535
n_predict: Optional[int] = None

examples/openai/prompting.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -712,15 +712,19 @@ def get_chat_handler(args: ChatHandlerArgs, parallel_calls: bool, tool_style: Op
712712
else:
713713
raise ValueError(f"Unsupported tool call style: {args.chat_template.tool_style}")
714714

715-
_ts_converter = SchemaToTypeScriptConverter()
716-
717715
# os.environ.get('NO_TS')
718716
def _please_respond_with_schema(schema: dict) -> str:
719717
# sig = json.dumps(schema, indent=2)
718+
_ts_converter = SchemaToTypeScriptConverter()
719+
_ts_converter.resolve_refs(schema, 'schema')
720720
sig = _ts_converter.visit(schema)
721721
return f'Please respond in JSON format with the following schema: {sig}'
722722

723723
def _tools_typescript_signatures(tools: list[Tool]) -> str:
724+
_ts_converter = SchemaToTypeScriptConverter()
725+
for tool in tools:
726+
_ts_converter.resolve_refs(tool.function.parameters, tool.function.name)
727+
724728
return 'namespace functions {\n' + '\n'.join(
725729
'// ' + tool.function.description.replace('\n', '\n// ') + '\n' + ''
726730
'type ' + tool.function.name + ' = (_: ' + _ts_converter.visit(tool.function.parameters) + ") => any;\n"

examples/openai/server.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def main(
7373
]
7474
server_process = subprocess.Popen(cmd, stdout=sys.stderr)
7575
atexit.register(server_process.kill)
76-
endpoint = f"http://{server_host}:{server_port}/completions"
76+
endpoint = f"http://{server_host}:{server_port}"
7777

7878

7979
# print(chat_template.render([
@@ -125,7 +125,7 @@ async def chat_completions(request: Request, chat_request: ChatCompletionRequest
125125

126126
if chat_request.response_format is not None:
127127
assert chat_request.response_format.type == "json_object", f"Unsupported response format: {chat_request.response_format.type}"
128-
response_schema = chat_request.response_format.json_schema or {}
128+
response_schema = chat_request.response_format.schema or {}
129129
else:
130130
response_schema = None
131131

@@ -164,7 +164,7 @@ async def chat_completions(request: Request, chat_request: ChatCompletionRequest
164164

165165
async with httpx.AsyncClient() as client:
166166
response = await client.post(
167-
f"{endpoint}",
167+
f'{endpoint}/completions',
168168
json=data,
169169
headers=headers,
170170
timeout=None)

examples/openai/ts_converter.py

Lines changed: 52 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,56 @@ class SchemaToTypeScriptConverter:
1414
# // where to get weather.
1515
# location: string,
1616
# }) => any;
17+
18+
def __init__(self):
19+
self._refs = {}
20+
self._refs_being_resolved = set()
21+
22+
def resolve_refs(self, schema: dict, url: str):
23+
'''
24+
Resolves all $ref fields in the given schema, fetching any remote schemas,
25+
replacing $ref with absolute reference URL and populating self._refs with the
26+
respective referenced (sub)schema dictionaries.
27+
'''
28+
def visit(n: dict):
29+
if isinstance(n, list):
30+
return [visit(x) for x in n]
31+
elif isinstance(n, dict):
32+
ref = n.get('$ref')
33+
if ref is not None and ref not in self._refs:
34+
if ref.startswith('https://'):
35+
assert self._allow_fetch, 'Fetching remote schemas is not allowed (use --allow-fetch for force)'
36+
import requests
37+
38+
frag_split = ref.split('#')
39+
base_url = frag_split[0]
40+
41+
target = self._refs.get(base_url)
42+
if target is None:
43+
target = self.resolve_refs(requests.get(ref).json(), base_url)
44+
self._refs[base_url] = target
45+
46+
if len(frag_split) == 1 or frag_split[-1] == '':
47+
return target
48+
elif ref.startswith('#/'):
49+
target = schema
50+
ref = f'{url}{ref}'
51+
n['$ref'] = ref
52+
else:
53+
raise ValueError(f'Unsupported ref {ref}')
54+
55+
for sel in ref.split('#')[-1].split('/')[1:]:
56+
assert target is not None and sel in target, f'Error resolving ref {ref}: {sel} not in {target}'
57+
target = target[sel]
58+
59+
self._refs[ref] = target
60+
else:
61+
for v in n.values():
62+
visit(v)
63+
64+
return n
65+
return visit(schema)
66+
1767
def _desc_comment(self, schema: dict):
1868
desc = schema.get("description", "").replace("\n", "\n// ") if 'description' in schema else None
1969
return f'// {desc}\n' if desc else ''
@@ -78,7 +128,7 @@ def add_component(comp_schema, is_required):
78128
else:
79129
add_component(t, is_required=True)
80130

81-
return self._build_object_rule(properties, required, additional_properties=[])
131+
return self._build_object_rule(properties, required, additional_properties={})
82132

83133
elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):
84134
items = schema.get('items') or schema['prefixItems']
@@ -94,4 +144,4 @@ def add_component(comp_schema, is_required):
94144
return 'any'
95145

96146
else:
97-
return 'number' if schema_type == 'integer' else schema_type
147+
return 'number' if schema_type == 'integer' else schema_type or 'any'

0 commit comments

Comments
 (0)