Skip to content

Commit 13c07d6

Browse files
fix: misc improvements related to creds and prompting (#536)
Signed-off-by: Grant Linville <[email protected]> Co-authored-by: Donnie Adams <[email protected]>
1 parent 0425038 commit 13c07d6

File tree

4 files changed

+27
-13
lines changed

4 files changed

+27
-13
lines changed

pkg/prompt/prompt.go

+12
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"context"
66
"encoding/json"
7+
"errors"
78
"fmt"
89
"io"
910
"net/http"
@@ -75,6 +76,17 @@ func SysPrompt(ctx context.Context, envs []string, input string, _ chan<- string
7576
func sysPrompt(ctx context.Context, req types.Prompt) (_ string, err error) {
7677
defer context2.GetPauseFuncFromCtx(ctx)()()
7778

79+
if req.Message != "" && len(req.Fields) == 1 && strings.TrimSpace(req.Fields[0]) == "" {
80+
var errs []error
81+
_, err := fmt.Fprintln(os.Stderr, req.Message)
82+
errs = append(errs, err)
83+
_, err = fmt.Fprintln(os.Stderr, "Press enter to continue...")
84+
errs = append(errs, err)
85+
_, err = fmt.Fscanln(os.Stdin)
86+
errs = append(errs, err)
87+
return "", errors.Join(errs...)
88+
}
89+
7890
if req.Message != "" && len(req.Fields) != 1 {
7991
_, _ = fmt.Fprintln(os.Stderr, req.Message)
8092
}

pkg/runner/runner.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -567,7 +567,7 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
567567
Time: time.Now(),
568568
CallContext: callCtx.GetCallContext(),
569569
Type: EventTypeCallFinish,
570-
Content: getFinishEventContent(*state, callCtx),
570+
Content: getEventContent(*state.Continuation.Result, callCtx),
571571
})
572572
if callCtx.Tool.Chat {
573573
return &State{
@@ -681,7 +681,7 @@ func streamProgress(callCtx *engine.Context, monitor Monitor) (chan<- types.Comp
681681
CallContext: callCtx.GetCallContext(),
682682
Type: EventTypeCallProgress,
683683
ChatCompletionID: status.CompletionID,
684-
Content: message.String(),
684+
Content: getEventContent(message.String(), *callCtx),
685685
})
686686
} else {
687687
monitor.Event(Event{
@@ -821,13 +821,13 @@ func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string,
821821
return state, callResults, nil
822822
}
823823

824-
func getFinishEventContent(state State, callCtx engine.Context) string {
825-
// If it is a credential tool, the finish event contains its output, which is sensitive, so we don't return it.
824+
func getEventContent(content string, callCtx engine.Context) string {
825+
// If it is a credential tool, the progress and finish events may contain its output, which is sensitive, so we don't return it.
826826
if callCtx.ToolCategory == engine.CredentialToolCategory {
827827
return ""
828828
}
829829

830-
return *state.Continuation.Result
830+
return content
831831
}
832832

833833
func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env []string) ([]string, error) {

pkg/types/credential_test.go

+6-4
Original file line numberDiff line numberDiff line change
@@ -125,10 +125,12 @@ func TestParseCredentialArgs(t *testing.T) {
125125
wantErr: true,
126126
},
127127
{
128-
name: "invalid input",
129-
toolName: "myCredentialTool",
130-
input: `{"asdf":"asdf"`,
131-
wantErr: true,
128+
name: "invalid input",
129+
toolName: "myCredentialTool",
130+
input: `{"asdf":"asdf"`,
131+
expectedName: "myCredentialTool",
132+
expectedAlias: "",
133+
wantErr: false,
132134
},
133135
}
134136

pkg/types/tool.go

+4-4
Original file line numberDiff line numberDiff line change
@@ -255,10 +255,10 @@ func ParseCredentialArgs(toolName string, input string) (string, string, map[str
255255

256256
inputMap := make(map[string]any)
257257
if input != "" {
258-
err := json.Unmarshal([]byte(input), &inputMap)
259-
if err != nil {
260-
return "", "", nil, fmt.Errorf("failed to unmarshal input: %w", err)
261-
}
258+
// Sometimes this function can be called with input that is not a JSON string.
259+
// This typically happens during chat mode.
260+
// That's why we ignore the error if this fails to unmarshal.
261+
_ = json.Unmarshal([]byte(input), &inputMap)
262262
}
263263

264264
fields, err := shlex.Split(toolName)

0 commit comments

Comments
 (0)