Skip to content

Commit c74db50

Browse files
feat: add prompt callback URL
1 parent e65513e commit c74db50

File tree

1 file changed

+38
-1
lines changed

1 file changed

+38
-1
lines changed

pkg/builtin/builtin.go

+38-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package builtin
22

33
import (
4+
"bytes"
45
"context"
56
"encoding/json"
67
"errors"
@@ -774,7 +775,37 @@ func SysDownload(ctx context.Context, env []string, input string) (_ string, err
774775
return params.Location, nil
775776
}
776777

777-
func SysPrompt(_ context.Context, _ []string, input string) (_ string, err error) {
778+
func sysPromptHTTP(ctx context.Context, url, message string, fields []string, sensitive bool) (_ string, err error) {
779+
data, err := json.Marshal(map[string]any{
780+
"message": message,
781+
"fields": fields,
782+
"sensitive": sensitive,
783+
})
784+
if err != nil {
785+
return "", err
786+
}
787+
788+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(data))
789+
if err != nil {
790+
return "", err
791+
}
792+
req.Header.Set("Content-Type", "application/json")
793+
794+
resp, err := http.DefaultClient.Do(req)
795+
if err != nil {
796+
return "", err
797+
}
798+
resp.Body.Close()
799+
800+
if resp.StatusCode != 200 {
801+
return "", fmt.Errorf("invalid status code [%d], expected 200", resp.StatusCode)
802+
}
803+
804+
data, err = io.ReadAll(resp.Body)
805+
return string(data), err
806+
}
807+
808+
func SysPrompt(ctx context.Context, envs []string, input string) (_ string, err error) {
778809
var params struct {
779810
Message string `json:"message,omitempty"`
780811
Fields string `json:"fields,omitempty"`
@@ -784,6 +815,12 @@ func SysPrompt(_ context.Context, _ []string, input string) (_ string, err error
784815
return "", err
785816
}
786817

818+
for _, env := range envs {
819+
if url, ok := strings.CutPrefix(env, "GPTSCRIPT_PROMPT_URL="); ok {
820+
return sysPromptHTTP(ctx, url, params.Message, strings.Split(params.Fields, ","), params.Sensitive == "true")
821+
}
822+
}
823+
787824
if params.Message != "" {
788825
_, _ = fmt.Fprintln(os.Stderr, params.Message)
789826
}

0 commit comments

Comments
 (0)