Skip to content

Commit 29d0080

Browse files
feat: add sys.model.provider.credential
1 parent 50503da commit 29d0080

File tree

12 files changed

+235
-17
lines changed

12 files changed

+235
-17
lines changed

go.mod

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ require (
1616
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
1717
github.com/google/uuid v1.6.0
1818
github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86
19-
github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379
19+
github.com/gptscript-ai/chat-completion-client v0.0.0-20240813051153-a440ada7e3c3
2020
github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb
2121
github.com/gptscript-ai/go-gptscript v0.9.4-0.20240801203434-840b14393b17
2222
github.com/gptscript-ai/tui v0.0.0-20240804004233-efc5673dc76e

go.sum

+2-2
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,8 @@ github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY=
200200
github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c=
201201
github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86 h1:m9yLtIEd0z1ia8qFjq3u0Ozb6QKwidyL856JLJp6nbA=
202202
github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86/go.mod h1:lK3K5EZx4dyT24UG3yCt0wmspkYqrj4D/8kxdN3relk=
203-
github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379 h1:vYnXoIyCXzaCEw0sYifQ4bDpsv3/fO/dZ2suEsTwCIo=
204-
github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
203+
github.com/gptscript-ai/chat-completion-client v0.0.0-20240813051153-a440ada7e3c3 h1:EQiFTZv+BnOWJX2B9XdF09fL2Zj7h19n1l23TpWCafc=
204+
github.com/gptscript-ai/chat-completion-client v0.0.0-20240813051153-a440ada7e3c3/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
205205
github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb h1:ky2J2CzBOskC7Jgm2VJAQi2x3p7FVGa+2/PcywkFJuc=
206206
github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb/go.mod h1:DJAo1xTht1LDkNYFNydVjTHd576TC7MlpsVRl3oloVw=
207207
github.com/gptscript-ai/go-gptscript v0.9.4-0.20240801203434-840b14393b17 h1:BTfJ6ls31Roq42lznlZnuPzRf0wrT8jT+tWcvq7wDXY=

pkg/builtin/builtin.go

+25
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,15 @@ var tools = map[string]types.Tool{
248248
BuiltinFunc: SysContext,
249249
},
250250
},
251+
"sys.model.provider.credential": {
252+
ToolDef: types.ToolDef{
253+
Parameters: types.Parameters{
254+
Description: "A credential tool to set the OPENAI_API_KEY and OPENAI_BASE_URL to give access to the default model provider",
255+
Arguments: types.ObjectSchema(),
256+
},
257+
BuiltinFunc: SysModelProviderCredential,
258+
},
259+
},
251260
}
252261

253262
func ListTools() (result []types.Tool) {
@@ -678,6 +687,22 @@ func invalidArgument(input string, err error) string {
678687
return fmt.Sprintf("Failed to parse arguments %s: %v", input, err)
679688
}
680689

690+
func SysModelProviderCredential(ctx context.Context, _ []string, _ string, _ chan<- string) (string, error) {
691+
engineContext, _ := engine.FromContext(ctx)
692+
auth, url, err := engineContext.Engine.Model.ProxyInfo()
693+
if err != nil {
694+
return "", err
695+
}
696+
data, err := json.Marshal(map[string]any{
697+
"env": map[string]string{
698+
"OPENAI_API_KEY": auth,
699+
"OPENAI_BASE_URL": url,
700+
},
701+
"ephemeral": true,
702+
})
703+
return string(data), err
704+
}
705+
681706
func SysContext(ctx context.Context, _ []string, _ string, _ chan<- string) (string, error) {
682707
engineContext, _ := engine.FromContext(ctx)
683708

pkg/credentials/credential.go

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ type Credential struct {
2424
ToolName string `json:"toolName"`
2525
Type CredentialType `json:"type"`
2626
Env map[string]string `json:"env"`
27+
Ephemeral bool `json:"ephemeral,omitempty"`
2728
ExpiresAt *time.Time `json:"expiresAt"`
2829
RefreshToken string `json:"refreshToken"`
2930
}

pkg/engine/cmd.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ func (e *Engine) runCommand(ctx Context, tool types.Tool, input string, toolCate
109109
}
110110
}()
111111

112-
return tool.BuiltinFunc(ctx.WrappedContext(), e.Env, input, progress)
112+
return tool.BuiltinFunc(ctx.WrappedContext(e), e.Env, input, progress)
113113
}
114114

