Skip to content

Commit 2b377fc

Browse files
committed
enhance: ask user for OpenAI key and store it in the cred store
Signed-off-by: Grant Linville <[email protected]>
1 parent e6c23dd commit 2b377fc

File tree

12 files changed

+232
-36
lines changed

12 files changed

+232
-36
lines changed

go.mod

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ require (
1515
github.com/getkin/kin-openapi v0.123.0
1616
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
1717
github.com/google/uuid v1.6.0
18-
github.com/gptscript-ai/chat-completion-client v0.0.0-20240515050533-bdef9f2226a9
18+
github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379
1919
github.com/gptscript-ai/tui v0.0.0-20240604045848-e01b0b7aab9f
2020
github.com/hexops/autogold/v2 v2.2.1
2121
github.com/hexops/valast v1.4.4

go.sum

+2-2
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,8 @@ github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY=
170170
github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c=
171171
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
172172
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
173-
github.com/gptscript-ai/chat-completion-client v0.0.0-20240515050533-bdef9f2226a9 h1:s6nL/aokB1sJTqVXEjN0zFI5CJa66ubw9g68VTMzEw0=
174-
github.com/gptscript-ai/chat-completion-client v0.0.0-20240515050533-bdef9f2226a9/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
173+
github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379 h1:vYnXoIyCXzaCEw0sYifQ4bDpsv3/fO/dZ2suEsTwCIo=
174+
github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
175175
github.com/gptscript-ai/go-gptscript v0.0.0-20240604030145-39497c0575b3 h1:mXLpCzEg4DoOeFZt6w99QFh9n60UwpRGGG0c+aaT+5k=
176176
github.com/gptscript-ai/go-gptscript v0.0.0-20240604030145-39497c0575b3/go.mod h1:h1yYzC0rgB5Kk7lwdba+Xs6cWkuJfLq6sPRna45OVG0=
177177
github.com/gptscript-ai/tui v0.0.0-20240604045848-e01b0b7aab9f h1:7dCE0E/y3y3p1BPSQGQ4mtsz5cWWl0FbXfCCDCf57SI=

pkg/credentials/credential.go

+24-2
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,17 @@ import (
88
"github.com/docker/cli/cli/config/types"
99
)
1010

11+
type CredentialType string
12+
13+
const (
14+
CredentialTypeTool CredentialType = "tool"
15+
CredentialTypeModelProvider CredentialType = "modelprovider"
16+
)
17+
1118
type Credential struct {
1219
Context string `json:"context"`
1320
ToolName string `json:"toolName"`
21+
Type CredentialType `json:"type"`
1422
Env map[string]string `json:"env"`
1523
}
1624

@@ -21,7 +29,7 @@ func (c Credential) toDockerAuthConfig() (types.AuthConfig, error) {
2129
}
2230

2331
return types.AuthConfig{
24-
Username: "gptscript", // Username is required, but not used
32+
Username: string(c.Type),
2533
Password: string(env),
2634
ServerAddress: toolNameWithCtx(c.ToolName, c.Context),
2735
}, nil
@@ -33,14 +41,28 @@ func credentialFromDockerAuthConfig(authCfg types.AuthConfig) (Credential, error
3341
return Credential{}, err
3442
}
3543

36-
tool, ctx, err := toolNameAndCtxFromAddress(strings.TrimPrefix(authCfg.ServerAddress, "https://"))
44+
// We used to hardcode the username as "gptscript" before CredentialType was introduced, so
45+
// check for that here.
46+
credType := authCfg.Username
47+
if credType == "gptscript" {
48+
credType = string(CredentialTypeTool)
49+
}
50+
51+
// If it's a tool credential or sys.openai, remove the http[s] prefix.
52+
address := authCfg.ServerAddress
53+
if credType == string(CredentialTypeTool) || strings.HasPrefix(address, "https://sys.openai///") {
54+
address = strings.TrimPrefix(strings.TrimPrefix(address, "https://"), "http://")
55+
}
56+
57+
tool, ctx, err := toolNameAndCtxFromAddress(address)
3758
if err != nil {
3859
return Credential{}, err
3960
}
4061

4162
return Credential{
4263
Context: ctx,
4364
ToolName: tool,
65+
Type: CredentialType(credType),
4466
Env: env,
4567
}, nil
4668
}

