Skip to content

Commit a1158cc

Browse files
zalcit, 朱庆超, and crazywoola
authored
fix: Update prompt message content types to use Literal and add union type for content (#17136)
Co-authored-by: 朱庆超 <[email protected]> Co-authored-by: crazywoola <[email protected]>
1 parent 404f8a7 commit a1158cc

File tree

10 files changed

+73
-39
lines changed

10 files changed

+73
-39
lines changed

api/core/agent/base_agent_runner.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,13 @@
2121
AssistantPromptMessage,
2222
LLMUsage,
2323
PromptMessage,
24-
PromptMessageContent,
2524
PromptMessageTool,
2625
SystemPromptMessage,
2726
TextPromptMessageContent,
2827
ToolPromptMessage,
2928
UserPromptMessage,
3029
)
31-
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
30+
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
3231
from core.model_runtime.entities.model_entities import ModelFeature
3332
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
3433
from core.prompt.utils.extract_thread_messages import extract_thread_messages
@@ -501,7 +500,7 @@ def organize_agent_user_prompt(self, message: Message) -> UserPromptMessage:
501500
)
502501
if not file_objs:
503502
return UserPromptMessage(content=message.query)
504-
prompt_message_contents: list[PromptMessageContent] = []
503+
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
505504
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
506505
for file in file_objs:
507506
prompt_message_contents.append(

api/core/agent/cot_chat_agent_runner.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,11 @@
55
from core.model_runtime.entities import (
66
AssistantPromptMessage,
77
PromptMessage,
8-
PromptMessageContent,
98
SystemPromptMessage,
109
TextPromptMessageContent,
1110
UserPromptMessage,
1211
)
13-
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
12+
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
1413
from core.model_runtime.utils.encoders import jsonable_encoder
1514

1615

@@ -40,7 +39,7 @@ def _organize_user_query(self, query, prompt_messages: list[PromptMessage]) -> l
4039
Organize user query
4140
"""
4241
if self.files:
43-
prompt_message_contents: list[PromptMessageContent] = []
42+
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
4443
prompt_message_contents.append(TextPromptMessageContent(data=query))
4544

4645
# get image detail config

api/core/agent/fc_agent_runner.py

+2-3
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,13 @@
1515
LLMResultChunkDelta,
1616
LLMUsage,
1717
PromptMessage,
18-
PromptMessageContent,
1918
PromptMessageContentType,
2019
SystemPromptMessage,
2120
TextPromptMessageContent,
2221
ToolPromptMessage,
2322
UserPromptMessage,
2423
)
25-
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
24+
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
2625
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
2726
from core.tools.entities.tool_entities import ToolInvokeMeta
2827
from core.tools.tool_engine import ToolEngine
@@ -395,7 +394,7 @@ def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage])
395394
Organize user query
396395
"""
397396
if self.files:
398-
prompt_message_contents: list[PromptMessageContent] = []
397+
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
399398
prompt_message_contents.append(TextPromptMessageContent(data=query))
400399

401400
# get image detail config

api/core/file/file_manager.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
AudioPromptMessageContent,
88
DocumentPromptMessageContent,
99
ImagePromptMessageContent,
10-
MultiModalPromptMessageContent,
1110
VideoPromptMessageContent,
1211
)
12+
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
1313
from extensions.ext_storage import storage
1414

