Skip to content

Commit fa52ff4

Browse files
Merge pull request #792 from ibuildthecloud/main
feat: add sys.model.provider.credential
2 parents 78b5c3d + 89cff87 commit fa52ff4

File tree

13 files changed

+245
-26
lines changed

13 files changed

+245
-26
lines changed

go.mod

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ require (
1515
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
1616
github.com/google/uuid v1.6.0
1717
github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86
18-
github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379
18+
github.com/gptscript-ai/chat-completion-client v0.0.0-20240813051153-a440ada7e3c3
1919
github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb
2020
github.com/gptscript-ai/go-gptscript v0.9.4-0.20240801203434-840b14393b17
2121
github.com/gptscript-ai/tui v0.0.0-20240804004233-efc5673dc76e

go.sum

+2-2
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,8 @@ github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY=
200200
github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c=
201201
github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86 h1:m9yLtIEd0z1ia8qFjq3u0Ozb6QKwidyL856JLJp6nbA=
202202
github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86/go.mod h1:lK3K5EZx4dyT24UG3yCt0wmspkYqrj4D/8kxdN3relk=
203-
github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379 h1:vYnXoIyCXzaCEw0sYifQ4bDpsv3/fO/dZ2suEsTwCIo=
204-
github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
203+
github.com/gptscript-ai/chat-completion-client v0.0.0-20240813051153-a440ada7e3c3 h1:EQiFTZv+BnOWJX2B9XdF09fL2Zj7h19n1l23TpWCafc=
204+
github.com/gptscript-ai/chat-completion-client v0.0.0-20240813051153-a440ada7e3c3/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
205205
github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb h1:ky2J2CzBOskC7Jgm2VJAQi2x3p7FVGa+2/PcywkFJuc=
206206
github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb/go.mod h1:DJAo1xTht1LDkNYFNydVjTHd576TC7MlpsVRl3oloVw=
207207
github.com/gptscript-ai/go-gptscript v0.9.4-0.20240801203434-840b14393b17 h1:BTfJ6ls31Roq42lznlZnuPzRf0wrT8jT+tWcvq7wDXY=

pkg/builtin/builtin.go

+34-8
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,15 @@ import (
2626
)
2727

2828
var SafeTools = map[string]struct{}{
29-
"sys.abort": {},
30-
"sys.chat.finish": {},
31-
"sys.chat.history": {},
32-
"sys.chat.current": {},
33-
"sys.echo": {},
34-
"sys.prompt": {},
35-
"sys.time.now": {},
36-
"sys.context": {},
29+
"sys.abort": {},
30+
"sys.chat.finish": {},
31+
"sys.chat.history": {},
32+
"sys.chat.current": {},
33+
"sys.echo": {},
34+
"sys.prompt": {},
35+
"sys.time.now": {},
36+
"sys.context": {},
37+
"sys.model.provider.credential": {},
3738
}
3839

3940
var tools = map[string]types.Tool{
@@ -248,6 +249,15 @@ var tools = map[string]types.Tool{
248249
BuiltinFunc: SysContext,
249250
},
250251
},
252+
"sys.model.provider.credential": {
253+
ToolDef: types.ToolDef{
254+
Parameters: types.Parameters{
255+
Description: "A credential tool to set the OPENAI_API_KEY and OPENAI_BASE_URL to give access to the default model provider",
256+
Arguments: types.ObjectSchema(),
257+
},
258+
BuiltinFunc: SysModelProviderCredential,
259+
},
260+
},
251261
}
252262

