Skip to content

Commit 799c2b7

Browse files
chore: add authorization hook
1 parent 4a35d81 commit 799c2b7

File tree

2 files changed

+39
-1
lines changed

2 files changed

+39
-1
lines changed

pkg/openai/client.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -435,7 +435,9 @@ func appendMessage(msg types.CompletionMessage, response openai.ChatCompletionSt
435435
tc.ToolCall.Index = tool.Index
436436
}
437437
tc.ToolCall.ID = override(tc.ToolCall.ID, tool.ID)
438-
tc.ToolCall.Function.Name += tool.Function.Name
438+
if tc.ToolCall.Function.Name != tool.Function.Name {
439+
tc.ToolCall.Function.Name += tool.Function.Name
440+
}
439441
tc.ToolCall.Function.Arguments += tool.Function.Arguments
440442

441443
msg.Content[idx] = tc

pkg/runner/runner.go

+36
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,20 @@ type Options struct {
3636
EndPort int64 `usage:"-"`
3737
CredentialOverride string `usage:"-"`
3838
Sequential bool `usage:"-"`
39+
Authorizer AuthorizerFunc `usage:"-"`
40+
}
41+
42+
type AuthorizerResponse struct {
43+
Accept bool
44+
Message string
45+
}
46+
47+
type AuthorizerFunc func(ctx engine.Context, input string) (AuthorizerResponse, error)
48+
49+
func DefaultAuthorizer(ctx engine.Context, input string) (AuthorizerResponse, error) {
50+
return AuthorizerResponse{
51+
Accept: true,
52+
}, nil
3953
}
4054

4155
func complete(opts ...Options) (result Options) {
@@ -46,6 +60,9 @@ func complete(opts ...Options) (result Options) {
4660
result.EndPort = types.FirstSet(opt.EndPort, result.EndPort)
4761
result.CredentialOverride = types.FirstSet(opt.CredentialOverride, result.CredentialOverride)
4862
result.Sequential = types.FirstSet(opt.Sequential, result.Sequential)
63+
if opt.Authorizer != nil {
64+
result.Authorizer = opt.Authorizer
65+
}
4966
}
5067
if result.MonitorFactory == nil {
5168
result.MonitorFactory = noopFactory{}
@@ -56,11 +73,15 @@ func complete(opts ...Options) (result Options) {
5673
if result.StartPort == 0 {
5774
result.StartPort = result.EndPort
5875
}
76+
if result.Authorizer == nil {
77+
result.Authorizer = DefaultAuthorizer
78+
}
5979
return
6080
}
6181

6282
type Runner struct {
6383
c engine.Model
84+
auth AuthorizerFunc
6485
factory MonitorFactory
6586
runtimeManager engine.RuntimeManager
6687
ports engine.Ports
@@ -81,6 +102,7 @@ func New(client engine.Model, credCtx string, opts ...Options) (*Runner, error)
81102
credMutex: sync.Mutex{},
82103
credOverrides: opt.CredentialOverride,
83104
sequential: opt.Sequential,
105+
auth: opt.Authorizer,
84106
}
85107

86108
if opt.StartPort != 0 {
@@ -405,6 +427,20 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en
405427

406428
callCtx.Ctx = context2.AddPauseFuncToCtx(callCtx.Ctx, monitor.Pause)
407429

430+
authResp, err := r.auth(callCtx, input)
431+
if err != nil {
432+
return nil, err
433+
}
434+
435+
if !authResp.Accept {
436+
msg := fmt.Sprintf("[CALL REJECTED]: %s", authResp.Message)
437+
return &State{
438+
Continuation: &engine.Return{
439+
Result: &msg,
440+
},
441+
}, nil
442+
}
443+
408444
ret, err := e.Start(callCtx, input)
409445
if err != nil {
410446
return nil, err

0 commit comments

Comments
 (0)