Skip to content

Commit f47e2bc

Browse files
yash025Yashwanth H L
and
Yashwanth H L
authored
fix: Use detached context for SSE message handling (#244)
* fix: Use detached context for SSE message handling Prevents premature cancellation of message processing when HTTP request ends. * test for message processing when we return early to the client * rename variable --------- Co-authored-by: Yashwanth H L <[email protected]>
1 parent a999079 commit f47e2bc

File tree

2 files changed

+94
-3
lines changed

2 files changed

+94
-3
lines changed

server/sse.go

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -465,10 +465,19 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
465465
return
466466
}
467467

468+
// Create a context that preserves all values from parent ctx but won't be canceled when the parent is canceled.
469+
// this is required because the http ctx will be canceled when the client disconnects
470+
detachedCtx := context.WithoutCancel(ctx)
471+
468472
// quick return request, send 202 Accepted with no body, then deal the message and sent response via SSE
469473
w.WriteHeader(http.StatusAccepted)
470474

471-
go func() {
475+
// Create a new context for handling the message that will be canceled when the message handling is done
476+
messageCtx, cancel := context.WithCancel(detachedCtx)
477+
478+
go func(ctx context.Context) {
479+
defer cancel()
480+
// Use the context that will be canceled when session is done
472481
// Process message through MCPServer
473482
response := s.server.HandleMessage(ctx, rawMessage)
474483
// Only send response if there is one (not for notifications)
@@ -493,7 +502,7 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
493502
log.Printf("Event queue full for session %s", sessionID)
494503
}
495504
}
496-
}()
505+
}(messageCtx)
497506
}
498507

499508
// writeJSONRPCError writes a JSON-RPC error response with the given error details.

server/sse_test.go

Lines changed: 83 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,6 @@ func TestSSEServer(t *testing.T) {
203203
strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0],
204204
)
205205

206-
fmt.Printf("========> %v", respFromSee)
207206
var response map[string]interface{}
208207
if err := json.NewDecoder(strings.NewReader(respFromSee)).Decode(&response); err != nil {
209208
t.Errorf(
@@ -1318,6 +1317,89 @@ func TestSSEServer(t *testing.T) {
13181317
t.Errorf("Expected id to be null")
13191318
}
13201319
})
1320+
1321+
t.Run("Message processing continues after we return back result to client", func(t *testing.T) {
1322+
mcpServer := NewMCPServer("test", "1.0.0")
1323+
1324+
processingCompleted := make(chan struct{})
1325+
processingStarted := make(chan struct{})
1326+
1327+
mcpServer.AddTool(mcp.NewTool("slowMethod"), func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
1328+
close(processingStarted) // signal for processing started
1329+
1330+
select {
1331+
case <-ctx.Done(): // If this happens, the test will fail because processingCompleted won't be closed
1332+
return nil, fmt.Errorf("context was canceled")
1333+
case <-time.After(1 * time.Second): // Simulate processing time
1334+
// Successfully completed processing, now close the completed channel to signal completion
1335+
close(processingCompleted)
1336+
return &mcp.CallToolResult{
1337+
Content: []mcp.Content{
1338+
mcp.TextContent{
1339+
Type: "text",
1340+
Text: "success",
1341+
},
1342+
},
1343+
}, nil
1344+
}
1345+
})
1346+
1347+
testServer := NewTestServer(mcpServer)
1348+
defer testServer.Close()
1349+
1350+
sseResp, err := http.Get(fmt.Sprintf("%s/sse", testServer.URL))
1351+
require.NoError(t, err, "Failed to connect to SSE endpoint")
1352+
defer sseResp.Body.Close()
1353+
1354+
endpointEvent, err := readSSEEvent(sseResp)
1355+
require.NoError(t, err, "Failed to read SSE response")
1356+
require.Contains(t, endpointEvent, "event: endpoint", "Expected endpoint event")
1357+
1358+
messageURL := strings.TrimSpace(
1359+
strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0],
1360+
)
1361+
1362+
messageRequest := map[string]interface{}{
1363+
"jsonrpc": "2.0",
1364+
"id": 1,
1365+
"method": "tools/call",
1366+
"params": map[string]interface{}{
1367+
"name": "slowMethod",
1368+
"parameters": map[string]interface{}{},
1369+
},
1370+
}
1371+
1372+
requestBody, err := json.Marshal(messageRequest)
1373+
require.NoError(t, err, "Failed to marshal request")
1374+
1375+
ctx, cancel := context.WithCancel(context.Background())
1376+
req, err := http.NewRequestWithContext(ctx, "POST", messageURL, bytes.NewBuffer(requestBody))
1377+
require.NoError(t, err, "Failed to create request")
1378+
req.Header.Set("Content-Type", "application/json")
1379+
1380+
client := &http.Client{}
1381+
resp, err := client.Do(req)
1382+
require.NoError(t, err, "Failed to send message")
1383+
defer resp.Body.Close()
1384+
1385+
require.Equal(t, http.StatusAccepted, resp.StatusCode, "Expected status 202 Accepted")
1386+
1387+
// Wait for processing to start
1388+
select {
1389+
case <-processingStarted: // Processing has started, now cancel the client context to simulate disconnection
1390+
case <-time.After(2 * time.Second):
1391+
t.Fatal("Timed out waiting for processing to start")
1392+
}
1393+
1394+
cancel() // cancel the client context to simulate disconnection
1395+
1396+
// wait for processing to complete, if the test passes, it means the processing continued despite client disconnection
1397+
select {
1398+
case <-processingCompleted:
1399+
case <-time.After(2 * time.Second):
1400+
t.Fatal("Processing did not complete after client disconnection")
1401+
}
1402+
})
13211403
}
13221404

13231405
func readSSEEvent(sseResp *http.Response) (string, error) {

0 commit comments

Comments
 (0)