Skip to content

Commit 0f4134f

Browse files
committed
validate tools params
1 parent a26315f commit 0f4134f

File tree

8 files changed

+991
-228
lines changed

8 files changed

+991
-228
lines changed

pkg/github/code_scanning.go

Lines changed: 32 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,18 @@ func getCodeScanningAlert(client *github.Client, t translations.TranslationHelpe
3030
),
3131
),
3232
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
33-
owner, _ := request.Params.Arguments["owner"].(string)
34-
repo, _ := request.Params.Arguments["repo"].(string)
35-
alertNumber, _ := request.Params.Arguments["alert_number"].(float64)
33+
owner, err := requiredStringParam(request, "owner")
34+
if err != nil {
35+
return mcp.NewToolResultError(err.Error()), nil
36+
}
37+
repo, err := requiredStringParam(request, "repo")
38+
if err != nil {
39+
return mcp.NewToolResultError(err.Error()), nil
40+
}
41+
alertNumber, err := requiredNumberParam(request, "alert_number")
42+
if err != nil {
43+
return mcp.NewToolResultError(err.Error()), nil
44+
}
3645

3746
alert, resp, err := client.CodeScanning.GetAlert(ctx, owner, repo, int64(alertNumber))
3847
if err != nil {
@@ -80,11 +89,26 @@ func listCodeScanningAlerts(client *github.Client, t translations.TranslationHel
8089
),
8190
),
8291
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
83-
owner, _ := request.Params.Arguments["owner"].(string)
84-
repo, _ := request.Params.Arguments["repo"].(string)
85-
ref, _ := request.Params.Arguments["ref"].(string)
86-
state, _ := request.Params.Arguments["state"].(string)
87-
severity, _ := request.Params.Arguments["severity"].(string)
92+
owner, err := requiredStringParam(request, "owner")
93+
if err != nil {
94+
return mcp.NewToolResultError(err.Error()), nil
95+
}
96+
repo, err := requiredStringParam(request, "repo")
97+
if err != nil {
98+
return mcp.NewToolResultError(err.Error()), nil
99+
}
100+
ref, err := optionalStringParam(request, "ref")
101+
if err != nil {
102+
return mcp.NewToolResultError(err.Error()), nil
103+
}
104+
state, err := optionalStringParam(request, "state")
105+
if err != nil {
106+
return mcp.NewToolResultError(err.Error()), nil
107+
}
108+
severity, err := optionalStringParam(request, "severity")
109+
if err != nil {
110+
return mcp.NewToolResultError(err.Error()), nil
111+
}
88112

89113
alerts, resp, err := client.CodeScanning.ListAlertsForRepo(ctx, owner, repo, &github.AlertListOptions{Ref: ref, State: state, Severity: severity})
90114
if err != nil {

pkg/github/issues.go

Lines changed: 137 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,18 @@ func getIssue(client *github.Client, t translations.TranslationHelperFunc) (tool
3232
),
3333
),
3434
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
35-
owner := request.Params.Arguments["owner"].(string)
36-
repo := request.Params.Arguments["repo"].(string)
37-
issueNumber := int(request.Params.Arguments["issue_number"].(float64))
35+
owner, err := requiredStringParam(request, "owner")
36+
if err != nil {
37+
return mcp.NewToolResultError(err.Error()), nil
38+
}
39+
repo, err := requiredStringParam(request, "repo")
40+
if err != nil {
41+
return mcp.NewToolResultError(err.Error()), nil
42+
}
43+
issueNumber, err := requiredNumberParam(request, "issue_number")
44+
if err != nil {
45+
return mcp.NewToolResultError(err.Error()), nil
46+
}
3847

