Skip to content

Commit 13bff8e

Browse files
chore: add authorization hook, change --confirm
1 parent 4a35d81 commit 13bff8e

File tree

8 files changed

+121
-82
lines changed

8 files changed

+121
-82
lines changed

pkg/auth/auth.go

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
package auth
2+
3+
import (
4+
"fmt"
5+
"path/filepath"
6+
"strings"
7+
8+
"github.com/AlecAivazis/survey/v2"
9+
"github.com/gptscript-ai/gptscript/pkg/builtin"
10+
"github.com/gptscript-ai/gptscript/pkg/context"
11+
"github.com/gptscript-ai/gptscript/pkg/engine"
12+
"github.com/gptscript-ai/gptscript/pkg/runner"
13+
)
14+
15+
func Authorize(ctx engine.Context, input string) (runner.AuthorizerResponse, error) {
16+
defer context.GetPauseFuncFromCtx(ctx.Ctx)()()
17+
18+
if !ctx.Tool.IsCommand() {
19+
return runner.AuthorizerResponse{
20+
Accept: true,
21+
}, nil
22+
}
23+
24+
var (
25+
result bool
26+
loc = ctx.Tool.Source.Location
27+
interpreter = strings.Split(ctx.Tool.Instructions, "\n")[0][2:]
28+
)
29+
30+
if _, ok := builtin.SafeTools[interpreter]; ok {
31+
return runner.AuthorizerResponse{
32+
Accept: true,
33+
}, nil
34+
}
35+
36+
if ctx.Tool.Source.Repo != nil {
37+
loc = ctx.Tool.Source.Repo.Root
38+
loc = strings.TrimPrefix(loc, "https://")
39+
loc = strings.TrimSuffix(loc, ".git")
40+
loc = filepath.Join(loc, ctx.Tool.Source.Repo.Path, ctx.Tool.Source.Repo.Name)
41+
}
42+
43+
if ctx.Tool.BuiltinFunc != nil {
44+
loc = "Builtin"
45+
}
46+
47+
err := survey.AskOne(&survey.Confirm{
48+
Help: fmt.Sprintf("The full source of the tools is as follows:\n\n%s", ctx.Tool.String()),
49+
Default: true,
50+
Message: fmt.Sprintf(`Description: %s
51+
Interpreter: %s
52+
Source: %s
53+
Input: %s
54+
Allow the above tool to execute?`, ctx.Tool.Description, interpreter, loc, strings.TrimSpace(input)),
55+
}, &result)
56+
if err != nil {
57+
return runner.AuthorizerResponse{}, err
58+
}
59+
60+
return runner.AuthorizerResponse{
61+
Accept: result,
62+
Message: "Request denied, blocking execution.",
63+
}, nil
64+
}

pkg/builtin/builtin.go

+7-21
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,17 @@ import (
2020
"github.com/AlecAivazis/survey/v2"
2121
"github.com/BurntSushi/locker"
2222
"github.com/google/shlex"
23-
"github.com/gptscript-ai/gptscript/pkg/confirm"
2423
"github.com/gptscript-ai/gptscript/pkg/types"
2524
"github.com/jaytaylor/html2text"
2625
)
2726

