Skip to content

Commit c98b7ed

Browse files
feat: add input filters
1 parent 81d3b48 commit c98b7ed

File tree

15 files changed

+308
-62
lines changed

15 files changed

+308
-62
lines changed

pkg/engine/cmd.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,8 @@ func appendInputAsEnv(env []string, input string) []string {
185185
dec := json.NewDecoder(bytes.NewReader([]byte(input)))
186186
dec.UseNumber()
187187

188+
env = appendEnv(env, "GPTSCRIPT_INPUT", input)
189+
188190
if err := json.Unmarshal([]byte(input), &data); err != nil {
189191
// ignore invalid JSON
190192
return env
@@ -206,7 +208,6 @@ func appendInputAsEnv(env []string, input string) []string {
206208
}
207209
}
208210

209-
env = appendEnv(env, "GPTSCRIPT_INPUT", input)
210211
return env
211212
}
212213

pkg/engine/engine.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ const (
9898
ProviderToolCategory ToolCategory = "provider"
9999
CredentialToolCategory ToolCategory = "credential"
100100
ContextToolCategory ToolCategory = "context"
101+
InputToolCategory ToolCategory = "input"
101102
NoCategory ToolCategory = ""
102103
)
103104

@@ -180,7 +181,7 @@ func NewContext(ctx context.Context, prg *types.Program, input string) Context {
180181
return callCtx
181182
}
182183

