Skip to content

Fix: handle subprocess failing after startup #134

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
4 changes: 2 additions & 2 deletions client/stdio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,13 +45,13 @@ func TestStdioMCPClient(t *testing.T) {
}
tempFile.Close()
mockServerPath := tempFile.Name()

// Add .exe suffix on Windows
if runtime.GOOS == "windows" {
os.Remove(mockServerPath) // Remove the empty file first
mockServerPath += ".exe"
}

if compileErr := compileTestServer(mockServerPath); compileErr != nil {
t.Fatalf("Failed to compile mock server: %v", compileErr)
}
Expand Down
1 change: 0 additions & 1 deletion client/transport/sse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -501,5 +501,4 @@ func TestSSEErrors(t *testing.T) {
t.Errorf("Expected error when sending request after close, got nil")
}
})

}
76 changes: 51 additions & 25 deletions client/transport/stdio.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,23 @@ import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"os/exec"
"sync"
"sync/atomic"
"time"

"github.com/mark3labs/mcp-go/mcp"
)

const (
readyTimeout = 5 * time.Second
readyCheckTimeout = 1 * time.Second
)

// Stdio implements the transport layer of the MCP protocol using stdio communication.
// It launches a subprocess and communicates with it via standard input/output streams
// using JSON-RPC messages. The client handles message routing between requests and
Expand All @@ -31,6 +39,8 @@ type Stdio struct {
done chan struct{}
onNotification func(mcp.JSONRPCNotification)
notifyMu sync.RWMutex
processExited chan struct{}
exitErr atomic.Value
}

// NewIO returns a new stdio-based transport using existing input, output, and
Expand Down Expand Up @@ -61,8 +71,9 @@ func NewStdio(
args: args,
env: env,

responses: make(map[int64]chan *JSONRPCResponse),
done: make(chan struct{}),
responses: make(map[int64]chan *JSONRPCResponse),
done: make(chan struct{}),
processExited: make(chan struct{}),
}

return client
Expand All @@ -72,14 +83,6 @@ func (c *Stdio) Start(ctx context.Context) error {
if err := c.spawnCommand(ctx); err != nil {
return err
}

ready := make(chan struct{})
go func() {
close(ready)
c.readResponses()
}()
<-ready

return nil
}