3948
issue, resp, err := client.Issues.Get(ctx, owner, repo, issueNumber)
4049
if err != nil {
@@ -81,10 +90,22 @@ func addIssueComment(client *github.Client, t translations.TranslationHelperFunc
8190
),
8291
),
8392
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
84-
owner := request.Params.Arguments["owner"].(string)
85-
repo := request.Params.Arguments["repo"].(string)
86-
issueNumber := int(request.Params.Arguments["issue_number"].(float64))
87-
body := request.Params.Arguments["body"].(string)
93+
owner, err := requiredStringParam(request, "owner")
94+
if err != nil {
95+
return mcp.NewToolResultError(err.Error()), nil
96+
}
97+
repo, err := requiredStringParam(request, "repo")
98+
if err != nil {
99+
return mcp.NewToolResultError(err.Error()), nil
100+
}
101+
issueNumber, err := requiredNumberParam(request, "issue_number")
102+
if err != nil {
103+
return mcp.NewToolResultError(err.Error()), nil
104+
}
105+
body, err := requiredStringParam(request, "body")
106+
if err != nil {
107+
return mcp.NewToolResultError(err.Error()), nil
108+
}
88109

89110
comment := &github.IssueComment{
90111
Body: github.Ptr(body),
@@ -135,22 +156,25 @@ func searchIssues(client *github.Client, t translations.TranslationHelperFunc) (
135156
),
136157
),
137158
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
138-
query := request.Params.Arguments["q"].(string)
139-
sort := ""
140-
if s, ok := request.Params.Arguments["sort"].(string); ok {
141-
sort = s
159+
query, err := requiredStringParam(request, "q")
160+
if err != nil {
161+
return mcp.NewToolResultError(err.Error()), nil
142162
}
143-
order := ""
144-
if o, ok := request.Params.Arguments["order"].(string); ok {
145-
order = o
163+
sort, err := optionalStringParam(request, "sort")
164+
if err != nil {
165+
return mcp.NewToolResultError(err.Error()), nil
146166
}
147-
perPage := 30
148-
if pp, ok := request.Params.Arguments["per_page"].(float64); ok {
149-
perPage = int(pp)
167+
order, err := optionalStringParam(request, "order")
168+
if err != nil {
169+
return mcp.NewToolResultError(err.Error()), nil
150170
}
151-
page := 1
152-
if p, ok := request.Params.Arguments["page"].(float64); ok {
153-
page = int(p)
171+
perPage, err := optionalNumberParamWithDefault(request, "per_page", 30)
172+
if err != nil {
173+
return mcp.NewToolResultError(err.Error()), nil
174+
}
175+
page, err := optionalNumberParamWithDefault(request, "page", 1)
176+
if err != nil {
177+
return mcp.NewToolResultError(err.Error()), nil
154178
}
155179

156180
opts := &github.SearchOptions{
@@ -212,26 +236,34 @@ func createIssue(client *github.Client, t translations.TranslationHelperFunc) (t
212236
),
213237
),
214238
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
215-
owner := request.Params.Arguments["owner"].(string)
216-
repo := request.Params.Arguments["repo"].(string)
217-
title := request.Params.Arguments["title"].(string)
239+
owner, err := requiredStringParam(request, "owner")
240+
if err != nil {
241+
return mcp.NewToolResultError(err.Error()), nil
242+
}
243+
repo, err := requiredStringParam(request, "repo")
244+
if err != nil {
245+
return mcp.NewToolResultError(err.Error()), nil
246+
}
247+
title, err := requiredStringParam(request, "title")
248+
if err != nil {
249+
return mcp.NewToolResultError(err.Error()), nil
250+
}
218251

219252
// Optional parameters
220-
var body string
221-
if b, ok := request.Params.Arguments["body"].(string); ok {
222-
body = b
253+
body, err := optionalStringParam(request, "body")
254+
if err != nil {
255+
return mcp.NewToolResultError(err.Error()), nil
223256
}
224257

225-
// Parse assignees if present
226-
assignees := []string{} // default to empty slice, can't be nil
227-
if a, ok := request.Params.Arguments["assignees"].(string); ok && a != "" {
228-
assignees = parseCommaSeparatedList(a)
258+
// Get assignees
259+
assignees, err := optionalCommaSeparatedListParam(request, "assignees")
260+
if err != nil {
261+
return mcp.NewToolResultError(err.Error()), nil
229262
}
230-
231-
// Parse labels if present
232-
labels := []string{} // default to empty slice, can't be nil
233-
if l, ok := request.Params.Arguments["labels"].(string); ok && l != "" {
234-
labels = parseCommaSeparatedList(l)
263+
// Get labels
264+
labels, err := optionalCommaSeparatedListParam(request, "labels")
265+
if err != nil {
266+
return mcp.NewToolResultError(err.Error()), nil
235267
}
236268

237269
// Create the issue request
@@ -300,29 +332,43 @@ func listIssues(client *github.Client, t translations.TranslationHelperFunc) (to
300332
),
301333
),
302334
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
303-
owner := request.Params.Arguments["owner"].(string)
304-
repo := request.Params.Arguments["repo"].(string)
335+
owner, err := requiredStringParam(request, "owner")
336+
if err != nil {
337+
return mcp.NewToolResultError(err.Error()), nil
338+
}
339+
repo, err := requiredStringParam(request, "repo")
340+
if err != nil {
341+
return mcp.NewToolResultError(err.Error()), nil
342+
}
305343

306344
opts := &github.IssueListByRepoOptions{}
307345

308346
// Set optional parameters if provided
309-
if state, ok := request.Params.Arguments["state"].(string); ok && state != "" {
310-
opts.State = state
347+
opts.State, err = optionalStringParam(request, "state")
348+
if err != nil {
349+
return mcp.NewToolResultError(err.Error()), nil
311350
}
312351

313-
if labels, ok := request.Params.Arguments["labels"].(string); ok && labels != "" {
314-
opts.Labels = parseCommaSeparatedList(labels)
352+
opts.Labels, err = optionalCommaSeparatedListParam(request, "labels")
353+
if err != nil {
354+
return mcp.NewToolResultError(err.Error()), nil
315355
}
316356

317-
if sort, ok := request.Params.Arguments["sort"].(string); ok && sort != "" {
318-
opts.Sort = sort
357+
opts.Sort, err = optionalStringParam(request, "sort")
358+
if err != nil {
359+
return mcp.NewToolResultError(err.Error()), nil
319360
}
320361

321-
if direction, ok := request.Params.Arguments["direction"].(string); ok && direction != "" {
322-
opts.Direction = direction
362+
opts.Direction, err = optionalStringParam(request, "direction")
363+
if err != nil {
364+
return mcp.NewToolResultError(err.Error()), nil
323365
}
324366

325-
if since, ok := request.Params.Arguments["since"].(string); ok && since != "" {
367+
since, err := optionalStringParam(request, "since")
368+
if err != nil {
369+
return mcp.NewToolResultError(err.Error()), nil
370+
}
371+
if since != "" {
326372
timestamp, err := parseISOTimestamp(since)
327373
if err != nil {
328374
return mcp.NewToolResultError(fmt.Sprintf("failed to list issues: %s", err.Error())), nil
@@ -397,38 +443,69 @@ func updateIssue(client *github.Client, t translations.TranslationHelperFunc) (t
397443
),
398444
),
399445
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
400-
owner := request.Params.Arguments["owner"].(string)
401-
repo := request.Params.Arguments["repo"].(string)
402-
issueNumber := int(request.Params.Arguments["issue_number"].(float64))
446+
owner, err := requiredStringParam(request, "owner")
447+
if err != nil {
448+
return mcp.NewToolResultError(err.Error()), nil
449+
}
450+
repo, err := requiredStringParam(request, "repo")
451+
if err != nil {
452+
return mcp.NewToolResultError(err.Error()), nil
453+
}
454+
issueNumber, err := requiredNumberParam(request, "issue_number")
455+
if err != nil {
456+
return mcp.NewToolResultError(err.Error()), nil
457+
}
403458

404459
// Create the issue request with only provided fields
405460
issueRequest := &github.IssueRequest{}
406461

407462
// Set optional parameters if provided
408-
if title, ok := request.Params.Arguments["title"].(string); ok && title != "" {
463+
title, err := optionalStringParam(request, "title")
464+
if err != nil {
465+
return mcp.NewToolResultError(err.Error()), nil
466+
}
467+
if title != "" {
409468
issueRequest.Title = github.Ptr(title)
410469
}
411470

412-
if body, ok := request.Params.Arguments["body"].(string); ok && body != "" {
471+
body, err := optionalStringParam(request, "body")
472+
if err != nil {
473+
return mcp.NewToolResultError(err.Error()), nil
474+
}
475+
if body != "" {
413476
issueRequest.Body = github.Ptr(body)
414477
}
415478

416-
if state, ok := request.Params.Arguments["state"].(string); ok && state != "" {
479+
state, err := optionalStringParam(request, "state")
480+
if err != nil {
481+
return mcp.NewToolResultError(err.Error()), nil
482+
}
483+
if state != "" {
417484
issueRequest.State = github.Ptr(state)
418485
}
419486

420-
if labels, ok := request.Params.Arguments["labels"].(string); ok && labels != "" {
421-
labelsList := parseCommaSeparatedList(labels)
422-
issueRequest.Labels = &labelsList
487+
labels, err := optionalCommaSeparatedListParam(request, "labels")
488+
if err != nil {
489+
return mcp.NewToolResultError(err.Error()), nil
490+
}
491+
if len(labels) > 0 {
492+
issueRequest.Labels = &labels
423493
}
424494

425-
if assignees, ok := request.Params.Arguments["assignees"].(string); ok && assignees != "" {
426-
assigneesList := parseCommaSeparatedList(assignees)
427-
issueRequest.Assignees = &assigneesList
495+
assignees, err := optionalCommaSeparatedListParam(request, "assignees")
496+
if err != nil {
497+
return mcp.NewToolResultError(err.Error()), nil
498+
}
499+
if len(assignees) > 0 {
500+
issueRequest.Assignees = &assignees
428501
}
429502

430-
if milestone, ok := request.Params.Arguments["milestone"].(float64); ok {
431-
milestoneNum := int(milestone)
503+
milestone, err := optionalNumberParam(request, "milestone")
504+
if err != nil {
505+
return mcp.NewToolResultError(err.Error()), nil
506+
}
507+
if milestone != 0 {
508+
milestoneNum := milestone
432509
issueRequest.Milestone = &milestoneNum
433510
}
434511

pkg/github/issues_test.go

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,8 @@ func Test_AddIssueComment(t *testing.T) {
176176
"issue_number": float64(42),
177177
"body": "",
178178
},
179-
expectError: true,
180-
expectedErrMsg: "failed to create comment",
179+
expectError: false,
180+
expectedErrMsg: "missing required parameter: body",
181181
},
182182
}
183183

@@ -210,6 +210,13 @@ func Test_AddIssueComment(t *testing.T) {
210210
return
211211
}
212212

213+
if tc.expectedErrMsg != "" {
214+
require.NotNil(t, result)
215+
textContent := getTextResult(t, result)
216+
assert.Contains(t, textContent.Text, tc.expectedErrMsg)
217+
return
218+
}
219+
213220
require.NoError(t, err)
214221

215222
// Parse the result and get the text content if no error
@@ -419,8 +426,8 @@ func Test_CreateIssue(t *testing.T) {
419426
"repo": "repo",
420427
"title": "Test Issue",
421428
"body": "This is a test issue",
422-
"assignees": []interface{}{"user1", "user2"},
423-
"labels": []interface{}{"bug", "help wanted"},
429+
"assignees": "user1, user2",
430+
"labels": "bug, help wanted",
424431
},
425432
expectError: false,
426433
expectedIssue: mockIssue,
@@ -467,8 +474,8 @@ func Test_CreateIssue(t *testing.T) {
467474
"repo": "repo",
468475
"title": "",
469476
},
470-
expectError: true,
471-
expectedErrMsg: "failed to create issue",
477+
expectError: false,
478+
expectedErrMsg: "missing required parameter: title",
472479
},
473480
}
474481

@@ -491,6 +498,13 @@ func Test_CreateIssue(t *testing.T) {
491498
return
492499
}
493500

501+
if tc.expectedErrMsg != "" {
502+
require.NotNil(t, result)
503+
textContent := getTextResult(t, result)
504+
assert.Contains(t, textContent.Text, tc.expectedErrMsg)
505+
return
506+
}
507+
494508
require.NoError(t, err)
495509
textContent := getTextResult(t, result)
496510

0 commit comments

Comments
 (0)