Skip to content

Commit dbdef79

Browse files
refactor: update notification functions to use GetClientFn . Fix conflicts
1 parent b5b3211 commit dbdef79

File tree

2 files changed

+79
-30
lines changed

2 files changed

+79
-30
lines changed

pkg/github/notifications.go

Lines changed: 33 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ import (
1515
)
1616

1717
// getNotifications creates a tool to list notifications for the current user.
18-
func getNotifications(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
18+
func GetNotifications(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
1919
return mcp.NewTool("get_notifications",
2020
mcp.WithDescription(t("TOOL_GET_NOTIFICATIONS_DESCRIPTION", "Get notifications for the authenticated GitHub user")),
2121
mcp.WithBoolean("all",
@@ -38,33 +38,38 @@ func getNotifications(client *github.Client, t translations.TranslationHelperFun
3838
),
3939
),
4040
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
41+
client, err := getClient(ctx)
42+
if err != nil {
43+
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
44+
}
45+
4146
// Extract optional parameters with defaults
42-
all, err := optionalParamWithDefault[bool](request, "all", false)
47+
all, err := OptionalBoolParamWithDefault(request, "all", false)
4348
if err != nil {
4449
return mcp.NewToolResultError(err.Error()), nil
4550
}
4651

47-
participating, err := optionalParamWithDefault[bool](request, "participating", false)
52+
participating, err := OptionalBoolParamWithDefault(request, "participating", false)
4853
if err != nil {
4954
return mcp.NewToolResultError(err.Error()), nil
5055
}
5156

52-
since, err := optionalParam[string](request, "since")
57+
since, err := OptionalStringParamWithDefault(request, "since", "")
5358
if err != nil {
5459
return mcp.NewToolResultError(err.Error()), nil
5560
}
5661

57-
before, err := optionalParam[string](request, "before")
62+
before, err := OptionalStringParam(request, "before")
5863
if err != nil {
5964
return mcp.NewToolResultError(err.Error()), nil
6065
}
6166

62-
perPage, err := optionalIntParamWithDefault(request, "per_page", 30)
67+
perPage, err := OptionalIntParamWithDefault(request, "per_page", 30)
6368
if err != nil {
6469
return mcp.NewToolResultError(err.Error()), nil
6570
}
6671

67-
page, err := optionalIntParamWithDefault(request, "page", 1)
72+
page, err := OptionalIntParamWithDefault(request, "page", 1)
6873
if err != nil {
6974
return mcp.NewToolResultError(err.Error()), nil
7075
}
@@ -122,7 +127,7 @@ func getNotifications(client *github.Client, t translations.TranslationHelperFun
122127
}
123128

124129
// markNotificationRead creates a tool to mark a notification as read.
125-
func markNotificationRead(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
130+
func MarkNotificationRead(getclient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
126131
return mcp.NewTool("mark_notification_read",
127132
mcp.WithDescription(t("TOOL_MARK_NOTIFICATION_READ_DESCRIPTION", "Mark a notification as read")),
128133
mcp.WithString("threadID",
@@ -131,6 +136,11 @@ func markNotificationRead(client *github.Client, t translations.TranslationHelpe
131136
),
132137
),
133138
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
139+
client, err := getclient(ctx)
140+
if err != nil {
141+
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
142+
}
143+
134144
threadID, err := requiredParam[string](request, "threadID")
135145
if err != nil {
136146
return mcp.NewToolResultError(err.Error()), nil
@@ -154,16 +164,21 @@ func markNotificationRead(client *github.Client, t translations.TranslationHelpe
154164
}
155165
}
156166

157-
// markAllNotificationsRead creates a tool to mark all notifications as read.
158-
func markAllNotificationsRead(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
167+
// MarkAllNotificationsRead creates a tool to mark all notifications as read.
168+
func MarkAllNotificationsRead(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
159169
return mcp.NewTool("mark_all_notifications_read",
160170
mcp.WithDescription(t("TOOL_MARK_ALL_NOTIFICATIONS_READ_DESCRIPTION", "Mark all notifications as read")),
161171
mcp.WithString("lastReadAt",
162172
mcp.Description("Describes the last point that notifications were checked (optional). Default: Now"),
163173
),
164174
),
165175
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
166-
lastReadAt, err := optionalParam[string](request, "lastReadAt")
176+
client, err := getClient(ctx)
177+
if err != nil {
178+
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
179+
}
180+
181+
lastReadAt, err := OptionalStringParam(request, "lastReadAt")
167182
if err != nil {
168183
return mcp.NewToolResultError(err.Error()), nil
169184
}
@@ -197,8 +212,8 @@ func markAllNotificationsRead(client *github.Client, t translations.TranslationH
197212
}
198213
}
199214

200-
// getNotificationThread creates a tool to get a specific notification thread.
201-
func getNotificationThread(client *github.Client, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
215+
// GetNotificationThread creates a tool to get a specific notification thread.
216+
func GetNotificationThread(getClient GetClientFn, t translations.TranslationHelperFunc) (tool mcp.Tool, handler server.ToolHandlerFunc) {
202217
return mcp.NewTool("get_notification_thread",
203218
mcp.WithDescription(t("TOOL_GET_NOTIFICATION_THREAD_DESCRIPTION", "Get a specific notification thread")),
204219
mcp.WithString("threadID",
@@ -207,6 +222,11 @@ func getNotificationThread(client *github.Client, t translations.TranslationHelp
207222
),
208223
),
209224
func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
225+
client, err := getClient(ctx)
226+
if err != nil {
227+
return nil, fmt.Errorf("failed to get GitHub client: %w", err)
228+
}
229+
210230
threadID, err := requiredParam[string](request, "threadID")
211231
if err != nil {
212232
return mcp.NewToolResultError(err.Error()), nil

pkg/github/server.go

Lines changed: 46 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -91,12 +91,14 @@ func NewServer(getClient GetClientFn, version string, readOnly bool, t translati
9191
s.AddTool(GetCodeScanningAlert(getClient, t))
9292
s.AddTool(ListCodeScanningAlerts(getClient, t))
9393

94-
// Add GitHub tools - Notifications
94+
// Add GitHub tools - Notifications
95+
s.AddTool(GetNotifications(getClient, t))
96+
s.AddTool(GetNotificationThread(getClient, t))
9597
if !readOnly {
96-
s.AddTool(markNotificationRead(client, t))
97-
s.AddTool(markAllNotificationsRead(client, t))
98+
s.AddTool(MarkNotificationRead(getClient, t))
99+
s.AddTool(MarkAllNotificationsRead(getClient, t))
98100
}
99-
101+
100102
return s
101103
}
102104

@@ -237,28 +239,55 @@ func OptionalIntParam(r mcp.CallToolRequest, p string) (int, error) {
237239
return int(v), nil
238240
}
239241

240-
// optionalParamWithDefault is a generic helper function that can be used to fetch a requested parameter from the request
241-
// with a default value if the parameter is not provided or is zero value.
242-
func optionalParamWithDefault[T comparable](r mcp.CallToolRequest, p string, d T) (T, error) {
243-
var zero T
244-
v, err := optionalParam[T](r, p)
242+
// OptionalIntParamWithDefault is a helper function that can be used to fetch a requested parameter from the request
243+
// similar to optionalIntParam, but it also takes a default value.
244+
func OptionalIntParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, error) {
245+
v, err := OptionalIntParam(r, p)
245246
if err != nil {
246-
return zero, err
247+
return 0, err
247248
}
248-
if v == zero {
249+
if v == 0 {
249250
return d, nil
250251
}
251252
return v, nil
252253
}
253254

254-
// OptionalIntParamWithDefault is a helper function that can be used to fetch a requested parameter from the request
255-
// similar to optionalIntParam, but it also takes a default value.
256-
func OptionalIntParamWithDefault(r mcp.CallToolRequest, p string, d int) (int, error) {
257-
v, err := OptionalIntParam(r, p)
255+
// OptionalBoolParamWithDefault is a helper function that can be used to fetch a requested parameter from the request
256+
// similar to optionalParam, but it also takes a default value.
257+
func OptionalBoolParamWithDefault(r mcp.CallToolRequest, p string, d bool) (bool, error) {
258+
v, err := OptionalParam[bool](r, p)
258259
if err != nil {
259-
return 0, err
260+
return false, err
260261
}
261-
if v == 0 {
262+
if v == false {
263+
return d, nil
264+
}
265+
return v, nil
266+
}
267+
268+
// OptionalStringParam is a helper function that can be used to fetch a requested parameter from the request.
269+
// It does the following checks:
270+
// 1. Checks if the parameter is present in the request, if not, it returns its zero-value
271+
// 2. If it is present, it checks if the parameter is of the expected type and returns it
272+
func OptionalStringParam(r mcp.CallToolRequest, p string) (string, error) {
273+
v, err := OptionalParam[string](r, p)
274+
if err != nil {
275+
return "", err
276+
}
277+
if v == "" {
278+
return "", nil
279+
}
280+
return v, nil
281+
}
282+
283+
// OptionalStringParamWithDefault is a helper function that can be used to fetch a requested parameter from the request
284+
// similar to optionalParam, but it also takes a default value.
285+
func OptionalStringParamWithDefault(r mcp.CallToolRequest, p string, d string) (string, error) {
286+
v, err := OptionalParam[string](r, p)
287+
if err != nil {
288+
return "", err
289+
}
290+
if v == "" {
262291
return d, nil
263292
}
264293
return v, nil

0 commit comments

Comments
 (0)