@@ -12,28 +12,41 @@ import (
12
12
13
13
openai "github.com/gptscript-ai/chat-completion-client"
14
14
"github.com/gptscript-ai/gptscript/pkg/cache"
15
+ "github.com/gptscript-ai/gptscript/pkg/config"
15
16
"github.com/gptscript-ai/gptscript/pkg/counter"
17
+ "github.com/gptscript-ai/gptscript/pkg/credentials"
16
18
"github.com/gptscript-ai/gptscript/pkg/hash"
19
+ "github.com/gptscript-ai/gptscript/pkg/prompt"
17
20
"github.com/gptscript-ai/gptscript/pkg/system"
18
21
"github.com/gptscript-ai/gptscript/pkg/types"
19
22
)
20
23
21
24
const (
22
- DefaultModel = openai .GPT4o
25
+ DefaultModel = openai .GPT4o
26
+ BuiltinCredName = "sys.openai"
23
27
)
24
28
25
29
var (
26
30
key = os .Getenv ("OPENAI_API_KEY" )
27
31
url = os .Getenv ("OPENAI_URL" )
28
32
)
29
33
34
+ type InvalidAuthError struct {}
35
+
36
+ func (InvalidAuthError ) Error () string {
37
+ return "OPENAI_API_KEY is not set. Please set the OPENAI_API_KEY environment variable"
38
+ }
39
+
30
40
type Client struct {
31
41
defaultModel string
32
42
c * openai.Client
33
43
cache * cache.Client
34
44
invalidAuth bool
35
45
cacheKeyBase string
36
46
setSeed bool
47
+ cliCfg * config.CLIConfig
48
+ credCtx string
49
+ envs []string
37
50
}
38
51
39
52
type Options struct {
@@ -75,12 +88,28 @@ func complete(opts ...Options) (result Options, err error) {
75
88
return result , err
76
89
}
77
90
78
- func NewClient (opts ... Options ) (* Client , error ) {
91
+ func NewClient (cliCfg * config. CLIConfig , credCtx string , opts ... Options ) (* Client , error ) {
79
92
opt , err := complete (opts ... )
80
93
if err != nil {
81
94
return nil , err
82
95
}
83
96
97
+ // If the API key is not set, try to get it from the cred store
98
+ if opt .APIKey == "" && opt .BaseURL == "" {
99
+ store , err := credentials .NewStore (cliCfg , credCtx )
100
+ if err != nil {
101
+ return nil , err
102
+ }
103
+
104
+ cred , exists , err := store .Get (BuiltinCredName )
105
+ if err != nil {
106
+ return nil , err
107
+ }
108
+ if exists {
109
+ opt .APIKey = cred .Env ["OPENAI_API_KEY" ]
110
+ }
111
+ }
112
+
84
113
cfg := openai .DefaultConfig (opt .APIKey )
85
114
cfg .BaseURL = types .FirstSet (opt .BaseURL , cfg .BaseURL )
86
115
cfg .OrgID = types .FirstSet (opt .OrgID , cfg .OrgID )
@@ -97,21 +126,33 @@ func NewClient(opts ...Options) (*Client, error) {
97
126
cacheKeyBase : cacheKeyBase ,
98
127
invalidAuth : opt .APIKey == "" && opt .BaseURL == "" ,
99
128
setSeed : opt .SetSeed ,
129
+ cliCfg : cliCfg ,
130
+ credCtx : credCtx ,
100
131
}, nil
101
132
}
102
133
103
134
func (c * Client ) ValidAuth () error {
104
135
if c .invalidAuth {
105
- return fmt . Errorf ( "OPENAI_API_KEY is not set. Please set the OPENAI_API_KEY environment variable" )
136
+ return InvalidAuthError {}
106
137
}
107
138
return nil
108
139
}
109
140
141
+ func (c * Client ) SetEnvs (env []string ) {
142
+ c .envs = env
143
+ }
144
+
110
145
func (c * Client ) Supports (ctx context.Context , modelName string ) (bool , error ) {
111
146
models , err := c .ListModels (ctx )
112
147
if err != nil {
113
148
return false , err
114
149
}
150
+
151
+ if len (models ) == 0 {
152
+ // We got no models back, which means our auth is invalid.
153
+ return false , InvalidAuthError {}
154
+ }
155
+
115
156
return slices .Contains (models , modelName ), nil
116
157
}
117
158
@@ -121,8 +162,13 @@ func (c *Client) ListModels(ctx context.Context, providers ...string) (result []
121
162
return nil , nil
122
163
}
123
164
165
+ // If auth is invalid, we just want to return nothing.
166
+ // Returning an InvalidAuthError here will lead to cases where the user is prompted to enter their OpenAI key,
167
+ // even when we don't want them to be prompted.
168
+ // So the UX we settled on is that no models get printed if the user does gptscript --list-models
169
+ // without having provided their key through the environment variable or the creds store.
124
170
if err := c .ValidAuth (); err != nil {
125
- return nil , err
171
+ return nil , nil
126
172
}
127
173
128
174
models , err := c .c .ListModels (ctx )
@@ -251,7 +297,9 @@ func toMessages(request types.CompletionRequest, compat bool) (result []openai.C
251
297
252
298
func (c * Client ) Call (ctx context.Context , messageRequest types.CompletionRequest , status chan <- types.CompletionStatus ) (* types.CompletionMessage , error ) {
253
299
if err := c .ValidAuth (); err != nil {
254
- return nil , err
300
+ if err := c .RetrieveAPIKey (ctx ); err != nil {
301
+ return nil , err
302
+ }
255
303
}
256
304
257
305
if messageRequest .Model == "" {
@@ -499,6 +547,17 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest,
499
547
}
500
548
}
501
549
550
+ func (c * Client ) RetrieveAPIKey (ctx context.Context ) error {
551
+ k , err := prompt .GetModelProviderCredential (ctx , BuiltinCredName , "OPENAI_API_KEY" , "Please provide your OpenAI API key:" , c .credCtx , c .envs , c .cliCfg )
552
+ if err != nil {
553
+ return err
554
+ }
555
+
556
+ c .c .SetAPIKey (k )
557
+ c .invalidAuth = false
558
+ return nil
559
+ }
560
+
502
561
func ptr [T any ](v T ) * T {
503
562
return & v
504
563
}
0 commit comments