Skip to content

Commit b823ee6

Browse files
committed
set up different credential types and prompt for model providers
Signed-off-by: Grant Linville <[email protected]>
1 parent f8983d2 commit b823ee6

File tree

8 files changed

+149
-64
lines changed

8 files changed

+149
-64
lines changed

pkg/credentials/credential.go

+24-2
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,17 @@ import (
88
"github.com/docker/cli/cli/config/types"
99
)
1010

11+
type CredentialType string
12+
13+
const (
14+
CredentialTypeTool CredentialType = "tool"
15+
CredentialTypeModelProvider CredentialType = "modelprovider"
16+
)
17+
1118
type Credential struct {
1219
Context string `json:"context"`
1320
ToolName string `json:"toolName"`
21+
Type CredentialType `json:"type"`
1422
Env map[string]string `json:"env"`
1523
}
1624

@@ -21,7 +29,7 @@ func (c Credential) toDockerAuthConfig() (types.AuthConfig, error) {
2129
}
2230

2331
return types.AuthConfig{
24-
Username: "gptscript", // Username is required, but not used
32+
Username: string(c.Type),
2533
Password: string(env),
2634
ServerAddress: toolNameWithCtx(c.ToolName, c.Context),
2735
}, nil
@@ -33,14 +41,28 @@ func credentialFromDockerAuthConfig(authCfg types.AuthConfig) (Credential, error
3341
return Credential{}, err
3442
}
3543

36-
tool, ctx, err := toolNameAndCtxFromAddress(strings.TrimPrefix(authCfg.ServerAddress, "https://"))
44+
// We used to hardcode the username as "gptscript" before CredentialType was introduced, so
45+
// check for that here.
46+
credType := authCfg.Username
47+
if credType == "gptscript" {
48+
credType = string(CredentialTypeTool)
49+
}
50+
51+
// If it's a tool credential, remove the http[s] prefix.
52+
address := authCfg.ServerAddress
53+
if credType == string(CredentialTypeTool) {
54+
address = strings.TrimPrefix(strings.TrimPrefix(address, "https://"), "http://")
55+
}
56+
57+
tool, ctx, err := toolNameAndCtxFromAddress(address)
3758
if err != nil {
3859
return Credential{}, err
3960
}
4061

4162
return Credential{
4263
Context: ctx,
4364
ToolName: tool,
65+
Type: CredentialType(credType),
4466
Env: env,
4567
}, nil
4668
}

pkg/credentials/helper.go

+25-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package credentials
22

