Skip to content

Fix: handle subprocess failing after startup #134

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
40 changes: 35 additions & 5 deletions client/transport/stdio.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,22 @@ import (
"bufio"
"context"
"encoding/json"
"errors"
"fmt"
"io"
"os"
"os/exec"
"sync"
"time"

"github.com/mark3labs/mcp-go/mcp"
)

const (
readyTimeout = 5 * time.Second
readyCheckTimeout = 1 * time.Second
)

// Stdio implements the transport layer of the MCP protocol using stdio communication.
// It launches a subprocess and communicates with it via standard input/output streams
// using JSON-RPC messages. The client handles message routing between requests and
Expand All @@ -31,6 +38,7 @@ type Stdio struct {
done chan struct{}
onNotification func(mcp.JSONRPCNotification)
notifyMu sync.RWMutex
processExitErr chan error
}

// NewStdio creates a new stdio transport to communicate with a subprocess.
Expand All @@ -47,8 +55,9 @@ func NewStdio(
args: args,
env: env,

responses: make(map[int64]chan *JSONRPCResponse),
done: make(chan struct{}),
responses: make(map[int64]chan *JSONRPCResponse),
done: make(chan struct{}),
processExitErr: make(chan error, 1),
}

return client
Expand Down Expand Up @@ -86,17 +95,38 @@ func (c *Stdio) Start(ctx context.Context) error {
return fmt.Errorf("failed to start command: %w", err)
}

go func() {
c.processExitErr <- cmd.Wait()
}()

// Start reading responses in a goroutine and wait for it to be ready
ready := make(chan struct{})
go func() {
close(ready)
c.readResponses()
}()
<-ready

if err := waitUntilReadyOrExit(ready, c.processExitErr, readyTimeout); err != nil {
return err
}
return nil
}

func waitUntilReadyOrExit(ready <-chan struct{}, waitErr <-chan error, timeout time.Duration) error {
select {
case err := <-waitErr:
return fmt.Errorf("process exited early: %w", err)
case <-ready:
select {
case err := <-waitErr:
return fmt.Errorf("process exited after ready: %w", err)
case <-time.After(readyCheckTimeout):
return nil
}
case <-time.After(timeout):
return errors.New("timeout waiting for process ready")
}
}

// Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit.
// Returns an error if there are issues closing stdin or waiting for the subprocess to terminate.
func (c *Stdio) Close() error {
Expand All @@ -107,7 +137,7 @@ func (c *Stdio) Close() error {
if err := c.stderr.Close(); err != nil {
return fmt.Errorf("failed to close stderr: %w", err)
}
return c.cmd.Wait()
return <-c.processExitErr
}

// OnNotification registers a handler function to be called when notifications are received.
Expand Down
31 changes: 15 additions & 16 deletions server/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,15 @@ var _ ClientSession = (*sseSession)(nil)
// SSEServer implements a Server-Sent Events (SSE) based MCP server.
// It provides real-time communication capabilities over HTTP using the SSE protocol.
type SSEServer struct {
server *MCPServer
baseURL string
basePath string
useFullURLForMessageEndpoint bool
messageEndpoint string
sseEndpoint string
sessions sync.Map
srv *http.Server
contextFunc SSEContextFunc
server *MCPServer
baseURL string
basePath string
useFullURLForMessageEndpoint bool
messageEndpoint string
sseEndpoint string
sessions sync.Map
srv *http.Server
contextFunc SSEContextFunc

keepAlive bool
keepAliveInterval time.Duration
Expand Down Expand Up @@ -158,12 +158,12 @@ func WithSSEContextFunc(fn SSEContextFunc) SSEOption {
// NewSSEServer creates a new SSE server instance with the given MCP server and options.
func NewSSEServer(server *MCPServer, opts ...SSEOption) *SSEServer {
s := &SSEServer{
server: server,
sseEndpoint: "/sse",
messageEndpoint: "/message",
useFullURLForMessageEndpoint: true,
keepAlive: false,
keepAliveInterval: 10 * time.Second,
server: server,
sseEndpoint: "/sse",
messageEndpoint: "/message",
useFullURLForMessageEndpoint: true,
keepAlive: false,
keepAliveInterval: 10 * time.Second,
}

// Apply all options
Expand Down Expand Up @@ -293,7 +293,6 @@ func (s *SSEServer) handleSSE(w http.ResponseWriter, r *http.Request) {
}()
}


// Send the initial endpoint event
fmt.Fprintf(w, "event: endpoint\ndata: %s\r\n\r\n", s.GetMessageEndpointForClient(sessionID))
flusher.Flush()
Expand Down