115115
var instructions []string

pkg/engine/engine.go

+11-2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616

1717
type Model interface {
1818
Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error)
19+
ProxyInfo() (string, string, error)
1920
}
2021

2122
type RuntimeManager interface {
@@ -79,6 +80,7 @@ type Context struct {
7980
Parent *Context
8081
LastReturn *Return
8182
CurrentReturn *Return
83+
Engine *Engine
8284
Program *types.Program
8385
// Input is saved only so that we can render display text, don't use otherwise
8486
Input string
@@ -250,8 +252,10 @@ func FromContext(ctx context.Context) (*Context, bool) {
250252
return c, ok
251253
}
252254

253-
func (c *Context) WrappedContext() context.Context {
254-
return context.WithValue(c.Ctx, engineContext{}, c)
255+
func (c *Context) WrappedContext(e *Engine) context.Context {
256+
cp := *c
257+
cp.Engine = e
258+
return context.WithValue(c.Ctx, engineContext{}, &cp)
255259
}
256260

257261
func (e *Engine) Start(ctx Context, input string) (ret *Return, _ error) {
@@ -280,6 +284,11 @@ func (e *Engine) Start(ctx Context, input string) (ret *Return, _ error) {
280284
return &Return{
281285
Result: &s,
282286
}, nil
287+
} else if tool.IsNoop() {
288+
var empty string
289+
return &Return{
290+
Result: &empty,
291+
}, nil
283292
}
284293

285294
if ctx.ToolCategory == CredentialToolCategory {

pkg/llm/proxy.go

+104
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
package llm
2+
3+
import (
4+
"bytes"
5+
"encoding/json"
6+
"io"
7+
"net"
8+
"net/http"
9+
"net/http/httputil"
10+
"net/url"
11+
"path"
12+
"strings"
13+
14+
"github.com/gptscript-ai/gptscript/pkg/builtin"
15+
"github.com/gptscript-ai/gptscript/pkg/openai"
16+
)
17+
18+
func (r *Registry) ProxyInfo() (string, string, error) {
19+
r.proxyLock.Lock()
20+
defer r.proxyLock.Unlock()
21+
22+
if r.proxyURL != "" {
23+
return r.proxyToken, r.proxyURL, nil
24+
}
25+
26+
l, err := net.Listen("tcp", "127.0.0.1:0")
27+
if err != nil {
28+
return "", "", err
29+
}
30+
31+
go func() {
32+
_ = http.Serve(l, r)
33+
r.proxyLock.Lock()
34+
defer r.proxyLock.Unlock()
35+
_ = l.Close()
36+
r.proxyURL = ""
37+
}()
38+
39+
r.proxyURL = "http://" + l.Addr().String()
40+
return r.proxyToken, r.proxyURL, nil
41+
}
42+
43+
func (r *Registry) ServeHTTP(w http.ResponseWriter, req *http.Request) {
44+
if r.proxyToken != strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ") {
45+
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
46+
return
47+
}
48+
49+
inBytes, err := io.ReadAll(req.Body)
50+
if err != nil {
51+
http.Error(w, err.Error(), http.StatusBadRequest)
52+
return
53+
}
54+
55+
var (
56+
model string
57+
data = map[string]any{}
58+
)
59+
60+
if json.Unmarshal(inBytes, &data) == nil {
61+
model, _ = data["model"].(string)
62+
}
63+
64+
if model == "" {
65+
model = builtin.GetDefaultModel()
66+
}
67+
68+
c, err := r.getClient(req.Context(), model)
69+
if err != nil {
70+
http.Error(w, err.Error(), http.StatusInternalServerError)
71+
return
72+
}
73+
74+
oai, ok := c.(*openai.Client)
75+
if !ok {
76+
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
77+
return
78+
}
79+
80+
auth, targetURL := oai.ProxyInfo()
81+
if targetURL == "" {
82+
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
83+
return
84+
}
85+
86+
newURL, err := url.Parse(targetURL)
87+
if err != nil {
88+
http.Error(w, err.Error(), http.StatusInternalServerError)
89+
return
90+
}
91+
92+
newURL.Path = path.Join(newURL.Path, req.URL.Path)
93+
94+
rp := httputil.ReverseProxy{
95+
Director: func(proxyReq *http.Request) {
96+
proxyReq.Body = io.NopCloser(bytes.NewReader(inBytes))
97+
proxyReq.URL = newURL
98+
proxyReq.Header.Del("Authorization")
99+
proxyReq.Header.Add("Authorization", "Bearer "+auth)
100+
proxyReq.Host = newURL.Hostname()
101+
},
102+
}
103+
rp.ServeHTTP(w, req)
104+
}

pkg/llm/registry.go

+55-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@ import (
55
"errors"
66
"fmt"
77
"sort"
8+
"sync"
89

10+
"github.com/google/uuid"
11+
"github.com/gptscript-ai/gptscript/pkg/env"
912
"github.com/gptscript-ai/gptscript/pkg/openai"
1013
"github.com/gptscript-ai/gptscript/pkg/remote"
1114
"github.com/gptscript-ai/gptscript/pkg/types"
@@ -18,11 +21,16 @@ type Client interface {
1821
}
1922

2023
type Registry struct {
21-
clients []Client
24+
proxyToken string
25+
proxyURL string
26+
proxyLock sync.Mutex
27+
clients []Client
2228
}
2329

2430
func NewRegistry() *Registry {
25-
return &Registry{}
31+
return &Registry{
32+
proxyToken: env.VarOrDefault("GPTSCRIPT_INTERNAL_PROXY_TOKEN", uuid.New().String()),
33+
}
2634
}
2735

2836
func (r *Registry) AddClient(client Client) error {
@@ -44,6 +52,10 @@ func (r *Registry) ListModels(ctx context.Context, providers ...string) (result
4452

4553
func (r *Registry) fastPath(modelName string) Client {
4654
// This is optimization hack to avoid doing List Models
55+
if len(r.clients) == 1 {
56+
return r.clients[0]
57+
}
58+
4759
if len(r.clients) != 2 {
4860
return nil
4961
}
@@ -66,6 +78,47 @@ func (r *Registry) fastPath(modelName string) Client {
6678
return r.clients[0]
6779
}
6880

81+
func (r *Registry) getClient(ctx context.Context, modelName string) (Client, error) {
82+
if c := r.fastPath(modelName); c != nil {
83+
return c, nil
84+
}
85+
86+
var errs []error
87+
var oaiClient *openai.Client
88+
for _, client := range r.clients {
89+
ok, err := client.Supports(ctx, modelName)
90+
if err != nil {
91+
// If we got an OpenAI invalid auth error back, store the OpenAI client for later.
92+
if errors.Is(err, openai.InvalidAuthError{}) {
93+
oaiClient = client.(*openai.Client)
94+
}
95+
96+
errs = append(errs, err)
97+
} else if ok {
98+
return client, nil
99+
}
100+
}
101+
102+
if len(errs) > 0 && oaiClient != nil {
103+
// Prompt the user to enter their OpenAI API key and try again.
104+
if err := oaiClient.RetrieveAPIKey(ctx); err != nil {
105+
return nil, err
106+
}
107+
ok, err := oaiClient.Supports(ctx, modelName)
108+
if err != nil {
109+
return nil, err
110+
} else if ok {
111+
return oaiClient, nil
112+
}
113+
}
114+
115+
if len(errs) == 0 {
116+
return nil, fmt.Errorf("failed to find a model provider for model [%s]", modelName)
117+
}
118+
119+
return nil, errors.Join(errs...)
120+
}
121+
69122
func (r *Registry) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
70123
if messageRequest.Model == "" {
71124
return nil, fmt.Errorf("model is required")

pkg/openai/client.go

+7
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,13 @@ func NewClient(ctx context.Context, credStore credentials.CredentialStore, opts
130130
}, nil
131131
}
132132

133+
func (c *Client) ProxyInfo() (token, urlBase string) {
134+
if c.invalidAuth {
135+
return "", ""
136+
}
137+
return c.c.GetAPIKeyAndBaseURL()
138+
}
139+
133140
func (c *Client) ValidAuth() error {
134141
if c.invalidAuth {
135142
return InvalidAuthError{}

pkg/runner/runner.go

+14-8
Original file line numberDiff line numberDiff line change
@@ -944,6 +944,10 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
944944
return nil, fmt.Errorf("invalid state: credential tool [%s] can not result in a continuation", ref.Reference)
945945
}
946946

947+
if *res.Result == "" {
948+
continue
949+
}
950+
947951
if err := json.Unmarshal([]byte(*res.Result), &c); err != nil {
948952
return nil, fmt.Errorf("failed to unmarshal credential tool %s response: %w", ref.Reference, err)
949953
}
@@ -958,15 +962,17 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
958962
}
959963
}
960964

961-
// Only store the credential if the tool is on GitHub or has an alias, and the credential is non-empty.
962-
if (isGitHubTool(toolName) && callCtx.Program.ToolSet[ref.ToolID].Source.Repo != nil) || credentialAlias != "" {
963-
if isEmpty {
964-
log.Warnf("Not saving empty credential for tool %s", toolName)
965-
} else if err := r.credStore.Add(callCtx.Ctx, *c); err != nil {
966-
return nil, fmt.Errorf("failed to add credential for tool %s: %w", toolName, err)
965+
if !c.Ephemeral {
966+
// Only store the credential if the tool is on GitHub or has an alias, and the credential is non-empty.
967+
if (isGitHubTool(toolName) && callCtx.Program.ToolSet[ref.ToolID].Source.Repo != nil) || credentialAlias != "" {
968+
if isEmpty {
969+
log.Warnf("Not saving empty credential for tool %s", toolName)
970+
} else if err := r.credStore.Add(callCtx.Ctx, *c); err != nil {
971+
return nil, fmt.Errorf("failed to add credential for tool %s: %w", toolName, err)
972+
}
973+
} else {
974+
log.Warnf("Not saving credential for tool %s - credentials will only be saved for tools from GitHub, or tools that use aliases.", toolName)
967975
}
968-
} else {
969-
log.Warnf("Not saving credential for tool %s - credentials will only be saved for tools from GitHub, or tools that use aliases.", toolName)
970976
}
971977
}
972978

pkg/tests/tester/runner.go

+4
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ type Result struct {
3131
Err error
3232
}
3333

34+
func (c *Client) ProxyInfo() (string, string, error) {
35+
return "test-auth", "test-url", nil
36+
}
37+
3438
func (c *Client) Call(_ context.Context, messageRequest types.CompletionRequest, _ chan<- types.CompletionStatus) (resp *types.CompletionMessage, respErr error) {
3539
msgData, err := json.MarshalIndent(messageRequest, "", " ")
3640
require.NoError(c.t, err)

pkg/types/tool.go

+10-1
Original file line numberDiff line numberDiff line change
@@ -753,7 +753,16 @@ func (t Tool) GetCredentialTools(prg Program, agentGroup []ToolReference) ([]Too
753753

754754
result.AddAll(t.getCompletionToolRefs(prg, nil, ToolTypeCredential))
755755

756-
toolRefs, err := t.getCompletionToolRefs(prg, agentGroup)
756+
toolRefs, err := result.List()
757+
if err != nil {
758+
return nil, err
759+
}
760+
for _, toolRef := range toolRefs {
761+
referencedTool := prg.ToolSet[toolRef.ToolID]
762+
result.AddAll(referencedTool.GetToolRefsFromNames(referencedTool.ExportCredentials))
763+
}
764+
765+
toolRefs, err = t.getCompletionToolRefs(prg, agentGroup)
757766
if err != nil {
758767
return nil, err
759768
}

0 commit comments

Comments
 (0)