Skip to content

Commit f664d57

Browse files
Merge pull request #340 from ibuildthecloud/cache-openapi
bug: speed up the openapi loading
2 parents f0ae38a + 82ce76a commit f664d57

File tree

8 files changed

+102
-28
lines changed

go.mod

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@ require (
2929
golang.org/x/exp v0.0.0-20240103183307-be819d1f06fc
3030
golang.org/x/sync v0.7.0
3131
golang.org/x/term v0.19.0
32-
gopkg.in/yaml.v3 v3.0.1
3332
)
3433

3534
require (
@@ -81,6 +80,7 @@ require (
8180
golang.org/x/sys v0.19.0 // indirect
8281
golang.org/x/text v0.14.0 // indirect
8382
golang.org/x/tools v0.20.0 // indirect
83+
gopkg.in/yaml.v3 v3.0.1 // indirect
8484
gotest.tools/v3 v3.5.1 // indirect
8585
mvdan.cc/gofumpt v0.6.0 // indirect
8686
)

pkg/cli/gptscript.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -400,7 +400,7 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) (retErr error) {
400400

401401
if prg.IsChat() || r.ForceChat {
402402
return chat.Start(r.NewRunContext(cmd), nil, gptScript, func() (types.Program, error) {
403-
return r.readProgram(ctx, gptScript, args)
403+
return prg, nil
404404
}, os.Environ(), toolInput)
405405
}
406406

pkg/hash/sha256.go

+9-2
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,15 @@ func ID(parts ...string) string {
2020

2121
// Digest returns the hex-encoded SHA-256 digest of obj.
//
// A []byte or a string is hashed as its raw bytes, so a string and a []byte
// holding the same content produce the same digest. Any other value is
// gob-encoded into the hash, and Digest panics if that encoding fails.
func Digest(obj any) string {
2222
hash := sha256.New()
23-
if err := gob.NewEncoder(hash).Encode(obj); err != nil {
24-
panic(err)
23+
switch v := obj.(type) {
24+
case []byte:
25+
// Raw bytes are written straight into the hash, bypassing gob.
// The hash.Hash contract documents that Write never returns an error.
hash.Write(v)
26+
case string:
27+
hash.Write([]byte(v))
28+
default:
29+
if err := gob.NewEncoder(hash).Encode(obj); err != nil {
30+
panic(err)
31+
}
2532
}
2633
return hex.EncodeToString(hash.Sum(nil))
2734
}

pkg/loader/github/github.go

+12
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,12 @@ import (
88
"net/http"
99
"os"
1010
"path/filepath"
11+
"regexp"
1112
"strings"
1213

1314
"github.com/gptscript-ai/gptscript/pkg/cache"
1415
"github.com/gptscript-ai/gptscript/pkg/loader"
16+
"github.com/gptscript-ai/gptscript/pkg/mvl"
1517
"github.com/gptscript-ai/gptscript/pkg/repos/git"
1618
"github.com/gptscript-ai/gptscript/pkg/system"
1719
"github.com/gptscript-ai/gptscript/pkg/types"
@@ -26,6 +28,7 @@ const (
2628

2729
var (
2830
githubAuthToken = os.Getenv("GITHUB_AUTH_TOKEN")
31+
log = mvl.Package()
2932
)
3033

3134
func init() {
@@ -37,7 +40,14 @@ func getCommitLsRemote(ctx context.Context, account, repo, ref string) (string,
3740
return git.LsRemote(ctx, url, ref)
3841
}
3942

43+
// commitRegexp matches a full 40-character lowercase hexadecimal git commit SHA.
44+
var commitRegexp = regexp.MustCompile("^[a-f0-9]{40}$")
45+
4046
func getCommit(ctx context.Context, account, repo, ref string) (string, error) {
47+
if commitRegexp.MatchString(ref) {
48+
return ref, nil
49+
}
50+
4151
url := fmt.Sprintf(githubCommitURL, account, repo, ref)
4252
req, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil)
4353
if err != nil {
@@ -69,6 +79,8 @@ func getCommit(ctx context.Context, account, repo, ref string) (string, error) {
6979
return "", fmt.Errorf("failed to decode GitHub commit of %s/%s at %s: %w", account, repo, url, err)
7080
}
7181

82+
log.Debugf("loaded github commit of %s/%s at %s as %q", account, repo, url, commit.SHA)
83+
7284
if commit.SHA == "" {
7385
return "", fmt.Errorf("failed to find commit in response of %s, got empty string", url)
7486
}

pkg/loader/loader.go

+65-21
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@ import (
2020
"github.com/gptscript-ai/gptscript/pkg/assemble"
2121
"github.com/gptscript-ai/gptscript/pkg/builtin"
2222
"github.com/gptscript-ai/gptscript/pkg/cache"
23+
"github.com/gptscript-ai/gptscript/pkg/hash"
2324
"github.com/gptscript-ai/gptscript/pkg/parser"
2425
"github.com/gptscript-ai/gptscript/pkg/system"
2526
"github.com/gptscript-ai/gptscript/pkg/types"
26-
"gopkg.in/yaml.v3"
2727
)
2828

2929
const CacheTimeout = time.Hour
@@ -120,24 +120,50 @@ func loadProgram(data []byte, into *types.Program, targetToolName string) (types
120120
return tool, nil
121121
}
122122

123+
// loadOpenAPI parses data as an OpenAPI document and memoizes the outcome on prg.
//
// The cache key is hash.Digest(data) and results are kept in prg.OpenAPICache.
// It returns nil when data fails to load as OpenAPI or when the loaded document
// has no paths. That nil outcome is cached too (as a nil *openapi3.T), so the
// same content is not parsed again on a later call with the same prg.
func loadOpenAPI(prg *types.Program, data []byte) *openapi3.T {
124+
var (
125+
openAPICacheKey = hash.Digest(data)
126+
// Reading from a nil map is safe; ok is true for any previously stored
// entry, including a cached nil document.
openAPIDocument, ok = prg.OpenAPICache[openAPICacheKey].(*openapi3.T)
127+
err error
128+
)
129+
130+
if ok {
131+
return openAPIDocument
132+
}
133+
134+
// The cache map is created lazily, before the write further down.
if prg.OpenAPICache == nil {
135+
prg.OpenAPICache = map[string]any{}
136+
}
137+
138+
openAPIDocument, err = openapi3.NewLoader().LoadFromData(data)
139+
// A load error or a document without paths is treated as "not OpenAPI".
if err != nil || openAPIDocument.Paths.Len() == 0 {
140+
openAPIDocument = nil
141+
}
142+
143+
prg.OpenAPICache[openAPICacheKey] = openAPIDocument
144+
return openAPIDocument
145+
}
146+
123147
func readTool(ctx context.Context, cache *cache.Client, prg *types.Program, base *source, targetToolName string) (types.Tool, error) {
124148
data := base.Content
125149

126150
if bytes.HasPrefix(data, assemble.Header) {
127151
return loadProgram(data, prg, targetToolName)
128152
}
129153

130-
var tools []types.Tool
131-
if isOpenAPI(data) {
132-
if t, err := openapi3.NewLoader().LoadFromData(data); err == nil {
133-
if base.Remote {
134-
tools, err = getOpenAPITools(t, base.Location)
135-
} else {
136-
tools, err = getOpenAPITools(t, "")
137-
}
138-
if err != nil {
139-
return types.Tool{}, fmt.Errorf("error parsing OpenAPI definition: %w", err)
140-
}
154+
var (
155+
tools []types.Tool
156+
)
157+
158+
if openAPIDocument := loadOpenAPI(prg, data); openAPIDocument != nil {
159+
var err error
160+
if base.Remote {
161+
tools, err = getOpenAPITools(openAPIDocument, base.Location)
162+
} else {
163+
tools, err = getOpenAPITools(openAPIDocument, "")
164+
}
165+
if err != nil {
166+
return types.Tool{}, fmt.Errorf("error parsing OpenAPI definition: %w", err)
141167
}
142168
}
143169

@@ -263,6 +289,12 @@ func link(ctx context.Context, cache *cache.Client, prg *types.Program, base *so
263289
}
264290

265291
func ProgramFromSource(ctx context.Context, content, subToolName string, opts ...Options) (types.Program, error) {
292+
if log.IsDebug() {
293+
start := time.Now()
294+
defer func() {
295+
log.Debugf("loaded program from source took %v", time.Since(start))
296+
}()
297+
}
266298
opt := complete(opts...)
267299

268300
prg := types.Program{
@@ -292,6 +324,13 @@ func complete(opts ...Options) (result Options) {
292324
}
293325

294326
func Program(ctx context.Context, name, subToolName string, opts ...Options) (types.Program, error) {
327+
if log.IsDebug() {
328+
start := time.Now()
329+
defer func() {
330+
log.Debugf("loaded program %s source took %v", name, time.Since(start))
331+
}()
332+
}
333+
295334
opt := complete(opts...)
296335

297336
if subToolName == "" {
@@ -346,15 +385,20 @@ func input(ctx context.Context, cache *cache.Client, base *source, name string)
346385
return nil, fmt.Errorf("can not load tools path=%s name=%s", base.Path, name)
347386
}
348387

349-
func isOpenAPI(data []byte) bool {
350-
var fragment struct {
351-
Paths map[string]any `json:"paths,omitempty"`
352-
}
388+
// SplitToolRef splits a tool reference of the form "<subTool> from <toolName>"
// into its two parts.
//
// The input is split on whitespace and the first field equal to "from" is the
// separator: fields after it (joined by single spaces) become toolName, fields
// before it become subTool. When there is no "from" field, the whole trimmed
// input is the toolName and subTool is empty. In both cases the returned
// toolName is finally replaced by the first result of types.SplitArg.
func SplitToolRef(targetToolName string) (toolName, subTool string) {
389+
var (
390+
fields = strings.Fields(targetToolName)
391+
idx = slices.Index(fields, "from")
392+
)
353393

354-
if err := json.Unmarshal(data, &fragment); err != nil {
355-
if err := yaml.Unmarshal(data, &fragment); err != nil {
356-
return false
357-
}
394+
// The deferred closure runs after the return values are set and rewrites the
// named result toolName on every return path.
// NOTE(review): presumably SplitArg strips trailing arguments from the
// reference — confirm against types.SplitArg.
defer func() {
395+
toolName, _ = types.SplitArg(toolName)
396+
}()
397+
398+
if idx == -1 {
399+
return strings.TrimSpace(targetToolName), ""
358400
}
359-
return len(fragment.Paths) > 0
401+
402+
return strings.Join(fields[idx+1:], " "),
403+
strings.Join(fields[:idx], " ")
360404
}

pkg/loader/openapi.go

+7
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"slices"
88
"sort"
99
"strings"
10+
"time"
1011

1112
"github.com/getkin/kin-openapi/openapi3"
1213
"github.com/gptscript-ai/gptscript/pkg/engine"
@@ -18,6 +19,12 @@ import (
1819
// The tool's Instructions will be in the format "#!sys.openapi '{JSON Instructions}'",
1920
// where the JSON Instructions are a JSON-serialized engine.OpenAPIInstructions struct.
2021
func getOpenAPITools(t *openapi3.T, defaultHost string) ([]types.Tool, error) {
22+
if log.IsDebug() {
23+
start := time.Now()
24+
defer func() {
25+
log.Debugf("loaded openapi tools in %v", time.Since(start))
26+
}()
27+
}
2128
// Determine the default server.
2229
if len(t.Servers) == 0 {
2330
if defaultHost != "" {

pkg/repos/git/cmd.go

+3
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ import (
99
)
1010

1111
// newGitCommand returns a debugcmd-wrapped "git" command for the given args,
// created with ctx. When debug logging is enabled it first logs the
// space-joined argument list.
func newGitCommand(ctx context.Context, args ...string) *debugcmd.WrappedCmd {
12+
// The IsDebug guard skips the strings.Join call when debug logging is off.
if log.IsDebug() {
13+
log.Debugf("running git command: %s", strings.Join(args, " "))
14+
}
1215
cmd := debugcmd.New(ctx, "git", args...)
1316
return cmd
1417
}

pkg/types/tool.go

+4-3
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,10 @@ func (e *ErrToolNotFound) Error() string {
3636
type ToolSet map[string]Tool
3737

3838
// Program holds a program's name, the ID of its entry tool, and its tool set.
// These three fields are serialized to JSON; OpenAPICache is not.
type Program struct {
39-
Name string `json:"name,omitempty"`
40-
EntryToolID string `json:"entryToolId,omitempty"`
41-
ToolSet ToolSet `json:"toolSet,omitempty"`
39+
Name string `json:"name,omitempty"`
40+
EntryToolID string `json:"entryToolId,omitempty"`
41+
ToolSet ToolSet `json:"toolSet,omitempty"`
42+
// OpenAPICache is excluded from JSON via the "-" tag. The loader's
// loadOpenAPI stores parsed *openapi3.T documents in it, keyed by the
// digest of the source bytes.
OpenAPICache map[string]any `json:"-"`
4243
}
4344

4445
func (p Program) IsChat() bool {

0 commit comments

Comments
 (0)