Skip to content

Commit 6093748

Browse files
chore: allow image data to be in prompt input
1 parent f9c09f1 commit 6093748

File tree

3 files changed

+82
-5
lines changed

3 files changed

+82
-5
lines changed

pkg/openai/client.go

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -282,10 +282,7 @@ func toMessages(request types.CompletionRequest, compat bool) (result []openai.C
282282
chatMessage.ToolCalls = append(chatMessage.ToolCalls, toToolCall(*content.ToolCall))
283283
}
284284
if content.Text != "" {
285-
chatMessage.MultiContent = append(chatMessage.MultiContent, openai.ChatMessagePart{
286-
Type: openai.ChatMessagePartTypeText,
287-
Text: content.Text,
288-
})
285+
chatMessage.MultiContent = append(chatMessage.MultiContent, textToMultiContent(content.Text)...)
289286
}
290287
}
291288

@@ -307,6 +304,35 @@ func toMessages(request types.CompletionRequest, compat bool) (result []openai.C
307304
return
308305
}
309306

307+
const imagePrefix = "data:image/png;base64,"
308+
309+
func textToMultiContent(text string) []openai.ChatMessagePart {
310+
var chatParts []openai.ChatMessagePart
311+
parts := strings.Split(text, "\n")
312+
for i := len(parts) - 1; i >= 0; i-- {
313+
if strings.HasPrefix(parts[i], imagePrefix) {
314+
chatParts = append(chatParts, openai.ChatMessagePart{
315+
Type: openai.ChatMessagePartTypeImageURL,
316+
ImageURL: &openai.ChatMessageImageURL{
317+
URL: parts[i],
318+
},
319+
})
320+
parts = parts[:i]
321+
} else {
322+
break
323+
}
324+
}
325+
if len(parts) > 0 {
326+
chatParts = append(chatParts, openai.ChatMessagePart{
327+
Type: openai.ChatMessagePartTypeText,
328+
Text: strings.Join(parts, "\n"),
329+
})
330+
}
331+
332+
slices.Reverse(chatParts)
333+
return chatParts
334+
}
335+
310336
func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) {
311337
if err := c.ValidAuth(); err != nil {
312338
if err := c.RetrieveAPIKey(ctx, env); err != nil {

pkg/openai/client_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,44 @@ import (
99
"github.com/hexops/valast"
1010
)
1111

12+
func TestTextToMultiContent(t *testing.T) {
13+
autogold.Expect([]openai.ChatMessagePart{{
14+
Type: "text",
15+
Text: "hi\ndata:image/png;base64,xxxxx\n",
16+
}}).Equal(t, textToMultiContent("hi\ndata:image/png;base64,xxxxx\n"))
17+
18+
autogold.Expect([]openai.ChatMessagePart{
19+
{
20+
Type: "text",
21+
Text: "hi",
22+
},
23+
{
24+
Type: "image_url",
25+
ImageURL: &openai.ChatMessageImageURL{URL: "data:image/png;base64,xxxxx"},
26+
},
27+
}).Equal(t, textToMultiContent("hi\ndata:image/png;base64,xxxxx"))
28+
29+
autogold.Expect([]openai.ChatMessagePart{{
30+
Type: "image_url",
31+
ImageURL: &openai.ChatMessageImageURL{URL: "data:image/png;base64,xxxxx"},
32+
}}).Equal(t, textToMultiContent("data:image/png;base64,xxxxx"))
33+
34+
autogold.Expect([]openai.ChatMessagePart{
35+
{
36+
Type: "text",
37+
Text: "\none\ntwo",
38+
},
39+
{
40+
Type: "image_url",
41+
ImageURL: &openai.ChatMessageImageURL{URL: "data:image/png;base64,xxxxx"},
42+
},
43+
{
44+
Type: "image_url",
45+
ImageURL: &openai.ChatMessageImageURL{URL: "data:image/png;base64,yyyyy"},
46+
},
47+
}).Equal(t, textToMultiContent("\none\ntwo\ndata:image/png;base64,xxxxx\ndata:image/png;base64,yyyyy"))
48+
}
49+
1250
func Test_appendMessage(t *testing.T) {
1351
autogold.Expect(types.CompletionMessage{Content: []types.ContentPart{
1452
{ToolCall: &types.CompletionToolCall{

pkg/runner/runner.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -655,6 +655,17 @@ func (r *Runner) newDispatcher(ctx context.Context) dispatcher {
655655
return newParallelDispatcher(ctx)
656656
}
657657

658+
func idForToolCall(id string, state *engine.Return) string {
659+
if state == nil || state.State == nil {
660+
return id
661+
}
662+
tc, ok := state.State.Pending[id]
663+
if !ok || tc.Index == nil {
664+
return id
665+
}
666+
return fmt.Sprintf("%03d", *tc.Index)
667+
}
668+
658669
func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string, state *State, toolCategory engine.ToolCategory) (*State, []SubCallResult, error) {
659670
var (
660671
resultLock sync.Mutex
@@ -698,7 +709,9 @@ func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string,
698709

699710
// Sort the ids so that, when the calls run sequentially, the results are predictable
700711
ids := maps.Keys(state.Continuation.Calls)
701-
sort.Strings(ids)
712+
sort.Slice(ids, func(i, j int) bool {
713+
return idForToolCall(ids[i], state.Continuation) < idForToolCall(ids[j], state.Continuation)
714+
})
702715

703716
for _, id := range ids {
704717
call := state.Continuation.Calls[id]

0 commit comments

Comments
 (0)