Skip to content

Commit eca61d2

Browse files
feat: add display text to callframe to make it easier on the sdk clients
1 parent 87429c1 commit eca61d2

File tree

6 files changed

+138
-36
lines changed

6 files changed

+138
-36
lines changed

pkg/builtin/builtin.go

+7-19
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,18 @@ import (
2020

2121
"github.com/AlecAivazis/survey/v2"
2222
"github.com/BurntSushi/locker"
23-
"github.com/google/shlex"
2423
"github.com/gptscript-ai/gptscript/pkg/engine"
2524
"github.com/gptscript-ai/gptscript/pkg/types"
2625
"github.com/jaytaylor/html2text"
2726
)
2827

2928
var SafeTools = map[string]struct{}{
30-
"sys.echo": {},
31-
"sys.time.now": {},
32-
"sys.prompt": {},
29+
"sys.abort": {},
3330
"sys.chat.finish": {},
3431
"sys.chat.history": {},
32+
"sys.echo": {},
33+
"sys.prompt": {},
34+
"sys.time.now": {},
3535
}
3636

3737
var tools = map[string]types.Tool{
@@ -333,11 +333,7 @@ func SysExec(ctx context.Context, env []string, input string) (string, error) {
333333
var cmd *exec.Cmd
334334

335335
if runtime.GOOS == "windows" {
336-
args, err := shlex.Split(params.Command)
337-
if err != nil {
338-
return "", fmt.Errorf("parsing command: %w", err)
339-
}
340-
cmd = exec.Command(args[0], args[1:]...)
336+
cmd = exec.Command("cmd.exe", "/c", params.Command)
341337
} else {
342338
cmd = exec.Command("/bin/sh", "-c", params.Command)
343339
}
@@ -346,7 +342,7 @@ func SysExec(ctx context.Context, env []string, input string) (string, error) {
346342
cmd.Dir = params.Directory
347343
out, err := cmd.CombinedOutput()
348344
if err != nil {
349-
return string(out), fmt.Errorf("OUTPUT: %s, ERROR: %w", out, err)
345+
return fmt.Sprintf("ERROR: %s\nOUTPUT: %s", err, out), nil
350346
}
351347
return string(out), nil
352348
}
@@ -362,10 +358,6 @@ func getWorkspaceDir(envs []string) (string, error) {
362358
}
363359

364360
func SysLs(_ context.Context, _ []string, input string) (string, error) {
365-
return sysLs("", input)
366-
}
367-
368-
func sysLs(base, input string) (string, error) {
369361
var params struct {
370362
Dir string `json:"dir,omitempty"`
371363
}
@@ -378,10 +370,6 @@ func sysLs(base, input string) (string, error) {
378370
dir = "."
379371
}
380372

381-
if base != "" {
382-
dir = filepath.Join(base, dir)
383-
}
384-
385373
entries, err := os.ReadDir(dir)
386374
if errors.Is(err, fs.ErrNotExist) {
387375
return fmt.Sprintf("directory does not exist: %s", params.Dir), nil
@@ -772,7 +760,7 @@ func SysDownload(ctx context.Context, env []string, input string) (_ string, err
772760
return "", fmt.Errorf("failed copying data from [%s] to [%s]: %w", params.URL, params.Location, err)
773761
}
774762

775-
return params.Location, nil
763+
return fmt.Sprintf("Downloaded %s to %s", params.URL, params.Location), nil
776764
}
777765

778766
func sysPromptHTTP(ctx context.Context, url, message string, fields []string, sensitive bool) (_ string, err error) {

pkg/builtin/builtin_test.go

+8
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"testing"
66

7+
"github.com/gptscript-ai/gptscript/pkg/types"
78
"github.com/hexops/autogold/v2"
89
"github.com/stretchr/testify/require"
910
)
@@ -21,3 +22,10 @@ func TestSysGetenv(t *testing.T) {
2122
require.NoError(t, err)
2223
autogold.Expect("").Equal(t, v)
2324
}
25+
26+
func TestDisplayCoverage(t *testing.T) {
27+
for _, tool := range ListTools() {
28+
_, err := types.ToSysDisplayString(tool.ID, nil)
29+
require.NoError(t, err)
30+
}
31+
}

pkg/engine/engine.go

+8-2
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ type CallContext struct {
6464
commonContext `json:",inline"`
6565
ToolName string `json:"toolName,omitempty"`
6666
ParentID string `json:"parentID,omitempty"`
67+
DisplayText string `json:"displayText,omitempty"`
6768
}
6869

6970
type Context struct {
@@ -72,6 +73,8 @@ type Context struct {
7273
Parent *Context
7374
LastReturn *Return
7475
Program *types.Program
76+
// Input is saved only so that we can render display text, don't use otherwise
77+
Input string
7578
}
7679

7780
type ChatHistory struct {
@@ -123,6 +126,7 @@ func (c *Context) GetCallContext() *CallContext {
123126
commonContext: c.commonContext,
124127
ParentID: c.ParentID(),
125128
ToolName: toolName,
129+
DisplayText: types.ToDisplayText(c.Tool, c.Input),
126130
}
127131
}
128132

@@ -140,7 +144,7 @@ func WithToolCategory(ctx context.Context, toolCategory ToolCategory) context.Co
140144
return context.WithValue(ctx, toolCategoryKey{}, toolCategory)
141145
}
142146

143-
func NewContext(ctx context.Context, prg *types.Program) Context {
147+
func NewContext(ctx context.Context, prg *types.Program, input string) Context {
144148
category, _ := ctx.Value(toolCategoryKey{}).(ToolCategory)
145149

146150
callCtx := Context{
@@ -151,11 +155,12 @@ func NewContext(ctx context.Context, prg *types.Program) Context {
151155
},
152156
Ctx: ctx,
153157
Program: prg,
158+
Input: input,
154159
}
155160
return callCtx
156161
}
157162

158-
func (c *Context) SubCall(ctx context.Context, toolID, callID string, toolCategory ToolCategory) (Context, error) {
163+
func (c *Context) SubCall(ctx context.Context, input, toolID, callID string, toolCategory ToolCategory) (Context, error) {
159164
tool, ok := c.Program.ToolSet[toolID]
160165
if !ok {
161166
return Context{}, fmt.Errorf("failed to file tool for id [%s]", toolID)
@@ -174,6 +179,7 @@ func (c *Context) SubCall(ctx context.Context, toolID, callID string, toolCatego
174179
Ctx: ctx,
175180
Parent: c,
176181
Program: c.Program,
182+
Input: input,
177183
}, nil
178184
}
179185

pkg/runner/runner.go

+18-15
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ func (r *Runner) Chat(ctx context.Context, prevState ChatState, prg types.Progra
150150
monitor.Stop(resp.Content, err)
151151
}()
152152

153-
callCtx := engine.NewContext(ctx, &prg)
153+
callCtx := engine.NewContext(ctx, &prg, input)
154154
if state == nil || state.StartContinuation {
155155
if state != nil {
156156
state = state.WithResumeInput(&input)
@@ -423,18 +423,21 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en
423423

424424
callCtx.Ctx = context2.AddPauseFuncToCtx(callCtx.Ctx, monitor.Pause)
425425

426-
authResp, err := r.auth(callCtx, input)
427-
if err != nil {
428-
return nil, err
429-
}
426+
_, safe := builtin.SafeTools[callCtx.Tool.ID]
427+
if callCtx.Tool.IsCommand() && !safe {
428+
authResp, err := r.auth(callCtx, input)
429+
if err != nil {
430+
return nil, err
431+
}
430432

431-
if !authResp.Accept {
432-
msg := fmt.Sprintf("[AUTHORIZATION ERROR]: %s", authResp.Message)
433-
return &State{
434-
Continuation: &engine.Return{
435-
Result: &msg,
436-
},
437-
}, nil
433+
if !authResp.Accept {
434+
msg := fmt.Sprintf("[AUTHORIZATION ERROR]: %s", authResp.Message)
435+
return &State{
436+
Continuation: &engine.Return{
437+
Result: &msg,
438+
},
439+
}, nil
440+
}
438441
}
439442

440443
ret, err := e.Start(callCtx, input)
@@ -671,7 +674,7 @@ func streamProgress(callCtx *engine.Context, monitor Monitor) (chan<- types.Comp
671674
}
672675

673676
func (r *Runner) subCall(ctx context.Context, parentContext engine.Context, monitor Monitor, env []string, toolID, input, callID string, toolCategory engine.ToolCategory) (*State, error) {
674-
callCtx, err := parentContext.SubCall(ctx, toolID, callID, toolCategory)
677+
callCtx, err := parentContext.SubCall(ctx, input, toolID, callID, toolCategory)
675678
if err != nil {
676679
return nil, err
677680
}
@@ -680,7 +683,7 @@ func (r *Runner) subCall(ctx context.Context, parentContext engine.Context, moni
680683
}
681684

682685
func (r *Runner) subCallResume(ctx context.Context, parentContext engine.Context, monitor Monitor, env []string, toolID, callID string, state *State, toolCategory engine.ToolCategory) (*State, error) {
683-
callCtx, err := parentContext.SubCall(ctx, toolID, callID, toolCategory)
686+
callCtx, err := parentContext.SubCall(ctx, "", toolID, callID, toolCategory)
684687
if err != nil {
685688
return nil, err
686689
}
@@ -834,7 +837,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
834837
return nil, fmt.Errorf("failed to find ID for tool %s", credToolName)
835838
}
836839

837-
subCtx, err := callCtx.SubCall(callCtx.Ctx, credToolRefs[0].ToolID, "", engine.CredentialToolCategory) // leaving callID as "" will cause it to be set by the engine
840+
subCtx, err := callCtx.SubCall(callCtx.Ctx, "", credToolRefs[0].ToolID, "", engine.CredentialToolCategory) // leaving callID as "" will cause it to be set by the engine
838841
if err != nil {
839842
return nil, fmt.Errorf("failed to create subcall context for tool %s: %w", credToolName, err)
840843
}

pkg/types/tool.go

+15
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package types
33
import (
44
"context"
55
"fmt"
6+
"path/filepath"
67
"slices"
78
"sort"
89
"strings"
@@ -453,6 +454,20 @@ func (t ToolSource) String() string {
453454
return fmt.Sprintf("%s:%d", t.Location, t.LineNo)
454455
}
455456

457+
func (t Tool) GetInterpreter() string {
458+
if !strings.HasPrefix(t.Instructions, CommandPrefix) {
459+
return ""
460+
}
461+
fields := strings.Fields(strings.TrimPrefix(t.Instructions, CommandPrefix))
462+
for _, field := range fields {
463+
name := filepath.Base(field)
464+
if name != "env" {
465+
return name
466+
}
467+
}
468+
return fields[0]
469+
}
470+
456471
func (t Tool) IsCommand() bool {
457472
return strings.HasPrefix(t.Instructions, CommandPrefix)
458473
}

pkg/types/toolstring.go

+82
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
package types
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"path/filepath"
7+
"strings"
8+
)
9+
10+
func ToDisplayText(tool Tool, input string) string {
11+
interpreter := tool.GetInterpreter()
12+
if interpreter == "" {
13+
return ""
14+
}
15+
16+
if strings.HasPrefix(interpreter, "sys.") {
17+
data := map[string]string{}
18+
_ = json.Unmarshal([]byte(input), &data)
19+
out, err := ToSysDisplayString(interpreter, data)
20+
if err != nil {
21+
return fmt.Sprintf("Running %s", interpreter)
22+
}
23+
return out
24+
}
25+
26+
if tool.Source.Repo != nil {
27+
repo := tool.Source.Repo
28+
root := strings.TrimPrefix(repo.Root, "https://")
29+
root = strings.TrimSuffix(root, ".git")
30+
name := repo.Name
31+
if name == "tool.gpt" {
32+
name = ""
33+
}
34+
35+
return fmt.Sprintf("Running %s from %s", tool.Name, filepath.Join(root, repo.Path, name))
36+
}
37+
38+
if tool.Source.Location != "" {
39+
return fmt.Sprintf("Running %s from %s", tool.Name, tool.Source.Location)
40+
}
41+
42+
return ""
43+
}
44+
45+
func ToSysDisplayString(id string, args map[string]string) (string, error) {
46+
switch id {
47+
case "sys.append":
48+
return fmt.Sprintf("Appending to file `%s`", args["filename"]), nil
49+
case "sys.download":
50+
if location := args["location"]; location != "" {
51+
return fmt.Sprintf("Downloading `%s` to `%s`", args["url"], location), nil
52+
} else {
53+
return fmt.Sprintf("Downloading `%s` to workspace", args["url"]), nil
54+
}
55+
case "sys.exec":
56+
return fmt.Sprintf("Running `%s`", args["command"]), nil
57+
case "sys.find":
58+
dir := args["directory"]
59+
if dir == "" {
60+
dir = "."
61+
}
62+
return fmt.Sprintf("Finding `%s` in `%s`", args["pattern"], dir), nil
63+
case "sys.http.get":
64+
return fmt.Sprintf("Downloading `%s`", args["url"]), nil
65+
case "sys.http.post":
66+
return fmt.Sprintf("Sending to `%s`", args["url"]), nil
67+
case "sys.http.html2text":
68+
return fmt.Sprintf("Downloading `%s`", args["url"]), nil
69+
case "sys.ls":
70+
return fmt.Sprintf("Listing `%s`", args["dir"]), nil
71+
case "sys.read":
72+
return fmt.Sprintf("Reading `%s`", args["filename"]), nil
73+
case "sys.remove":
74+
return fmt.Sprintf("Removing `%s`", args["location"]), nil
75+
case "sys.write":
76+
return fmt.Sprintf("Writing `%s`", args["filename"]), nil
77+
case "sys.stat", "sys.getenv", "sys.abort", "sys.chat.finish", "sys.chat.history", "sys.echo", "sys.prompt", "sys.time.now":
78+
return "", nil
79+
default:
80+
return "", fmt.Errorf("unknown tool for display string: %s", id)
81+
}
82+
}

0 commit comments

Comments
 (0)