Skip to content

Commit e27d5fe

Browse files
CeerDecyadlternative
authored andcommitted
feat: quick return tool-call request, send response via SSE in goroutine (mark3labs#163)
* feat: quick return tool-call request, send response via SSE in goroutine * update: comment * feat: handle JSON marshal errors and add logging when queue is full * feat: send a generic error response if marshal response has error
1 parent 31cea2e commit e27d5fe

File tree

2 files changed

+70
-47
lines changed

2 files changed

+70
-47
lines changed

server/sse.go

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"encoding/json"
66
"fmt"
7+
"log"
78
"net/http"
89
"net/http/httptest"
910
"net/url"
@@ -364,7 +365,7 @@ func (s *SSEServer) GetMessageEndpointForClient(sessionID string) string {
364365
}
365366

366367
// handleMessage processes incoming JSON-RPC messages from clients and sends responses
367-
// back through both the SSE connection and HTTP response.
368+
// back through the SSE connection and 202 code to HTTP response.
368369
func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
369370
if r.Method != http.MethodPost {
370371
s.writeJSONRPCError(w, nil, mcp.INVALID_REQUEST, "Method not allowed")
@@ -396,31 +397,37 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
396397
return
397398
}
398399

399-
// Process message through MCPServer
400-
response := s.server.HandleMessage(ctx, rawMessage)
400+
// quick return request, send 202 Accepted with no body, then deal the message and sent response via SSE
401+
w.WriteHeader(http.StatusAccepted)
401402

402-
// Only send response if there is one (not for notifications)
403-
if response != nil {
404-
eventData, _ := json.Marshal(response)
403+
go func() {
404+
// Process message through MCPServer
405+
response := s.server.HandleMessage(ctx, rawMessage)
406+
407+
// Only send response if there is one (not for notifications)
408+
if response != nil {
409+
var message string
410+
if eventData, err := json.Marshal(response); err != nil {
411+
// If there is an error marshalling the response, send a generic error response
412+
log.Printf("failed to marshal response: %v", err)
413+
message = fmt.Sprintf("event: message\ndata: {\"error\": \"internal error\",\"jsonrpc\": \"2.0\", \"id\": null}\n\n")
414+
return
415+
} else {
416+
message = fmt.Sprintf("event: message\ndata: %s\n\n", eventData)
417+
}
405418

406-
// Queue the event for sending via SSE
407-
select {
408-
case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData):
409-
// Event queued successfully
410-
case <-session.done:
411-
// Session is closed, don't try to queue
412-
default:
413-
// Queue is full, could log this
419+
// Queue the event for sending via SSE
420+
select {
421+
case session.eventQueue <- message:
422+
// Event queued successfully
423+
case <-session.done:
424+
// Session is closed, don't try to queue
425+
default:
426+
// Queue is full, log this situation
427+
log.Printf("Event queue full for session %s", sessionID)
428+
}
414429
}
415-
416-
// Send HTTP response
417-
w.Header().Set("Content-Type", "application/json")
418-
w.WriteHeader(http.StatusAccepted)
419-
json.NewEncoder(w).Encode(response)
420-
} else {
421-
// For notifications, just send 202 Accepted with no body
422-
w.WriteHeader(http.StatusAccepted)
423-
}
430+
}()
424431
}
425432

426433
// writeJSONRPCError writes a JSON-RPC error response with the given error details.

server/sse_test.go

