Commit c91045a

fix(fail-branch): prevent streaming output in exception branches (#17153)
1 parent 44cdb3d commit c91045a

File tree

api/core/workflow/nodes/answer/answer_stream_processor.py
api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py

2 files changed: +73 / -7 lines

api/core/workflow/nodes/answer/answer_stream_processor.py

Lines changed: 21 additions & 2 deletions
```diff
@@ -155,9 +155,28 @@ def _get_stream_out_answer_node_ids(self, event: NodeRunStreamChunkEvent) -> list[str]:
         for answer_node_id, route_position in self.route_position.items():
             if answer_node_id not in self.rest_node_ids:
                 continue
-            # exclude current node id
+            # Remove current node id from answer dependencies to support stream output if it is a success branch
             answer_dependencies = self.generate_routes.answer_dependencies
-            if event.node_id in answer_dependencies[answer_node_id]:
+            edge_mapping = self.graph.edge_mapping.get(event.node_id)
+            success_edge = (
+                next(
+                    (
+                        edge
+                        for edge in edge_mapping
+                        if edge.run_condition
+                        and edge.run_condition.type == "branch_identify"
+                        and edge.run_condition.branch_identify == "success-branch"
+                    ),
+                    None,
+                )
+                if edge_mapping
+                else None
+            )
+            if (
+                event.node_id in answer_dependencies[answer_node_id]
+                and success_edge
+                and success_edge.target_node_id == answer_node_id
+            ):
                 answer_dependencies[answer_node_id].remove(event.node_id)
             answer_dependencies_ids = answer_dependencies.get(answer_node_id, [])
             # all depends on answer node id not in rest node ids
```
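The core of the fix is the `success_edge` lookup: a source node is only dropped from an answer node's dependency list (which is what lets its chunks stream straight out) when the edge connecting them is the node's success branch, so chunks produced while the node ends up in a fail/exception branch stay held back. Below is a minimal, self-contained sketch of that check; the `Edge` and `RunCondition` dataclasses are simplified stand-ins for Dify's real graph entities, and the failure-branch identifier in the example data is illustrative.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class RunCondition:
    type: str
    branch_identify: Optional[str] = None


@dataclass
class Edge:
    target_node_id: str
    run_condition: Optional[RunCondition] = None


def find_success_edge(edges: Optional[list[Edge]]) -> Optional[Edge]:
    """Return the outgoing edge marked as the success branch, if one exists."""
    if not edges:
        return None
    return next(
        (
            edge
            for edge in edges
            if edge.run_condition
            and edge.run_condition.type == "branch_identify"
            and edge.run_condition.branch_identify == "success-branch"
        ),
        None,
    )


# Outgoing edges of a node using the fail-branch error strategy: one edge to the
# "success" answer node, one to the "error" answer node (the failure identifier
# below is illustrative, not taken from the diff).
edges = [
    Edge(target_node_id="success", run_condition=RunCondition("branch_identify", "success-branch")),
    Edge(target_node_id="error", run_condition=RunCondition("branch_identify", "fail-branch")),
]

success_edge = find_success_edge(edges)

# The dependency is only cleared for the answer node that the success branch targets,
# so the "error" answer node never streams the upstream node's chunks.
assert success_edge is not None and success_edge.target_node_id == "success"
assert success_edge.target_node_id != "error"
```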

api/tests/unit_tests/core/workflow/nodes/test_continue_on_error.py

Lines changed: 52 additions & 5 deletions
```diff
@@ -1,14 +1,20 @@
+from unittest.mock import patch
+
 from core.app.entities.app_invoke_entities import InvokeFrom
+from core.workflow.entities.node_entities import NodeRunMetadataKey, NodeRunResult
 from core.workflow.enums import SystemVariableKey
 from core.workflow.graph_engine.entities.event import (
     GraphRunPartialSucceededEvent,
     NodeRunExceptionEvent,
+    NodeRunFailedEvent,
     NodeRunStreamChunkEvent,
 )
 from core.workflow.graph_engine.entities.graph import Graph
 from core.workflow.graph_engine.graph_engine import GraphEngine
+from core.workflow.nodes.event.event import RunCompletedEvent, RunStreamChunkEvent
+from core.workflow.nodes.llm.node import LLMNode
 from models.enums import UserFrom
-from models.workflow import WorkflowType
+from models.workflow import WorkflowNodeExecutionStatus, WorkflowType
 
 
 class ContinueOnErrorTestHelper:
@@ -492,10 +498,7 @@ def test_no_node_in_fail_branch_continue_on_error():
         "edges": FAIL_BRANCH_EDGES[:-1],
         "nodes": [
             {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
-            {
-                "data": {"title": "success", "type": "answer", "answer": "HTTP request successful"},
-                "id": "success",
-            },
+            {"data": {"title": "success", "type": "answer", "answer": "HTTP request successful"}, "id": "success"},
             ContinueOnErrorTestHelper.get_http_node(),
         ],
     }
@@ -506,3 +509,47 @@ def test_no_node_in_fail_branch_continue_on_error():
     assert any(isinstance(e, NodeRunExceptionEvent) for e in events)
     assert any(isinstance(e, GraphRunPartialSucceededEvent) and e.outputs == {} for e in events)
     assert sum(1 for e in events if isinstance(e, NodeRunStreamChunkEvent)) == 0
+
+
+def test_stream_output_with_fail_branch_continue_on_error():
+    """Test stream output with fail-branch error strategy"""
+    graph_config = {
+        "edges": FAIL_BRANCH_EDGES,
+        "nodes": [
+            {"data": {"title": "Start", "type": "start", "variables": []}, "id": "start"},
+            {
+                "data": {"title": "success", "type": "answer", "answer": "LLM request successful"},
+                "id": "success",
+            },
+            {
+                "data": {"title": "error", "type": "answer", "answer": "{{#node.text#}}"},
+                "id": "error",
+            },
+            ContinueOnErrorTestHelper.get_llm_node(),
+        ],
+    }
+    graph_engine = ContinueOnErrorTestHelper.create_test_graph_engine(graph_config)
+
+    def llm_generator(self):
+        contents = ["hi", "bye", "good morning"]
+
+        yield RunStreamChunkEvent(chunk_content=contents[0], from_variable_selector=[self.node_id, "text"])
+
+        yield RunCompletedEvent(
+            run_result=NodeRunResult(
+                status=WorkflowNodeExecutionStatus.SUCCEEDED,
+                inputs={},
+                process_data={},
+                outputs={},
+                metadata={
+                    NodeRunMetadataKey.TOTAL_TOKENS: 1,
+                    NodeRunMetadataKey.TOTAL_PRICE: 1,
+                    NodeRunMetadataKey.CURRENCY: "USD",
+                },
+            )
+        )
+
+    with patch.object(LLMNode, "_run", new=llm_generator):
+        events = list(graph_engine.run())
+        assert sum(isinstance(e, NodeRunStreamChunkEvent) for e in events) == 1
+        assert all(not isinstance(e, NodeRunFailedEvent | NodeRunExceptionEvent) for e in events)
```
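For readers unfamiliar with the mocking pattern in the new test: `patch.object(LLMNode, "_run", new=llm_generator)` swaps the node's `_run` method for the plain generator function for the duration of the `with` block, so the graph engine consumes the stubbed stream instead of invoking a real model. A generic sketch of that pattern, using a toy class rather than any Dify API, is shown below.

```python
from unittest.mock import patch


class FakeNode:
    """Toy stand-in for a workflow node; not a Dify class."""

    node_id = "llm"

    def _run(self):
        yield "real chunk"  # what would run without the patch


def fake_run(self):
    # Replacement generator; `self` is the FakeNode instance, so instance
    # attributes such as node_id remain available, just as in the real test.
    yield f"stubbed chunk from {self.node_id}"


with patch.object(FakeNode, "_run", new=fake_run):
    # Inside the with-block, calls to _run go to the replacement generator.
    assert list(FakeNode()._run()) == ["stubbed chunk from llm"]

# Once the with-block exits, the original method is restored automatically.
assert list(FakeNode()._run()) == ["real chunk"]
```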
