Skip to content

Commit 721d518

Browse files
authored
enhance: automatically use the generic credential tool for OpenAPI (#488)
Signed-off-by: Grant Linville <[email protected]>
1 parent 3277a3c commit 721d518

File tree

4 files changed

+66
-18
lines changed

4 files changed

+66
-18
lines changed

pkg/engine/openapi.go

+50-16
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ import (
1313
"github.com/gptscript-ai/gptscript/pkg/env"
1414
"github.com/gptscript-ai/gptscript/pkg/types"
1515
"github.com/tidwall/gjson"
16+
"golang.org/x/exp/maps"
1617
)
1718

1819
var (
@@ -35,6 +36,45 @@ type SecurityInfo struct {
3536
In string `json:"in"` // header, query, or cookie, for type==apiKey
3637
}
3738

39+
func (i SecurityInfo) GetCredentialToolStrings(hostname string) []string {
40+
vars := i.getCredentialNamesAndEnvVars(hostname)
41+
var tools []string
42+
43+
for cred, v := range vars {
44+
field := "value"
45+
switch i.Type {
46+
case "apiKey":
47+
field = i.APIKeyName
48+
case "http":
49+
if i.Scheme == "bearer" {
50+
field = "bearer token"
51+
} else {
52+
if strings.Contains(v, "PASSWORD") {
53+
field = "password"
54+
} else {
55+
field = "username"
56+
}
57+
}
58+
}
59+
60+
tools = append(tools, fmt.Sprintf("github.com/gptscript-ai/credential as %s with %s as env and %q as message and %q as field",
61+
cred, v, "Please provide a value for the "+v+" environment variable", field))
62+
}
63+
return tools
64+
}
65+
66+
func (i SecurityInfo) getCredentialNamesAndEnvVars(hostname string) map[string]string {
67+
if i.Type == "http" && i.Scheme == "basic" {
68+
return map[string]string{
69+
hostname + i.Name + "Username": "GPTSCRIPT_" + env.ToEnvLike(hostname) + "_" + env.ToEnvLike(i.Name) + "_USERNAME",
70+
hostname + i.Name + "Password": "GPTSCRIPT_" + env.ToEnvLike(hostname) + "_" + env.ToEnvLike(i.Name) + "_PASSWORD",
71+
}
72+
}
73+
return map[string]string{
74+
hostname + i.Name: "GPTSCRIPT_" + env.ToEnvLike(hostname) + "_" + env.ToEnvLike(i.Name),
75+
}
76+
}
77+
3878
type OpenAPIInstructions struct {
3979
Server string `json:"server"`
4080
Path string `json:"path"`
@@ -83,8 +123,8 @@ func (e *Engine) runOpenAPI(tool types.Tool, input string) (*Return, error) {
83123
return nil, fmt.Errorf("failed to create request: %w", err)
84124
}
85125

86-
// Check for authentication (only if using HTTPS)
87-
if u.Scheme == "https" {
126+
// Check for authentication (only if using HTTPS or localhost)
127+
if u.Scheme == "https" || u.Hostname() == "localhost" || u.Hostname() == "127.0.0.1" {
88128
if len(instructions.SecurityInfos) > 0 {
89129
if err := handleAuths(req, envMap, instructions.SecurityInfos); err != nil {
90130
return nil, fmt.Errorf("error setting up authentication: %w", err)
@@ -181,15 +221,9 @@ func handleAuths(req *http.Request, envMap map[string]string, infoSets [][]Secur
181221
for _, infoSet := range infoSets {
182222
var missing []string // Keep track of any missing environment variables
183223
for _, info := range infoSet {
184-
envNames := []string{"GPTSCRIPT_" + env.ToEnvLike(req.URL.Hostname()) + "_" + env.ToEnvLike(info.Name)}
185-
if info.Type == "http" && info.Scheme == "basic" {
186-
envNames = []string{
187-
"GPTSCRIPT_" + env.ToEnvLike(req.URL.Hostname()) + "_" + env.ToEnvLike(info.Name) + "_USERNAME",
188-
"GPTSCRIPT_" + env.ToEnvLike(req.URL.Hostname()) + "_" + env.ToEnvLike(info.Name) + "_PASSWORD",
189-
}
190-
}
224+
vars := info.getCredentialNamesAndEnvVars(req.URL.Hostname())
191225

192-
for _, envName := range envNames {
226+
for _, envName := range vars {
193227
if _, ok := envMap[envName]; !ok {
194228
missing = append(missing, envName)
195229
}
@@ -203,28 +237,28 @@ func handleAuths(req *http.Request, envMap map[string]string, infoSets [][]Secur
203237
// We're using this info set, because no environment variables were missing.
204238
// Set up the request as needed.
205239
for _, info := range infoSet {
206-
envName := "GPTSCRIPT_" + env.ToEnvLike(req.URL.Hostname()) + "_" + env.ToEnvLike(info.Name)
240+
envNames := maps.Values(info.getCredentialNamesAndEnvVars(req.URL.Hostname()))
207241
switch info.Type {
208242
case "apiKey":
209243
switch info.In {
210244
case "header":
211-
req.Header.Set(info.APIKeyName, envMap[envName])
245+
req.Header.Set(info.APIKeyName, envMap[envNames[0]])
212246
case "query":
213247
v := url.Values{}
214-
v.Add(info.APIKeyName, envMap[envName])
248+
v.Add(info.APIKeyName, envMap[envNames[0]])
215249
req.URL.RawQuery = v.Encode()
216250
case "cookie":
217251
req.AddCookie(&http.Cookie{
218252
Name: info.APIKeyName,
219-
Value: envMap[envName],
253+
Value: envMap[envNames[0]],
220254
})
221255
}
222256
case "http":
223257
switch info.Scheme {
224258
case "bearer":
225-
req.Header.Set("Authorization", "Bearer "+envMap[envName])
259+
req.Header.Set("Authorization", "Bearer "+envMap[envNames[0]])
226260
case "basic":
227-
req.SetBasicAuth(envMap[envName+"_USERNAME"], envMap[envName+"_PASSWORD"])
261+
req.SetBasicAuth(envMap[envNames[0]], envMap[envNames[1]])
228262
}
229263
}
230264
}

pkg/loader/openapi.go

+11
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,17 @@ func getOpenAPITools(t *openapi3.T, defaultHost string) ([]types.Tool, error) {
278278
return nil, err
279279
}
280280

281+
if len(infos) > 0 {
282+
// Set up credential tools for the first set of infos.
283+
for _, info := range infos[0] {
284+
operationServerURL, err := url.Parse(operationServer)
285+
if err != nil {
286+
return nil, fmt.Errorf("failed to parse operation server URL: %w", err)
287+
}
288+
tool.Credentials = info.GetCredentialToolStrings(operationServerURL.Hostname())
289+
}
290+
}
291+
281292
// Register
282293
toolNames = append(toolNames, tool.Parameters.Name)
283294
tools = append(tools, tool)

pkg/types/tool.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,9 @@ func (t ToolDef) String() string {
411411
_, _ = fmt.Fprintf(buf, "Internal Prompt: %v\n", *t.Parameters.InternalPrompt)
412412
}
413413
if len(t.Parameters.Credentials) > 0 {
414-
_, _ = fmt.Fprintf(buf, "Credentials: %s\n", strings.Join(t.Parameters.Credentials, ", "))
414+
for _, cred := range t.Parameters.Credentials {
415+
_, _ = fmt.Fprintf(buf, "Credential: %s\n", cred)
416+
}
415417
}
416418
if t.Parameters.Chat {
417419
_, _ = fmt.Fprintf(buf, "Chat: true\n")

pkg/types/tool_test.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,8 @@ Temperature: 0.800000
5050
Parameter: arg1: desc1
5151
Parameter: arg2: desc2
5252
Internal Prompt: true
53-
Credentials: Credential1, Credential2
53+
Credential: Credential1
54+
Credential: Credential2
5455
Chat: true
5556
5657
This is a sample instruction

0 commit comments

Comments
 (0)