Lines changed: 40 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -61,13 +61,10 @@ func TestSSEServer(t *testing.T) {
6161
defer sseResp.Body.Close()
6262

6363
// Read the endpoint event
64-
buf := make([]byte, 1024)
65-
n, err := sseResp.Body.Read(buf)
64+
endpointEvent, err := readSeeEvent(sseResp)
6665
if err != nil {
6766
t.Fatalf("Failed to read SSE response: %v", err)
6867
}
69-
70-
endpointEvent := string(buf[:n])
7168
if !strings.Contains(endpointEvent, "event: endpoint") {
7269
t.Fatalf("Expected endpoint event, got: %s", endpointEvent)
7370
}
@@ -109,19 +106,6 @@ func TestSSEServer(t *testing.T) {
109106
if resp.StatusCode != http.StatusAccepted {
110107
t.Errorf("Expected status 202, got %d", resp.StatusCode)
111108
}
112-
113-
// Verify response
114-
var response map[string]interface{}
115-
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
116-
t.Fatalf("Failed to decode response: %v", err)
117-
}
118-
119-
if response["jsonrpc"] != "2.0" {
120-
t.Errorf("Expected jsonrpc 2.0, got %v", response["jsonrpc"])
121-
}
122-
if response["id"].(float64) != 1 {
123-
t.Errorf("Expected id 1, got %v", response["id"])
124-
}
125109
})
126110

127111
t.Run("Can handle multiple sessions", func(t *testing.T) {
@@ -210,8 +194,17 @@ func TestSSEServer(t *testing.T) {
210194
}
211195
defer resp.Body.Close()
212196

197+
endpointEvent, err = readSeeEvent(sseResp)
198+
if err != nil {
199+
t.Fatalf("Failed to read SSE response: %v", err)
200+
}
201+
respFromSee := strings.TrimSpace(
202+
strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0],
203+
)
204+
205+
fmt.Printf("========> %v", respFromSee)
213206
var response map[string]interface{}
214-
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
207+
if err := json.NewDecoder(strings.NewReader(respFromSee)).Decode(&response); err != nil {
215208
t.Errorf(
216209
"Session %d: Failed to decode response: %v",
217210
sessionNum,
@@ -588,13 +581,10 @@ func TestSSEServer(t *testing.T) {
588581
defer sseResp.Body.Close()
589582

590583
// Read the endpoint event
591-
buf := make([]byte, 1024)
592-
n, err := sseResp.Body.Read(buf)
584+
endpointEvent, err := readSeeEvent(sseResp)
593585
if err != nil {
594586
t.Fatalf("Failed to read SSE response: %v", err)
595587
}
596-
597-
endpointEvent := string(buf[:n])
598588
messageURL := strings.TrimSpace(
599589
strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0],
600590
)
@@ -634,8 +624,16 @@ func TestSSEServer(t *testing.T) {
634624
}
635625

636626
// Verify response
627+
endpointEvent, err = readSeeEvent(sseResp)
628+
if err != nil {
629+
t.Fatalf("Failed to read SSE response: %v", err)
630+
}
631+
respFromSee := strings.TrimSpace(
632+
strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0],
633+
)
634+
637635
var response map[string]interface{}
638-
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
636+
if err := json.NewDecoder(strings.NewReader(respFromSee)).Decode(&response); err != nil {
639637
t.Fatalf("Failed to decode response: %v", err)
640638
}
641639

@@ -673,8 +671,17 @@ func TestSSEServer(t *testing.T) {
673671
}
674672
defer resp.Body.Close()
675673

674+
endpointEvent, err = readSeeEvent(sseResp)
675+
if err != nil {
676+
t.Fatalf("Failed to read SSE response: %v", err)
677+
}
678+
679+
respFromSee = strings.TrimSpace(
680+
strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0],
681+
)
682+
676683
response = make(map[string]interface{})
677-
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
684+
if err := json.NewDecoder(strings.NewReader(respFromSee)).Decode(&response); err != nil {
678685
t.Fatalf("Failed to decode response: %v", err)
679686
}
680687

@@ -855,3 +862,12 @@ func TestSSEServer(t *testing.T) {
855862
}
856863
})
857864
}
865+
866+
func readSeeEvent(sseResp *http.Response) (string, error) {
867+
buf := make([]byte, 1024)
868+
n, err := sseResp.Body.Read(buf)
869+
if err != nil {
870+
return "", err
871+
}
872+
return string(buf[:n]), nil
873+
}

0 commit comments

Comments
 (0)