Skip to content

enhance: ask user for OpenAI key and store it in the cred store #396

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ require (
github.com/getkin/kin-openapi v0.123.0
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
github.com/google/uuid v1.6.0
github.com/gptscript-ai/chat-completion-client v0.0.0-20240515050533-bdef9f2226a9
github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379
github.com/gptscript-ai/tui v0.0.0-20240604233332-4a5ff43cdc58
github.com/hexops/autogold/v2 v2.2.1
github.com/hexops/valast v1.4.4
Expand Down
4 changes: 2 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -170,8 +170,8 @@ github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY=
github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c=
github.com/gorilla/websocket v1.5.0 h1:PPwGk2jz7EePpoHN/+ClbZu8SPxiqlu12wZP/3sWmnc=
github.com/gorilla/websocket v1.5.0/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gptscript-ai/chat-completion-client v0.0.0-20240515050533-bdef9f2226a9 h1:s6nL/aokB1sJTqVXEjN0zFI5CJa66ubw9g68VTMzEw0=
github.com/gptscript-ai/chat-completion-client v0.0.0-20240515050533-bdef9f2226a9/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379 h1:vYnXoIyCXzaCEw0sYifQ4bDpsv3/fO/dZ2suEsTwCIo=
github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
github.com/gptscript-ai/go-gptscript v0.0.0-20240604231423-7a845df843b1 h1:SHoqsU8Ne2V4zfrFve9kQn4vcv4N4TItD6Oju+pzKV8=
github.com/gptscript-ai/go-gptscript v0.0.0-20240604231423-7a845df843b1/go.mod h1:h1yYzC0rgB5Kk7lwdba+Xs6cWkuJfLq6sPRna45OVG0=
github.com/gptscript-ai/tui v0.0.0-20240604233332-4a5ff43cdc58 h1:kbr6cY4VdvxAfJf+xk6x/gxDzmgZMkX2ZfLcyskGYsw=
Expand Down
34 changes: 29 additions & 5 deletions pkg/credentials/credential.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,19 @@ import (
"github.com/docker/cli/cli/config/types"
)

const ctxSeparator = "///"

type CredentialType string

const (
CredentialTypeTool CredentialType = "tool"
CredentialTypeModelProvider CredentialType = "modelProvider"
)

type Credential struct {
Context string `json:"context"`
ToolName string `json:"toolName"`
Type CredentialType `json:"type"`
Env map[string]string `json:"env"`
}

Expand All @@ -21,7 +31,7 @@ func (c Credential) toDockerAuthConfig() (types.AuthConfig, error) {
}

return types.AuthConfig{
Username: "gptscript", // Username is required, but not used
Username: string(c.Type),
Password: string(env),
ServerAddress: toolNameWithCtx(c.ToolName, c.Context),
}, nil
Expand All @@ -33,26 +43,40 @@ func credentialFromDockerAuthConfig(authCfg types.AuthConfig) (Credential, error
return Credential{}, err
}

tool, ctx, err := toolNameAndCtxFromAddress(strings.TrimPrefix(authCfg.ServerAddress, "https://"))
// We used to hardcode the username as "gptscript" before CredentialType was introduced, so
// check for that here.
credType := authCfg.Username
if credType == "gptscript" {
credType = string(CredentialTypeTool)
}

// If it's a tool credential or sys.openai, remove the http[s] prefix.
address := authCfg.ServerAddress
if credType == string(CredentialTypeTool) || strings.HasPrefix(address, "https://sys.openai"+ctxSeparator) {
address = strings.TrimPrefix(strings.TrimPrefix(address, "https://"), "http://")
}

tool, ctx, err := toolNameAndCtxFromAddress(address)
if err != nil {
return Credential{}, err
}

return Credential{
Context: ctx,
ToolName: tool,
Type: CredentialType(credType),
Env: env,
}, nil
}

func toolNameWithCtx(toolName, credCtx string) string {
return toolName + "///" + credCtx
return toolName + ctxSeparator + credCtx
}

func toolNameAndCtxFromAddress(address string) (string, string, error) {
parts := strings.Split(address, "///")
parts := strings.Split(address, ctxSeparator)
if len(parts) != 2 {
return "", "", fmt.Errorf("error parsing tool name and context %q. Tool names cannot contain '///'", address)
return "", "", fmt.Errorf("error parsing tool name and context %q. Tool names cannot contain '%s'", address, ctxSeparator)
}
return parts[0], parts[1], nil
}
29 changes: 28 additions & 1 deletion pkg/credentials/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package credentials

