Skip to content

Commit b823f89

Browse files
feat: add support for wildcard subtool names
1 parent c2edee9 commit b823f89

File tree

11 files changed

+295
-62
lines changed

11 files changed

+295
-62
lines changed

pkg/engine/engine.go

+7-4
Original file line numberDiff line numberDiff line change
@@ -109,10 +109,13 @@ func (c *Context) ParentID() string {
109109
func (c *Context) GetCallContext() *CallContext {
110110
var toolName string
111111
if c.Parent != nil {
112-
for name, id := range c.Parent.Tool.ToolMapping {
113-
if id == c.Tool.ID {
114-
toolName = name
115-
break
112+
outer:
113+
for name, refs := range c.Parent.Tool.ToolMapping {
114+
for _, ref := range refs {
115+
if ref.ToolID == c.Tool.ID {
116+
toolName = name
117+
break outer
118+
}
116119
}
117120
}
118121
}

pkg/engine/http.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -35,11 +35,11 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too
3535

3636
if strings.HasSuffix(parsed.Hostname(), DaemonURLSuffix) {
3737
referencedToolName := strings.TrimSuffix(parsed.Hostname(), DaemonURLSuffix)
38-
referencedToolID, ok := tool.ToolMapping[referencedToolName]
39-
if !ok {
38+
referencedToolRefs, ok := tool.ToolMapping[referencedToolName]
39+
if !ok || len(referencedToolRefs) != 1 {
4040
return nil, fmt.Errorf("invalid reference [%s] to tool [%s] from [%s], missing \"tools: %s\" parameter", toolURL, referencedToolName, tool.Source, referencedToolName)
4141
}
42-
referencedTool, ok := prg.ToolSet[referencedToolID]
42+
referencedTool, ok := prg.ToolSet[referencedToolRefs[0].ToolID]
4343
if !ok {
4444
return nil, fmt.Errorf("failed to find tool [%s] for [%s]", referencedToolName, parsed.Hostname())
4545
}

pkg/loader/loader.go

+54-28
Original file line numberDiff line numberDiff line change
@@ -181,11 +181,15 @@ func loadOpenAPI(prg *types.Program, data []byte) *openapi3.T {
181181
return openAPIDocument
182182
}
183183

184-
func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, targetToolName string) (types.Tool, error) {
184+
func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, targetToolName string) ([]types.Tool, error) {
185185
data := base.Content
186186

187187
if bytes.HasPrefix(data, assemble.Header) {
188-
return loadProgram(data, prg, targetToolName)
188+
tool, err := loadProgram(data, prg, targetToolName)
189+
if err != nil {
190+
return nil, err
191+
}
192+
return []types.Tool{tool}, nil
189193
}
190194

191195
var (
@@ -200,7 +204,7 @@ func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base
200204
tools, err = getOpenAPITools(openAPIDocument, "")
201205
}
202206
if err != nil {
203-
return types.Tool{}, fmt.Errorf("error parsing OpenAPI definition: %w", err)
207+
return nil, fmt.Errorf("error parsing OpenAPI definition: %w", err)
204208
}
205209
}
206210

@@ -222,17 +226,17 @@ func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base
222226
AssignGlobals: true,
223227
})
224228
if err != nil {
225-
return types.Tool{}, err
229+
return nil, err
226230
}
227231
}
228232

229233
if len(tools) == 0 {
230-
return types.Tool{}, fmt.Errorf("no tools found in %s", base)
234+
return nil, fmt.Errorf("no tools found in %s", base)
231235
}
232236

233237
var (
234-
localTools = types.ToolSet{}
235-
mainTool types.Tool
238+
localTools = types.ToolSet{}
239+
targetTools []types.Tool
236240
)
237241

238242
for i, tool := range tools {
@@ -243,44 +247,65 @@ func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base
243247
// Probably a better way to come up with an ID
244248
tool.ID = tool.Source.Location + ":" + tool.Name
245249

246-
if i == 0 {
247-
mainTool = tool
250+
if i == 0 && targetToolName == "" {
251+
targetTools = append(targetTools, tool)
248252
}
249253

250254
if i != 0 && tool.Parameters.Name == "" {
251-
return types.Tool{}, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, fmt.Errorf("only the first tool in a file can have no name"))
255+
return nil, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, fmt.Errorf("only the first tool in a file can have no name"))
252256
}
253257

254258
if i != 0 && tool.Parameters.GlobalModelName != "" {
255-
return types.Tool{}, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, fmt.Errorf("only the first tool in a file can have global model name"))
259+
return nil, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, fmt.Errorf("only the first tool in a file can have global model name"))
256260
}
257261

