Skip to content

Commit f5cfde0

Browse files
chore: add progress output for builtins specifically sys.exec
1 parent 889eff2 commit f5cfde0

File tree

5 files changed

+67
-25
lines changed

5 files changed

+67
-25
lines changed

pkg/builtin/builtin.go

+39-21
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package builtin
22

33
import (
4+
"bytes"
45
"context"
56
"encoding/json"
67
"errors"
@@ -264,7 +265,7 @@ func Builtin(name string) (types.Tool, bool) {
264265
return SetDefaults(t), ok
265266
}
266267

267-
func SysFind(_ context.Context, _ []string, input string) (string, error) {
268+
func SysFind(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
268269
var result []string
269270
var params struct {
270271
Pattern string `json:"pattern,omitempty"`
@@ -305,7 +306,7 @@ func SysFind(_ context.Context, _ []string, input string) (string, error) {
305306
return strings.Join(result, "\n"), nil
306307
}
307308

308-
func SysExec(_ context.Context, env []string, input string) (string, error) {
309+
func SysExec(_ context.Context, env []string, input string, progress chan<- string) (string, error) {
309310
var params struct {
310311
Command string `json:"command,omitempty"`
311312
Directory string `json:"directory,omitempty"`
@@ -328,13 +329,30 @@ func SysExec(_ context.Context, env []string, input string) (string, error) {
328329
cmd = exec.Command("/bin/sh", "-c", params.Command)
329330
}
330331

332+
var (
333+
out bytes.Buffer
334+
pw = progressWriter{
335+
out: progress,
336+
}
337+
combined = io.MultiWriter(&out, &pw)
338+
)
331339
cmd.Env = env
332340
cmd.Dir = params.Directory
333-
out, err := cmd.CombinedOutput()
334-
if err != nil {
341+
cmd.Stdout = combined
342+
cmd.Stderr = combined
343+
if err := cmd.Run(); err != nil {
335344
return fmt.Sprintf("ERROR: %s\nOUTPUT:\n%s", err, out), nil
336345
}
337-
return string(out), nil
346+
return out.String(), nil
347+
}
348+
349+
type progressWriter struct {
350+
out chan<- string
351+
}
352+
353+
func (pw *progressWriter) Write(p []byte) (n int, err error) {
354+
pw.out <- string(p)
355+
return len(p), nil
338356
}
339357

340358
func getWorkspaceDir(envs []string) (string, error) {
@@ -347,7 +365,7 @@ func getWorkspaceDir(envs []string) (string, error) {
347365
return "", fmt.Errorf("no workspace directory found in env")
348366
}
349367

350-
func SysLs(_ context.Context, _ []string, input string) (string, error) {
368+
func SysLs(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
351369
var params struct {
352370
Dir string `json:"dir,omitempty"`
353371
}
@@ -383,7 +401,7 @@ func SysLs(_ context.Context, _ []string, input string) (string, error) {
383401
return strings.Join(result, "\n"), nil
384402
}
385403

386-
func SysRead(_ context.Context, _ []string, input string) (string, error) {
404+
func SysRead(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
387405
var params struct {
388406
Filename string `json:"filename,omitempty"`
389407
}
@@ -411,7 +429,7 @@ func SysRead(_ context.Context, _ []string, input string) (string, error) {
411429
return string(data), nil
412430
}
413431

414-
func SysWrite(_ context.Context, _ []string, input string) (string, error) {
432+
func SysWrite(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
415433
var params struct {
416434
Filename string `json:"filename,omitempty"`
417435
Content string `json:"content,omitempty"`
@@ -443,7 +461,7 @@ func SysWrite(_ context.Context, _ []string, input string) (string, error) {
443461
return fmt.Sprintf("Wrote (%d) bytes to file %s", len(data), file), nil
444462
}
445463

446-
func SysAppend(_ context.Context, _ []string, input string) (string, error) {
464+
func SysAppend(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
447465
var params struct {
448466
Filename string `json:"filename,omitempty"`
449467
Content string `json:"content,omitempty"`
@@ -489,7 +507,7 @@ func fixQueries(u string) string {
489507
return url.String()
490508
}
491509

492-
func SysHTTPGet(_ context.Context, _ []string, input string) (_ string, err error) {
510+
func SysHTTPGet(_ context.Context, _ []string, input string, _ chan<- string) (_ string, err error) {
493511
var params struct {
494512
URL string `json:"url,omitempty"`
495513
}
@@ -523,8 +541,8 @@ func SysHTTPGet(_ context.Context, _ []string, input string) (_ string, err erro
523541
return string(data), nil
524542
}
525543

526-
func SysHTTPHtml2Text(ctx context.Context, env []string, input string) (string, error) {
527-
content, err := SysHTTPGet(ctx, env, input)
544+
func SysHTTPHtml2Text(ctx context.Context, env []string, input string, progress chan<- string) (string, error) {
545+
content, err := SysHTTPGet(ctx, env, input, progress)
528546
if err != nil {
529547
return "", err
530548
}
@@ -533,7 +551,7 @@ func SysHTTPHtml2Text(ctx context.Context, env []string, input string) (string,
533551
})
534552
}
535553

536-
func SysHTTPPost(ctx context.Context, _ []string, input string) (_ string, err error) {
554+
func SysHTTPPost(ctx context.Context, _ []string, input string, _ chan<- string) (_ string, err error) {
537555
var params struct {
538556
URL string `json:"url,omitempty"`
539557
Content string `json:"content,omitempty"`
@@ -569,7 +587,7 @@ func SysHTTPPost(ctx context.Context, _ []string, input string) (_ string, err e
569587
return fmt.Sprintf("Wrote %d to %s", len([]byte(params.Content)), params.URL), nil
570588
}
571589

572-
func SysGetenv(_ context.Context, env []string, input string) (string, error) {
590+
func SysGetenv(_ context.Context, env []string, input string, _ chan<- string) (string, error) {
573591
var params struct {
574592
Name string `json:"name,omitempty"`
575593
}
@@ -597,7 +615,7 @@ func invalidArgument(input string, err error) string {
597615
return fmt.Sprintf("Failed to parse arguments %s: %v", input, err)
598616
}
599617

600-
func SysChatHistory(ctx context.Context, _ []string, _ string) (string, error) {
618+
func SysChatHistory(ctx context.Context, _ []string, _ string, _ chan<- string) (string, error) {
601619
engineContext, _ := engine.FromContext(ctx)
602620

603621
data, err := json.Marshal(engine.ChatHistory{
@@ -627,7 +645,7 @@ func writeHistory(ctx *engine.Context) (result []engine.ChatHistoryCall) {
627645
return
628646
}
629647

630-
func SysChatFinish(_ context.Context, _ []string, input string) (string, error) {
648+
func SysChatFinish(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
631649
var params struct {
632650
Message string `json:"return,omitempty"`
633651
}
@@ -641,7 +659,7 @@ func SysChatFinish(_ context.Context, _ []string, input string) (string, error)
641659
}
642660
}
643661

644-
func SysAbort(_ context.Context, _ []string, input string) (string, error) {
662+
func SysAbort(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
645663
var params struct {
646664
Message string `json:"message,omitempty"`
647665
}
@@ -651,7 +669,7 @@ func SysAbort(_ context.Context, _ []string, input string) (string, error) {
651669
return "", fmt.Errorf("ABORT: %s", params.Message)
652670
}
653671

654-
func SysRemove(_ context.Context, _ []string, input string) (string, error) {
672+
func SysRemove(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
655673
var params struct {
656674
Location string `json:"location,omitempty"`
657675
}
@@ -670,7 +688,7 @@ func SysRemove(_ context.Context, _ []string, input string) (string, error) {
670688
return fmt.Sprintf("Removed file: %s", params.Location), nil
671689
}
672690

673-
func SysStat(_ context.Context, _ []string, input string) (string, error) {
691+
func SysStat(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
674692
var params struct {
675693
Filepath string `json:"filepath,omitempty"`
676694
}
@@ -690,7 +708,7 @@ func SysStat(_ context.Context, _ []string, input string) (string, error) {
690708
return fmt.Sprintf("%s %s mode: %s, size: %d bytes, modtime: %s", title, params.Filepath, stat.Mode().String(), stat.Size(), stat.ModTime().String()), nil
691709
}
692710

693-
func SysDownload(_ context.Context, env []string, input string) (_ string, err error) {
711+
func SysDownload(_ context.Context, env []string, input string, _ chan<- string) (_ string, err error) {
694712
var params struct {
695713
URL string `json:"url,omitempty"`
696714
Location string `json:"location,omitempty"`
@@ -763,6 +781,6 @@ func SysDownload(_ context.Context, env []string, input string) (_ string, err e
763781
return fmt.Sprintf("Downloaded %s to %s", params.URL, params.Location), nil
764782
}
765783

766-
func SysTimeNow(context.Context, []string, string) (string, error) {
784+
func SysTimeNow(context.Context, []string, string, chan<- string) (string, error) {
767785
return time.Now().Format(time.RFC3339), nil
768786
}

pkg/engine/cmd.go

+25-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"runtime"
1313
"sort"
1414
"strings"
15+
"sync"
1516

1617
"github.com/google/shlex"
1718
"github.com/gptscript-ai/gptscript/pkg/counter"
@@ -64,7 +65,30 @@ func (e *Engine) runCommand(ctx Context, tool types.Tool, input string, toolCate
6465
"input": input,
6566
},
6667
}
67-
return tool.BuiltinFunc(ctx.WrappedContext(), e.Env, input)
68+
69+
var (
70+
progress = make(chan string)
71+
wg sync.WaitGroup
72+
)
73+
wg.Add(1)
74+
defer wg.Wait()
75+
defer close(progress)
76+
go func() {
77+
defer wg.Done()
78+
buf := strings.Builder{}
79+
for line := range progress {
80+
buf.WriteString(line)
81+
e.Progress <- types.CompletionStatus{
82+
CompletionID: id,
83+
PartialResponse: &types.CompletionMessage{
84+
Role: types.CompletionMessageRoleTypeAssistant,
85+
Content: types.Text(buf.String()),
86+
},
87+
}
88+
}
89+
}()
90+
91+
return tool.BuiltinFunc(ctx.WrappedContext(), e.Env, input, progress)
6892
}
6993

7094
var instructions []string

pkg/prompt/credential.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ func GetModelProviderCredential(ctx context.Context, credStore credentials.Crede
1818
if exists {
1919
k = cred.Env[env]
2020
} else {
21-
result, err := SysPrompt(ctx, envs, fmt.Sprintf(`{"message":"%s","fields":"key","sensitive":"true"}`, message))
21+
result, err := SysPrompt(ctx, envs, fmt.Sprintf(`{"message":"%s","fields":"key","sensitive":"true"}`, message), nil)
2222
if err != nil {
2323
return "", err
2424
}

pkg/prompt/prompt.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func sysPromptHTTP(ctx context.Context, envs []string, url string, prompt types.
4848
return string(data), err
4949
}
5050

51-
func SysPrompt(ctx context.Context, envs []string, input string) (_ string, err error) {
51+
func SysPrompt(ctx context.Context, envs []string, input string, _ chan<- string) (_ string, err error) {
5252
var params struct {
5353
Message string `json:"message,omitempty"`
5454
Fields string `json:"fields,omitempty"`

pkg/types/tool.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ func (p Program) SetBlocking() Program {
117117
return p
118118
}
119119

120-
type BuiltinFunc func(ctx context.Context, env []string, input string) (string, error)
120+
type BuiltinFunc func(ctx context.Context, env []string, input string, progress chan<- string) (string, error)
121121

122122
type Parameters struct {
123123
Name string `json:"name,omitempty"`

0 commit comments

Comments
 (0)