1515
from . import helpers
@@ -43,7 +43,7 @@ def to_prompt_message_content(
4343
/,
4444
*,
4545
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
46-
) -> MultiModalPromptMessageContent:
46+
) -> PromptMessageContentUnionTypes:
4747
if f.extension is None:
4848
raise ValueError("Missing file extension")
4949
if f.mime_type is None:
@@ -58,7 +58,7 @@ def to_prompt_message_content(
5858
if f.type == FileType.IMAGE:
5959
params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
6060

61-
prompt_class_map: Mapping[FileType, type[MultiModalPromptMessageContent]] = {
61+
prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = {
6262
FileType.IMAGE: ImagePromptMessageContent,
6363
FileType.AUDIO: AudioPromptMessageContent,
6464
FileType.VIDEO: VideoPromptMessageContent,

api/core/memory/token_buffer_memory.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,11 @@
88
AssistantPromptMessage,
99
ImagePromptMessageContent,
1010
PromptMessage,
11-
PromptMessageContent,
1211
PromptMessageRole,
1312
TextPromptMessageContent,
1413
UserPromptMessage,
1514
)
15+
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
1616
from core.prompt.utils.extract_thread_messages import extract_thread_messages
1717
from extensions.ext_database import db
1818
from factories import file_factory
@@ -100,7 +100,7 @@ def get_history_prompt_messages(
100100
if not file_objs:
101101
prompt_messages.append(UserPromptMessage(content=message.query))
102102
else:
103-
prompt_message_contents: list[PromptMessageContent] = []
103+
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
104104
prompt_message_contents.append(TextPromptMessageContent(data=message.query))
105105
for file in file_objs:
106106
prompt_message = file_manager.to_prompt_message_content(

api/core/model_runtime/entities/message_entities.py

+20-13
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from collections.abc import Sequence
22
from enum import Enum, StrEnum
3-
from typing import Any, Optional, Union
3+
from typing import Annotated, Any, Literal, Optional, Union
44

55
from pydantic import BaseModel, Field, field_serializer, field_validator
66

@@ -61,19 +61,15 @@ class PromptMessageContentType(StrEnum):
6161

6262

6363
class PromptMessageContent(BaseModel):
64-
"""
65-
Model class for prompt message content.
66-
"""
67-
68-
type: PromptMessageContentType
64+
pass
6965

7066

7167
class TextPromptMessageContent(PromptMessageContent):
7268
"""
7369
Model class for text prompt message content.
7470
"""
7571

76-
type: PromptMessageContentType = PromptMessageContentType.TEXT
72+
type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT
7773
data: str
7874

7975

@@ -82,7 +78,6 @@ class MultiModalPromptMessageContent(PromptMessageContent):
8278
Model class for multi-modal prompt message content.
8379
"""
8480

85-
type: PromptMessageContentType
8681
format: str = Field(default=..., description="the format of multi-modal file")
8782
base64_data: str = Field(default="", description="the base64 data of multi-modal file")
8883
url: str = Field(default="", description="the url of multi-modal file")
@@ -94,11 +89,11 @@ def data(self):
9489

9590

9691
class VideoPromptMessageContent(MultiModalPromptMessageContent):
97-
type: PromptMessageContentType = PromptMessageContentType.VIDEO
92+
type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO
9893

9994

10095
class AudioPromptMessageContent(MultiModalPromptMessageContent):
101-
type: PromptMessageContentType = PromptMessageContentType.AUDIO
96+
type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO
10297

10398

10499
class ImagePromptMessageContent(MultiModalPromptMessageContent):
@@ -110,12 +105,24 @@ class DETAIL(StrEnum):
110105
LOW = "low"
111106
HIGH = "high"
112107

113-
type: PromptMessageContentType = PromptMessageContentType.IMAGE
108+
type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE
114109
detail: DETAIL = DETAIL.LOW
115110

116111

117112
class DocumentPromptMessageContent(MultiModalPromptMessageContent):
118-
type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
113+
type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT
114+
115+
116+
PromptMessageContentUnionTypes = Annotated[
117+
Union[
118+
TextPromptMessageContent,
119+
ImagePromptMessageContent,
120+
DocumentPromptMessageContent,
121+
AudioPromptMessageContent,
122+
VideoPromptMessageContent,
123+
],
124+
Field(discriminator="type"),
125+
]
119126

120127

121128
class PromptMessage(BaseModel):
@@ -124,7 +131,7 @@ class PromptMessage(BaseModel):
124131
"""
125132

126133
role: PromptMessageRole
127-
content: Optional[str | Sequence[PromptMessageContent]] = None
134+
content: Optional[str | list[PromptMessageContentUnionTypes]] = None
128135
name: Optional[str] = None
129136

130137
def is_empty(self) -> bool:

api/core/prompt/advanced_prompt_transform.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,12 @@
99
from core.model_runtime.entities import (
1010
AssistantPromptMessage,
1111
PromptMessage,
12-
PromptMessageContent,
1312
PromptMessageRole,
1413
SystemPromptMessage,
1514
TextPromptMessageContent,
1615
UserPromptMessage,
1716
)
18-
from core.model_runtime.entities.message_entities import ImagePromptMessageContent
17+
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
1918
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
2019
from core.prompt.prompt_transform import PromptTransform
2120
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
@@ -125,7 +124,7 @@ def _get_completion_model_prompt_messages(
125124
prompt = Jinja2Formatter.format(prompt, prompt_inputs)
126125

127126
if files:
128-
prompt_message_contents: list[PromptMessageContent] = []
127+
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
129128
prompt_message_contents.append(TextPromptMessageContent(data=prompt))
130129
for file in files:
131130
prompt_message_contents.append(
@@ -201,7 +200,7 @@ def _get_chat_model_prompt_messages(
201200
prompt_messages = self._append_chat_histories(memory, memory_config, prompt_messages, model_config)
202201

203202
if files and query is not None:
204-
prompt_message_contents: list[PromptMessageContent] = []
203+
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
205204
prompt_message_contents.append(TextPromptMessageContent(data=query))
206205
for file in files:
207206
prompt_message_contents.append(

api/core/prompt/simple_prompt_transform.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from core.model_runtime.entities.message_entities import (
1212
ImagePromptMessageContent,
1313
PromptMessage,
14-
PromptMessageContent,
14+
PromptMessageContentUnionTypes,
1515
SystemPromptMessage,
1616
TextPromptMessageContent,
1717
UserPromptMessage,
@@ -277,7 +277,7 @@ def _get_last_user_message(
277277
image_detail_config: Optional[ImagePromptMessageContent.DETAIL] = None,
278278
) -> UserPromptMessage:
279279
if files:
280-
prompt_message_contents: list[PromptMessageContent] = []
280+
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
281281
prompt_message_contents.append(TextPromptMessageContent(data=prompt))
282282
for file in files:
283283
prompt_message_contents.append(

api/core/workflow/nodes/llm/node.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
2525
from core.model_runtime.entities.message_entities import (
2626
AssistantPromptMessage,
27-
PromptMessageContent,
27+
PromptMessageContentUnionTypes,
2828
PromptMessageRole,
2929
SystemPromptMessage,
3030
UserPromptMessage,
@@ -594,8 +594,7 @@ def _fetch_prompt_messages(
594594
variable_pool: VariablePool,
595595
jinja2_variables: Sequence[VariableSelector],
596596
) -> tuple[Sequence[PromptMessage], Optional[Sequence[str]]]:
597-
# FIXME: fix the type error cause prompt_messages is type quick a few times
598-
prompt_messages: list[Any] = []
597+
prompt_messages: list[PromptMessage] = []
599598

600599
if isinstance(prompt_template, list):
601600
# For chat model
@@ -657,12 +656,14 @@ def _fetch_prompt_messages(
657656
# For issue #11247 - Check if prompt content is a string or a list
658657
prompt_content_type = type(prompt_content)
659658
if prompt_content_type == str:
659+
prompt_content = str(prompt_content)
660660
if "#histories#" in prompt_content:
661661
prompt_content = prompt_content.replace("#histories#", memory_text)
662662
else:
663663
prompt_content = memory_text + "\n" + prompt_content
664664
prompt_messages[0].content = prompt_content
665665
elif prompt_content_type == list:
666+
prompt_content = prompt_content if isinstance(prompt_content, list) else []
666667
for content_item in prompt_content:
667668
if content_item.type == PromptMessageContentType.TEXT:
668669
if "#histories#" in content_item.data:
@@ -675,9 +676,10 @@ def _fetch_prompt_messages(
675676
# Add current query to the prompt message
676677
if sys_query:
677678
if prompt_content_type == str:
678-
prompt_content = prompt_messages[0].content.replace("#sys.query#", sys_query)
679+
prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
679680
prompt_messages[0].content = prompt_content
680681
elif prompt_content_type == list:
682+
prompt_content = prompt_content if isinstance(prompt_content, list) else []
681683
for content_item in prompt_content:
682684
if content_item.type == PromptMessageContentType.TEXT:
683685
content_item.data = sys_query + "\n" + content_item.data
@@ -707,7 +709,7 @@ def _fetch_prompt_messages(
707709
filtered_prompt_messages = []
708710
for prompt_message in prompt_messages:
709711
if isinstance(prompt_message.content, list):
710-
prompt_message_content = []
712+
prompt_message_content: list[PromptMessageContentUnionTypes] = []
711713
for content_item in prompt_message.content:
712714
# Skip content if features are not defined
713715
if not model_config.model_schema.features:
@@ -1132,7 +1134,9 @@ def _check_model_structured_output_support(self) -> SupportStructuredOutputStatu
11321134
)
11331135

11341136

1135-
def _combine_message_content_with_role(*, contents: Sequence[PromptMessageContent], role: PromptMessageRole):
1137+
def _combine_message_content_with_role(
1138+
*, contents: Optional[str | list[PromptMessageContentUnionTypes]] = None, role: PromptMessageRole
1139+
):
11361140
match role:
11371141
case PromptMessageRole.USER:
11381142
return UserPromptMessage(content=contents)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
from core.model_runtime.entities.message_entities import (
2+
ImagePromptMessageContent,
3+
TextPromptMessageContent,
4+
UserPromptMessage,
5+
)
6+
7+
8+
def test_build_prompt_message_with_prompt_message_contents():
9+
prompt = UserPromptMessage(content=[TextPromptMessageContent(data="Hello, World!")])
10+
assert isinstance(prompt.content, list)
11+
assert isinstance(prompt.content[0], TextPromptMessageContent)
12+
assert prompt.content[0].data == "Hello, World!"
13+
14+
15+
def test_dump_prompt_message():
16+
example_url = "https://example.com/image.jpg"
17+
prompt = UserPromptMessage(
18+
content=[
19+
ImagePromptMessageContent(
20+
url=example_url,
21+
format="jpeg",
22+
mime_type="image/jpeg",
23+
)
24+
]
25+
)
26+
data = prompt.model_dump()
27+
assert data["content"][0].get("url") == example_url

0 commit comments

Comments (0)