Skip to content

Commit 4871ad8

Browse files
feat: add "tools: foo as bar" syntax
1 parent e9b8847 commit 4871ad8

File tree

6 files changed

+111
-11
lines changed

6 files changed

+111
-11
lines changed

pkg/tests/runner_test.go

+7
Original file line numberDiff line numberDiff line change
@@ -678,6 +678,13 @@ func TestContextArg(t *testing.T) {
678678
assert.Equal(t, "TEST RESULT CALL: 1", x)
679679
}
680680

681+
func TestToolAs(t *testing.T) {
682+
runner := tester.NewRunner(t)
683+
x, err := runner.Run("", `{}`)
684+
require.NoError(t, err)
685+
assert.Equal(t, "TEST RESULT CALL: 1", x)
686+
}
687+
681688
func TestCwd(t *testing.T) {
682689
runner := tester.NewRunner(t)
683690

+65
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
`{
2+
"Model": "gpt-4o-2024-05-13",
3+
"InternalSystemPrompt": null,
4+
"Tools": [
5+
{
6+
"function": {
7+
"toolID": "testdata/TestToolAs/test.gpt:6",
8+
"name": "local",
9+
"parameters": {
10+
"properties": {
11+
"defaultPromptParameter": {
12+
"description": "Prompt to send to the tool or assistant. This may be instructions or question.",
13+
"type": "string"
14+
}
15+
},
16+
"required": [
17+
"defaultPromptParameter"
18+
],
19+
"type": "object"
20+
}
21+
}
22+
},
23+
{
24+
"function": {
25+
"toolID": "testdata/TestToolAs/other.gpt:1",
26+
"name": "remote",
27+
"parameters": {
28+
"properties": {
29+
"defaultPromptParameter": {
30+
"description": "Prompt to send to the tool or assistant. This may be instructions or question.",
31+
"type": "string"
32+
}
33+
},
34+
"required": [
35+
"defaultPromptParameter"
36+
],
37+
"type": "object"
38+
}
39+
}
40+
}
41+
],
42+
"Messages": [
43+
{
44+
"role": "system",
45+
"content": [
46+
{
47+
"text": "A tool"
48+
}
49+
]
50+
},
51+
{
52+
"role": "user",
53+
"content": [
54+
{
55+
"text": "{}"
56+
}
57+
]
58+
}
59+
],
60+
"MaxTokens": 0,
61+
"Temperature": null,
62+
"JSONResponse": false,
63+
"Grammar": "",
64+
"Cache": null
65+
}`
+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
other file
+8
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
tools: infile as local, ./other.gpt as remote
2+
3+
A tool
4+
5+
---
6+
name: infile
7+
8+
infile tool

pkg/types/tool.go

+24-11
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ func (p Program) ChatName() string {
5757
}
5858

5959
type ToolReference struct {
60+
Named string
6061
Reference string
6162
Arg string
6263
ToolID string
@@ -184,9 +185,14 @@ func SplitArg(hasArg string) (prefix, arg string) {
184185
var (
185186
fields = strings.Fields(hasArg)
186187
idx = slices.Index(fields, "with")
188+
asIdx = slices.Index(fields, "as")
187189
)
188190

189191
if idx == -1 {
192+
if asIdx != -1 {
193+
return strings.Join(fields[:asIdx], " "),
194+
strings.Join(fields[asIdx:], " ")
195+
}
190196
return strings.TrimSpace(hasArg), ""
191197
}
192198

@@ -201,7 +207,12 @@ func (t Tool) GetToolRefsFromNames(names []string) (result []ToolReference, _ er
201207
return nil, NewErrToolNotFound(toolName)
202208
}
203209
_, arg := SplitArg(toolName)
210+
named, ok := strings.CutPrefix(arg, "as ")
211+
if !ok {
212+
named = ""
213+
}
204214
result = append(result, ToolReference{
215+
Named: named,
205216
Arg: arg,
206217
Reference: toolName,
207218
ToolID: toolID,
@@ -287,15 +298,13 @@ func (t Tool) String() string {
287298
func (t Tool) GetCompletionTools(prg Program) (result []CompletionTool, err error) {
288299
toolNames := map[string]struct{}{}
289300

290-
for _, subToolName := range t.Parameters.Tools {
291-
result, err = appendTool(result, prg, t, subToolName, toolNames)
292-
if err != nil {
293-
return nil, err
294-
}
301+
subToolRefs, err := t.GetToolRefsFromNames(t.Parameters.Tools)
302+
if err != nil {
303+
return nil, err
295304
}
296305

297-
for _, subToolName := range t.Parameters.Context {
298-
result, err = appendExports(result, prg, t, subToolName, toolNames)
306+
for _, subToolRef := range subToolRefs {
307+
result, err = appendTool(result, prg, t, subToolRef.Reference, toolNames, subToolRef.Named)
299308
if err != nil {
300309
return nil, err
301310
}
@@ -327,7 +336,7 @@ func appendExports(completionTools []CompletionTool, prg Program, parentTool Too
327336
}
328337

329338
for _, export := range subTool.Export {
330-
completionTools, err = appendTool(completionTools, prg, subTool, export, toolNames)
339+
completionTools, err = appendTool(completionTools, prg, subTool, export, toolNames, "")
331340
if err != nil {
332341
return nil, err
333342
}
@@ -336,7 +345,7 @@ func appendExports(completionTools []CompletionTool, prg Program, parentTool Too
336345
return completionTools, nil
337346
}
338347

339-
func appendTool(completionTools []CompletionTool, prg Program, parentTool Tool, subToolName string, toolNames map[string]struct{}) ([]CompletionTool, error) {
348+
func appendTool(completionTools []CompletionTool, prg Program, parentTool Tool, subToolName string, toolNames map[string]struct{}, asName string) ([]CompletionTool, error) {
340349
subTool, err := getTool(prg, parentTool, subToolName)
341350
if err != nil {
342351
return nil, err
@@ -356,18 +365,22 @@ func appendTool(completionTools []CompletionTool, prg Program, parentTool Tool,
356365
if subTool.Instructions == "" {
357366
log.Debugf("Skipping zero instruction tool %s (%s)", subToolName, subTool.ID)
358367
} else {
368+
name := subToolName
369+
if asName != "" {
370+
name = asName
371+
}
359372
completionTools = append(completionTools, CompletionTool{
360373
Function: CompletionFunctionDefinition{
361374
ToolID: subTool.ID,
362-
Name: PickToolName(subToolName, toolNames),
375+
Name: PickToolName(name, toolNames),
363376
Description: subTool.Parameters.Description,
364377
Parameters: args,
365378
},
366379
})
367380
}
368381

369382
for _, export := range subTool.Export {
370-
completionTools, err = appendTool(completionTools, prg, subTool, export, toolNames)
383+
completionTools, err = appendTool(completionTools, prg, subTool, export, toolNames, "")
371384
if err != nil {
372385
return nil, err
373386
}

pkg/types/toolname_test.go

+6
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,12 @@ func TestParse(t *testing.T) {
2020
tool, subTool := SplitToolRef("a from b with x")
2121
autogold.Expect([]string{"b", "a"}).Equal(t, []string{tool, subTool})
2222

23+
tool, subTool = SplitToolRef("a from b with x as other")
24+
autogold.Expect([]string{"b", "a"}).Equal(t, []string{tool, subTool})
25+
2326
tool, subTool = SplitToolRef("a with x")
2427
autogold.Expect([]string{"a", ""}).Equal(t, []string{tool, subTool})
28+
29+
tool, subTool = SplitToolRef("a with x as other")
30+
autogold.Expect([]string{"a", ""}).Equal(t, []string{tool, subTool})
2531
}

0 commit comments

Comments
 (0)