Skip to content

Commit f636186

Browse files
committed
feat: add prompt support to the SDK server
Signed-off-by: Donnie Adams <[email protected]>
1 parent d2e0f57 commit f636186

File tree

9 files changed

+309
-65
lines changed

9 files changed

+309
-65
lines changed

pkg/builtin/builtin.go

+23-22
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,7 @@ func Builtin(name string) (types.Tool, bool) {
265265
return SetDefaults(t), ok
266266
}
267267

268-
func SysFind(ctx context.Context, env []string, input string) (string, error) {
268+
func SysFind(_ context.Context, _ []string, input string) (string, error) {
269269
var result []string
270270
var params struct {
271271
Pattern string `json:"pattern,omitempty"`
@@ -306,7 +306,7 @@ func SysFind(ctx context.Context, env []string, input string) (string, error) {
306306
return strings.Join(result, "\n"), nil
307307
}
308308

309-
func SysExec(ctx context.Context, env []string, input string) (string, error) {
309+
func SysExec(_ context.Context, env []string, input string) (string, error) {
310310
var params struct {
311311
Command string `json:"command,omitempty"`
312312
Directory string `json:"directory,omitempty"`
@@ -412,7 +412,7 @@ func SysRead(_ context.Context, _ []string, input string) (string, error) {
412412
return string(data), nil
413413
}
414414

415-
func SysWrite(ctx context.Context, _ []string, input string) (string, error) {
415+
func SysWrite(_ context.Context, _ []string, input string) (string, error) {
416416
var params struct {
417417
Filename string `json:"filename,omitempty"`
418418
Content string `json:"content,omitempty"`
@@ -444,7 +444,7 @@ func SysWrite(ctx context.Context, _ []string, input string) (string, error) {
444444
return fmt.Sprintf("Wrote (%d) bytes to file %s", len(data), file), nil
445445
}
446446

447-
func SysAppend(ctx context.Context, env []string, input string) (string, error) {
447+
func SysAppend(_ context.Context, _ []string, input string) (string, error) {
448448
var params struct {
449449
Filename string `json:"filename,omitempty"`
450450
Content string `json:"content,omitempty"`
@@ -490,7 +490,7 @@ func fixQueries(u string) string {
490490
return url.String()
491491
}
492492

493-
func SysHTTPGet(ctx context.Context, env []string, input string) (_ string, err error) {
493+
func SysHTTPGet(_ context.Context, _ []string, input string) (_ string, err error) {
494494
var params struct {
495495
URL string `json:"url,omitempty"`
496496
}
@@ -534,7 +534,7 @@ func SysHTTPHtml2Text(ctx context.Context, env []string, input string) (string,
534534
})
535535
}
536536

537-
func SysHTTPPost(ctx context.Context, env []string, input string) (_ string, err error) {
537+
func SysHTTPPost(ctx context.Context, _ []string, input string) (_ string, err error) {
538538
var params struct {
539539
URL string `json:"url,omitempty"`
540540
Content string `json:"content,omitempty"`
@@ -570,7 +570,7 @@ func SysHTTPPost(ctx context.Context, env []string, input string) (_ string, err
570570
return fmt.Sprintf("Wrote %d to %s", len([]byte(params.Content)), params.URL), nil
571571
}
572572

573-
func SysGetenv(ctx context.Context, env []string, input string) (string, error) {
573+
func SysGetenv(_ context.Context, env []string, input string) (string, error) {
574574
var params struct {
575575
Name string `json:"name,omitempty"`
576576
}
@@ -636,7 +636,7 @@ func writeHistory(ctx *engine.Context) (result []engine.ChatHistoryCall) {
636636
return
637637
}
638638

639-
func SysChatFinish(ctx context.Context, env []string, input string) (string, error) {
639+
func SysChatFinish(_ context.Context, _ []string, input string) (string, error) {
640640
var params struct {
641641
Message string `json:"return,omitempty"`
642642
}
@@ -650,7 +650,7 @@ func SysChatFinish(ctx context.Context, env []string, input string) (string, err
650650
}
651651
}
652652

653-
func SysAbort(ctx context.Context, env []string, input string) (string, error) {
653+
func SysAbort(_ context.Context, _ []string, input string) (string, error) {
654654
var params struct {
655655
Message string `json:"message,omitempty"`
656656
}
@@ -660,7 +660,7 @@ func SysAbort(ctx context.Context, env []string, input string) (string, error) {
660660
return "", fmt.Errorf("ABORT: %s", params.Message)
661661
}
662662

663-
func SysRemove(ctx context.Context, env []string, input string) (string, error) {
663+
func SysRemove(_ context.Context, _ []string, input string) (string, error) {
664664
var params struct {
665665
Location string `json:"location,omitempty"`
666666
}
@@ -679,7 +679,7 @@ func SysRemove(ctx context.Context, env []string, input string) (string, error)
679679
return fmt.Sprintf("Removed file: %s", params.Location), nil
680680
}
681681

682-
func SysStat(ctx context.Context, env []string, input string) (string, error) {
682+
func SysStat(_ context.Context, _ []string, input string) (string, error) {
683683
var params struct {
684684
Filepath string `json:"filepath,omitempty"`
685685
}
@@ -699,7 +699,7 @@ func SysStat(ctx context.Context, env []string, input string) (string, error) {
699699
return fmt.Sprintf("%s %s mode: %s, size: %d bytes, modtime: %s", title, params.Filepath, stat.Mode().String(), stat.Size(), stat.ModTime().String()), nil
700700
}
701701

702-
func SysDownload(ctx context.Context, env []string, input string) (_ string, err error) {
702+
func SysDownload(_ context.Context, env []string, input string) (_ string, err error) {
703703
var params struct {
704704
URL string `json:"url,omitempty"`
705705
Location string `json:"location,omitempty"`
@@ -772,12 +772,8 @@ func SysDownload(ctx context.Context, env []string, input string) (_ string, err
772772
return fmt.Sprintf("Downloaded %s to %s", params.URL, params.Location), nil
773773
}
774774

775-
func sysPromptHTTP(ctx context.Context, url, message string, fields []string, sensitive bool) (_ string, err error) {
776-
data, err := json.Marshal(map[string]any{
777-
"message": message,
778-
"fields": fields,
779-
"sensitive": sensitive,
780-
})
775+
func sysPromptHTTP(ctx context.Context, url string, prompt types.Prompt) (_ string, err error) {
776+
data, err := json.Marshal(prompt)
781777
if err != nil {
782778
return "", err
783779
}
@@ -792,7 +788,7 @@ func sysPromptHTTP(ctx context.Context, url, message string, fields []string, se
792788
if err != nil {
793789
return "", err
794790
}
795-
resp.Body.Close()
791+
defer resp.Body.Close()
796792

797793
if resp.StatusCode != 200 {
798794
return "", fmt.Errorf("invalid status code [%d], expected 200", resp.StatusCode)
@@ -813,8 +809,13 @@ func SysPrompt(ctx context.Context, envs []string, input string) (_ string, err
813809
}
814810

815811
for _, env := range envs {
816-
if url, ok := strings.CutPrefix(env, "GPTSCRIPT_PROMPT_URL="); ok {
817-
return sysPromptHTTP(ctx, url, params.Message, strings.Split(params.Fields, ","), params.Sensitive == "true")
812+
if url, ok := strings.CutPrefix(env, types.PromptURLEnvVar+"="); ok {
813+
httpPrompt := types.Prompt{
814+
Message: params.Message,
815+
Fields: strings.Split(params.Fields, ","),
816+
Sensitive: params.Sensitive == "true",
817+
}
818+
return sysPromptHTTP(ctx, url, httpPrompt)
818819
}
819820
}
820821

@@ -844,6 +845,6 @@ func SysPrompt(ctx context.Context, envs []string, input string) (_ string, err
844845
return string(resultsStr), nil
845846
}
846847

847-
func SysTimeNow(ctx context.Context, env []string, input string) (string, error) {
848+
func SysTimeNow(context.Context, []string, string) (string, error) {
848849
return time.Now().Format(time.RFC3339), nil
849850
}

pkg/sdkserver/confirm.go

+9-7
Original file line numberDiff line numberDiff line change
@@ -39,14 +39,16 @@ func (s *server) authorize(ctx engine.Context, input string) (runner.AuthorizerR
3939
s.lock.Unlock()
4040
}(ctx.ID)
4141

42-
s.events.C <- gserver.Event{
43-
Event: runner.Event{
44-
Time: time.Now(),
45-
CallContext: ctx.GetCallContext(),
46-
Type: CallConfirm,
42+
s.events.C <- event{
43+
Event: gserver.Event{
44+
Event: runner.Event{
45+
Time: time.Now(),
46+
CallContext: ctx.GetCallContext(),
47+
Type: CallConfirm,
48+
},
49+
Input: input,
50+
RunID: runID,
4751
},
48-
Input: input,
49-
RunID: runID,
5052
}
5153

5254
// Wait for the confirmation to come through.

pkg/sdkserver/monitor.go

+94
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
package sdkserver
2+
3+
import (
4+
"context"
5+
"sync"
6+
"time"
7+
8+
"github.com/acorn-io/broadcaster"
9+
"github.com/gptscript-ai/gptscript/pkg/runner"
10+
gserver "github.com/gptscript-ai/gptscript/pkg/server"
11+
"github.com/gptscript-ai/gptscript/pkg/types"
12+
)
13+
14+
type SessionFactory struct {
15+
events *broadcaster.Broadcaster[event]
16+
}
17+
18+
func NewSessionFactory(events *broadcaster.Broadcaster[event]) *SessionFactory {
19+
return &SessionFactory{
20+
events: events,
21+
}
22+
}
23+
24+
func (s SessionFactory) Start(ctx context.Context, prg *types.Program, env []string, input string) (runner.Monitor, error) {
25+
id := gserver.RunIDFromContext(ctx)
26+
27+
s.events.C <- event{
28+
Event: gserver.Event{
29+
Event: runner.Event{
30+
Time: time.Now(),
31+
Type: runner.EventTypeRunStart,
32+
},
33+
RunID: id,
34+
Program: prg,
35+
},
36+
}
37+
38+
return &Session{
39+
id: id,
40+
prj: prg,
41+
env: env,
42+
input: input,
43+
events: s.events,
44+
}, nil
45+
}
46+
47+
type Session struct {
48+
id string
49+
prj *types.Program
50+
env []string
51+
input string
52+
events *broadcaster.Broadcaster[event]
53+
runLock sync.Mutex
54+
}
55+
56+
func (s *Session) Event(e runner.Event) {
57+
s.runLock.Lock()
58+
defer s.runLock.Unlock()
59+
s.events.C <- event{
60+
Event: gserver.Event{
61+
Event: e,
62+
RunID: s.id,
63+
Input: s.input,
64+
},
65+
}
66+
}
67+
68+
func (s *Session) Stop(output string, err error) {
69+
e := event{
70+
Event: gserver.Event{
71+
Event: runner.Event{
72+
Time: time.Now(),
73+
Type: runner.EventTypeRunFinish,
74+
},
75+
RunID: s.id,
76+
Input: s.input,
77+
Output: output,
78+
},
79+
}
80+
if err != nil {
81+
e.Err = err.Error()
82+
}
83+
84+
s.runLock.Lock()
85+
defer s.runLock.Unlock()
86+
s.events.C <- e
87+
}
88+
89+
func (s *Session) Pause() func() {
90+
s.runLock.Lock()
91+
return func() {
92+
s.runLock.Unlock()
93+
}
94+
}

pkg/sdkserver/prompt.go

+111
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
package sdkserver
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"net/http"
7+
"time"
8+
9+
gcontext "github.com/gptscript-ai/gptscript/pkg/context"
10+
"github.com/gptscript-ai/gptscript/pkg/mvl"
11+
"github.com/gptscript-ai/gptscript/pkg/runner"
12+
gserver "github.com/gptscript-ai/gptscript/pkg/server"
13+
"github.com/gptscript-ai/gptscript/pkg/types"
14+
)
15+
16+
func (s *server) promptResponse(w http.ResponseWriter, r *http.Request) {
17+
logger := gcontext.GetLogger(r.Context())
18+
id := r.PathValue("id")
19+
20+
s.lock.RLock()
21+
promptChan := s.waitingToPrompt[id]
22+
s.lock.RUnlock()
23+
24+
if promptChan == nil {
25+
writeError(logger, w, http.StatusNotFound, fmt.Errorf("no prompt found with id %q", id))
26+
return
27+
}
28+
29+
var promptResponse map[string]string
30+
if err := json.NewDecoder(r.Body).Decode(&promptResponse); err != nil {
31+
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to decode request body: %w", err))
32+
return
33+
}
34+
35+
// Don't block here because, if the prompter is no longer waiting on this then it will never unblock.
36+
select {
37+
case promptChan <- promptResponse:
38+
w.WriteHeader(http.StatusAccepted)
39+
default:
40+
w.WriteHeader(http.StatusConflict)
41+
}
42+
}
43+
44+
func (s *server) prompt(w http.ResponseWriter, r *http.Request) {
45+
logger := gcontext.GetLogger(r.Context())
46+
id := r.PathValue("id")
47+
48+
s.lock.RLock()
49+
promptChan := s.waitingToPrompt[id]
50+
s.lock.RUnlock()
51+
52+
if promptChan != nil {
53+
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("prompt called multiple times for same ID: %s", id))
54+
return
55+
}
56+
57+
var prompt types.Prompt
58+
if err := json.NewDecoder(r.Body).Decode(&prompt); err != nil {
59+
writeError(logger, w, http.StatusBadRequest, fmt.Errorf("failed to decode request body: %v", err))
60+
return
61+
}
62+
63+
s.lock.Lock()
64+
promptChan = make(chan map[string]string)
65+
s.waitingToPrompt[id] = promptChan
66+
s.lock.Unlock()
67+
defer func(id string) {
68+
s.lock.Lock()
69+
delete(s.waitingToPrompt, id)
70+
s.lock.Unlock()
71+
}(id)
72+
73+
s.events.C <- event{
74+
Prompt: types.Prompt{
75+
Message: prompt.Message,
76+
Fields: prompt.Fields,
77+
Sensitive: prompt.Sensitive,
78+
},
79+
Event: gserver.Event{
80+
RunID: id,
81+
Event: runner.Event{
82+
Type: Prompt,
83+
Time: time.Now(),
84+
},
85+
},
86+
}
87+
88+
// Wait for the prompt response to come through.
89+
select {
90+
case <-r.Context().Done():
91+
writeError(logger, w, http.StatusInternalServerError, fmt.Errorf("context canceled: %v", r.Context().Err()))
92+
return
93+
case promptResponse := <-promptChan:
94+
writePromptResponse(logger, w, http.StatusOK, promptResponse)
95+
}
96+
}
97+
98+
func writePromptResponse(logger mvl.Logger, w http.ResponseWriter, code int, resp any) {
99+
b, err := json.Marshal(resp)
100+
if err != nil {
101+
logger.Errorf("failed to marshal response: %v", err)
102+
w.WriteHeader(http.StatusInternalServerError)
103+
} else {
104+
w.WriteHeader(code)
105+
}
106+
107+
_, err = w.Write(b)
108+
if err != nil {
109+
logger.Errorf("failed to write response: %v", err)
110+
}
111+
}

0 commit comments

Comments
 (0)