33
import (
44
"errors"
5+
"regexp"
6+
"strings"
57

68
"github.com/docker/cli/cli/config/credentials"
79
"github.com/docker/cli/cli/config/types"
@@ -55,7 +57,29 @@ func (h *HelperStore) GetAll() (map[string]types.AuthConfig, error) {
5557
return nil, err
5658
}
5759

58-
for serverAddress := range serverAddresses {
60+
newCredAddresses := make(map[string]string, len(serverAddresses))
61+
for serverAddress, val := range serverAddresses {
62+
// If the serverAddress contains a port, we need to put it back in the right spot.
63+
// For some reason, even when a credential is stored properly as http://hostname:8080///credctx,
64+
// the list function will return http://hostname///credctx:8080. This is something wrong
65+
// with macOS's built-in libraries. So we need to fix it here.
66+
toolName, ctx, err := toolNameAndCtxFromAddress(serverAddress)
67+
if err != nil {
68+
return nil, err
69+
}
70+
71+
possiblePortNumber := strings.Split(ctx, ":")[len(strings.Split(ctx, ":"))-1]
72+
if regexp.MustCompile(`\d+$`).MatchString(possiblePortNumber) {
73+
// port number confirmed
74+
toolName = toolName + ":" + possiblePortNumber
75+
ctx = strings.TrimSuffix(ctx, ":"+possiblePortNumber)
76+
}
77+
78+
newCredAddresses[toolNameWithCtx(toolName, ctx)] = val
79+
delete(serverAddresses, serverAddress)
80+
}
81+
82+
for serverAddress := range newCredAddresses {
5983
ac, err := h.Get(serverAddress)
6084
if err != nil {
6185
return nil, err

pkg/llm/registry.go

+22-15
Original file line numberDiff line numberDiff line change
@@ -33,21 +33,7 @@ func (r *Registry) ListModels(ctx context.Context, providers ...string) (result
3333
for _, v := range r.clients {
3434
models, err := v.ListModels(ctx, providers...)
3535
if err != nil {
36-
// If we got back an InvalidAuthError, then we know it came from the OpenAI client, and we can
37-
// try to get the credential from the cred store.
38-
if errors.Is(err, openai.InvalidAuthError{}) {
39-
if err := v.(*openai.Client).RetrieveAPIKey(); err != nil {
40-
return nil, err
41-
}
42-
43-
// Now that the API key has been retrieved, try to list models again.
44-
models, err = v.ListModels(ctx, providers...)
45-
if err != nil {
46-
return nil, err
47-
}
48-
} else {
49-
return nil, err
50-
}
36+
return nil, err
5137
}
5238
result = append(result, models...)
5339
}
@@ -59,15 +45,36 @@ func (r *Registry) Call(ctx context.Context, messageRequest types.CompletionRequ
5945
if messageRequest.Model == "" {
6046
return nil, fmt.Errorf("model is required")
6147
}
48+
6249
var errs []error
50+
var oaiClient *openai.Client
6351
for _, client := range r.clients {
6452
ok, err := client.Supports(ctx, messageRequest.Model)
6553
if err != nil {
54+
// If we got an OpenAI invalid auth error back, store the OpenAI client for later.
55+
if errors.Is(err, openai.InvalidAuthError{}) {
56+
oaiClient = client.(*openai.Client)
57+
}
58+
6659
errs = append(errs, err)
6760
} else if ok {
6861
return client.Call(ctx, messageRequest, status)
6962
}
7063
}
64+
65+
if len(errs) > 0 && oaiClient != nil {
66+
// Prompt the user to enter their OpenAI API key and try again.
67+
if err := oaiClient.RetrieveAPIKey(); err != nil {
68+
return nil, err
69+
}
70+
ok, err := oaiClient.Supports(ctx, messageRequest.Model)
71+
if err != nil {
72+
return nil, err
73+
} else if ok {
74+
return oaiClient.Call(ctx, messageRequest, status)
75+
}
76+
}
77+
7178
if len(errs) == 0 {
7279
return nil, fmt.Errorf("failed to find a model provider for model [%s]", messageRequest.Model)
7380
}

pkg/openai/client.go

+9-30
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ import (
1919
"github.com/gptscript-ai/gptscript/pkg/prompt"
2020
"github.com/gptscript-ai/gptscript/pkg/system"
2121
"github.com/gptscript-ai/gptscript/pkg/types"
22-
"github.com/tidwall/gjson"
2322
)
2423

2524
const (
@@ -168,6 +167,12 @@ func (c *Client) Supports(ctx context.Context, modelName string) (bool, error) {
168167
if err != nil {
169168
return false, err
170169
}
170+
171+
if len(models) == 0 {
172+
// We got no models back, which means our auth is invalid.
173+
return false, InvalidAuthError{}
174+
}
175+
171176
return slices.Contains(models, modelName), nil
172177
}
173178

@@ -177,8 +182,9 @@ func (c *Client) ListModels(ctx context.Context, providers ...string) (result []
177182
return nil, nil
178183
}
179184

185+
// If auth is invalid, we just want to return nothing.
180186
if err := c.ValidAuth(); err != nil {
181-
return nil, err
187+
return nil, nil
182188
}
183189

184190
models, err := c.c.ListModels(ctx)
@@ -559,38 +565,11 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest,
559565
}
560566

561567
func (c *Client) RetrieveAPIKey() error {
562-
store, err := credentials.NewStore(c.cliCfg, c.credCtx)
563-
if err != nil {
564-
return err
565-
}
566-
567-
cred, exists, err := store.Get(BuiltinCredName)
568+
k, err := prompt.GetModelProviderCredential(BuiltinCredName, "OPENAI_API_KEY", "Please provide your OpenAI API key:", c.credCtx, c.cliCfg)
568569
if err != nil {
569570
return err
570571
}
571572

572-
var k string
573-
if exists {
574-
k = cred.Env[key]
575-
} else {
576-
// SysPrompt doesn't use its first two arguments, so we can safely pass whatever to them
577-
result, err := prompt.SysPrompt(context.Background(), nil, `{"message":"Please provide your OpenAI API key:","fields":"key","sensitive":"true"}`)
578-
if err != nil {
579-
return err
580-
}
581-
582-
k = gjson.Get(result, "key").String()
583-
if err := store.Add(credentials.Credential{
584-
ToolName: BuiltinCredName,
585-
Env: map[string]string{
586-
"OPENAI_API_KEY": k,
587-
},
588-
}); err != nil {
589-
return err
590-
}
591-
log.Infof("Saved API key as credential %s", BuiltinCredName)
592-
}
593-
594573
c.c.SetAPIKey(k)
595574
c.invalidAuth = false
596575
return nil

pkg/prompt/credential.go

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
package prompt
2+
3+
import (
4+
"context"
5+
"fmt"
6+
7+
"github.com/gptscript-ai/gptscript/pkg/config"
8+
"github.com/gptscript-ai/gptscript/pkg/credentials"
9+
"github.com/tidwall/gjson"
10+
)
11+
12+
func GetModelProviderCredential(credName, env, message, credCtx string, cliCfg *config.CLIConfig) (string, error) {
13+
store, err := credentials.NewStore(cliCfg, credCtx)
14+
if err != nil {
15+
return "", err
16+
}
17+
18+
cred, exists, err := store.Get(credName)
19+
if err != nil {
20+
return "", err
21+
}
22+
23+
var k string
24+
if exists {
25+
k = cred.Env[env]
26+
} else {
27+
// SysPrompt doesn't use its first two arguments, so we can safely pass whatever to them
28+
result, err := SysPrompt(context.Background(), nil, fmt.Sprintf(`{"message":"%s","fields":"key","sensitive":"true"}`, message))
29+
if err != nil {
30+
return "", err
31+
}
32+
33+
k = gjson.Get(result, "key").String()
34+
if err := store.Add(credentials.Credential{
35+
ToolName: credName,
36+
Type: credentials.CredentialTypeModelProvider,
37+
Env: map[string]string{
38+
env: k,
39+
},
40+
}); err != nil {
41+
return "", err
42+
}
43+
log.Infof("Saved API key as credential %s", credName)
44+
}
45+
46+
return k, nil
47+
}

pkg/prompt/log.go

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package prompt
2+
3+
import "github.com/gptscript-ai/gptscript/pkg/mvl"
4+
5+
var log = mvl.Package()

pkg/remote/remote.go

+16-16
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package remote
22

33
import (
44
"context"
5-
"errors"
65
"fmt"
76
"net/url"
87
"os"
@@ -16,6 +15,7 @@ import (
1615
env2 "github.com/gptscript-ai/gptscript/pkg/env"
1716
"github.com/gptscript-ai/gptscript/pkg/loader"
1817
"github.com/gptscript-ai/gptscript/pkg/openai"
18+
"github.com/gptscript-ai/gptscript/pkg/prompt"
1919
"github.com/gptscript-ai/gptscript/pkg/runner"
2020
"github.com/gptscript-ai/gptscript/pkg/types"
2121
)
@@ -63,21 +63,7 @@ func (c *Client) ListModels(ctx context.Context, providers ...string) (result []
6363
}
6464
models, err := client.ListModels(ctx, "")
6565
if err != nil {
66-
// If we got back an InvalidAuthError, then we know it came from the OpenAI client, and we can
67-
// try to get the credential from the cred store.
68-
if errors.Is(err, openai.InvalidAuthError{}) {
69-
if err := client.RetrieveAPIKey(); err != nil {
70-
return nil, err
71-
}
72-
73-
// Now that the API key has been retrieved, try to list models again.
74-
models, err = client.ListModels(ctx, "")
75-
if err != nil {
76-
return nil, err
77-
}
78-
} else {
79-
return nil, err
80-
}
66+
return nil, err
8167
}
8268
for _, model := range models {
8369
result = append(result, model+" from "+provider)
@@ -121,6 +107,16 @@ func (c *Client) clientFromURL(apiURL string) (*openai.Client, error) {
121107
return nil, err
122108
}
123109
env := "GPTSCRIPT_PROVIDER_" + env2.ToEnvLike(parsed.Hostname()) + "_API_KEY"
110+
key := os.Getenv(env)
111+
112+
if key == "" {
113+
var err error
114+
key, err = c.retrieveAPIKey(env, apiURL)
115+
if err != nil {
116+
return nil, err
117+
}
118+
}
119+
124120
return openai.NewClient(c.cliCfg, c.credCtx, openai.Options{
125121
BaseURL: apiURL,
126122
Cache: c.cache,
@@ -180,3 +176,7 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err
180176
c.clients[toolName] = client
181177
return client, nil
182178
}
179+
180+
func (c *Client) retrieveAPIKey(env, url string) (string, error) {
181+
return prompt.GetModelProviderCredential(url, env, fmt.Sprintf("Please provide your API key for %s", url), c.credCtx, c.cliCfg)
182+
}

pkg/runner/runner.go

+1
Original file line numberDiff line numberDiff line change
@@ -861,6 +861,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
861861

862862
cred = &credentials.Credential{
863863
ToolName: credToolName,
864+
Type: credentials.CredentialTypeTool,
864865
Env: envMap.Env,
865866
}
866867

0 commit comments

Comments
 (0)