Skip to content

Commit 9da532f

Browse files
committed
optimize hooks
1 parent db92287 commit 9da532f

File tree

1 file changed

+27
-0
lines changed

1 file changed

+27
-0
lines changed

server/server_test.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"encoding/json"
77
"errors"
88
"fmt"
9+
"sync"
910
"testing"
1011
"time"
1112

@@ -1305,6 +1306,7 @@ var globalHookCtx = &globalHookContext{
13051306
}
13061307

13071308
type globalHookContext struct {
1309+
mu sync.Mutex
13081310
BeforeAnyMessages []any
13091311
BeforeAnyCount *int
13101312
OnSuccessCount *int
@@ -1347,6 +1349,7 @@ func (*globalHook) BeforeAny(ctx context.Context, hookContext HookContext, id an
13471349
// Only collect ping messages for our test
13481350
if method == mcp.MethodPing {
13491351
hookContext.(*globalHookContext).BeforeAnyMessages = append(hookContext.(*globalHookContext).BeforeAnyMessages, message)
1352+
hookContext.(*globalHookContext).mu.Unlock()
13501353
}
13511354
}
13521355

@@ -1368,14 +1371,23 @@ func (*globalHook) OnError(ctx context.Context, hookContext HookContext, id any,
13681371
}
13691372

13701373
func (e *globalHook) BeforeInitialize(ctx context.Context, hookContext HookContext, id any, message *mcp.InitializeRequest) {
1374+
hookContext.(*globalHookContext).mu.Lock()
1375+
defer hookContext.(*globalHookContext).mu.Unlock()
1376+
13711377
e.BeforeAny(ctx, hookContext, id, mcp.MethodInitialize, message)
13721378
}
13731379

13741380
func (e *globalHook) AfterInitialize(ctx context.Context, hookContext HookContext, id any, message *mcp.InitializeRequest, result *mcp.InitializeResult) {
1381+
hookContext.(*globalHookContext).mu.Lock()
1382+
defer hookContext.(*globalHookContext).mu.Unlock()
1383+
13751384
e.OnSuccess(ctx, hookContext, id, mcp.MethodInitialize, message, result)
13761385
}
13771386

13781387
func (e *globalHook) BeforePing(ctx context.Context, hookContext HookContext, id any, message *mcp.PingRequest) {
1388+
hookContext.(*globalHookContext).mu.Lock()
1389+
defer hookContext.(*globalHookContext).mu.Unlock()
1390+
13791391
e.BeforeAny(ctx, hookContext, id, mcp.MethodPing, message)
13801392

13811393
newCount := (*hookContext.(*globalHookContext).BeforePingCount) + 1
@@ -1385,6 +1397,9 @@ func (e *globalHook) BeforePing(ctx context.Context, hookContext HookContext, id
13851397
}
13861398

13871399
func (e *globalHook) AfterPing(ctx context.Context, hookContext HookContext, id any, message *mcp.PingRequest, result *mcp.EmptyResult) {
1400+
hookContext.(*globalHookContext).mu.Lock()
1401+
defer hookContext.(*globalHookContext).mu.Unlock()
1402+
13881403
e.OnSuccess(ctx, hookContext, id, mcp.MethodPing, message, result)
13891404

13901405
newCount := (*hookContext.(*globalHookContext).AfterPingCount) + 1
@@ -1427,24 +1442,36 @@ func (*globalHook) AfterGetPrompt(ctx context.Context, hookContext HookContext,
14271442
}
14281443

14291444
func (e *globalHook) BeforeListTools(ctx context.Context, hookContext HookContext, id any, message *mcp.ListToolsRequest) {
1445+
hookContext.(*globalHookContext).mu.Lock()
1446+
defer hookContext.(*globalHookContext).mu.Unlock()
1447+
14301448
e.BeforeAny(ctx, hookContext, id, mcp.MethodToolsList, message)
14311449

14321450
newCount := (*hookContext.(*globalHookContext).BeforeToolsCount) + 1
14331451
hookContext.(*globalHookContext).BeforeToolsCount = &(newCount)
14341452
}
14351453

14361454
func (e *globalHook) AfterListTools(ctx context.Context, hookContext HookContext, id any, message *mcp.ListToolsRequest, result *mcp.ListToolsResult) {
1455+
hookContext.(*globalHookContext).mu.Lock()
1456+
defer hookContext.(*globalHookContext).mu.Unlock()
1457+
14371458
e.OnSuccess(ctx, hookContext, id, mcp.MethodToolsList, message, result)
14381459

14391460
newCount := (*hookContext.(*globalHookContext).AfterToolsCount) + 1
14401461
hookContext.(*globalHookContext).AfterToolsCount = &(newCount)
14411462
}
14421463

14431464
func (e *globalHook) BeforeCallTool(ctx context.Context, hookContext HookContext, id any, message *mcp.CallToolRequest) {
1465+
hookContext.(*globalHookContext).mu.Lock()
1466+
defer hookContext.(*globalHookContext).mu.Unlock()
1467+
14441468
e.BeforeAny(ctx, hookContext, id, mcp.MethodToolsCall, message)
14451469
}
14461470

14471471
func (e *globalHook) AfterCallTool(ctx context.Context, hookContext HookContext, id any, message *mcp.CallToolRequest, result *mcp.CallToolResult) {
1472+
hookContext.(*globalHookContext).mu.Lock()
1473+
defer hookContext.(*globalHookContext).mu.Unlock()
1474+
14481475
e.OnSuccess(ctx, hookContext, id, mcp.MethodToolsCall, message, result)
14491476
}
14501477

0 commit comments

Comments
 (0)