Skip to content

Commit 49fc1c9

Browse files
authored
enhance: ask user for OpenAI key and store it in the cred store (#396)
Signed-off-by: Grant Linville <[email protected]>
1 parent 4283d79 commit 49fc1c9

File tree

12 files changed

+241
-39
lines changed

12 files changed

+241
-39
lines changed

go.mod

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ require (
1515
github.com/getkin/kin-openapi v0.123.0
1616
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
1717
github.com/google/uuid v1.6.0
18-
github.com/gptscript-ai/chat-completion-client v0.0.0-20240515050533-bdef9f2226a9
18+
github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379
1919
github.com/gptscript-ai/tui v0.0.0-20240604233332-4a5ff43cdc58
2020
github.com/hexops/autogold/v2 v2.2.1
2121
github.com/hexops/valast v1.4.4

go.sum

+2-2
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,8 @@ github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY=
170170
github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c=
171171
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
172172
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
173-
github.com/gptscript-ai/chat-completion-client v0.0.0-20240515050533-bdef9f2226a9 h1:s6nL/aokB1sJTqVXEjN0zFI5CJa66ubw9g68VTMzEw0=
174-
github.com/gptscript-ai/chat-completion-client v0.0.0-20240515050533-bdef9f2226a9/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
173+
github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379 h1:vYnXoIyCXzaCEw0sYifQ4bDpsv3/fO/dZ2suEsTwCIo=
174+
github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
175175
github.com/gptscript-ai/go-gptscript v0.0.0-20240604231423-7a845df843b1 h1:SHoqsU8Ne2V4zfrFve9kQn4vcv4N4TItD6Oju+pzKV8=
176176
github.com/gptscript-ai/go-gptscript v0.0.0-20240604231423-7a845df843b1/go.mod h1:h1yYzC0rgB5Kk7lwdba+Xs6cWkuJfLq6sPRna45OVG0=
177177
github.com/gptscript-ai/tui v0.0.0-20240604233332-4a5ff43cdc58 h1:kbr6cY4VdvxAfJf+xk6x/gxDzmgZMkX2ZfLcyskGYsw=

pkg/credentials/credential.go

+29-5
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,19 @@ import (
88
"github.com/docker/cli/cli/config/types"
99
)
1010

11+
const ctxSeparator = "///"
12+
13+
type CredentialType string
14+
15+
const (
16+
CredentialTypeTool CredentialType = "tool"
17+
CredentialTypeModelProvider CredentialType = "modelProvider"
18+
)
19+
1120
type Credential struct {
1221
Context string `json:"context"`
1322
ToolName string `json:"toolName"`
23+
Type CredentialType `json:"type"`
1424
Env map[string]string `json:"env"`
1525
}
1626

@@ -21,7 +31,7 @@ func (c Credential) toDockerAuthConfig() (types.AuthConfig, error) {
2131
}
2232

2333
return types.AuthConfig{
24-
Username: "gptscript", // Username is required, but not used
34+
Username: string(c.Type),
2535
Password: string(env),
2636
ServerAddress: toolNameWithCtx(c.ToolName, c.Context),
2737
}, nil
@@ -33,26 +43,40 @@ func credentialFromDockerAuthConfig(authCfg types.AuthConfig) (Credential, error
3343
return Credential{}, err
3444
}
3545

36-
tool, ctx, err := toolNameAndCtxFromAddress(strings.TrimPrefix(authCfg.ServerAddress, "https://"))
46+
// We used to hardcode the username as "gptscript" before CredentialType was introduced, so
47+
// check for that here.
48+
credType := authCfg.Username
49+
if credType == "gptscript" {
50+
credType = string(CredentialTypeTool)
51+
}
52+
53+
// If it's a tool credential or sys.openai, remove the http[s] prefix.
54+
address := authCfg.ServerAddress
55+
if credType == string(CredentialTypeTool) || strings.HasPrefix(address, "https://sys.openai"+ctxSeparator) {
56+
address = strings.TrimPrefix(strings.TrimPrefix(address, "https://"), "http://")
57+
}
58+
59+
tool, ctx, err := toolNameAndCtxFromAddress(address)
3760
if err != nil {
3861
return Credential{}, err
3962
}
4063

4164
return Credential{
4265
Context: ctx,
4366
ToolName: tool,
67+
Type: CredentialType(credType),
4468
Env: env,
4569
}, nil
4670
}
4771

4872
func toolNameWithCtx(toolName, credCtx string) string {
49-
return toolName + "///" + credCtx
73+
return toolName + ctxSeparator + credCtx
5074
}
5175

5276
func toolNameAndCtxFromAddress(address string) (string, string, error) {
53-
parts := strings.Split(address, "///")
77+
parts := strings.Split(address, ctxSeparator)
5478
if len(parts) != 2 {
55-
return "", "", fmt.Errorf("error parsing tool name and context %q. Tool names cannot contain '///'", address)
79+
return "", "", fmt.Errorf("error parsing tool name and context %q. Tool names cannot contain '%s'", address, ctxSeparator)
5680
}
5781
return parts[0], parts[1], nil
5882
}

pkg/credentials/helper.go

+28-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package credentials
22

33
import (
44
"errors"
5+
"regexp"
6+
"strings"
57

68
"github.com/docker/cli/cli/config/credentials"
79
"github.com/docker/cli/cli/config/types"
@@ -55,7 +57,32 @@ func (h *HelperStore) GetAll() (map[string]types.AuthConfig, error) {
5557
return nil, err
5658
}
5759

58-
for serverAddress := range serverAddresses {
60+
newCredAddresses := make(map[string]string, len(serverAddresses))
61+
for serverAddress, val := range serverAddresses {
62+
// If the serverAddress contains a port, we need to put it back in the right spot.
63+
// For some reason, even when a credential is stored properly as http://hostname:8080///credctx,
64+
// the list function will return http://hostname///credctx:8080. This is something wrong
65+
// with macOS's built-in libraries. So we need to fix it here.
66+
toolName, ctx, err := toolNameAndCtxFromAddress(serverAddress)
67+
if err != nil {
68+
return nil, err
69+
}
70+
71+
contextPieces := strings.Split(ctx, ":")
72+
if len(contextPieces) > 1 {
73+
possiblePortNumber := contextPieces[len(contextPieces)-1]
74+
if regexp.MustCompile(`\d+$`).MatchString(possiblePortNumber) {
75+
// port number confirmed
76+
toolName = toolName + ":" + possiblePortNumber
77+
ctx = strings.TrimSuffix(ctx, ":"+possiblePortNumber)
78+
}
79+
}
80+
81+
newCredAddresses[toolNameWithCtx(toolName, ctx)] = val
82+
delete(serverAddresses, serverAddress)
83+
}
84+
85+
for serverAddress := range newCredAddresses {
5986
ac, err := h.Get(serverAddress)
6087
if err != nil {
6188
return nil, err

pkg/gptscript/gptscript.go

+15-8
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99

1010
"github.com/gptscript-ai/gptscript/pkg/builtin"
1111
"github.com/gptscript-ai/gptscript/pkg/cache"
12+
"github.com/gptscript-ai/gptscript/pkg/config"
1213
context2 "github.com/gptscript-ai/gptscript/pkg/context"
1314
"github.com/gptscript-ai/gptscript/pkg/engine"
1415
"github.com/gptscript-ai/gptscript/pkg/hash"
@@ -70,15 +71,20 @@ func New(opts *Options) (*GPTScript, error) {
7071
return nil, err
7172
}
7273

73-
oAIClient, err := openai.NewClient(opts.OpenAI, openai.Options{
74+
cliCfg, err := config.ReadCLIConfig(opts.OpenAI.ConfigFile)
75+
if err != nil {
76+
return nil, err
77+
}
78+
79+
oaiClient, err := openai.NewClient(cliCfg, opts.CredentialContext, opts.OpenAI, openai.Options{
7480
Cache: cacheClient,
7581
SetSeed: true,
7682
})
7783
if err != nil {
7884
return nil, err
7985
}
8086

81-
if err := registry.AddClient(oAIClient); err != nil {
87+
if err := registry.AddClient(oaiClient); err != nil {
8288
return nil, err
8389
}
8490

@@ -95,18 +101,19 @@ func New(opts *Options) (*GPTScript, error) {
95101
return nil, err
96102
}
97103

98-
remoteClient := remote.New(runner, opts.Env, cacheClient)
99-
100-
if err := registry.AddClient(remoteClient); err != nil {
101-
return nil, err
102-
}
103-
104104
ctx, closeServer := context.WithCancel(context2.AddPauseFuncToCtx(context.Background(), opts.Runner.MonitorFactory.Pause))
105105
extraEnv, err := prompt.NewServer(ctx, opts.Env)
106106
if err != nil {
107107
closeServer()
108108
return nil, err
109109
}
110+
oaiClient.SetEnvs(extraEnv)
111+
112+
remoteClient := remote.New(runner, extraEnv, cacheClient, cliCfg, opts.CredentialContext)
113+
if err := registry.AddClient(remoteClient); err != nil {
114+
closeServer()
115+
return nil, err
116+
}
110117

111118
return &GPTScript{
112119
Registry: registry,

pkg/llm/registry.go

+22
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"sort"
88

9+
"github.com/gptscript-ai/gptscript/pkg/openai"
910
"github.com/gptscript-ai/gptscript/pkg/types"
1011
)
1112

@@ -44,15 +45,36 @@ func (r *Registry) Call(ctx context.Context, messageRequest types.CompletionRequ
4445
if messageRequest.Model == "" {
4546
return nil, fmt.Errorf("model is required")
4647
}
48+
4749
var errs []error
50+
var oaiClient *openai.Client
4851
for _, client := range r.clients {
4952
ok, err := client.Supports(ctx, messageRequest.Model)
5053
if err != nil {
54+
// If we got an OpenAI invalid auth error back, store the OpenAI client for later.
55+
if errors.Is(err, openai.InvalidAuthError{}) {
56+
oaiClient = client.(*openai.Client)
57+
}
58+
5159
errs = append(errs, err)
5260
} else if ok {
5361
return client.Call(ctx, messageRequest, status)
5462
}
5563
}
64+
65+
if len(errs) > 0 && oaiClient != nil {
66+
// Prompt the user to enter their OpenAI API key and try again.
67+
if err := oaiClient.RetrieveAPIKey(ctx); err != nil {
68+
return nil, err
69+
}
70+
ok, err := oaiClient.Supports(ctx, messageRequest.Model)
71+
if err != nil {
72+
return nil, err
73+
} else if ok {
74+
return oaiClient.Call(ctx, messageRequest, status)
75+
}
76+
}
77+
5678
if len(errs) == 0 {
5779
return nil, fmt.Errorf("failed to find a model provider for model [%s]", messageRequest.Model)
5880
}

pkg/openai/client.go

+64-5
Original file line numberDiff line numberDiff line change
@@ -12,28 +12,41 @@ import (
1212

1313
openai "github.com/gptscript-ai/chat-completion-client"
1414
"github.com/gptscript-ai/gptscript/pkg/cache"
15+
"github.com/gptscript-ai/gptscript/pkg/config"
1516
"github.com/gptscript-ai/gptscript/pkg/counter"
17+
"github.com/gptscript-ai/gptscript/pkg/credentials"
1618
"github.com/gptscript-ai/gptscript/pkg/hash"
19+
"github.com/gptscript-ai/gptscript/pkg/prompt"
1720
"github.com/gptscript-ai/gptscript/pkg/system"
1821
"github.com/gptscript-ai/gptscript/pkg/types"
1922
)
2023

2124
const (
22-
DefaultModel = openai.GPT4o
25+
DefaultModel = openai.GPT4o
26+
BuiltinCredName = "sys.openai"
2327
)
2428

2529
var (
2630
key = os.Getenv("OPENAI_API_KEY")
2731
url = os.Getenv("OPENAI_URL")
2832
)
2933

34+
type InvalidAuthError struct{}
35+
36+
func (InvalidAuthError) Error() string {
37+
return "OPENAI_API_KEY is not set. Please set the OPENAI_API_KEY environment variable"
38+
}
39+
3040
type Client struct {
3141
defaultModel string
3242
c *openai.Client
3343
cache *cache.Client
3444
invalidAuth bool
3545
cacheKeyBase string
3646
setSeed bool
47+
cliCfg *config.CLIConfig
48+
credCtx string
49+
envs []string
3750
}
3851

3952
type Options struct {
@@ -75,12 +88,28 @@ func complete(opts ...Options) (result Options, err error) {
7588
return result, err
7689
}
7790

78-
func NewClient(opts ...Options) (*Client, error) {
91+
func NewClient(cliCfg *config.CLIConfig, credCtx string, opts ...Options) (*Client, error) {
7992
opt, err := complete(opts...)
8093
if err != nil {
8194
return nil, err
8295
}
8396

97+
// If the API key is not set, try to get it from the cred store
98+
if opt.APIKey == "" && opt.BaseURL == "" {
99+
store, err := credentials.NewStore(cliCfg, credCtx)
100+
if err != nil {
101+
return nil, err
102+
}
103+
104+
cred, exists, err := store.Get(BuiltinCredName)
105+
if err != nil {
106+
return nil, err
107+
}
108+
if exists {
109+
opt.APIKey = cred.Env["OPENAI_API_KEY"]
110+
}
111+
}
112+
84113
cfg := openai.DefaultConfig(opt.APIKey)
85114
cfg.BaseURL = types.FirstSet(opt.BaseURL, cfg.BaseURL)
86115
cfg.OrgID = types.FirstSet(opt.OrgID, cfg.OrgID)
@@ -97,21 +126,33 @@ func NewClient(opts ...Options) (*Client, error) {
97126
cacheKeyBase: cacheKeyBase,
98127
invalidAuth: opt.APIKey == "" && opt.BaseURL == "",
99128
setSeed: opt.SetSeed,
129+
cliCfg: cliCfg,
130+
credCtx: credCtx,
100131
}, nil
101132
}
102133

103134
func (c *Client) ValidAuth() error {
104135
if c.invalidAuth {
105-
return fmt.Errorf("OPENAI_API_KEY is not set. Please set the OPENAI_API_KEY environment variable")
136+
return InvalidAuthError{}
106137
}
107138
return nil
108139
}
109140

141+
func (c *Client) SetEnvs(env []string) {
142+
c.envs = env
143+
}
144+
110145
func (c *Client) Supports(ctx context.Context, modelName string) (bool, error) {
111146
models, err := c.ListModels(ctx)
112147
if err != nil {
113148
return false, err
114149
}
150+
151+
if len(models) == 0 {
152+
// We got no models back, which means our auth is invalid.
153+
return false, InvalidAuthError{}
154+
}
155+
115156
return slices.Contains(models, modelName), nil
116157
}
117158

@@ -121,8 +162,13 @@ func (c *Client) ListModels(ctx context.Context, providers ...string) (result []
121162
return nil, nil
122163
}
123164

165+
// If auth is invalid, we just want to return nothing.
166+
// Returning an InvalidAuthError here will lead to cases where the user is prompted to enter their OpenAI key,
167+
// even when we don't want them to be prompted.
168+
// So the UX we settled on is that no models get printed if the user does gptscript --list-models
169+
// without having provided their key through the environment variable or the creds store.
124170
if err := c.ValidAuth(); err != nil {
125-
return nil, err
171+
return nil, nil
126172
}
127173

128174
models, err := c.c.ListModels(ctx)
@@ -251,7 +297,9 @@ func toMessages(request types.CompletionRequest, compat bool) (result []openai.C
251297

252298
func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
253299
if err := c.ValidAuth(); err != nil {
254-
return nil, err
300+
if err := c.RetrieveAPIKey(ctx); err != nil {
301+
return nil, err
302+
}
255303
}
256304

257305
if messageRequest.Model == "" {
@@ -499,6 +547,17 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest,
499547
}
500548
}
501549

550+
func (c *Client) RetrieveAPIKey(ctx context.Context) error {
551+
k, err := prompt.GetModelProviderCredential(ctx, BuiltinCredName, "OPENAI_API_KEY", "Please provide your OpenAI API key:", c.credCtx, c.envs, c.cliCfg)
552+
if err != nil {
553+
return err
554+
}
555+
556+
c.c.SetAPIKey(k)
557+
c.invalidAuth = false
558+
return nil
559+
}
560+
502561
func ptr[T any](v T) *T {
503562
return &v
504563
}

0 commit comments

Comments
 (0)