Skip to content

Commit c0507a2

Browse files
committed
feat: allow providers to be restarted if they stop
By not caching the client, gptscript is able to restart the provider daemon if it stops. If the daemon is still running, there is little overhead, because the daemon URL is cached and the tool will not be completely reprocessed. The model-to-provider mapping is still cached so that the client can be recreated when necessary. Signed-off-by: Donnie Adams <[email protected]>
1 parent 4fd8e8a commit c0507a2

File tree

1 file changed

+18
-34
lines changed

1 file changed

+18
-34
lines changed

pkg/remote/remote.go

+18-34
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,9 @@ import (
2222
)
2323

2424
type Client struct {
25-
clientsLock sync.Mutex
25+
modelsLock sync.Mutex
2626
cache *cache.Client
27-
clients map[string]*openai.Client
28-
models map[string]*openai.Client
27+
modelToProvider map[string]string
2928
runner *runner.Runner
3029
envs []string
3130
credStore credentials.CredentialStore
@@ -43,14 +42,19 @@ func New(r *runner.Runner, envs []string, cache *cache.Client, credStore credent
4342
}
4443

4544
func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
46-
c.clientsLock.Lock()
47-
client, ok := c.models[messageRequest.Model]
48-
c.clientsLock.Unlock()
45+
c.modelsLock.Lock()
46+
provider, ok := c.modelToProvider[messageRequest.Model]
47+
c.modelsLock.Unlock()
4948

5049
if !ok {
5150
return nil, fmt.Errorf("failed to find remote model %s", messageRequest.Model)
5251
}
5352

53+
client, err := c.load(ctx, provider)
54+
if err != nil {
55+
return nil, err
56+
}
57+
5458
toolName, modelName := types.SplitToolRef(messageRequest.Model)
5559
if modelName == "" {
5660
// modelName is empty, then the messageRequest.Model is not of the form 'modelName from provider'
@@ -96,19 +100,19 @@ func (c *Client) Supports(ctx context.Context, modelString string) (bool, error)
96100
return false, nil
97101
}
98102

99-
client, err := c.load(ctx, providerName)
103+
_, err := c.load(ctx, providerName)
100104
if err != nil {
101105
return false, err
102106
}
103107

104-
c.clientsLock.Lock()
105-
defer c.clientsLock.Unlock()
108+
c.modelsLock.Lock()
109+
defer c.modelsLock.Unlock()
106110

107-
if c.models == nil {
108-
c.models = map[string]*openai.Client{}
111+
if c.modelToProvider == nil {
112+
c.modelToProvider = map[string]string{}
109113
}
110114

111-
c.models[modelString] = client
115+
c.modelToProvider[modelString] = providerName
112116
return true, nil
113117
}
114118

@@ -141,24 +145,11 @@ func (c *Client) clientFromURL(ctx context.Context, apiURL string) (*openai.Clie
141145
}
142146

143147
func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, error) {
144-
c.clientsLock.Lock()
145-
defer c.clientsLock.Unlock()
146-
147-
client, ok := c.clients[toolName]
148-
if ok {
149-
return client, nil
150-
}
151-
152-
if c.clients == nil {
153-
c.clients = make(map[string]*openai.Client)
154-
}
155-
156148
if isHTTPURL(toolName) {
157149
remoteClient, err := c.clientFromURL(ctx, toolName)
158150
if err != nil {
159151
return nil, err
160152
}
161-
c.clients[toolName] = remoteClient
162153
return remoteClient, nil
163154
}
164155

@@ -174,22 +165,15 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err
174165
return nil, err
175166
}
176167

177-
if strings.HasSuffix(url, "/") {
178-
url += "v1"
179-
} else {
180-
url += "/v1"
181-
}
182-
183-
client, err = openai.NewClient(ctx, c.credStore, openai.Options{
184-
BaseURL: url,
168+
client, err := openai.NewClient(ctx, c.credStore, openai.Options{
169+
BaseURL: strings.TrimSuffix(url, "/") + "/v1",
185170
Cache: c.cache,
186171
CacheKey: prg.EntryToolID,
187172
})
188173
if err != nil {
189174
return nil, err
190175
}
191176

192-
c.clients[toolName] = client
193177
return client, nil
194178
}
195179

0 commit comments

Comments
 (0)