Skip to content

Commit cde2b2a

Browse files
committed
feat: send a generic error response if marshal response has error
1 parent 5bd00ad commit cde2b2a

File tree

2 files changed

+46
-26
lines changed

2 files changed

+46
-26
lines changed

server/sse.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -365,15 +365,19 @@ func (s *SSEServer) handleMessage(w http.ResponseWriter, r *http.Request) {
365365

366366
// Only send response if there is one (not for notifications)
367367
if response != nil {
368+
var message string
368369
eventData, err := json.Marshal(response)
369370
if err != nil {
370-
s.writeJSONRPCError(w, nil, mcp.INTERNAL_ERROR, "Fail to marshal response")
371+
// If there is an error marshalling the response, send a generic error response
372+
log.Printf("failed to marshal response: %v", err)
373+
message = fmt.Sprintf("event: message\ndata: {\"error\": \"internal error\",\"jsonrpc\": \"2.0\", \"id\": null}\n\n")
371374
return
372375
}
376+
message = fmt.Sprintf("event: message\ndata: %s\n\n", eventData)
373377

374378
// Queue the event for sending via SSE
375379
select {
376-
case session.eventQueue <- fmt.Sprintf("event: message\ndata: %s\n\n", eventData):
380+
case session.eventQueue <- message:
377381
// Event queued successfully
378382
case <-session.done:
379383
// Session is closed, don't try to queue

server/sse_test.go

Lines changed: 40 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,10 @@ func TestSSEServer(t *testing.T) {
5959
defer sseResp.Body.Close()
6060

6161
// Read the endpoint event
62-
buf := make([]byte, 1024)
63-
n, err := sseResp.Body.Read(buf)
62+
endpointEvent, err := readSeeEvent(sseResp)
6463
if err != nil {
6564
t.Fatalf("Failed to read SSE response: %v", err)
6665
}
67-
68-
endpointEvent := string(buf[:n])
6966
if !strings.Contains(endpointEvent, "event: endpoint") {
7067
t.Fatalf("Expected endpoint event, got: %s", endpointEvent)
7168
}
@@ -107,19 +104,6 @@ func TestSSEServer(t *testing.T) {
107104
if resp.StatusCode != http.StatusAccepted {
108105
t.Errorf("Expected status 202, got %d", resp.StatusCode)
109106
}
110-
111-
// Verify response
112-
var response map[string]interface{}
113-
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
114-
t.Fatalf("Failed to decode response: %v", err)
115-
}
116-
117-
if response["jsonrpc"] != "2.0" {
118-
t.Errorf("Expected jsonrpc 2.0, got %v", response["jsonrpc"])
119-
}
120-
if response["id"].(float64) != 1 {
121-
t.Errorf("Expected id 1, got %v", response["id"])
122-
}
123107
})
124108

125109
t.Run("Can handle multiple sessions", func(t *testing.T) {
@@ -208,8 +192,17 @@ func TestSSEServer(t *testing.T) {
208192
}
209193
defer resp.Body.Close()
210194

195+
endpointEvent, err = readSeeEvent(sseResp)
196+
if err != nil {
197+
t.Fatalf("Failed to read SSE response: %v", err)
198+
}
199+
respFromSee := strings.TrimSpace(
200+
strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0],
201+
)
202+
203+
fmt.Printf("========> %v", respFromSee)
211204
var response map[string]interface{}
212-
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
205+
if err := json.NewDecoder(strings.NewReader(respFromSee)).Decode(&response); err != nil {
213206
t.Errorf(
214207
"Session %d: Failed to decode response: %v",
215208
sessionNum,
@@ -586,13 +579,10 @@ func TestSSEServer(t *testing.T) {
586579
defer sseResp.Body.Close()
587580

588581
// Read the endpoint event
589-
buf := make([]byte, 1024)
590-
n, err := sseResp.Body.Read(buf)
582+
endpointEvent, err := readSeeEvent(sseResp)
591583
if err != nil {
592584
t.Fatalf("Failed to read SSE response: %v", err)
593585
}
594-
595-
endpointEvent := string(buf[:n])
596586
messageURL := strings.TrimSpace(
597587
strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0],
598588
)
@@ -632,8 +622,16 @@ func TestSSEServer(t *testing.T) {
632622
}
633623

634624
// Verify response
625+
endpointEvent, err = readSeeEvent(sseResp)
626+
if err != nil {
627+
t.Fatalf("Failed to read SSE response: %v", err)
628+
}
629+
respFromSee := strings.TrimSpace(
630+
strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0],
631+
)
632+
635633
var response map[string]interface{}
636-
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
634+
if err := json.NewDecoder(strings.NewReader(respFromSee)).Decode(&response); err != nil {
637635
t.Fatalf("Failed to decode response: %v", err)
638636
}
639637

@@ -671,8 +669,17 @@ func TestSSEServer(t *testing.T) {
671669
}
672670
defer resp.Body.Close()
673671

672+
endpointEvent, err = readSeeEvent(sseResp)
673+
if err != nil {
674+
t.Fatalf("Failed to read SSE response: %v", err)
675+
}
676+
677+
respFromSee = strings.TrimSpace(
678+
strings.Split(strings.Split(endpointEvent, "data: ")[1], "\n")[0],
679+
)
680+
674681
response = make(map[string]interface{})
675-
if err := json.NewDecoder(resp.Body).Decode(&response); err != nil {
682+
if err := json.NewDecoder(strings.NewReader(respFromSee)).Decode(&response); err != nil {
676683
t.Fatalf("Failed to decode response: %v", err)
677684
}
678685

@@ -740,3 +747,12 @@ func TestSSEServer(t *testing.T) {
740747
}
741748
})
742749
}
750+
751+
func readSeeEvent(sseResp *http.Response) (string, error) {
752+
buf := make([]byte, 1024)
753+
n, err := sseResp.Body.Read(buf)
754+
if err != nil {
755+
return "", err
756+
}
757+
return string(buf[:n]), nil
758+
}

0 commit comments

Comments
 (0)