@@ -41,6 +41,7 @@ async def handle_sse(request):
41
41
from anyio .streams .memory import MemoryObjectReceiveStream , MemoryObjectSendStream
42
42
from pydantic import ValidationError
43
43
from sse_starlette import EventSourceResponse
44
+ from starlette .background import BackgroundTask
44
45
from starlette .requests import Request
45
46
from starlette .responses import Response
46
47
from starlette .types import Receive , Scope , Send
@@ -78,6 +79,18 @@ def __init__(self, endpoint: str) -> None:
78
79
self ._read_stream_writers = {}
79
80
logger .debug (f"SseServerTransport initialized with endpoint: { endpoint } " )
80
81
82
+ async def _remove_stream_writer (self , session_id : UUID ) -> None :
83
+ """
84
+ Remove the SSE session with the given session ID.
85
+ """
86
+ logger .debug (f"Remove SSE session with ID: { session_id } " )
87
+ writer = self ._read_stream_writers .pop (session_id , None )
88
+ if writer :
89
+ await writer .aclose ()
90
+ logger .debug (f"Closed SSE session with ID: { session_id } " )
91
+ else :
92
+ logger .warning (f"Session ID { session_id } not found for removal" )
93
+
81
94
@asynccontextmanager
82
95
async def connect_sse (self , scope : Scope , receive : Receive , send : Send ):
83
96
if scope ["type" ] != "http" :
@@ -119,10 +132,11 @@ async def sse_writer():
119
132
),
120
133
}
121
134
)
122
-
135
+ background_task = BackgroundTask ( self . _remove_stream_writer , session_id )
123
136
async with anyio .create_task_group () as tg :
124
137
response = EventSourceResponse (
125
- content = sse_stream_reader , data_sender_callable = sse_writer
138
+ content = sse_stream_reader , data_sender_callable = sse_writer ,
139
+ background = background_task ,
126
140
)
127
141
logger .debug ("Starting SSE response task" )
128
142
tg .start_soon (response , scope , receive , send )
0 commit comments