
Commit c1559a7

fix: LLMResultChunk cause concatenate str and list exception (#18852)
1 parent 993ef87 commit c1559a7

File tree

3 files changed: +25 −15 lines changed

api/core/model_runtime/model_providers/__base/large_language_model.py
api/core/model_runtime/utils/helper.py
api/core/workflow/nodes/llm/node.py

api/core/model_runtime/model_providers/__base/large_language_model.py

Lines changed: 5 additions & 2 deletions
@@ -2,7 +2,7 @@
 import time
 import uuid
 from collections.abc import Generator, Sequence
-from typing import Optional, Union
+from typing import Optional, Union, cast
 
 from pydantic import ConfigDict
 
@@ -20,6 +20,7 @@
     PriceType,
 )
 from core.model_runtime.model_providers.__base.ai_model import AIModel
+from core.model_runtime.utils.helper import convert_llm_result_chunk_to_str
 from core.plugin.manager.model import PluginModelManager
 
 logger = logging.getLogger(__name__)
@@ -280,7 +281,9 @@ def _invoke_result_generator(
                     callbacks=callbacks,
                 )
 
-                assistant_message.content += chunk.delta.message.content
+                text = convert_llm_result_chunk_to_str(chunk.delta.message.content)
+                current_content = cast(str, assistant_message.content)
+                assistant_message.content = current_content + text
                 real_model = chunk.model
                 if chunk.delta.usage:
                     usage = chunk.delta.usage
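
The root cause is visible in the last hunk: `assistant_message.content` starts out as a `str`, but a chunk's `delta.message.content` may arrive as a list of content parts, and `str += list` raises `TypeError`. A minimal, self-contained repro of the failure and of the normalize-then-concatenate fix (`TextPart` is a hypothetical stand-in for Dify's `PromptMessageContent` entities):

from dataclasses import dataclass


@dataclass
class TextPart:
    # Hypothetical stand-in for a PromptMessageContent with a str "data" field.
    data: str


assistant_text = ""  # assistant_message.content starts out as an empty str

# Some models stream the delta as a list of content parts rather than a str:
delta_content = [TextPart(data="Hello, "), TextPart(data="world")]

try:
    assistant_text += delta_content  # the old code path: str += list
except TypeError as exc:
    print(exc)  # can only concatenate str (not "list") to str

# The fix normalizes the delta to str first, mirroring the new helper:
text = "".join(part.data for part in delta_content)
assistant_text += text
print(assistant_text)  # Hello, world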
api/core/model_runtime/utils/helper.py

Lines changed: 17 additions & 0 deletions
@@ -1,10 +1,27 @@
 import pydantic
 from pydantic import BaseModel
 
+from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
+
 
 def dump_model(model: BaseModel) -> dict:
     if hasattr(pydantic, "model_dump"):
         # FIXME mypy error, try to fix it instead of using type: ignore
         return pydantic.model_dump(model)  # type: ignore
     else:
         return model.model_dump()
+
+
+def convert_llm_result_chunk_to_str(content: None | str | list[PromptMessageContentUnionTypes]) -> str:
+    if content is None:
+        message_text = ""
+    elif isinstance(content, str):
+        message_text = content
+    elif isinstance(content, list):
+        # Assuming the list contains PromptMessageContent objects with a "data" attribute
+        message_text = "".join(
+            item.data if hasattr(item, "data") and isinstance(item.data, str) else str(item) for item in content
+        )
+    else:
+        message_text = str(content)
+    return message_text
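
A short sketch of how the new helper behaves for each input shape it accepts. `SimpleNamespace` stands in for a real `PromptMessageContent` object carrying a `data` attribute, so these calls exercise the same runtime branches even though a static type checker would reject the arguments:

from types import SimpleNamespace

from core.model_runtime.utils.helper import convert_llm_result_chunk_to_str

assert convert_llm_result_chunk_to_str(None) == ""
assert convert_llm_result_chunk_to_str("plain text") == "plain text"

# List input: str "data" attributes are joined in order...
parts = [SimpleNamespace(data="foo"), SimpleNamespace(data="bar")]
assert convert_llm_result_chunk_to_str(parts) == "foobar"

# ...and items without a str "data" attribute fall back to str(item).
assert convert_llm_result_chunk_to_str([42]) == "42"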

api/core/workflow/nodes/llm/node.py

Lines changed: 3 additions & 13 deletions
@@ -38,6 +38,7 @@
 )
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.utils.encoders import jsonable_encoder
+from core.model_runtime.utils.helper import convert_llm_result_chunk_to_str
 from core.plugin.entities.plugin import ModelProviderID
 from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
 from core.prompt.utils.prompt_message_util import PromptMessageUtil
@@ -269,18 +270,7 @@ def _invoke_llm(
 
     def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generator[NodeEvent, None, None]:
         if isinstance(invoke_result, LLMResult):
-            content = invoke_result.message.content
-            if content is None:
-                message_text = ""
-            elif isinstance(content, str):
-                message_text = content
-            elif isinstance(content, list):
-                # Assuming the list contains PromptMessageContent objects with a "data" attribute
-                message_text = "".join(
-                    item.data if hasattr(item, "data") and isinstance(item.data, str) else str(item) for item in content
-                )
-            else:
-                message_text = str(content)
+            message_text = convert_llm_result_chunk_to_str(invoke_result.message.content)
 
             yield ModelInvokeCompletedEvent(
                 text=message_text,
@@ -295,7 +285,7 @@ def _handle_invoke_result(self, invoke_result: LLMResult | Generator) -> Generat
             usage = None
             finish_reason = None
             for result in invoke_result:
-                text = result.delta.message.content
+                text = convert_llm_result_chunk_to_str(result.delta.message.content)
                 full_text += text
 
                 yield RunStreamChunkEvent(chunk_content=text, from_variable_selector=[self.node_id, "text"])
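
Taken together, the streaming branch of `_handle_invoke_result` now tolerates a stream that mixes `str`, list, and `None` deltas. A self-contained sketch under those assumptions (the `Fake*` dataclasses are hypothetical stand-ins for `LLMResultChunk` and its nested objects, and a behavior-equivalent copy of the helper is inlined so the snippet runs on its own):

from dataclasses import dataclass
from types import SimpleNamespace
from typing import Any


@dataclass
class FakeMessage:
    content: Any  # str, list of content parts, or None


@dataclass
class FakeDelta:
    message: FakeMessage


@dataclass
class FakeChunk:
    delta: FakeDelta


def convert_llm_result_chunk_to_str(content) -> str:
    # Behavior-equivalent copy of the new helper, inlined for self-containment.
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return "".join(
            item.data if hasattr(item, "data") and isinstance(item.data, str) else str(item)
            for item in content
        )
    return str(content)


# A stream mixing all three delta shapes the helper handles:
stream = [
    FakeChunk(FakeDelta(FakeMessage("Hello"))),
    FakeChunk(FakeDelta(FakeMessage([SimpleNamespace(data=", "), SimpleNamespace(data="world")]))),
    FakeChunk(FakeDelta(FakeMessage(None))),
]

full_text = ""
for result in stream:
    text = convert_llm_result_chunk_to_str(result.delta.message.content)
    full_text += text  # previously raised TypeError on the list-shaped chunk

print(full_text)  # Hello, world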
