Skip to content

Commit 9ca6e93

Browse files
Merge pull request #629 from ibuildthecloud/main
chore: add location option to loading scripts
2 parents 0c73f4b + 2676b35 commit 9ca6e93

File tree

5 files changed

+51
-16
lines changed

5 files changed

+51
-16
lines changed

pkg/loader/loader.go

+16-2
Original file line numberDiff line numberDiff line change
@@ -373,12 +373,20 @@ func ProgramFromSource(ctx context.Context, content, subToolName string, opts ..
373373
}
374374
opt := complete(opts...)
375375

376+
var locationPath, locationName string
377+
if opt.Location != "" {
378+
locationPath = path.Dir(opt.Location)
379+
locationName = path.Base(opt.Location)
380+
}
381+
376382
prg := types.Program{
377383
ToolSet: types.ToolSet{},
378384
}
379385
tools, err := readTool(ctx, opt.Cache, &prg, &source{
380386
Content: []byte(content),
381-
Location: "inline",
387+
Path: locationPath,
388+
Name: locationName,
389+
Location: opt.Location,
382390
}, subToolName)
383391
if err != nil {
384392
return types.Program{}, err
@@ -388,12 +396,18 @@ func ProgramFromSource(ctx context.Context, content, subToolName string, opts ..
388396
}
389397

390398
type Options struct {
391-
Cache *cache.Client
399+
Cache *cache.Client
400+
Location string
392401
}
393402

394403
func complete(opts ...Options) (result Options) {
395404
for _, opt := range opts {
396405
result.Cache = types.FirstSet(opt.Cache, result.Cache)
406+
result.Location = types.FirstSet(opt.Location, result.Location)
407+
}
408+
409+
if result.Location == "" {
410+
result.Location = "inline"
397411
}
398412

399413
return

pkg/loader/url.go

+25-13
Original file line numberDiff line numberDiff line change
@@ -111,11 +111,20 @@ func loadURL(ctx context.Context, cache *cache.Client, base *source, name string
111111
req.Header.Set("Authorization", "Bearer "+bearerToken)
112112
}
113113

114-
data, err := getWithDefaults(req)
114+
data, defaulted, err := getWithDefaults(req)
115115
if err != nil {
116116
return nil, false, fmt.Errorf("error loading %s: %v", url, err)
117117
}
118118

119+
if defaulted != "" {
120+
pathString = url
121+
name = defaulted
122+
if repo != nil {
123+
repo.Path = path.Join(repo.Path, repo.Name)
124+
repo.Name = defaulted
125+
}
126+
}
127+
119128
log.Debugf("opened %s", url)
120129

121130
result := &source{
@@ -137,31 +146,32 @@ func loadURL(ctx context.Context, cache *cache.Client, base *source, name string
137146
return result, true, nil
138147
}
139148

140-
func getWithDefaults(req *http.Request) ([]byte, error) {
149+
func getWithDefaults(req *http.Request) ([]byte, string, error) {
141150
originalPath := req.URL.Path
142151

143152
// First, try to get the original path as is. It might be an OpenAPI definition.
144153
resp, err := http.DefaultClient.Do(req)
145154
if err != nil {
146-
return nil, err
155+
return nil, "", err
147156
}
148157
defer resp.Body.Close()
149158

150159
if resp.StatusCode == http.StatusOK {
151-
if toolBytes, err := io.ReadAll(resp.Body); err == nil && isOpenAPI(toolBytes) != 0 {
152-
return toolBytes, nil
153-
}
160+
toolBytes, err := io.ReadAll(resp.Body)
161+
return toolBytes, "", err
162+
}
163+
164+
base := path.Base(originalPath)
165+
if strings.Contains(base, ".") {
166+
return nil, "", fmt.Errorf("error loading %s: %s", req.URL.String(), resp.Status)
154167
}
155168

156169
for i, def := range types.DefaultFiles {
157-
base := path.Base(originalPath)
158-
if !strings.Contains(base, ".") {
159-
req.URL.Path = path.Join(originalPath, def)
160-
}
170+
req.URL.Path = path.Join(originalPath, def)
161171

162172
resp, err := http.DefaultClient.Do(req)
163173
if err != nil {
164-
return nil, err
174+
return nil, "", err
165175
}
166176
defer resp.Body.Close()
167177

@@ -170,11 +180,13 @@ func getWithDefaults(req *http.Request) ([]byte, error) {
170180
}
171181

172182
if resp.StatusCode != http.StatusOK {
173-
return nil, fmt.Errorf("error loading %s: %s", req.URL.String(), resp.Status)
183+
return nil, "", fmt.Errorf("error loading %s: %s", req.URL.String(), resp.Status)
174184
}
175185

176-
return io.ReadAll(resp.Body)
186+
data, err := io.ReadAll(resp.Body)
187+
return data, def, err
177188
}
189+
178190
panic("unreachable")
179191
}
180192

pkg/sdkserver/routes.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ func (s *server) execHandler(w http.ResponseWriter, r *http.Request) {
183183
logger.Debugf("executing tool: %+v", reqObject)
184184
var (
185185
def fmt.Stringer = &reqObject.ToolDefs
186-
programLoader loaderFunc = loader.ProgramFromSource
186+
programLoader = loaderWithLocation(loader.ProgramFromSource, reqObject.Location)
187187
)
188188
if reqObject.Content != "" {
189189
def = &reqObject.content

pkg/sdkserver/run.go

+8
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,14 @@ import (
1616

1717
type loaderFunc func(context.Context, string, string, ...loader.Options) (types.Program, error)
1818

19+
func loaderWithLocation(f loaderFunc, loc string) loaderFunc {
20+
return func(ctx context.Context, s string, s2 string, options ...loader.Options) (types.Program, error) {
21+
return f(ctx, s, s2, append(options, loader.Options{
22+
Location: loc,
23+
})...)
24+
}
25+
}
26+
1927
func (s *server) execAndStream(ctx context.Context, programLoader loaderFunc, logger mvl.Logger, w http.ResponseWriter, opts gptscript.Options, chatState, input, subTool string, toolDef fmt.Stringer) {
2028
g, err := gptscript.New(ctx, s.gptscriptOpts, opts)
2129
if err != nil {

pkg/sdkserver/types.go

+1
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ type toolOrFileRequest struct {
6161
CredentialContext string `json:"credentialContext"`
6262
CredentialOverrides []string `json:"credentialOverrides"`
6363
Confirm bool `json:"confirm"`
64+
Location string `json:"location,omitempty"`
6465
}
6566

6667
type content struct {

0 commit comments

Comments
 (0)