27+
var SafeTools = map[string]struct{}{
28+
"sys.echo": {},
29+
"sys.time.now": {},
30+
"sys.prompt": {},
31+
"sys.chat.finish": {},
32+
}
33+
2834
var tools = map[string]types.Tool{
2935
"sys.time.now": {
3036
Parameters: types.Parameters{
@@ -278,10 +284,6 @@ func SysExec(ctx context.Context, env []string, input string) (string, error) {
278284

279285
log.Debugf("Running %s in %s", params.Command, params.Directory)
280286

281-
if err := confirm.Promptf(ctx, "Run command: %s", params.Command); err != nil {
282-
return "", err
283-
}
284-
285287
var cmd *exec.Cmd
286288

287289
if runtime.GOOS == "windows" {
@@ -404,12 +406,6 @@ func SysWrite(ctx context.Context, _ []string, input string) (string, error) {
404406
}
405407
}
406408

407-
if _, err := os.Stat(file); err == nil {
408-
if err := confirm.Promptf(ctx, "Overwrite: %s", params.Filename); err != nil {
409-
return "", err
410-
}
411-
}
412-
413409
data := []byte(params.Content)
414410
log.Debugf("Wrote %d bytes to file %s", len(data), file)
415411

@@ -429,12 +425,6 @@ func SysAppend(ctx context.Context, env []string, input string) (string, error)
429425
locker.Lock(params.Filename)
430426
defer locker.Unlock(params.Filename)
431427

432-
if _, err := os.Stat(params.Filename); err == nil {
433-
if err := confirm.Promptf(ctx, "Write to existing file: %s.", params.Filename); err != nil {
434-
return "", err
435-
}
436-
}
437-
438428
f, err := os.OpenFile(params.Filename, os.O_APPEND|os.O_WRONLY|os.O_CREATE, 0644)
439429
if err != nil {
440430
return "", err
@@ -609,10 +599,6 @@ func SysRemove(ctx context.Context, env []string, input string) (string, error)
609599
return "", err
610600
}
611601

612-
if err := confirm.Promptf(ctx, "Remove: %s", params.Location); err != nil {
613-
return "", err
614-
}
615-
616602
// Lock the file to prevent concurrent writes from other tool calls.
617603
locker.Lock(params.Location)
618604
defer locker.Unlock(params.Location)

pkg/cli/eval.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -72,12 +72,12 @@ func (e *Eval) Run(cmd *cobra.Command, args []string) error {
7272
}
7373

7474
if e.Chat {
75-
return chat.Start(e.gptscript.NewRunContext(cmd), nil, runner, func() (types.Program, error) {
75+
return chat.Start(cmd.Context(), nil, runner, func() (types.Program, error) {
7676
return prg, nil
7777
}, os.Environ(), toolInput)
7878
}
7979

80-
toolOutput, err := runner.Run(e.gptscript.NewRunContext(cmd), prg, os.Environ(), toolInput)
80+
toolOutput, err := runner.Run(cmd.Context(), prg, os.Environ(), toolInput)
8181
if err != nil {
8282
return err
8383
}

pkg/cli/gptscript.go

+8-12
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,10 @@ import (
1313
"github.com/acorn-io/cmd"
1414
"github.com/fatih/color"
1515
"github.com/gptscript-ai/gptscript/pkg/assemble"
16+
"github.com/gptscript-ai/gptscript/pkg/auth"
1617
"github.com/gptscript-ai/gptscript/pkg/builtin"
1718
"github.com/gptscript-ai/gptscript/pkg/cache"
1819
"github.com/gptscript-ai/gptscript/pkg/chat"
19-
"github.com/gptscript-ai/gptscript/pkg/confirm"
2020
"github.com/gptscript-ai/gptscript/pkg/gptscript"
2121
"github.com/gptscript-ai/gptscript/pkg/input"
2222
"github.com/gptscript-ai/gptscript/pkg/loader"
@@ -117,14 +117,6 @@ func New() *cobra.Command {
117117
return command
118118
}
119119

120-
func (r *GPTScript) NewRunContext(cmd *cobra.Command) context.Context {
121-
ctx := cmd.Context()
122-
if r.Confirm {
123-
ctx = confirm.WithConfirm(ctx, confirm.TextPrompt{})
124-
}
125-
return ctx
126-
}
127-
128120
func (r *GPTScript) NewGPTScriptOpts() (gptscript.Options, error) {
129121
opts := gptscript.Options{
130122
Cache: cache.Options(r.CacheOptions),
@@ -140,6 +132,10 @@ func (r *GPTScript) NewGPTScriptOpts() (gptscript.Options, error) {
140132
Workspace: r.Workspace,
141133
}
142134

135+
if r.Confirm {
136+
opts.Runner.Authorizer = auth.Authorize
137+
}
138+
143139
if r.Ports != "" {
144140
start, end, _ := strings.Cut(r.Ports, "-")
145141
startNum, err := strconv.ParseInt(strings.TrimSpace(start), 10, 64)
@@ -388,7 +384,7 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) (retErr error) {
388384
}
389385

390386
if r.ChatState != "" {
391-
resp, err := gptScript.Chat(r.NewRunContext(cmd), r.ChatState, prg, os.Environ(), toolInput)
387+
resp, err := gptScript.Chat(cmd.Context(), r.ChatState, prg, os.Environ(), toolInput)
392388
if err != nil {
393389
return err
394390
}
@@ -400,12 +396,12 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) (retErr error) {
400396
}
401397

402398
if prg.IsChat() || r.ForceChat {
403-
return chat.Start(r.NewRunContext(cmd), nil, gptScript, func() (types.Program, error) {
399+
return chat.Start(cmd.Context(), nil, gptScript, func() (types.Program, error) {
404400
return r.readProgram(ctx, gptScript, args)
405401
}, os.Environ(), toolInput)
406402
}
407403

408-
s, err := gptScript.Run(r.NewRunContext(cmd), prg, os.Environ(), toolInput)
404+
s, err := gptScript.Run(cmd.Context(), prg, os.Environ(), toolInput)
409405
if err != nil {
410406
return err
411407
}

pkg/confirm/confirm.go

-45
This file was deleted.

pkg/mvl/log.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ func (f formatter) Format(entry *logrus.Entry) ([]byte, error) {
4343
}
4444
d, _ := json.Marshal(i)
4545
i = string(d)
46-
i = strings.TrimSpace(i[1 : len(i)-2])
46+
i = strings.TrimSpace(i[1 : len(i)-1])
4747
if addDot {
4848
i += "..."
4949
}

pkg/openai/client.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,9 @@ func appendMessage(msg types.CompletionMessage, response openai.ChatCompletionSt
435435
tc.ToolCall.Index = tool.Index
436436
}
437437
tc.ToolCall.ID = override(tc.ToolCall.ID, tool.ID)
438-
tc.ToolCall.Function.Name += tool.Function.Name
438+
if tc.ToolCall.Function.Name != tool.Function.Name {
439+
tc.ToolCall.Function.Name += tool.Function.Name
440+
}
439441
tc.ToolCall.Function.Arguments += tool.Function.Arguments
440442

441443
msg.Content[idx] = tc

pkg/runner/runner.go

+36
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,20 @@ type Options struct {
3636
EndPort int64 `usage:"-"`
3737
CredentialOverride string `usage:"-"`
3838
Sequential bool `usage:"-"`
39+
Authorizer AuthorizerFunc `usage:"-"`
40+
}
41+
42+
type AuthorizerResponse struct {
43+
Accept bool
44+
Message string
45+
}
46+
47+
type AuthorizerFunc func(ctx engine.Context, input string) (AuthorizerResponse, error)
48+
49+
func DefaultAuthorizer(ctx engine.Context, input string) (AuthorizerResponse, error) {
50+
return AuthorizerResponse{
51+
Accept: true,
52+
}, nil
3953
}
4054

4155
func complete(opts ...Options) (result Options) {
@@ -46,6 +60,9 @@ func complete(opts ...Options) (result Options) {
4660
result.EndPort = types.FirstSet(opt.EndPort, result.EndPort)
4761
result.CredentialOverride = types.FirstSet(opt.CredentialOverride, result.CredentialOverride)
4862
result.Sequential = types.FirstSet(opt.Sequential, result.Sequential)
63+
if opt.Authorizer != nil {
64+
result.Authorizer = opt.Authorizer
65+
}
4966
}
5067
if result.MonitorFactory == nil {
5168
result.MonitorFactory = noopFactory{}
@@ -56,11 +73,15 @@ func complete(opts ...Options) (result Options) {
5673
if result.StartPort == 0 {
5774
result.StartPort = result.EndPort
5875
}
76+
if result.Authorizer == nil {
77+
result.Authorizer = DefaultAuthorizer
78+
}
5979
return
6080
}
6181

6282
type Runner struct {
6383
c engine.Model
84+
auth AuthorizerFunc
6485
factory MonitorFactory
6586
runtimeManager engine.RuntimeManager
6687
ports engine.Ports
@@ -81,6 +102,7 @@ func New(client engine.Model, credCtx string, opts ...Options) (*Runner, error)
81102
credMutex: sync.Mutex{},
82103
credOverrides: opt.CredentialOverride,
83104
sequential: opt.Sequential,
105+
auth: opt.Authorizer,
84106
}
85107

86108
if opt.StartPort != 0 {
@@ -405,6 +427,20 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en
405427

406428
callCtx.Ctx = context2.AddPauseFuncToCtx(callCtx.Ctx, monitor.Pause)
407429

430+
authResp, err := r.auth(callCtx, input)
431+
if err != nil {
432+
return nil, err
433+
}
434+
435+
if !authResp.Accept {
436+
msg := fmt.Sprintf("[AUTHORIZATION ERROR]: %s", authResp.Message)
437+
return &State{
438+
Continuation: &engine.Return{
439+
Result: &msg,
440+
},
441+
}, nil
442+
}
443+
408444
ret, err := e.Start(callCtx, input)
409445
if err != nil {
410446
return nil, err

0 commit comments

Comments
 (0)