258262
if i != 0 && len(tool.Parameters.GlobalTools) > 0 {
259-
return types.Tool{}, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, fmt.Errorf("only the first tool in a file can have global tools"))
263+
return nil, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, fmt.Errorf("only the first tool in a file can have global tools"))
260264
}
261265

262-
if targetToolName != "" && strings.EqualFold(tool.Parameters.Name, targetToolName) {
263-
mainTool = tool
266+
if targetToolName != "" && tool.Parameters.Name != "" {
267+
if strings.EqualFold(tool.Parameters.Name, targetToolName) {
268+
targetTools = append(targetTools, tool)
269+
} else if strings.Contains(targetToolName, "*") {
270+
match, err := filepath.Match(strings.ToLower(targetToolName), strings.ToLower(tool.Parameters.Name))
271+
if err != nil {
272+
return nil, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo, err)
273+
}
274+
if match {
275+
targetTools = append(targetTools, tool)
276+
}
277+
}
264278
}
265279

266280
if existing, ok := localTools[strings.ToLower(tool.Parameters.Name)]; ok {
267-
return types.Tool{}, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo,
281+
return nil, parser.NewErrLine(tool.Source.Location, tool.Source.LineNo,
268282
fmt.Errorf("duplicate tool name [%s] in %s found at lines %d and %d", tool.Parameters.Name, tool.Source.Location,
269283
tool.Source.LineNo, existing.Source.LineNo))
270284
}
271285

272286
localTools[strings.ToLower(tool.Parameters.Name)] = tool
273287
}
274288

275-
return link(ctx, cache, prg, base, mainTool, localTools)
289+
return linkAll(ctx, cache, prg, base, targetTools, localTools)
290+
}
291+
292+
func linkAll(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, tools []types.Tool, localTools types.ToolSet) (result []types.Tool, _ error) {
293+
for _, tool := range tools {
294+
tool, err := link(ctx, cache, prg, base, tool, localTools)
295+
if err != nil {
296+
return nil, err
297+
}
298+
result = append(result, tool)
299+
}
300+
return
276301
}
277302

278303
func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, tool types.Tool, localTools types.ToolSet) (types.Tool, error) {
279304
if existing, ok := prg.ToolSet[tool.ID]; ok {
280305
return existing, nil
281306
}
282307

283-
tool.ToolMapping = map[string]string{}
308+
tool.ToolMapping = map[string][]types.ToolReference{}
284309
tool.LocalTools = map[string]string{}
285310
toolNames := map[string]struct{}{}
286311

@@ -310,16 +335,17 @@ func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *so
310335
}
311336
}
312337

313-
tool.ToolMapping[targetToolName] = linkedTool.ID
338+
tool.AddToolMapping(targetToolName, linkedTool)
314339
toolNames[targetToolName] = struct{}{}
315340
} else {
316341
toolName, subTool := types.SplitToolRef(targetToolName)
317-
resolvedTool, err := resolve(ctx, cache, prg, base, toolName, subTool)
342+
resolvedTools, err := resolve(ctx, cache, prg, base, toolName, subTool)
318343
if err != nil {
319344
return types.Tool{}, fmt.Errorf("failed resolving %s from %s: %w", targetToolName, base, err)
320345
}
321-
322-
tool.ToolMapping[targetToolName] = resolvedTool.ID
346+
for _, resolvedTool := range resolvedTools {
347+
tool.AddToolMapping(targetToolName, resolvedTool)
348+
}
323349
}
324350
}
325351

