Skip to content

Commit 62807c0

Browse files
committed
handle context
Signed-off-by: Grant Linville <[email protected]>
1 parent a095fff commit 62807c0

File tree

4 files changed

+11
-12
lines changed

4 files changed

+11
-12
lines changed

pkg/llm/registry.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ func (r *Registry) Call(ctx context.Context, messageRequest types.CompletionRequ
6464

6565
if len(errs) > 0 && oaiClient != nil {
6666
// Prompt the user to enter their OpenAI API key and try again.
67-
if err := oaiClient.RetrieveAPIKey(); err != nil {
67+
if err := oaiClient.RetrieveAPIKey(ctx); err != nil {
6868
return nil, err
6969
}
7070
ok, err := oaiClient.Supports(ctx, messageRequest.Model)

pkg/openai/client.go

+3-3
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ func toMessages(request types.CompletionRequest, compat bool) (result []openai.C
318318

319319
func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
320320
if err := c.ValidAuth(); err != nil {
321-
if err := c.RetrieveAPIKey(); err != nil {
321+
if err := c.RetrieveAPIKey(ctx); err != nil {
322322
return nil, err
323323
}
324324
}
@@ -569,8 +569,8 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest,
569569
}
570570
}
571571

572-
func (c *Client) RetrieveAPIKey() error {
573-
k, err := prompt.GetModelProviderCredential(BuiltinCredName, "OPENAI_API_KEY", "Please provide your OpenAI API key:", c.credCtx, c.envs, c.cliCfg)
572+
func (c *Client) RetrieveAPIKey(ctx context.Context) error {
573+
k, err := prompt.GetModelProviderCredential(ctx, BuiltinCredName, "OPENAI_API_KEY", "Please provide your OpenAI API key:", c.credCtx, c.envs, c.cliCfg)
574574
if err != nil {
575575
return err
576576
}

pkg/prompt/credential.go

+2-3
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99
"github.com/tidwall/gjson"
1010
)
1111

12-
func GetModelProviderCredential(credName, env, message, credCtx string, envs []string, cliCfg *config.CLIConfig) (string, error) {
12+
func GetModelProviderCredential(ctx context.Context, credName, env, message, credCtx string, envs []string, cliCfg *config.CLIConfig) (string, error) {
1313
store, err := credentials.NewStore(cliCfg, credCtx)
1414
if err != nil {
1515
return "", err
@@ -24,8 +24,7 @@ func GetModelProviderCredential(credName, env, message, credCtx string, envs []s
2424
if exists {
2525
k = cred.Env[env]
2626
} else {
27-
// SysPrompt doesn't use its first two arguments, so we can safely pass whatever to them
28-
result, err := SysPrompt(context.Background(), envs, fmt.Sprintf(`{"message":"%s","fields":"key","sensitive":"true"}`, message))
27+
result, err := SysPrompt(ctx, envs, fmt.Sprintf(`{"message":"%s","fields":"key","sensitive":"true"}`, message))
2928
if err != nil {
3029
return "", err
3130
}

pkg/remote/remote.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ func isHTTPURL(toolName string) bool {
101101
strings.HasPrefix(toolName, "https://")
102102
}
103103

104-
func (c *Client) clientFromURL(apiURL string) (*openai.Client, error) {
104+
func (c *Client) clientFromURL(ctx context.Context, apiURL string) (*openai.Client, error) {
105105
parsed, err := url.Parse(apiURL)
106106
if err != nil {
107107
return nil, err
@@ -111,7 +111,7 @@ func (c *Client) clientFromURL(apiURL string) (*openai.Client, error) {
111111

112112
if key == "" {
113113
var err error
114-
key, err = c.retrieveAPIKey(env, apiURL)
114+
key, err = c.retrieveAPIKey(ctx, env, apiURL)
115115
if err != nil {
116116
return nil, err
117117
}
@@ -138,7 +138,7 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err
138138
}
139139

140140
if isHTTPURL(toolName) {
141-
remoteClient, err := c.clientFromURL(toolName)
141+
remoteClient, err := c.clientFromURL(ctx, toolName)
142142
if err != nil {
143143
return nil, err
144144
}
@@ -177,6 +177,6 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err
177177
return client, nil
178178
}
179179

180-
func (c *Client) retrieveAPIKey(env, url string) (string, error) {
181-
return prompt.GetModelProviderCredential(url, env, fmt.Sprintf("Please provide your API key for %s", url), c.credCtx, c.envs, c.cliCfg)
180+
func (c *Client) retrieveAPIKey(ctx context.Context, env, url string) (string, error) {
181+
return prompt.GetModelProviderCredential(ctx, url, env, fmt.Sprintf("Please provide your API key for %s", url), c.credCtx, c.envs, c.cliCfg)
182182
}

0 commit comments

Comments
 (0)