import (
"errors"
"regexp"
"strings"

"github.com/docker/cli/cli/config/credentials"
"github.com/docker/cli/cli/config/types"
Expand Down Expand Up @@ -55,7 +57,32 @@ func (h *HelperStore) GetAll() (map[string]types.AuthConfig, error) {
return nil, err
}

for serverAddress := range serverAddresses {
newCredAddresses := make(map[string]string, len(serverAddresses))
for serverAddress, val := range serverAddresses {
// If the serverAddress contains a port, we need to put it back in the right spot.
// For some reason, even when a credential is stored properly as http://hostname:8080///credctx,
// the list function will return http://hostname///credctx:8080. This is something wrong
// with macOS's built-in libraries. So we need to fix it here.
toolName, ctx, err := toolNameAndCtxFromAddress(serverAddress)
if err != nil {
return nil, err
}

contextPieces := strings.Split(ctx, ":")
if len(contextPieces) > 1 {
possiblePortNumber := contextPieces[len(contextPieces)-1]
if regexp.MustCompile(`\d+$`).MatchString(possiblePortNumber) {
// port number confirmed
toolName = toolName + ":" + possiblePortNumber
ctx = strings.TrimSuffix(ctx, ":"+possiblePortNumber)
}
}

newCredAddresses[toolNameWithCtx(toolName, ctx)] = val
delete(serverAddresses, serverAddress)
}

for serverAddress := range newCredAddresses {
ac, err := h.Get(serverAddress)
if err != nil {
return nil, err
Expand Down
23 changes: 15 additions & 8 deletions pkg/gptscript/gptscript.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (

"github.com/gptscript-ai/gptscript/pkg/builtin"
"github.com/gptscript-ai/gptscript/pkg/cache"
"github.com/gptscript-ai/gptscript/pkg/config"
context2 "github.com/gptscript-ai/gptscript/pkg/context"
"github.com/gptscript-ai/gptscript/pkg/engine"
"github.com/gptscript-ai/gptscript/pkg/hash"
Expand Down Expand Up @@ -70,15 +71,20 @@ func New(opts *Options) (*GPTScript, error) {
return nil, err
}

oAIClient, err := openai.NewClient(opts.OpenAI, openai.Options{
cliCfg, err := config.ReadCLIConfig(opts.OpenAI.ConfigFile)
if err != nil {
return nil, err
}

oaiClient, err := openai.NewClient(cliCfg, opts.CredentialContext, opts.OpenAI, openai.Options{
Cache: cacheClient,
SetSeed: true,
})
if err != nil {
return nil, err
}

if err := registry.AddClient(oAIClient); err != nil {
if err := registry.AddClient(oaiClient); err != nil {
return nil, err
}

Expand All @@ -95,18 +101,19 @@ func New(opts *Options) (*GPTScript, error) {
return nil, err
}

remoteClient := remote.New(runner, opts.Env, cacheClient)

if err := registry.AddClient(remoteClient); err != nil {
return nil, err
}

ctx, closeServer := context.WithCancel(context2.AddPauseFuncToCtx(context.Background(), opts.Runner.MonitorFactory.Pause))
extraEnv, err := prompt.NewServer(ctx, opts.Env)
if err != nil {
closeServer()
return nil, err
}
oaiClient.SetEnvs(extraEnv)

remoteClient := remote.New(runner, extraEnv, cacheClient, cliCfg, opts.CredentialContext)
if err := registry.AddClient(remoteClient); err != nil {
closeServer()
return nil, err
}

return &GPTScript{
Registry: registry,
Expand Down
22 changes: 22 additions & 0 deletions pkg/llm/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"sort"

"github.com/gptscript-ai/gptscript/pkg/openai"
"github.com/gptscript-ai/gptscript/pkg/types"
)

Expand Down Expand Up @@ -44,15 +45,36 @@ func (r *Registry) Call(ctx context.Context, messageRequest types.CompletionRequ
if messageRequest.Model == "" {
return nil, fmt.Errorf("model is required")
}

var errs []error
var oaiClient *openai.Client
for _, client := range r.clients {
ok, err := client.Supports(ctx, messageRequest.Model)
if err != nil {
// If we got an OpenAI invalid auth error back, store the OpenAI client for later.
if errors.Is(err, openai.InvalidAuthError{}) {
oaiClient = client.(*openai.Client)
}

errs = append(errs, err)
} else if ok {
return client.Call(ctx, messageRequest, status)
}
}

if len(errs) > 0 && oaiClient != nil {
// Prompt the user to enter their OpenAI API key and try again.
if err := oaiClient.RetrieveAPIKey(ctx); err != nil {
return nil, err
}
ok, err := oaiClient.Supports(ctx, messageRequest.Model)
if err != nil {
return nil, err
} else if ok {
return oaiClient.Call(ctx, messageRequest, status)
}
}

if len(errs) == 0 {
return nil, fmt.Errorf("failed to find a model provider for model [%s]", messageRequest.Model)
}
Expand Down
69 changes: 64 additions & 5 deletions pkg/openai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,41 @@ import (

openai "github.com/gptscript-ai/chat-completion-client"
"github.com/gptscript-ai/gptscript/pkg/cache"
"github.com/gptscript-ai/gptscript/pkg/config"
"github.com/gptscript-ai/gptscript/pkg/counter"
"github.com/gptscript-ai/gptscript/pkg/credentials"
"github.com/gptscript-ai/gptscript/pkg/hash"
"github.com/gptscript-ai/gptscript/pkg/prompt"
"github.com/gptscript-ai/gptscript/pkg/system"
"github.com/gptscript-ai/gptscript/pkg/types"
)

const (
DefaultModel = openai.GPT4o
DefaultModel = openai.GPT4o
BuiltinCredName = "sys.openai"
)

var (
key = os.Getenv("OPENAI_API_KEY")
url = os.Getenv("OPENAI_URL")
)

type InvalidAuthError struct{}

func (InvalidAuthError) Error() string {
return "OPENAI_API_KEY is not set. Please set the OPENAI_API_KEY environment variable"
}

type Client struct {
defaultModel string
c *openai.Client
cache *cache.Client
invalidAuth bool
cacheKeyBase string
setSeed bool
cliCfg *config.CLIConfig
credCtx string
envs []string
}

type Options struct {
Expand Down Expand Up @@ -75,12 +88,28 @@ func complete(opts ...Options) (result Options, err error) {
return result, err
}

func NewClient(opts ...Options) (*Client, error) {
func NewClient(cliCfg *config.CLIConfig, credCtx string, opts ...Options) (*Client, error) {
opt, err := complete(opts...)
if err != nil {
return nil, err
}

// If the API key is not set, try to get it from the cred store
if opt.APIKey == "" && opt.BaseURL == "" {
store, err := credentials.NewStore(cliCfg, credCtx)
if err != nil {
return nil, err
}

cred, exists, err := store.Get(BuiltinCredName)
if err != nil {
return nil, err
}
if exists {
opt.APIKey = cred.Env["OPENAI_API_KEY"]
}
}

cfg := openai.DefaultConfig(opt.APIKey)
cfg.BaseURL = types.FirstSet(opt.BaseURL, cfg.BaseURL)
cfg.OrgID = types.FirstSet(opt.OrgID, cfg.OrgID)
Expand All @@ -97,21 +126,33 @@ func NewClient(opts ...Options) (*Client, error) {
cacheKeyBase: cacheKeyBase,
invalidAuth: opt.APIKey == "" && opt.BaseURL == "",
setSeed: opt.SetSeed,
cliCfg: cliCfg,
credCtx: credCtx,
}, nil
}

func (c *Client) ValidAuth() error {
if c.invalidAuth {
return fmt.Errorf("OPENAI_API_KEY is not set. Please set the OPENAI_API_KEY environment variable")
return InvalidAuthError{}
}
return nil
}

func (c *Client) SetEnvs(env []string) {
c.envs = env
}

func (c *Client) Supports(ctx context.Context, modelName string) (bool, error) {
models, err := c.ListModels(ctx)
if err != nil {
return false, err
}

if len(models) == 0 {
// We got no models back, which means our auth is invalid.
return false, InvalidAuthError{}
}

return slices.Contains(models, modelName), nil
}

Expand All @@ -121,8 +162,13 @@ func (c *Client) ListModels(ctx context.Context, providers ...string) (result []
return nil, nil
}

// If auth is invalid, we just want to return nothing.
// Returning an InvalidAuthError here will lead to cases where the user is prompted to enter their OpenAI key,
// even when we don't want them to be prompted.
// So the UX we settled on is that no models get printed if the user does gptscript --list-models
// without having provided their key through the environment variable or the creds store.
if err := c.ValidAuth(); err != nil {
return nil, err
return nil, nil
}

models, err := c.c.ListModels(ctx)
Expand Down Expand Up @@ -251,7 +297,9 @@ func toMessages(request types.CompletionRequest, compat bool) (result []openai.C

func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
if err := c.ValidAuth(); err != nil {
return nil, err
if err := c.RetrieveAPIKey(ctx); err != nil {
return nil, err
}
}

if messageRequest.Model == "" {
Expand Down Expand Up @@ -499,6 +547,17 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest,
}
}

func (c *Client) RetrieveAPIKey(ctx context.Context) error {
k, err := prompt.GetModelProviderCredential(ctx, BuiltinCredName, "OPENAI_API_KEY", "Please provide your OpenAI API key:", c.credCtx, c.envs, c.cliCfg)
if err != nil {
return err
}

c.c.SetAPIKey(k)
c.invalidAuth = false
return nil
}

func ptr[T any](v T) *T {
return &v
}
Loading