@@ -22,10 +22,9 @@ import (
22
22
)
23
23
24
24
type Client struct {
25
- clientsLock sync.Mutex
25
+ modelsLock sync.Mutex
26
26
cache * cache.Client
27
- clients map [string ]* openai.Client
28
- models map [string ]* openai.Client
27
+ modelToProvider map [string ]string
29
28
runner * runner.Runner
30
29
envs []string
31
30
credStore credentials.CredentialStore
@@ -43,14 +42,19 @@ func New(r *runner.Runner, envs []string, cache *cache.Client, credStore credent
43
42
}
44
43
45
44
func (c * Client ) Call (ctx context.Context , messageRequest types.CompletionRequest , status chan <- types.CompletionStatus ) (* types.CompletionMessage , error ) {
46
- c .clientsLock .Lock ()
47
- client , ok := c .models [messageRequest .Model ]
48
- c .clientsLock .Unlock ()
45
+ c .modelsLock .Lock ()
46
+ provider , ok := c .modelToProvider [messageRequest .Model ]
47
+ c .modelsLock .Unlock ()
49
48
50
49
if ! ok {
51
50
return nil , fmt .Errorf ("failed to find remote model %s" , messageRequest .Model )
52
51
}
53
52
53
+ client , err := c .load (ctx , provider )
54
+ if err != nil {
55
+ return nil , err
56
+ }
57
+
54
58
toolName , modelName := types .SplitToolRef (messageRequest .Model )
55
59
if modelName == "" {
56
60
// modelName is empty, then the messageRequest.Model is not of the form 'modelName from provider'
@@ -96,19 +100,19 @@ func (c *Client) Supports(ctx context.Context, modelString string) (bool, error)
96
100
return false , nil
97
101
}
98
102
99
- client , err := c .load (ctx , providerName )
103
+ _ , err := c .load (ctx , providerName )
100
104
if err != nil {
101
105
return false , err
102
106
}
103
107
104
- c .clientsLock .Lock ()
105
- defer c .clientsLock .Unlock ()
108
+ c .modelsLock .Lock ()
109
+ defer c .modelsLock .Unlock ()
106
110
107
- if c .models == nil {
108
- c .models = map [string ]* openai. Client {}
111
+ if c .modelToProvider == nil {
112
+ c .modelToProvider = map [string ]string {}
109
113
}
110
114
111
- c .models [modelString ] = client
115
+ c .modelToProvider [modelString ] = providerName
112
116
return true , nil
113
117
}
114
118
@@ -141,24 +145,11 @@ func (c *Client) clientFromURL(ctx context.Context, apiURL string) (*openai.Clie
141
145
}
142
146
143
147
func (c * Client ) load (ctx context.Context , toolName string ) (* openai.Client , error ) {
144
- c .clientsLock .Lock ()
145
- defer c .clientsLock .Unlock ()
146
-
147
- client , ok := c .clients [toolName ]
148
- if ok {
149
- return client , nil
150
- }
151
-
152
- if c .clients == nil {
153
- c .clients = make (map [string ]* openai.Client )
154
- }
155
-
156
148
if isHTTPURL (toolName ) {
157
149
remoteClient , err := c .clientFromURL (ctx , toolName )
158
150
if err != nil {
159
151
return nil , err
160
152
}
161
- c .clients [toolName ] = remoteClient
162
153
return remoteClient , nil
163
154
}
164
155
@@ -174,22 +165,15 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err
174
165
return nil , err
175
166
}
176
167
177
- if strings .HasSuffix (url , "/" ) {
178
- url += "v1"
179
- } else {
180
- url += "/v1"
181
- }
182
-
183
- client , err = openai .NewClient (ctx , c .credStore , openai.Options {
184
- BaseURL : url ,
168
+ client , err := openai .NewClient (ctx , c .credStore , openai.Options {
169
+ BaseURL : strings .TrimSuffix (url , "/" ) + "/v1" ,
185
170
Cache : c .cache ,
186
171
CacheKey : prg .EntryToolID ,
187
172
})
188
173
if err != nil {
189
174
return nil , err
190
175
}
191
176
192
- c .clients [toolName ] = client
193
177
return client , nil
194
178
}
195
179
0 commit comments