Skip to content

Commit 81d3b48

Browse files
Merge pull request #520 from ibuildthecloud/main
chore: add progress output for builtins specifically sys.exec
2 parents 889eff2 + e9c2bf9 commit 81d3b48

File tree

6 files changed

+84
-28
lines changed

6 files changed

+84
-28
lines changed

pkg/builtin/builtin.go

+51-22
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package builtin
22

33
import (
4+
"bytes"
45
"context"
56
"encoding/json"
67
"errors"
@@ -264,7 +265,7 @@ func Builtin(name string) (types.Tool, bool) {
264265
return SetDefaults(t), ok
265266
}
266267

267-
func SysFind(_ context.Context, _ []string, input string) (string, error) {
268+
func SysFind(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
268269
var result []string
269270
var params struct {
270271
Pattern string `json:"pattern,omitempty"`
@@ -305,7 +306,7 @@ func SysFind(_ context.Context, _ []string, input string) (string, error) {
305306
return strings.Join(result, "\n"), nil
306307
}
307308

308-
func SysExec(_ context.Context, env []string, input string) (string, error) {
309+
func SysExec(_ context.Context, env []string, input string, progress chan<- string) (string, error) {
309310
var params struct {
310311
Command string `json:"command,omitempty"`
311312
Directory string `json:"directory,omitempty"`
@@ -328,13 +329,30 @@ func SysExec(_ context.Context, env []string, input string) (string, error) {
328329
cmd = exec.Command("/bin/sh", "-c", params.Command)
329330
}
330331

332+
var (
333+
out bytes.Buffer
334+
pw = progressWriter{
335+
out: progress,
336+
}
337+
combined = io.MultiWriter(&out, &pw)
338+
)
331339
cmd.Env = env
332340
cmd.Dir = params.Directory
333-
out, err := cmd.CombinedOutput()
334-
if err != nil {
335-
return fmt.Sprintf("ERROR: %s\nOUTPUT:\n%s", err, out), nil
341+
cmd.Stdout = combined
342+
cmd.Stderr = combined
343+
if err := cmd.Run(); err != nil {
344+
return fmt.Sprintf("ERROR: %s\nOUTPUT:\n%s", err, &out), nil
336345
}
337-
return string(out), nil
346+
return out.String(), nil
347+
}
348+
349+
type progressWriter struct {
350+
out chan<- string
351+
}
352+
353+
func (pw *progressWriter) Write(p []byte) (n int, err error) {
354+
pw.out <- string(p)
355+
return len(p), nil
338356
}
339357

340358
func getWorkspaceDir(envs []string) (string, error) {
@@ -347,7 +365,7 @@ func getWorkspaceDir(envs []string) (string, error) {
347365
return "", fmt.Errorf("no workspace directory found in env")
348366
}
349367

350-
func SysLs(_ context.Context, _ []string, input string) (string, error) {
368+
func SysLs(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
351369
var params struct {
352370
Dir string `json:"dir,omitempty"`
353371
}
@@ -383,7 +401,7 @@ func SysLs(_ context.Context, _ []string, input string) (string, error) {
383401
return strings.Join(result, "\n"), nil
384402
}
385403

386-
func SysRead(_ context.Context, _ []string, input string) (string, error) {
404+
func SysRead(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
387405
var params struct {
388406
Filename string `json:"filename,omitempty"`
389407
}
@@ -411,7 +429,7 @@ func SysRead(_ context.Context, _ []string, input string) (string, error) {
411429
return string(data), nil
412430
}
413431

414-
func SysWrite(_ context.Context, _ []string, input string) (string, error) {
432+
func SysWrite(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
415433
var params struct {
416434
Filename string `json:"filename,omitempty"`
417435
Content string `json:"content,omitempty"`
@@ -443,7 +461,7 @@ func SysWrite(_ context.Context, _ []string, input string) (string, error) {
443461
return fmt.Sprintf("Wrote (%d) bytes to file %s", len(data), file), nil
444462
}
445463

446-
func SysAppend(_ context.Context, _ []string, input string) (string, error) {
464+
func SysAppend(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
447465
var params struct {
448466
Filename string `json:"filename,omitempty"`
449467
Content string `json:"content,omitempty"`
@@ -489,7 +507,7 @@ func fixQueries(u string) string {
489507
return url.String()
490508
}
491509

492-
func SysHTTPGet(_ context.Context, _ []string, input string) (_ string, err error) {
510+
func SysHTTPGet(_ context.Context, _ []string, input string, _ chan<- string) (_ string, err error) {
493511
var params struct {
494512
URL string `json:"url,omitempty"`
495513
}
@@ -523,8 +541,8 @@ func SysHTTPGet(_ context.Context, _ []string, input string) (_ string, err erro
523541
return string(data), nil
524542
}
525543

526-
func SysHTTPHtml2Text(ctx context.Context, env []string, input string) (string, error) {
527-
content, err := SysHTTPGet(ctx, env, input)
544+
func SysHTTPHtml2Text(ctx context.Context, env []string, input string, progress chan<- string) (string, error) {
545+
content, err := SysHTTPGet(ctx, env, input, progress)
528546
if err != nil {
529547
return "", err
530548
}
@@ -533,7 +551,7 @@ func SysHTTPHtml2Text(ctx context.Context, env []string, input string) (string,
533551
})
534552
}
535553

536-
func SysHTTPPost(ctx context.Context, _ []string, input string) (_ string, err error) {
554+
func SysHTTPPost(ctx context.Context, _ []string, input string, _ chan<- string) (_ string, err error) {
537555
var params struct {
538556
URL string `json:"url,omitempty"`
539557
Content string `json:"content,omitempty"`
@@ -569,7 +587,18 @@ func SysHTTPPost(ctx context.Context, _ []string, input string) (_ string, err e
569587
return fmt.Sprintf("Wrote %d to %s", len([]byte(params.Content)), params.URL), nil
570588
}
571589

572-
func SysGetenv(_ context.Context, env []string, input string) (string, error) {
590+
func DiscardProgress() (progress chan<- string, closeFunc func()) {
591+
ch := make(chan string)
592+
go func() {
593+
for range ch {
594+
}
595+
}()
596+
return ch, func() {
597+
close(ch)
598+
}
599+
}
600+
601+
func SysGetenv(_ context.Context, env []string, input string, _ chan<- string) (string, error) {
573602
var params struct {
574603
Name string `json:"name,omitempty"`
575604
}
@@ -597,7 +626,7 @@ func invalidArgument(input string, err error) string {
597626
return fmt.Sprintf("Failed to parse arguments %s: %v", input, err)
598627
}
599628

600-
func SysChatHistory(ctx context.Context, _ []string, _ string) (string, error) {
629+
func SysChatHistory(ctx context.Context, _ []string, _ string, _ chan<- string) (string, error) {
601630
engineContext, _ := engine.FromContext(ctx)
602631

603632
data, err := json.Marshal(engine.ChatHistory{
@@ -627,7 +656,7 @@ func writeHistory(ctx *engine.Context) (result []engine.ChatHistoryCall) {
627656
return
628657
}
629658

630-
func SysChatFinish(_ context.Context, _ []string, input string) (string, error) {
659+
func SysChatFinish(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
631660
var params struct {
632661
Message string `json:"return,omitempty"`
633662
}
@@ -641,7 +670,7 @@ func SysChatFinish(_ context.Context, _ []string, input string) (string, error)
641670
}
642671
}
643672

644-
func SysAbort(_ context.Context, _ []string, input string) (string, error) {
673+
func SysAbort(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
645674
var params struct {
646675
Message string `json:"message,omitempty"`
647676
}
@@ -651,7 +680,7 @@ func SysAbort(_ context.Context, _ []string, input string) (string, error) {
651680
return "", fmt.Errorf("ABORT: %s", params.Message)
652681
}
653682

654-
func SysRemove(_ context.Context, _ []string, input string) (string, error) {
683+
func SysRemove(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
655684
var params struct {
656685
Location string `json:"location,omitempty"`
657686
}
@@ -670,7 +699,7 @@ func SysRemove(_ context.Context, _ []string, input string) (string, error) {
670699
return fmt.Sprintf("Removed file: %s", params.Location), nil
671700
}
672701

673-
func SysStat(_ context.Context, _ []string, input string) (string, error) {
702+
func SysStat(_ context.Context, _ []string, input string, _ chan<- string) (string, error) {
674703
var params struct {
675704
Filepath string `json:"filepath,omitempty"`
676705
}
@@ -690,7 +719,7 @@ func SysStat(_ context.Context, _ []string, input string) (string, error) {
690719
return fmt.Sprintf("%s %s mode: %s, size: %d bytes, modtime: %s", title, params.Filepath, stat.Mode().String(), stat.Size(), stat.ModTime().String()), nil
691720
}
692721

693-
func SysDownload(_ context.Context, env []string, input string) (_ string, err error) {
722+
func SysDownload(_ context.Context, env []string, input string, _ chan<- string) (_ string, err error) {
694723
var params struct {
695724
URL string `json:"url,omitempty"`
696725
Location string `json:"location,omitempty"`
@@ -763,6 +792,6 @@ func SysDownload(_ context.Context, env []string, input string) (_ string, err e
763792
return fmt.Sprintf("Downloaded %s to %s", params.URL, params.Location), nil
764793
}
765794

766-
func SysTimeNow(context.Context, []string, string) (string, error) {
795+
func SysTimeNow(context.Context, []string, string, chan<- string) (string, error) {
767796
return time.Now().Format(time.RFC3339), nil
768797
}

pkg/builtin/builtin_test.go

+4-2
Original file line numberDiff line numberDiff line change
@@ -10,15 +10,17 @@ import (
1010
)
1111

1212
func TestSysGetenv(t *testing.T) {
13+
p, c := DiscardProgress()
14+
defer c()
1315
v, err := SysGetenv(context.Background(), []string{
1416
"MAGIC=VALUE",
15-
}, `{"name":"MAGIC"}`)
17+
}, `{"name":"MAGIC"}`, nil)
1618
require.NoError(t, err)
1719
autogold.Expect("VALUE").Equal(t, v)
1820

1921
v, err = SysGetenv(context.Background(), []string{
2022
"MAGIC=VALUE",
21-
}, `{"name":"MAGIC2"}`)
23+
}, `{"name":"MAGIC2"}`, p)
2224
require.NoError(t, err)
2325
autogold.Expect("MAGIC2 is not set or has no value").Equal(t, v)
2426
}

pkg/engine/cmd.go

+25-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"runtime"
1313
"sort"
1414
"strings"
15+
"sync"
1516

1617
"github.com/google/shlex"
1718
"github.com/gptscript-ai/gptscript/pkg/counter"
@@ -64,7 +65,30 @@ func (e *Engine) runCommand(ctx Context, tool types.Tool, input string, toolCate
6465
"input": input,
6566
},
6667
}
67-
return tool.BuiltinFunc(ctx.WrappedContext(), e.Env, input)
68+
69+
var (
70+
progress = make(chan string)
71+
wg sync.WaitGroup
72+
)
73+
wg.Add(1)
74+
defer wg.Wait()
75+
defer close(progress)
76+
go func() {
77+
defer wg.Done()
78+
buf := strings.Builder{}
79+
for line := range progress {
80+
buf.WriteString(line)
81+
e.Progress <- types.CompletionStatus{
82+
CompletionID: id,
83+
PartialResponse: &types.CompletionMessage{
84+
Role: types.CompletionMessageRoleTypeAssistant,
85+
Content: types.Text(buf.String()),
86+
},
87+
}
88+
}
89+
}()
90+
91+
return tool.BuiltinFunc(ctx.WrappedContext(), e.Env, input, progress)
6892
}
6993

7094
var instructions []string

pkg/prompt/credential.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,8 @@ func GetModelProviderCredential(ctx context.Context, credStore credentials.Crede
1818
if exists {
1919
k = cred.Env[env]
2020
} else {
21-
result, err := SysPrompt(ctx, envs, fmt.Sprintf(`{"message":"%s","fields":"key","sensitive":"true"}`, message))
21+
// we know progress isn't used so pass as nil
22+
result, err := SysPrompt(ctx, envs, fmt.Sprintf(`{"message":"%s","fields":"key","sensitive":"true"}`, message), nil)
2223
if err != nil {
2324
return "", err
2425
}

pkg/prompt/prompt.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ func sysPromptHTTP(ctx context.Context, envs []string, url string, prompt types.
4848
return string(data), err
4949
}
5050

51-
func SysPrompt(ctx context.Context, envs []string, input string) (_ string, err error) {
51+
func SysPrompt(ctx context.Context, envs []string, input string, _ chan<- string) (_ string, err error) {
5252
var params struct {
5353
Message string `json:"message,omitempty"`
5454
Fields string `json:"fields,omitempty"`

pkg/types/tool.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,7 @@ func (p Program) SetBlocking() Program {
117117
return p
118118
}
119119

120-
type BuiltinFunc func(ctx context.Context, env []string, input string) (string, error)
120+
type BuiltinFunc func(ctx context.Context, env []string, input string, progress chan<- string) (string, error)
121121

122122
type Parameters struct {
123123
Name string `json:"name,omitempty"`

0 commit comments

Comments
 (0)