@@ -345,14 +371,14 @@ func ProgramFromSource(ctx context.Context, content, subToolName string, opts ..
345371
prg := types.Program{
346372
ToolSet: types.ToolSet{},
347373
}
348-
tool, err := readTool(ctx, opt.Cache, &prg, &source{
374+
tools, err := readTool(ctx, opt.Cache, &prg, &source{
349375
Content: []byte(content),
350376
Location: "inline",
351377
}, subToolName)
352378
if err != nil {
353379
return types.Program{}, err
354380
}
355-
prg.EntryToolID = tool.ID
381+
prg.EntryToolID = tools[0].ID
356382
return prg, nil
357383
}
358384

@@ -385,26 +411,26 @@ func Program(ctx context.Context, name, subToolName string, opts ...Options) (ty
385411
Name: name,
386412
ToolSet: types.ToolSet{},
387413
}
388-
tool, err := resolve(ctx, opt.Cache, &prg, &source{}, name, subToolName)
414+
tools, err := resolve(ctx, opt.Cache, &prg, &source{}, name, subToolName)
389415
if err != nil {
390416
return types.Program{}, err
391417
}
392-
prg.EntryToolID = tool.ID
418+
prg.EntryToolID = tools[0].ID
393419
return prg, nil
394420
}
395421

396-
func resolve(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, name, subTool string) (types.Tool, error) {
422+
func resolve(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, name, subTool string) ([]types.Tool, error) {
397423
if subTool == "" {
398424
t, ok := builtin.Builtin(name)
399425
if ok {
400426
prg.ToolSet[t.ID] = t
401-
return t, nil
427+
return []types.Tool{t}, nil
402428
}
403429
}
404430

405431
s, err := input(ctx, cache, base, name)
406432
if err != nil {
407-
return types.Tool{}, err
433+
return nil, err
408434
}
409435

410436
return readTool(ctx, cache, prg, s, subTool)

pkg/loader/loader_test.go

+6-1
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,12 @@ func TestHelloWorld(t *testing.T) {
109109
"instructions": "call bob",
110110
"id": "https://raw.githubusercontent.com/ibuildthecloud/test/bafe5a62174e8a0ea162277dcfe3a2ddb7eea928/example/sub/tool.gpt:",
111111
"toolMapping": {
112-
"../bob.gpt": "https://raw.githubusercontent.com/ibuildthecloud/test/bafe5a62174e8a0ea162277dcfe3a2ddb7eea928/example/bob.gpt:"
112+
"../bob.gpt": [
113+
{
114+
"reference": "../bob.gpt",
115+
"toolID": "https://raw.githubusercontent.com/ibuildthecloud/test/bafe5a62174e8a0ea162277dcfe3a2ddb7eea928/example/bob.gpt:"
116+
}
117+
]
113118
},
114119
"localTools": {
115120
"": "https://raw.githubusercontent.com/ibuildthecloud/test/bafe5a62174e8a0ea162277dcfe3a2ddb7eea928/example/sub/tool.gpt:"

pkg/runner/runner.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -834,12 +834,12 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
834834
// If the credential doesn't already exist in the store, run the credential tool in order to get the value,
835835
// and save it in the store.
836836
if !exists {
837-
credToolID, ok := callCtx.Tool.ToolMapping[credToolName]
838-
if !ok {
837+
credToolRefs, ok := callCtx.Tool.ToolMapping[credToolName]
838+
if !ok || len(credToolRefs) != 1 {
839839
return nil, fmt.Errorf("failed to find ID for tool %s", credToolName)
840840
}
841841

842-
subCtx, err := callCtx.SubCall(callCtx.Ctx, credToolID, "", engine.CredentialToolCategory) // leaving callID as "" will cause it to be set by the engine
842+
subCtx, err := callCtx.SubCall(callCtx.Ctx, credToolRefs[0].ToolID, "", engine.CredentialToolCategory) // leaving callID as "" will cause it to be set by the engine
843843
if err != nil {
844844
return nil, fmt.Errorf("failed to create subcall context for tool %s: %w", credToolName, err)
845845
}
@@ -874,7 +874,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
874874
}
875875

876876
// Only store the credential if the tool is on GitHub, and the credential is non-empty.
877-
if isGitHubTool(credToolName) && callCtx.Program.ToolSet[credToolID].Source.Repo != nil {
877+
if isGitHubTool(credToolName) && callCtx.Program.ToolSet[credToolRefs[0].ToolID].Source.Repo != nil {
878878
if isEmpty {
879879
log.Warnf("Not saving empty credential for tool %s", credToolName)
880880
} else if err := store.Add(*cred); err != nil {

pkg/tests/runner_test.go

+85
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,91 @@ func toJSONString(t *testing.T, v interface{}) string {
2020
return string(x)
2121
}
2222

23+
func TestAsterick(t *testing.T) {
24+
r := tester.NewRunner(t)
25+
p, err := r.Load("")
26+
require.NoError(t, err)
27+
autogold.Expect(`{
28+
"name": "testdata/TestAsterick/test.gpt",
29+
"entryToolId": "testdata/TestAsterick/test.gpt:",
30+
"toolSet": {
31+
"testdata/TestAsterick/other.gpt:a": {
32+
"name": "a",
33+
"modelName": "gpt-4o",
34+
"internalPrompt": null,
35+
"instructions": "a",
36+
"id": "testdata/TestAsterick/other.gpt:a",
37+
"localTools": {
38+
"a": "testdata/TestAsterick/other.gpt:a",
39+
"afoo": "testdata/TestAsterick/other.gpt:afoo",
40+
"foo": "testdata/TestAsterick/other.gpt:foo",
41+
"fooa": "testdata/TestAsterick/other.gpt:fooa",
42+
"fooafoo": "testdata/TestAsterick/other.gpt:fooafoo"
43+
},
44+
"source": {
45+
"location": "testdata/TestAsterick/other.gpt",
46+
"lineNo": 10
47+
},
48+
"workingDir": "testdata/TestAsterick"
49+
},
50+
"testdata/TestAsterick/other.gpt:afoo": {
51+
"name": "afoo",
52+
"modelName": "gpt-4o",
53+
"internalPrompt": null,
54+
"instructions": "afoo",
55+
"id": "testdata/TestAsterick/other.gpt:afoo",
56+
"localTools": {
57+
"a": "testdata/TestAsterick/other.gpt:a",
58+
"afoo": "testdata/TestAsterick/other.gpt:afoo",
59+
"foo": "testdata/TestAsterick/other.gpt:foo",
60+
"fooa": "testdata/TestAsterick/other.gpt:fooa",
61+
"fooafoo": "testdata/TestAsterick/other.gpt:fooafoo"
62+
},
63+
"source": {
64+
"location": "testdata/TestAsterick/other.gpt",
65+
"lineNo": 4
66+
},
67+
"workingDir": "testdata/TestAsterick"
68+
},
69+
"testdata/TestAsterick/test.gpt:": {
70+
"modelName": "gpt-4o",
71+
"internalPrompt": null,
72+
"tools": [
73+
"a* from ./other.gpt"
74+
],
75+
"instructions": "Ask Bob how he is doing and let me know exactly what he said.",
76+
"id": "testdata/TestAsterick/test.gpt:",
77+
"toolMapping": {
78+
"a* from ./other.gpt": [
79+
{
80+
"reference": "afoo from ./other.gpt",
81+
"toolID": "testdata/TestAsterick/other.gpt:afoo"
82+
},
83+
{
84+
"reference": "a from ./other.gpt",
85+
"toolID": "testdata/TestAsterick/other.gpt:a"
86+
}
87+
]
88+
},
89+
"localTools": {
90+
"": "testdata/TestAsterick/test.gpt:"
91+
},
92+
"source": {
93+
"location": "testdata/TestAsterick/test.gpt",
94+
"lineNo": 1
95+
},
96+
"workingDir": "testdata/TestAsterick"
97+
}
98+
}
99+
}`).Equal(t, toJSONString(t, p))
100+
101+
r.RespondWith(tester.Result{
102+
Text: "hi",
103+
})
104+
_, err = r.Run("", "")
105+
require.NoError(t, err)
106+
}
107+
23108
func TestDualSubChat(t *testing.T) {
24109
r := tester.NewRunner(t)
25110
r.RespondWith(tester.Result{

0 commit comments

Comments
 (0)