183-
func (c *Context) SubCall(ctx context.Context, input, toolID, callID string, toolCategory ToolCategory) (Context, error) {
184+
func (c *Context) SubCallContext(ctx context.Context, input, toolID, callID string, toolCategory ToolCategory) (Context, error) {
184185
tool, ok := c.Program.ToolSet[toolID]
185186
if !ok {
186187
return Context{}, fmt.Errorf("failed to file tool for id [%s]", toolID)

pkg/openai/client.go

+7-2
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@ package openai
22

33
import (
44
"context"
5-
"fmt"
65
"io"
76
"log/slog"
87
"os"
@@ -16,6 +15,7 @@ import (
1615
"github.com/gptscript-ai/gptscript/pkg/counter"
1716
"github.com/gptscript-ai/gptscript/pkg/credentials"
1817
"github.com/gptscript-ai/gptscript/pkg/hash"
18+
"github.com/gptscript-ai/gptscript/pkg/mvl"
1919
"github.com/gptscript-ai/gptscript/pkg/prompt"
2020
"github.com/gptscript-ai/gptscript/pkg/system"
2121
"github.com/gptscript-ai/gptscript/pkg/types"
@@ -29,6 +29,7 @@ const (
2929
var (
3030
key = os.Getenv("OPENAI_API_KEY")
3131
url = os.Getenv("OPENAI_BASE_URL")
32+
log = mvl.Package()
3233
)
3334

3435
type InvalidAuthError struct{}
@@ -305,7 +306,11 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques
305306
}
306307

307308
if len(msgs) == 0 {
308-
return nil, fmt.Errorf("invalid request, no messages to send to LLM")
309+
log.Errorf("invalid request, no messages to send to LLM")
310+
return &types.CompletionMessage{
311+
Role: types.CompletionMessageRoleTypeAssistant,
312+
Content: types.Text(""),
313+
}, nil
309314
}
310315

311316
request := openai.ChatCompletionRequest{

pkg/parser/parser.go

+9-2
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,10 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) {
105105
tool.Parameters.Export = append(tool.Parameters.Export, csv(value)...)
106106
case "tool", "tools":
107107
tool.Parameters.Tools = append(tool.Parameters.Tools, csv(value)...)
108+
case "inputfilter", "inputfilters":
109+
tool.Parameters.InputFilters = append(tool.Parameters.InputFilters, csv(value)...)
110+
case "shareinputfilter", "shareinputfilters":
111+
tool.Parameters.ExportInputFilters = append(tool.Parameters.ExportInputFilters, csv(value)...)
108112
case "agent", "agents":
109113
tool.Parameters.Agents = append(tool.Parameters.Agents, csv(value)...)
110114
case "globaltool", "globaltools":
@@ -183,10 +187,13 @@ type context struct {
183187

184188
func (c *context) finish(tools *[]Node) {
185189
c.tool.Instructions = strings.TrimSpace(strings.Join(c.instructions, ""))
186-
if c.tool.Instructions != "" || c.tool.Parameters.Name != "" ||
187-
len(c.tool.Export) > 0 || len(c.tool.Tools) > 0 ||
190+
if c.tool.Instructions != "" ||
191+
c.tool.Parameters.Name != "" ||
192+
len(c.tool.Export) > 0 ||
193+
len(c.tool.Tools) > 0 ||
188194
c.tool.GlobalModelName != "" ||
189195
len(c.tool.GlobalTools) > 0 ||
196+
len(c.tool.ExportInputFilters) > 0 ||
190197
c.tool.Chat {
191198
*tools = append(*tools, Node{
192199
ToolNode: &ToolNode{

pkg/parser/parser_test.go

+24
Original file line numberDiff line numberDiff line change
@@ -191,3 +191,27 @@ name: bad
191191
},
192192
}}).Equal(t, out)
193193
}
194+
195+
func TestParseInput(t *testing.T) {
196+
input := `
197+
input filters: input
198+
share input filters: shared
199+
`
200+
out, err := Parse(strings.NewReader(input))
201+
require.NoError(t, err)
202+
autogold.Expect(Document{Nodes: []Node{
203+
{ToolNode: &ToolNode{
204+
Tool: types.Tool{
205+
ToolDef: types.ToolDef{
206+
Parameters: types.Parameters{
207+
InputFilters: []string{
208+
"input",
209+
},
210+
ExportInputFilters: []string{"shared"},
211+
},
212+
},
213+
Source: types.ToolSource{LineNo: 1},
214+
},
215+
}},
216+
}}).Equal(t, out)
217+
}

pkg/runner/input.go

+27
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
package runner
2+
3+
import (
4+
"fmt"
5+
6+
"github.com/gptscript-ai/gptscript/pkg/engine"
7+
)
8+
9+
func (r *Runner) handleInput(callCtx engine.Context, monitor Monitor, env []string, input string) (string, error) {
10+
inputToolRefs, err := callCtx.Tool.GetInputFilterTools(*callCtx.Program)
11+
if err != nil {
12+
return "", err
13+
}
14+
15+
for _, inputToolRef := range inputToolRefs {
16+
res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, inputToolRef.ToolID, input, "", engine.InputToolCategory)
17+
if err != nil {
18+
return "", err
19+
}
20+
if res.Result == nil {
21+
return "", fmt.Errorf("invalid state: input tool [%s] can not result in a chat continuation", inputToolRef.Reference)
22+
}
23+
input = *res.Result
24+
}
25+
26+
return input, nil
27+
}

pkg/runner/runner.go

+23-11
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,11 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en
408408
Content: input,
409409
})
410410

411+
input, err := r.handleInput(callCtx, monitor, env, input)
412+
if err != nil {
413+
return nil, err
414+
}
415+
411416
if len(callCtx.Tool.Credentials) > 0 {
412417
var err error
413418
env, err = r.handleCredentials(callCtx, monitor, env)
@@ -417,7 +422,6 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en
417422
}
418423

419424
var (
420-
err error
421425
newState *State
422426
)
423427
callCtx.InputContext, newState, err = r.getContext(callCtx, state, monitor, env, input)
@@ -446,7 +450,10 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en
446450
}
447451

448452
if !authResp.Accept {
449-
msg := fmt.Sprintf("[AUTHORIZATION ERROR]: %s", authResp.Message)
453+
msg := authResp.Message
454+
if msg == "" {
455+
msg = "Tool call request has been denied"
456+
}
450457
return &State{
451458
Continuation: &engine.Return{
452459
Result: &msg,
@@ -631,8 +638,12 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
631638
}
632639

633640
if state.ResumeInput != nil {
641+
input, err := r.handleInput(callCtx, monitor, env, *state.ResumeInput)
642+
if err != nil {
643+
return state, err
644+
}
634645
engineResults = append(engineResults, engine.CallResult{
635-
User: *state.ResumeInput,
646+
User: input,
636647
})
637648
}
638649

@@ -689,16 +700,22 @@ func streamProgress(callCtx *engine.Context, monitor Monitor) (chan<- types.Comp
689700
}
690701

691702
func (r *Runner) subCall(ctx context.Context, parentContext engine.Context, monitor Monitor, env []string, toolID, input, callID string, toolCategory engine.ToolCategory) (*State, error) {
692-
callCtx, err := parentContext.SubCall(ctx, input, toolID, callID, toolCategory)
703+
callCtx, err := parentContext.SubCallContext(ctx, input, toolID, callID, toolCategory)
693704
if err != nil {
694705
return nil, err
695706
}
696707

708+
if toolCategory == engine.ContextToolCategory && callCtx.Tool.IsNoop() {
709+
return &State{
710+
Result: new(string),
711+
}, nil
712+
}
713+
697714
return r.call(callCtx, monitor, env, input)
698715
}
699716

700717
func (r *Runner) subCallResume(ctx context.Context, parentContext engine.Context, monitor Monitor, env []string, toolID, callID string, state *State, toolCategory engine.ToolCategory) (*State, error) {
701-
callCtx, err := parentContext.SubCall(ctx, "", toolID, callID, toolCategory)
718+
callCtx, err := parentContext.SubCallContext(ctx, "", toolID, callID, toolCategory)
702719
if err != nil {
703720
return nil, err
704721
}
@@ -882,12 +899,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
882899
input = string(inputBytes)
883900
}
884901

885-
subCtx, err := callCtx.SubCall(callCtx.Ctx, input, credToolRefs[0].ToolID, "", engine.CredentialToolCategory) // leaving callID as "" will cause it to be set by the engine
886-
if err != nil {
887-
return nil, fmt.Errorf("failed to create subcall context for tool %s: %w", credToolName, err)
888-
}
889-
890-
res, err := r.call(subCtx, monitor, env, input)
902+
res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, credToolRefs[0].ToolID, input, "", engine.CredentialToolCategory)
891903
if err != nil {
892904
return nil, fmt.Errorf("failed to run credential tool %s: %w", credToolName, err)
893905
}

pkg/tests/runner_test.go

+21
Original file line numberDiff line numberDiff line change
@@ -822,3 +822,24 @@ func TestAgents(t *testing.T) {
822822
autogold.Expect("TEST RESULT CALL: 4").Equal(t, resp.Content)
823823
autogold.ExpectFile(t, toJSONString(t, resp), autogold.Name(t.Name()+"/step1"))
824824
}
825+
826+
func TestInput(t *testing.T) {
827+
r := tester.NewRunner(t)
828+
829+
prg, err := r.Load("")
830+
require.NoError(t, err)
831+
832+
resp, err := r.Chat(context.Background(), nil, prg, nil, "You're stupid")
833+
require.NoError(t, err)
834+
r.AssertResponded(t)
835+
assert.False(t, resp.Done)
836+
autogold.Expect("TEST RESULT CALL: 1").Equal(t, resp.Content)
837+
autogold.ExpectFile(t, toJSONString(t, resp), autogold.Name(t.Name()+"/step1"))
838+
839+
resp, err = r.Chat(context.Background(), resp.State, prg, nil, "You're ugly")
840+
require.NoError(t, err)
841+
r.AssertResponded(t)
842+
assert.False(t, resp.Done)
843+
autogold.Expect("TEST RESULT CALL: 2").Equal(t, resp.Content)
844+
autogold.ExpectFile(t, toJSONString(t, resp), autogold.Name(t.Name()+"/step2"))
845+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
`{
2+
"role": "assistant",
3+
"content": [
4+
{
5+
"text": "TEST RESULT CALL: 1"
6+
}
7+
],
8+
"usage": {}
9+
}`
+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
`{
2+
"model": "gpt-4o",
3+
"internalSystemPrompt": false,
4+
"messages": [
5+
{
6+
"role": "system",
7+
"content": [
8+
{
9+
"text": "\nTool body"
10+
}
11+
],
12+
"usage": {}
13+
},
14+
{
15+
"role": "user",
16+
"content": [
17+
{
18+
"text": "No, You're stupid!\n ha ha ha\n"
19+
}
20+
],
21+
"usage": {}
22+
}
23+
]
24+
}`
+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
`{
2+
"done": false,
3+
"content": "TEST RESULT CALL: 1",
4+
"toolID": "testdata/TestInput/test.gpt:",
5+
"state": {
6+
"continuation": {
7+
"state": {
8+
"input": "No, You're stupid!\n ha ha ha\n",
9+
"completion": {
10+
"model": "gpt-4o",
11+
"internalSystemPrompt": false,
12+
"messages": [
13+
{
14+
"role": "system",
15+
"content": [
16+
{
17+
"text": "\nTool body"
18+
}
19+
],
20+
"usage": {}
21+
},
22+
{
23+
"role": "user",
24+
"content": [
25+
{
26+
"text": "No, You're stupid!\n ha ha ha\n"
27+
}
28+
],
29+
"usage": {}
30+
},
31+
{
32+
"role": "assistant",
33+
"content": [
34+
{
35+
"text": "TEST RESULT CALL: 1"
36+
}
37+
],
38+
"usage": {}
39+
}
40+
]
41+
}
42+
},
43+
"result": "TEST RESULT CALL: 1"
44+
},
45+
"continuationToolID": "testdata/TestInput/test.gpt:"
46+
}
47+
}`

pkg/tests/testdata/TestInput/test.gpt

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
input filter: taunt
2+
context: exporter
3+
chat: true
4+
5+
Tool body
6+
7+
---
8+
name: taunt
9+
args: foo: this is useless
10+
#!/bin/bash
11+
12+
echo "No, ${GPTSCRIPT_INPUT}!"
13+
14+
---
15+
name: exporter
16+
share input filters: taunt2
17+
18+
---
19+
name: taunt2
20+
args: foo: this is useless
21+
22+
#!/bin/bash
23+
echo "${GPTSCRIPT_INPUT} ha ha ha"

pkg/types/set.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ func (t *toolRefSet) HasTool(toolID string) bool {
2929
}
3030

3131
func (t *toolRefSet) AddAll(values []ToolReference, err error) {
32-
if t.err != nil {
32+
if err != nil {
3333
t.err = err
3434
}
3535
for _, v := range values {

0 commit comments

Comments
 (0)