Expand All @@ -95,7 +98,6 @@ func (c *Stdio) spawnCommand(ctx context.Context) error {
mergedEnv = append(mergedEnv, c.env...)

cmd.Env = mergedEnv

stdin, err := cmd.StdinPipe()
if err != nil {
return fmt.Errorf("failed to create stdin pipe: %w", err)
Expand All @@ -119,37 +121,62 @@ func (c *Stdio) spawnCommand(ctx context.Context) error {
if err := cmd.Start(); err != nil {
return fmt.Errorf("failed to start command: %w", err)
}
go func() {
err := cmd.Wait()
if err != nil {
c.exitErr.Store(err)
}
close(c.processExited)
}()

// Start reading responses in a goroutine and wait for it to be ready
ready := make(chan struct{})
go func() {
close(ready)
c.readResponses()
}()

if err := waitUntilReadyOrExit(ready, c.processExited, readyTimeout); err != nil {
return err
}
return nil
}

func waitUntilReadyOrExit(ready <-chan struct{}, exited <-chan struct{}, timeout time.Duration) error {
select {
case <-exited:
return errors.New("process exited before signalling readiness")
case <-ready:
select {
case <-exited:
return errors.New("process exited after readiness")
case <-time.After(readyCheckTimeout):
return nil
}
case <-time.After(timeout):
return errors.New("timeout waiting for process ready")
}
}

// Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit.
// Returns an error if there are issues closing stdin or waiting for the subprocess to terminate.
func (c *Stdio) Close() error {
select {
case <-c.done:
return nil
default:
}
// cancel all in-flight request
close(c.done)

if err := c.stdin.Close(); err != nil {
return fmt.Errorf("failed to close stdin: %w", err)
}
if err := c.stderr.Close(); err != nil {
return fmt.Errorf("failed to close stderr: %w", err)
}

if c.cmd != nil {
return c.cmd.Wait()
<-c.processExited
if err, ok := c.exitErr.Load().(error); ok && err != nil {
return err
}

return nil
}

// SetNotificationHandler sets the handler function to be called when a notification is received.
// Only one handler can be set at a time; setting a new one replaces the previous handler.
// OnNotification registers a handler function to be called when notifications are received.
// Multiple handlers can be registered and will be called in the order they were added.
func (c *Stdio) SetNotificationHandler(
handler func(notification mcp.JSONRPCNotification),
) {
Expand Down Expand Up @@ -243,7 +270,6 @@ func (c *Stdio) SendRequest(
deleteResponseChan()
return nil, fmt.Errorf("failed to write request: %w", err)
}

select {
case <-ctx.Done():
deleteResponseChan()
Expand Down
41 changes: 35 additions & 6 deletions client/transport/stdio_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ func TestStdio(t *testing.T) {
}
tempFile.Close()
mockServerPath := tempFile.Name()

// Add .exe suffix on Windows
if runtime.GOOS == "windows" {
os.Remove(mockServerPath) // Remove the empty file first
mockServerPath += ".exe"
}

if compileErr := compileTestServer(mockServerPath); compileErr != nil {
t.Fatalf("Failed to compile mock server: %v", compileErr)
}
Expand Down Expand Up @@ -329,13 +329,13 @@ func TestStdioErrors(t *testing.T) {
}
tempFile.Close()
mockServerPath := tempFile.Name()

// Add .exe suffix on Windows
if runtime.GOOS == "windows" {
os.Remove(mockServerPath) // Remove the empty file first
mockServerPath += ".exe"
}

if compileErr := compileTestServer(mockServerPath); compileErr != nil {
t.Fatalf("Failed to compile mock server: %v", compileErr)
}
Expand Down Expand Up @@ -368,13 +368,13 @@ func TestStdioErrors(t *testing.T) {
}
tempFile.Close()
mockServerPath := tempFile.Name()

// Add .exe suffix on Windows
if runtime.GOOS == "windows" {
os.Remove(mockServerPath) // Remove the empty file first
mockServerPath += ".exe"
}

if compileErr := compileTestServer(mockServerPath); compileErr != nil {
t.Fatalf("Failed to compile mock server: %v", compileErr)
}
Expand Down Expand Up @@ -407,5 +407,34 @@ func TestStdioErrors(t *testing.T) {
t.Errorf("Expected error when sending request after close, got nil")
}
})
t.Run("SubprocessStartsAndExitsImmediately", func(t *testing.T) {
// Create a temporary file for the mock server
tempFile, err := os.CreateTemp("", "mockstdio_server")
if err != nil {
t.Fatalf("Failed to create temp file: %v", err)
}
tempFile.Close()
mockServerPath := tempFile.Name()

// Add .exe suffix on Windows
if runtime.GOOS == "windows" {
os.Remove(mockServerPath) // Remove the empty file first
mockServerPath += ".exe"
}

if compileErr := compileTestServer(mockServerPath); compileErr != nil {
t.Fatalf("Failed to compile mock server: %v", compileErr)
}
//defer os.Remove(mockServerPath)

// Create a new Stdio transport
stdio := NewStdio(mockServerPath, nil)
stdio.env = append(stdio.env, "MOCK_FAIL_IMMEDIATELY=1")
defer stdio.Close()
// Start the transport
ctx := context.Background()
if startErr := stdio.Start(ctx); startErr == nil {
t.Fatalf("Expected error when starting Stdio transport, got nil")
}
})
}
4 changes: 0 additions & 4 deletions server/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -655,10 +655,6 @@ func (s *SSEServer) MessageHandler() http.Handler {

// ServeHTTP implements the http.Handler interface.
func (s *SSEServer) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if s.dynamicBasePathFunc != nil {
http.Error(w, (&ErrDynamicPathConfig{Method: "ServeHTTP"}).Error(), http.StatusInternalServerError)
return
}
path := r.URL.Path
// Use exact path matching rather than Contains
ssePath := s.CompleteSsePath()
Expand Down
4 changes: 4 additions & 0 deletions testdata/mockstdio_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ type JSONRPCResponse struct {
}

func main() {
if os.Getenv("MOCK_FAIL_IMMEDIATELY") == "1" {
fmt.Fprintln(os.Stderr, "mock server: simulated startup failure")
os.Exit(1)
}
logger := slog.New(slog.NewJSONHandler(os.Stderr, &slog.HandlerOptions{}))
logger.Info("launch successful")
scanner := bufio.NewScanner(os.Stdin)
Expand Down
Loading