253263
func ListTools() (result []types.Tool) {
@@ -678,6 +688,22 @@ func invalidArgument(input string, err error) string {
678688
return fmt.Sprintf("Failed to parse arguments %s: %v", input, err)
679689
}
680690

691+
func SysModelProviderCredential(ctx context.Context, _ []string, _ string, _ chan<- string) (string, error) {
692+
engineContext, _ := engine.FromContext(ctx)
693+
auth, url, err := engineContext.Engine.Model.ProxyInfo()
694+
if err != nil {
695+
return "", err
696+
}
697+
data, err := json.Marshal(map[string]any{
698+
"env": map[string]string{
699+
"OPENAI_API_KEY": auth,
700+
"OPENAI_BASE_URL": url,
701+
},
702+
"ephemeral": true,
703+
})
704+
return string(data), err
705+
}
706+
681707
func SysContext(ctx context.Context, _ []string, _ string, _ chan<- string) (string, error) {
682708
engineContext, _ := engine.FromContext(ctx)
683709

pkg/credentials/credential.go

+1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ type Credential struct {
2424
ToolName string `json:"toolName"`
2525
Type CredentialType `json:"type"`
2626
Env map[string]string `json:"env"`
27+
Ephemeral bool `json:"ephemeral,omitempty"`
2728
ExpiresAt *time.Time `json:"expiresAt"`
2829
RefreshToken string `json:"refreshToken"`
2930
}

pkg/engine/cmd.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ func (e *Engine) runCommand(ctx Context, tool types.Tool, input string, toolCate
109109
}
110110
}()
111111

112-
return tool.BuiltinFunc(ctx.WrappedContext(), e.Env, input, progress)
112+
return tool.BuiltinFunc(ctx.WrappedContext(e), e.Env, input, progress)
113113
}
114114

115115
var instructions []string

pkg/engine/engine.go

+6-2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import (
1616

1717
type Model interface {
1818
Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error)
19+
ProxyInfo() (string, string, error)
1920
}
2021

