Skip to content

Commit c65164d

Browse files
fix
1 parent 7a5c8b3 commit c65164d

File tree

1 file changed

+24
-19
lines changed

1 file changed

+24
-19
lines changed

client/transport/stdio.go

Lines changed: 24 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ type Stdio struct {
3939
done chan struct{}
4040
onNotification func(mcp.JSONRPCNotification)
4141
notifyMu sync.RWMutex
42-
processExited chan struct{}
4342
exitErr atomic.Value
4443
}
4544

@@ -71,9 +70,8 @@ func NewStdio(
7170
args: args,
7271
env: env,
7372

74-
responses: make(map[int64]chan *JSONRPCResponse),
75-
done: make(chan struct{}),
76-
processExited: make(chan struct{}),
73+
responses: make(map[int64]chan *JSONRPCResponse),
74+
done: make(chan struct{}),
7775
}
7876

7977
return client
@@ -83,7 +81,14 @@ func (c *Stdio) Start(ctx context.Context) error {
8381
if err := c.spawnCommand(ctx); err != nil {
8482
return err
8583
}
86-
return nil
84+
85+
// Start reading responses in a goroutine and wait for it to be ready
86+
ready := make(chan struct{})
87+
go func() {
88+
close(ready)
89+
c.readResponses()
90+
}()
91+
return waitUntilReadyOrExit(ready, c.done, readyTimeout)
8792
}
8893

8994
// spawnCommand spawns a new process running c.command.
@@ -121,27 +126,25 @@ func (c *Stdio) spawnCommand(ctx context.Context) error {
121126
if err := cmd.Start(); err != nil {
122127
return fmt.Errorf("failed to start command: %w", err)
123128
}
129+
124130
go func() {
125131
err := cmd.Wait()
126132
if err != nil {
127133
c.exitErr.Store(err)
128134
}
129-
close(c.processExited)
130-
}()
131-
132-
// Start reading responses in a goroutine and wait for it to be ready
133-
ready := make(chan struct{})
134-
go func() {
135-
close(ready)
136-
c.readResponses()
135+
tryCloseDone(c.done)
137136
}()
138-
139-
if err := waitUntilReadyOrExit(ready, c.processExited, readyTimeout); err != nil {
140-
return err
141-
}
142137
return nil
143138
}
144139

140+
func tryCloseDone(done chan struct{}) {
141+
select {
142+
case <-done:
143+
return
144+
default:
145+
}
146+
close(done)
147+
}
145148
func waitUntilReadyOrExit(ready <-chan struct{}, exited <-chan struct{}, timeout time.Duration) error {
146149
select {
147150
case <-exited:
@@ -161,14 +164,15 @@ func waitUntilReadyOrExit(ready <-chan struct{}, exited <-chan struct{}, timeout
161164
// Close shuts down the stdio client, closing the stdin pipe and waiting for the subprocess to exit.
162165
// Returns an error if there are issues closing stdin or waiting for the subprocess to terminate.
163166
func (c *Stdio) Close() error {
164-
close(c.done)
167+
// cancel all in-flight request
168+
tryCloseDone(c.done)
165169
if err := c.stdin.Close(); err != nil {
166170
return fmt.Errorf("failed to close stdin: %w", err)
167171
}
168172
if err := c.stderr.Close(); err != nil {
169173
return fmt.Errorf("failed to close stderr: %w", err)
170174
}
171-
<-c.processExited
175+
172176
if err, ok := c.exitErr.Load().(error); ok && err != nil {
173177
return err
174178
}
@@ -270,6 +274,7 @@ func (c *Stdio) SendRequest(
270274
deleteResponseChan()
271275
return nil, fmt.Errorf("failed to write request: %w", err)
272276
}
277+
273278
select {
274279
case <-ctx.Done():
275280
deleteResponseChan()

0 commit comments

Comments
 (0)