Skip to content

Commit fb32c4c

Browse files
authored
Merge pull request #733 from thedadams/allow-provider-restart
feat: allow providers to be restarted if they stop
2 parents 4fd8e8a + c0507a2 commit fb32c4c

File tree

1 file changed

+18
-34
lines changed

1 file changed

+18
-34
lines changed

pkg/remote/remote.go

+18-34
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,9 @@ import (
2222
)
2323

2424
type Client struct {
25-
clientsLock sync.Mutex
25+
modelsLock sync.Mutex
2626
cache *cache.Client
27-
clients map[string]*openai.Client
28-
models map[string]*openai.Client
27+
modelToProvider map[string]string
2928
runner *runner.Runner
3029
envs []string
3130
credStore credentials.CredentialStore
@@ -43,14 +42,19 @@ func New(r *runner.Runner, envs []string, cache *cache.Client, credStore credent
4342
}
4443

4544
func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
46-
c.clientsLock.Lock()
47-
client, ok := c.models[messageRequest.Model]
48-
c.clientsLock.Unlock()
45+
c.modelsLock.Lock()
46+
provider, ok := c.modelToProvider[messageRequest.Model]
47+
c.modelsLock.Unlock()
4948

5049
if !ok {
5150
return nil, fmt.Errorf("failed to find remote model %s", messageRequest.Model)
5251
}
5352

53+
client, err := c.load(ctx, provider)
54+
if err != nil {
55+
return nil, err
56+
}
57+
5458
toolName, modelName := types.SplitToolRef(messageRequest.Model)
5559
if modelName == "" {
5660
// modelName is empty, then the messageRequest.Model is not of the form 'modelName from provider'
@@ -96,19 +100,19 @@ func (c *Client) Supports(ctx context.Context, modelString string) (bool, error)
96100
return false, nil
97101
}
98102

99-
client, err := c.load(ctx, providerName)
103+
_, err := c.load(ctx, providerName)
100104
if err != nil {
101105
return false, err
102106
}
103107

104-
c.clientsLock.Lock()
105-
defer c.clientsLock.Unlock()
108+
c.modelsLock.Lock()
109+
defer c.modelsLock.Unlock()
106110

107-
if c.models == nil {
108-
c.models = map[string]*openai.Client{}
111+
if c.modelToProvider == nil {
112+
c.modelToProvider = map[string]string{}
109113
}
110114

111-
c.models[modelString] = client
115+
c.modelToProvider[modelString] = providerName
112116
return true, nil
113117
}
114118

@@ -141,24 +145,11 @@ func (c *Client) clientFromURL(ctx context.Context, apiURL string) (*openai.Clie
141145
}
142146

143147
func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, error) {
144-
c.clientsLock.Lock()
145-
defer c.clientsLock.Unlock()
146-
147-
client, ok := c.clients[toolName]
148-
if ok {
149-
return client, nil
150-
}
151-
152-
if c.clients == nil {
153-
c.clients = make(map[string]*openai.Client)
154-
}
155-
156148
if isHTTPURL(toolName) {
157149
remoteClient, err := c.clientFromURL(ctx, toolName)
158150
if err != nil {
159151
return nil, err
160152
}
161-
c.clients[toolName] = remoteClient
162153
return remoteClient, nil
163154
}
164155

@@ -174,22 +165,15 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err
174165
return nil, err
175166
}
176167

177-
if strings.HasSuffix(url, "/") {
178-
url += "v1"
179-
} else {
180-
url += "/v1"
181-
}
182-
183-
client, err = openai.NewClient(ctx, c.credStore, openai.Options{
184-
BaseURL: url,
168+
client, err := openai.NewClient(ctx, c.credStore, openai.Options{
169+
BaseURL: strings.TrimSuffix(url, "/") + "/v1",
185170
Cache: c.cache,
186171
CacheKey: prg.EntryToolID,
187172
})
188173
if err != nil {
189174
return nil, err
190175
}
191176

192-
c.clients[toolName] = client
193177
return client, nil
194178
}
195179

0 commit comments

Comments
 (0)