2122
type RuntimeManager interface {
@@ -79,6 +80,7 @@ type Context struct {
7980
Parent *Context
8081
LastReturn *Return
8182
CurrentReturn *Return
83+
Engine *Engine
8284
Program *types.Program
8385
// Input is saved only so that we can render display text, don't use otherwise
8486
Input string
@@ -250,8 +252,10 @@ func FromContext(ctx context.Context) (*Context, bool) {
250252
return c, ok
251253
}
252254

253-
func (c *Context) WrappedContext() context.Context {
254-
return context.WithValue(c.Ctx, engineContext{}, c)
255+
func (c *Context) WrappedContext(e *Engine) context.Context {
256+
cp := *c
257+
cp.Engine = e
258+
return context.WithValue(c.Ctx, engineContext{}, &cp)
255259
}
256260

257261
func (e *Engine) Start(ctx Context, input string) (ret *Return, _ error) {

pkg/llm/proxy.go

+104
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,104 @@
1+
package llm
2+
3+
import (
4+
"bytes"
5+
"encoding/json"
6+
"io"
7+
"net"
8+
"net/http"
9+
"net/http/httputil"
10+
"net/url"
11+
"path"
12+
"strings"
13+
14+
"github.com/gptscript-ai/gptscript/pkg/builtin"
15+
"github.com/gptscript-ai/gptscript/pkg/openai"
16+
)
17+
18+
func (r *Registry) ProxyInfo() (string, string, error) {
19+
r.proxyLock.Lock()
20+
defer r.proxyLock.Unlock()
21+
22+
if r.proxyURL != "" {
23+
return r.proxyToken, r.proxyURL, nil
24+
}
25+
26+
l, err := net.Listen("tcp", "127.0.0.1:0")
27+
if err != nil {
28+
return "", "", err
29+
}
30+
31+
go func() {
32+
_ = http.Serve(l, r)
33+
r.proxyLock.Lock()
34+
defer r.proxyLock.Unlock()
35+
_ = l.Close()
36+
r.proxyURL = ""
37+
}()
38+
39+
r.proxyURL = "http://" + l.Addr().String()
40+
return r.proxyToken, r.proxyURL, nil
41+
}
42+
43+
func (r *Registry) ServeHTTP(w http.ResponseWriter, req *http.Request) {
44+
if r.proxyToken != strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ") {
45+
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
46+
return
47+
}
48+
49+
inBytes, err := io.ReadAll(req.Body)
50+
if err != nil {
51+
http.Error(w, err.Error(), http.StatusBadRequest)
52+
return
53+
}
54+
55+
var (
56+
model string
57+
data = map[string]any{}
58+
)
59+
60+
if json.Unmarshal(inBytes, &data) == nil {
61+
model, _ = data["model"].(string)
62+
}
63+
64+
if model == "" {
65+
model = builtin.GetDefaultModel()
66+
}
67+
68+
c, err := r.getClient(req.Context(), model)
69+
if err != nil {
70+
http.Error(w, err.Error(), http.StatusInternalServerError)
71+
return
72+
}
73+
74+
oai, ok := c.(*openai.Client)
75+
if !ok {
76+
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
77+
return
78+
}
79+
80+
auth, targetURL := oai.ProxyInfo()
81+
if targetURL == "" {
82+
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
83+
return
84+
}
85+
86+
newURL, err := url.Parse(targetURL)
87+
if err != nil {
88+
http.Error(w, err.Error(), http.StatusInternalServerError)
89+
return
90+
}
91+
92+
newURL.Path = path.Join(newURL.Path, req.URL.Path)
93+
94+
rp := httputil.ReverseProxy{
95+
Director: func(proxyReq *http.Request) {
96+
proxyReq.Body = io.NopCloser(bytes.NewReader(inBytes))
97+
proxyReq.URL = newURL
98+
proxyReq.Header.Del("Authorization")
99+
proxyReq.Header.Add("Authorization", "Bearer "+auth)
100+
proxyReq.Host = newURL.Hostname()
101+
},
102+
}
103+
rp.ServeHTTP(w, req)
104+
}

pkg/llm/registry.go

+55-2
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@ import (
55
"errors"
66
"fmt"
77
"sort"
8+
"sync"
89

10+
"github.com/google/uuid"
11+
"github.com/gptscript-ai/gptscript/pkg/env"
912
"github.com/gptscript-ai/gptscript/pkg/openai"
1013
"github.com/gptscript-ai/gptscript/pkg/remote"
1114
"github.com/gptscript-ai/gptscript/pkg/types"
@@ -18,11 +21,16 @@ type Client interface {
1821
}
1922

2023
type Registry struct {
21-
clients []Client
24+
proxyToken string
25+
proxyURL string
26+
proxyLock sync.Mutex
27+
clients []Client
2228
}
2329

2430
func NewRegistry() *Registry {
25-
return &Registry{}
31+
return &Registry{
32+
proxyToken: env.VarOrDefault("GPTSCRIPT_INTERNAL_PROXY_TOKEN", uuid.New().String()),
33+
}
2634
}
2735

2836
func (r *Registry) AddClient(client Client) error {
@@ -44,6 +52,10 @@ func (r *Registry) ListModels(ctx context.Context, providers ...string) (result
4452

4553
func (r *Registry) fastPath(modelName string) Client {
4654
// This is optimization hack to avoid doing List Models
55+
if len(r.clients) == 1 {
56+
return r.clients[0]
57+
}
58+
4759
if len(r.clients) != 2 {
4860
return nil
4961
}
@@ -66,6 +78,47 @@ func (r *Registry) fastPath(modelName string) Client {
6678
return r.clients[0]
6779
}
6880

81+
func (r *Registry) getClient(ctx context.Context, modelName string) (Client, error) {
82+
if c := r.fastPath(modelName); c != nil {
83+
return c, nil
84+
}
85+
86+
var errs []error
87+
var oaiClient *openai.Client
88+
for _, client := range r.clients {
89+
ok, err := client.Supports(ctx, modelName)
90+
if err != nil {
91+
// If we got an OpenAI invalid auth error back, store the OpenAI client for later.
92+
if errors.Is(err, openai.InvalidAuthError{}) {
93+
oaiClient = client.(*openai.Client)
94+
}
95+
96+
errs = append(errs, err)
97+
} else if ok {
98+
return client, nil
99+
}
100+
}
101+
102+
if len(errs) > 0 && oaiClient != nil {
103+
// Prompt the user to enter their OpenAI API key and try again.
104+
if err := oaiClient.RetrieveAPIKey(ctx); err != nil {
105+
return nil, err
106+
}
107+
ok, err := oaiClient.Supports(ctx, modelName)
108+
if err != nil {
109+
return nil, err
110+
} else if ok {
111+
return oaiClient, nil
112+
}
113+
}
114+
115+
if len(errs) == 0 {
116+
return nil, fmt.Errorf("failed to find a model provider for model [%s]", modelName)
117+
}
118+
119+
return nil, errors.Join(errs...)
120+
}
121+
69122
func (r *Registry) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
70123
if messageRequest.Model == "" {
71124
return nil, fmt.Errorf("model is required")

pkg/openai/client.go

+7
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,13 @@ func NewClient(ctx context.Context, credStore credentials.CredentialStore, opts
130130
}, nil
131131
}
132132

133+
func (c *Client) ProxyInfo() (token, urlBase string) {
134+
if c.invalidAuth {
135+
return "", ""
136+
}
137+
return c.c.GetAPIKeyAndBaseURL()
138+
}
139+
133140
func (c *Client) ValidAuth() error {
134141
if c.invalidAuth {
135142
return InvalidAuthError{}

pkg/runner/runner.go

+19-8
Original file line numberDiff line numberDiff line change
@@ -872,6 +872,11 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
872872
return nil, fmt.Errorf("failed to parse credential tool %q: %w", ref.Reference, err)
873873
}
874874

875+
if callCtx.Program.ToolSet[ref.ToolID].IsNoop() {
876+
// ignore empty tools
877+
continue
878+
}
879+
875880
credName := toolName
876881
if credentialAlias != "" {
877882
credName = credentialAlias
@@ -944,6 +949,10 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
944949
return nil, fmt.Errorf("invalid state: credential tool [%s] can not result in a continuation", ref.Reference)
945950
}
946951

952+
if *res.Result == "" {
953+
continue
954+
}
955+
947956
if err := json.Unmarshal([]byte(*res.Result), &c); err != nil {
948957
return nil, fmt.Errorf("failed to unmarshal credential tool %s response: %w", ref.Reference, err)
949958
}
@@ -958,15 +967,17 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
958967
}
959968
}
960969

961-
// Only store the credential if the tool is on GitHub or has an alias, and the credential is non-empty.
962-
if (isGitHubTool(toolName) && callCtx.Program.ToolSet[ref.ToolID].Source.Repo != nil) || credentialAlias != "" {
963-
if isEmpty {
964-
log.Warnf("Not saving empty credential for tool %s", toolName)
965-
} else if err := r.credStore.Add(callCtx.Ctx, *c); err != nil {
966-
return nil, fmt.Errorf("failed to add credential for tool %s: %w", toolName, err)
970+
if !c.Ephemeral {
971+
// Only store the credential if the tool is on GitHub or has an alias, and the credential is non-empty.
972+
if (isGitHubTool(toolName) && callCtx.Program.ToolSet[ref.ToolID].Source.Repo != nil) || credentialAlias != "" {
973+
if isEmpty {
974+
log.Warnf("Not saving empty credential for tool %s", toolName)
975+
} else if err := r.credStore.Add(callCtx.Ctx, *c); err != nil {
976+
return nil, fmt.Errorf("failed to add credential for tool %s: %w", toolName, err)
977+
}
978+
} else {
979+
log.Warnf("Not saving credential for tool %s - credentials will only be saved for tools from GitHub, or tools that use aliases.", toolName)
967980
}
968-
} else {
969-
log.Warnf("Not saving credential for tool %s - credentials will only be saved for tools from GitHub, or tools that use aliases.", toolName)
970981
}
971982
}
972983

pkg/tests/tester/runner.go

+4
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,10 @@ type Result struct {
3131
Err error
3232
}
3333

34+
func (c *Client) ProxyInfo() (string, string, error) {
35+
return "test-auth", "test-url", nil
36+
}
37+
3438
func (c *Client) Call(_ context.Context, messageRequest types.CompletionRequest, _ chan<- types.CompletionStatus) (resp *types.CompletionMessage, respErr error) {
3539
msgData, err := json.MarshalIndent(messageRequest, "", " ")
3640
require.NoError(c.t, err)

0 commit comments

Comments
 (0)