pkg/credentials/helper.go

+28-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package credentials
22

33
import (
44
"errors"
5+
"regexp"
6+
"strings"
57

68
"github.com/docker/cli/cli/config/credentials"
79
"github.com/docker/cli/cli/config/types"
@@ -55,7 +57,32 @@ func (h *HelperStore) GetAll() (map[string]types.AuthConfig, error) {
5557
return nil, err
5658
}
5759

58-
for serverAddress := range serverAddresses {
60+
newCredAddresses := make(map[string]string, len(serverAddresses))
61+
for serverAddress, val := range serverAddresses {
62+
// If the serverAddress contains a port, we need to put it back in the right spot.
63+
// For some reason, even when a credential is stored properly as http://hostname:8080///credctx,
64+
// the list function will return http://hostname///credctx:8080. This is something wrong
65+
// with macOS's built-in libraries. So we need to fix it here.
66+
toolName, ctx, err := toolNameAndCtxFromAddress(serverAddress)
67+
if err != nil {
68+
return nil, err
69+
}
70+
71+
contextPieces := strings.Split(ctx, ":")
72+
if len(contextPieces) > 1 {
73+
possiblePortNumber := contextPieces[len(contextPieces)-1]
74+
if regexp.MustCompile(`\d+$`).MatchString(possiblePortNumber) {
75+
// port number confirmed
76+
toolName = toolName + ":" + possiblePortNumber
77+
ctx = strings.TrimSuffix(ctx, ":"+possiblePortNumber)
78+
}
79+
}
80+
81+
newCredAddresses[toolNameWithCtx(toolName, ctx)] = val
82+
delete(serverAddresses, serverAddress)
83+
}
84+
85+
for serverAddress := range newCredAddresses {
5986
ac, err := h.Get(serverAddress)
6087
if err != nil {
6188
return nil, err

pkg/gptscript/gptscript.go

+15-8
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99

1010
"github.com/gptscript-ai/gptscript/pkg/builtin"
1111
"github.com/gptscript-ai/gptscript/pkg/cache"
12+
"github.com/gptscript-ai/gptscript/pkg/config"
1213
context2 "github.com/gptscript-ai/gptscript/pkg/context"
1314
"github.com/gptscript-ai/gptscript/pkg/engine"
1415
"github.com/gptscript-ai/gptscript/pkg/hash"
@@ -70,15 +71,20 @@ func New(opts *Options) (*GPTScript, error) {
7071
return nil, err
7172
}
7273

73-
oAIClient, err := openai.NewClient(opts.OpenAI, openai.Options{
74+
cliCfg, err := config.ReadCLIConfig(opts.OpenAI.ConfigFile)
75+
if err != nil {
76+
return nil, err
77+
}
78+
79+
oaiClient, err := openai.NewClient(cliCfg, opts.CredentialContext, opts.OpenAI, openai.Options{
7480
Cache: cacheClient,
7581
SetSeed: true,
7682
})
7783
if err != nil {
7884
return nil, err
7985
}
8086

81-
if err := registry.AddClient(oAIClient); err != nil {
87+
if err := registry.AddClient(oaiClient); err != nil {
8288
return nil, err
8389
}
8490

@@ -95,18 +101,19 @@ func New(opts *Options) (*GPTScript, error) {
95101
return nil, err
96102
}
97103

98-
remoteClient := remote.New(runner, opts.Env, cacheClient)
99-
100-
if err := registry.AddClient(remoteClient); err != nil {
101-
return nil, err
102-
}
103-
104104
ctx, closeServer := context.WithCancel(context2.AddPauseFuncToCtx(context.Background(), opts.Runner.MonitorFactory.Pause))
105105
extraEnv, err := prompt.NewServer(ctx, opts.Env)
106106
if err != nil {
107107
closeServer()
108108
return nil, err
109109
}
110+
oaiClient.SetEnvs(extraEnv)
111+
112+
remoteClient := remote.New(runner, extraEnv, cacheClient, cliCfg, opts.CredentialContext)
113+
if err := registry.AddClient(remoteClient); err != nil {
114+
closeServer()
115+
return nil, err
116+
}
110117

111118
return &GPTScript{
112119
Registry: registry,

pkg/llm/registry.go

+22
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"fmt"
77
"sort"
88

9+
"github.com/gptscript-ai/gptscript/pkg/openai"
910
"github.com/gptscript-ai/gptscript/pkg/types"
1011
)
1112

@@ -44,15 +45,36 @@ func (r *Registry) Call(ctx context.Context, messageRequest types.CompletionRequ
4445
if messageRequest.Model == "" {
4546
return nil, fmt.Errorf("model is required")
4647
}
48+
4749
var errs []error
50+
var oaiClient *openai.Client
4851
for _, client := range r.clients {
4952
ok, err := client.Supports(ctx, messageRequest.Model)
5053
if err != nil {
54+
// If we got an OpenAI invalid auth error back, store the OpenAI client for later.
55+
if errors.Is(err, openai.InvalidAuthError{}) {
56+
oaiClient = client.(*openai.Client)
57+
}
58+
5159
errs = append(errs, err)
5260
} else if ok {
5361
return client.Call(ctx, messageRequest, status)
5462
}
5563
}
64+
65+
if len(errs) > 0 && oaiClient != nil {
66+
// Prompt the user to enter their OpenAI API key and try again.
67+
if err := oaiClient.RetrieveAPIKey(ctx); err != nil {
68+
return nil, err
69+
}
70+
ok, err := oaiClient.Supports(ctx, messageRequest.Model)
71+
if err != nil {
72+
return nil, err
73+
} else if ok {
74+
return oaiClient.Call(ctx, messageRequest, status)
75+
}
76+
}
77+
5678
if len(errs) == 0 {
5779
return nil, fmt.Errorf("failed to find a model provider for model [%s]", messageRequest.Model)
5880
}

pkg/openai/client.go

+60-5
Original file line numberDiff line numberDiff line change
@@ -12,28 +12,41 @@ import (
1212

1313
openai "github.com/gptscript-ai/chat-completion-client"
1414
"github.com/gptscript-ai/gptscript/pkg/cache"
15+
"github.com/gptscript-ai/gptscript/pkg/config"
1516
"github.com/gptscript-ai/gptscript/pkg/counter"
17+
"github.com/gptscript-ai/gptscript/pkg/credentials"
1618
"github.com/gptscript-ai/gptscript/pkg/hash"
19+
"github.com/gptscript-ai/gptscript/pkg/prompt"
1720
"github.com/gptscript-ai/gptscript/pkg/system"
1821
"github.com/gptscript-ai/gptscript/pkg/types"
1922
)
2023

2124
const (
22-
DefaultModel = openai.GPT4o
25+
DefaultModel = openai.GPT4o
26+
BuiltinCredName = "sys.openai"
2327
)
2428

2529
var (
2630
key = os.Getenv("OPENAI_API_KEY")
2731
url = os.Getenv("OPENAI_URL")
2832
)
2933

34+
type InvalidAuthError struct{}
35+
36+
func (InvalidAuthError) Error() string {
37+
return "OPENAI_API_KEY is not set. Please set the OPENAI_API_KEY environment variable"
38+
}
39+
3040
type Client struct {
3141
defaultModel string
3242
c *openai.Client
3343
cache *cache.Client
3444
invalidAuth bool
3545
cacheKeyBase string
3646
setSeed bool
47+
cliCfg *config.CLIConfig
48+
credCtx string
49+
envs []string
3750
}
3851

3952
type Options struct {
@@ -75,12 +88,28 @@ func complete(opts ...Options) (result Options, err error) {
7588
return result, err
7689
}
7790

78-
func NewClient(opts ...Options) (*Client, error) {
91+
func NewClient(cliCfg *config.CLIConfig, credCtx string, opts ...Options) (*Client, error) {
7992
opt, err := complete(opts...)
8093
if err != nil {
8194
return nil, err
8295
}
8396

97+
// If the API key is not set, try to get it from the cred store
98+
if opt.APIKey == "" && opt.BaseURL == "" {
99+
store, err := credentials.NewStore(cliCfg, credCtx)
100+
if err != nil {
101+
return nil, err
102+
}
103+
104+
cred, exists, err := store.Get(BuiltinCredName)
105+
if err != nil {
106+
return nil, err
107+
}
108+
if exists {
109+
opt.APIKey = cred.Env["OPENAI_API_KEY"]
110+
}
111+
}
112+
84113
cfg := openai.DefaultConfig(opt.APIKey)
85114
cfg.BaseURL = types.FirstSet(opt.BaseURL, cfg.BaseURL)
86115
cfg.OrgID = types.FirstSet(opt.OrgID, cfg.OrgID)
@@ -97,21 +126,33 @@ func NewClient(opts ...Options) (*Client, error) {
97126
cacheKeyBase: cacheKeyBase,
98127
invalidAuth: opt.APIKey == "" && opt.BaseURL == "",
99128
setSeed: opt.SetSeed,
129+
cliCfg: cliCfg,
130+
credCtx: credCtx,
100131
}, nil
101132
}
102133

103134
func (c *Client) ValidAuth() error {
104135
if c.invalidAuth {
105-
return fmt.Errorf("OPENAI_API_KEY is not set. Please set the OPENAI_API_KEY environment variable")
136+
return InvalidAuthError{}
106137
}
107138
return nil
108139
}
109140

141+
func (c *Client) SetEnvs(env []string) {
142+
c.envs = env
143+
}
144+
110145
func (c *Client) Supports(ctx context.Context, modelName string) (bool, error) {
111146
models, err := c.ListModels(ctx)
112147
if err != nil {
113148
return false, err
114149
}
150+
151+
if len(models) == 0 {
152+
// We got no models back, which means our auth is invalid.
153+
return false, InvalidAuthError{}
154+
}
155+
115156
return slices.Contains(models, modelName), nil
116157
}
117158

@@ -121,8 +162,9 @@ func (c *Client) ListModels(ctx context.Context, providers ...string) (result []
121162
return nil, nil
122163
}
123164

165+
// If auth is invalid, we just want to return nothing.
124166
if err := c.ValidAuth(); err != nil {
125-
return nil, err
167+
return nil, nil
126168
}
127169

128170
models, err := c.c.ListModels(ctx)
@@ -251,7 +293,9 @@ func toMessages(request types.CompletionRequest, compat bool) (result []openai.C
251293

252294
func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
253295
if err := c.ValidAuth(); err != nil {
254-
return nil, err
296+
if err := c.RetrieveAPIKey(ctx); err != nil {
297+
return nil, err
298+
}
255299
}
256300

257301
if messageRequest.Model == "" {
@@ -499,6 +543,17 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest,
499543
}
500544
}
501545

546+
func (c *Client) RetrieveAPIKey(ctx context.Context) error {
547+
k, err := prompt.GetModelProviderCredential(ctx, BuiltinCredName, "OPENAI_API_KEY", "Please provide your OpenAI API key:", c.credCtx, c.envs, c.cliCfg)
548+
if err != nil {
549+
return err
550+
}
551+
552+
c.c.SetAPIKey(k)
553+
c.invalidAuth = false
554+
return nil
555+
}
556+
502557
func ptr[T any](v T) *T {
503558
return &v
504559
}

0 commit comments

Comments
 (0)