diff --git a/internal/mcp/handlers.go b/internal/mcp/handlers.go index 1940aeab..6a893e4c 100644 --- a/internal/mcp/handlers.go +++ b/internal/mcp/handlers.go @@ -13,6 +13,7 @@ import ( "path/filepath" "strings" "time" + "unicode/utf8" "github.com/mark3labs/mcp-go/mcp" "go.kenn.io/msgvault/internal/deletion" @@ -27,6 +28,14 @@ const ( maxLimit = 1000 maxSearchMessagesLimit = 50 defaultSearchLimit = 20 + defaultBodyChars = 2000 + bodyFormatAuto = "auto" + bodyFormatText = "text" + bodyFormatHTML = "html" + // maxBodyChars caps the body slice returned by get_message regardless of what + // the caller requests via max_chars. Prevents a single tool call from flooding + // the context window; callers page forward using offset. + maxBodyChars = 4000 // totalCountUnknown is returned when the backend cannot report a full match // count (body FTS fallback, hybrid/vector ranking depth, or list_messages // without a separate count query). Clients should use has_more for paging. @@ -644,6 +653,109 @@ func (h *handlers) filterFromFindSimilarArgs(ctx context.Context, args map[strin return f, nil } +// bodyByteSliceRange returns a UTF-8-safe subslice of body[start:end] and the +// adjusted byte offsets actually used. adjEnd is exclusive; callers use it for +// has_more and sequential paging via offset += body_returned. 
// bodyByteSliceRange returns the largest valid-UTF-8 slice of body that fits
// inside the byte window [start, end), together with the byte offsets that
// were actually used.
//
// start and end are clamped to [0, len(body)]. A start that lands on a
// continuation byte is moved forward to the next rune start, and an end that
// lands inside a rune is moved back, so the result never splits a multibyte
// rune. If the window is empty, too small to hold a whole rune, or begins
// with an invalid byte, the single rune (or single invalid byte) at or after
// start is returned instead so sequential paging always makes progress.
//
// adjEnd is exclusive. When start is at or past the end of body the result is
// ("", len(body), len(body)).
func bodyByteSliceRange(body string, start, end int) (text string, adjStart, adjEnd int) {
	if start < 0 {
		start = 0
	}
	if end > len(body) {
		end = len(body)
	}
	if start >= len(body) {
		return "", len(body), len(body)
	}
	if start >= end {
		return oneRuneSlice(body, start)
	}

	// Nudge both edges inward onto rune boundaries.
	adjStart, adjEnd = start, end
	for adjStart < adjEnd && !utf8.RuneStart(body[adjStart]) {
		adjStart++
	}
	for adjEnd > adjStart && adjEnd < len(body) && !utf8.RuneStart(body[adjEnd]) {
		adjEnd--
	}

	// Find the longest valid-UTF-8 prefix of body[adjStart:adjEnd]. The common
	// case (window already valid) is a single ValidString call. Otherwise walk
	// the window once, rune by rune, and stop at the first invalid or truncated
	// sequence. Re-validating the whole window after trimming one byte at a
	// time would be quadratic on bodies that contain invalid bytes.
	validEnd := adjStart
	if utf8.ValidString(body[adjStart:adjEnd]) {
		validEnd = adjEnd
	} else {
		for validEnd < adjEnd {
			r, size := utf8.DecodeRuneInString(body[validEnd:adjEnd])
			// (RuneError, 1) marks an invalid or truncated encoding; a
			// correctly encoded U+FFFD has size 3 and is kept.
			if r == utf8.RuneError && size <= 1 {
				break
			}
			validEnd += size
		}
	}
	if validEnd > adjStart {
		return body[adjStart:validEnd], adjStart, validEnd
	}
	// Nothing valid fits in the window: fall back to a single rune.
	return oneRuneSlice(body, adjStart)
}

// oneRuneSlice returns the single rune that starts at or after start, plus the
// byte range [adjStart, adjEnd) it occupies, so tiny windows and mid-rune
// offsets still advance sequential paging. Continuation bytes at start are
// skipped. An invalid leading byte is returned on its own as a one-byte slice.
// If no rune start exists at or after start, it returns
// ("", len(body), len(body)).
func oneRuneSlice(body string, start int) (text string, adjStart, adjEnd int) {
	if start < 0 {
		start = 0
	}
	adjStart = start
	for adjStart < len(body) && !utf8.RuneStart(body[adjStart]) {
		adjStart++
	}
	if adjStart >= len(body) {
		return "", len(body), len(body)
	}
	// On a non-empty string DecodeRuneInString always reports a size between
	// 1 and len(s), so adjEnd cannot run past the end of body.
	_, size := utf8.DecodeRuneInString(body[adjStart:])
	adjEnd = adjStart + size
	return body[adjStart:adjEnd], adjStart, adjEnd
}

// bodyByteSlice returns body[start:end], nudging boundaries inward so the
// result is always valid UTF-8. MCP body APIs use byte offsets; without
// this, a window can split a multibyte rune (emoji, CJK, accented letters).
// It is bodyByteSliceRange with the adjusted offsets discarded.
func bodyByteSlice(body string, start, end int) string {
	text, _, _ := bodyByteSliceRange(body, start, end)
	return text
}

// contextWindow returns byte offsets [start, end) for a window of up to
// contextChars bytes centered on a match at pos with byte length termLen.
+func contextWindow(bodyLen, pos, termLen, contextChars int) (start, end int) { + start = pos - (contextChars-termLen)/2 + end = start + contextChars + if start < 0 { + start = 0 + end = min(bodyLen, contextChars) + } else if end > bodyLen { + end = bodyLen + start = max(0, end-contextChars) + } + return start, end +} + +type getMessageResponse struct { + ID int64 `json:"id"` + SourceMessageID string `json:"source_message_id"` + ConversationID int64 `json:"conversation_id"` + SourceConversationID string `json:"source_conversation_id"` + Subject string `json:"subject"` + MessageType string `json:"message_type,omitempty"` + Snippet string `json:"snippet"` + SentAt time.Time `json:"sent_at"` + ReceivedAt *time.Time `json:"received_at,omitempty"` + DeletedAt *time.Time `json:"deleted_at,omitempty"` + SizeEstimate int64 `json:"size_estimate"` + HasAttachments bool `json:"has_attachments"` + From []query.Address `json:"from"` + To []query.Address `json:"to"` + Cc []query.Address `json:"cc"` + Bcc []query.Address `json:"bcc"` + BodyText string `json:"body_text"` + BodyHTML string `json:"body_html"` + BodyFormat string `json:"body_format,omitempty"` + BodyLength int `json:"body_length"` + BodyReturned int `json:"body_returned"` + Offset int `json:"offset"` + HasMore bool `json:"has_more"` + Labels []string `json:"labels"` + Attachments []query.AttachmentInfo `json:"attachments"` +} + func (h *handlers) getMessage(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) { args := req.GetArguments() @@ -656,8 +768,87 @@ func (h *handlers) getMessage(ctx context.Context, req mcp.CallToolRequest) (*mc if err != nil { return mcp.NewToolResultError(fmt.Sprintf("message not found: %v", err)), nil } + if msg == nil { + return mcp.NewToolResultError("message not found"), nil + } + + maxChars := intArg(args, "max_chars", defaultBodyChars) + if maxChars <= 0 { + maxChars = defaultBodyChars + } else if maxChars > maxBodyChars { + maxChars = maxBodyChars + } - return 
jsonResult(msg) + requestedBodyFormat, _ := args["body_format"].(string) + if requestedBodyFormat == "" { + requestedBodyFormat = bodyFormatAuto + } + + fullBody := msg.BodyText + bodyFormat := bodyFormatText + switch requestedBodyFormat { + case bodyFormatAuto: + if fullBody == "" && msg.BodyHTML != "" { + fullBody = msg.BodyHTML + bodyFormat = bodyFormatHTML + } + case bodyFormatText: + case bodyFormatHTML: + fullBody = msg.BodyHTML + bodyFormat = bodyFormatHTML + default: + return mcp.NewToolResultError("body_format must be one of auto, text, html"), nil + } + bodyLen := len(fullBody) + + var start, end int + fullBodyRequested, _ := args["full_body"].(bool) + if fullBodyRequested { + start, end = 0, bodyLen + } else if centerAt := intArg(args, "center_at", -1); centerAt >= 0 { + // Center the window on the given byte offset. contextWindow handles + // clamping to body boundaries. + start, end = contextWindow(bodyLen, centerAt, 0, maxChars) + } else { + start = min(intArg(args, "offset", 0), bodyLen) + end = min(start+maxChars, bodyLen) + } + + bodySlice, sliceStart, sliceEnd := bodyByteSliceRange(fullBody, start, end) + bodyText := bodySlice + bodyHTML := "" + if bodyFormat == bodyFormatHTML { + bodyText = "" + bodyHTML = bodySlice + } + + return jsonResult(getMessageResponse{ + ID: msg.ID, + SourceMessageID: msg.SourceMessageID, + ConversationID: msg.ConversationID, + SourceConversationID: msg.SourceConversationID, + Subject: msg.Subject, + MessageType: msg.MessageType, + Snippet: msg.Snippet, + SentAt: msg.SentAt, + ReceivedAt: msg.ReceivedAt, + DeletedAt: msg.DeletedAt, + SizeEstimate: msg.SizeEstimate, + HasAttachments: msg.HasAttachments, + From: msg.From, + To: msg.To, + Cc: msg.Cc, + Bcc: msg.Bcc, + BodyText: bodyText, + BodyHTML: bodyHTML, + BodyFormat: bodyFormat, + BodyLength: bodyLen, + BodyReturned: len(bodySlice), + Offset: sliceStart, + HasMore: sliceEnd < bodyLen, + Labels: msg.Labels, + Attachments: msg.Attachments, + }) } const maxAttachmentSize 
= 50 * 1024 * 1024 // 50MB @@ -943,6 +1134,18 @@ func (h *handlers) aggregate(ctx context.Context, req mcp.CallToolRequest) (*mcp return jsonResult(rows) } +// intArg extracts a non-negative integer from a map, with a default. +func intArg(args map[string]any, key string, def int) int { + v, ok := args[key].(float64) + if !ok { + return def + } + if math.IsNaN(v) || v < 0 || math.IsInf(v, 1) || v > float64(math.MaxInt) { + return def + } + return int(v) +} + // limitArg extracts a non-negative integer limit from a map, with a default. // JSON numbers arrive as float64. Clamps to maxLimit to prevent excessive // result sets. diff --git a/internal/mcp/server.go b/internal/mcp/server.go index 58a3c23e..7d4dc720 100644 --- a/internal/mcp/server.go +++ b/internal/mcp/server.go @@ -234,12 +234,33 @@ func searchMessagesTool(vectorAvailable bool) mcp.Tool { func getMessageTool() mcp.Tool { return mcp.NewTool(ToolGetMessage, - mcp.WithDescription("Get full message details including body text, recipients, labels, and attachments by message ID."), + mcp.WithDescription("Get message details including recipients, labels, attachments, and a slice of the message body. "+ + "Returns plain text when available; HTML-only messages return a body_html slice with body_format=html. "+ + "Body paging mirrors search pagination: body_length=total bytes, offset=where this chunk starts, body_returned=bytes in this chunk, has_more=more body follows. "+ + "To read sequentially: call again with offset += body_returned. "+ + "To jump to a known match location: use center_at= to center the window on that location. "+ + "Note: snippet is pre-stored source metadata (may be empty for non-Gmail sources)."), mcp.WithReadOnlyHintAnnotation(true), mcp.WithNumber("id", mcp.Required(), mcp.Description("Message ID"), ), + mcp.WithNumber("offset", + mcp.Description("Byte offset from the start of the selected body to begin reading (default 0). 
Ignored when center_at is provided."), + ), + mcp.WithNumber("center_at", + mcp.Description("Byte offset from the start of the selected body to center the window on. Takes precedence over offset."), + ), + mcp.WithNumber("max_chars", + mcp.Description("Maximum selected-body bytes to return (default 2000, max 4000). Values above 4000 are clamped to 4000; zero or negative values use the default."), + ), + mcp.WithString("body_format", + mcp.Description("Which body representation to page: auto (default, plain text when available, HTML fallback), text, or html."), + mcp.Enum(bodyFormatAuto, bodyFormatText, bodyFormatHTML), + ), + mcp.WithBoolean("full_body", + mcp.Description("Return the complete selected body in one response, ignoring offset, center_at, and max_chars. Use only when the full content is explicitly needed."), + ), ) } diff --git a/internal/mcp/server_test.go b/internal/mcp/server_test.go index 5cdd97ac..0e9ab0a3 100644 --- a/internal/mcp/server_test.go +++ b/internal/mcp/server_test.go @@ -11,6 +11,7 @@ import ( "strings" "testing" "time" + "unicode/utf8" "github.com/mark3labs/mcp-go/mcp" assertpkg "github.com/stretchr/testify/assert" @@ -65,6 +66,19 @@ type paginatedListMessages struct { HasMore bool `json:"has_more"` } +type getMessageResp struct { + ID int64 `json:"id"` + Subject string `json:"subject"` + BodyText string `json:"body_text"` + BodyHTML string `json:"body_html"` + BodyFormat string `json:"body_format"` + BodyLength int `json:"body_length"` + BodyReturned int `json:"body_returned"` + Offset int `json:"offset"` + HasMore bool `json:"has_more"` + ConversationID int64 `json:"conversation_id"` +} + // newTestHandlers creates a handlers instance with the given mock engine. 
func newTestHandlers(eng *querytest.MockEngine) *handlers { return &handlers{engine: eng} @@ -600,6 +614,54 @@ func TestSearchMessages_HybridPagination_ProbeRowDetectsMore(t *testing.T) { assert.False(resp2.HasMore, "has_more page 2") } +func TestBodyByteSlice(t *testing.T) { + t.Run("ascii unchanged", func(t *testing.T) { + body := "hello world" + assertpkg.Equal(t, "hello", bodyByteSlice(body, 0, 5)) + }) + + t.Run("does not split multibyte rune", func(t *testing.T) { + body := "café" + s := bodyByteSlice(body, 0, 4) + assertpkg.True(t, utf8.ValidString(s), "result must be valid UTF-8: %q", s) + assertpkg.Equal(t, "caf", s) + }) + + t.Run("emoji not bisected", func(t *testing.T) { + body := strings.Repeat("a", 10) + "😀" + strings.Repeat("b", 10) + emojiStart := 10 + s := bodyByteSlice(body, emojiStart, emojiStart+2) + assertpkg.True(t, utf8.ValidString(s), "result must be valid UTF-8: %q", s) + wide := bodyByteSlice(body, emojiStart, emojiStart+4) + assertpkg.True(t, utf8.ValidString(wide)) + assertpkg.Equal(t, "😀", wide) + }) + + t.Run("returns adjusted offsets for paging", func(t *testing.T) { + assert := assertpkg.New(t) + body := "aaa😀bbb" + text, adjStart, adjEnd := bodyByteSliceRange(body, 0, 5) + assert.Equal("aaa", text) + assert.Equal(0, adjStart) + assert.Equal(3, adjEnd) + + text2, adjStart2, adjEnd2 := bodyByteSliceRange(body, 3, 8) + assert.True(utf8.ValidString(text2)) + assert.Equal(3, adjStart2) + assert.Equal("😀b", text2) + assert.Equal(8, adjEnd2) + }) + + t.Run("tiny window returns one rune", func(t *testing.T) { + assert := assertpkg.New(t) + body := "aaa😀bbb" + text, adjStart, adjEnd := bodyByteSliceRange(body, 3, 4) + assert.Equal("😀", text) + assert.Equal(3, adjStart) + assert.Equal(7, adjEnd) + }) +} + func TestSearchMessages_UnknownMode(t *testing.T) { h := newTestHandlers(&querytest.MockEngine{}) @@ -620,9 +682,225 @@ func TestGetMessage(t *testing.T) { h := newTestHandlers(eng) t.Run("found", func(t *testing.T) { - msg := 
runTool[query.MessageDetail](t, "get_message", h.getMessage, map[string]any{"id": float64(42)}) - assertpkg.Equal(t, "Test Message", msg.Subject, "subject") - assertpkg.Equal(t, "thread-xyz", msg.SourceConversationID, "SourceConversationID") + assert := assertpkg.New(t) + msg := runTool[getMessageResp](t, "get_message", h.getMessage, map[string]any{"id": float64(42)}) + assert.Equal("Test Message", msg.Subject, "subject") + assert.Equal("Hello world", msg.BodyText, "body_text") + assert.Empty(msg.BodyHTML, "body_html stripped") + assert.Equal("text", msg.BodyFormat, "body_format") + assert.Equal(11, msg.BodyLength, "body_length") + assert.Equal(11, msg.BodyReturned, "body_returned") + assert.False(msg.HasMore, "has_more") + }) + + t.Run("html-only body returns html slice", func(t *testing.T) { + assert := assertpkg.New(t) + htmlBody := "

Hello world

" + eng2 := &querytest.MockEngine{ + Messages: map[int64]*query.MessageDetail{ + 57: testutil.NewMessageDetail(57).WithBodyText("").WithBodyHTML(htmlBody).BuildPtr(), + }, + } + h2 := newTestHandlers(eng2) + msg := runTool[getMessageResp](t, "get_message", h2.getMessage, map[string]any{"id": float64(57)}) + assert.Empty(msg.BodyText, "body_text") + assert.Equal(htmlBody, msg.BodyHTML, "body_html") + assert.Equal("html", msg.BodyFormat, "body_format") + assert.Equal(len(htmlBody), msg.BodyLength, "body_length") + assert.Equal(len(htmlBody), msg.BodyReturned, "body_returned") + assert.False(msg.HasMore, "has_more") + }) + + t.Run("html format selects html from mixed body", func(t *testing.T) { + assert := assertpkg.New(t) + htmlBody := "

Hello HTML

" + eng2 := &querytest.MockEngine{ + Messages: map[int64]*query.MessageDetail{ + 58: testutil.NewMessageDetail(58). + WithBodyText("Hello text"). + WithBodyHTML(htmlBody). + BuildPtr(), + }, + } + h2 := newTestHandlers(eng2) + msg := runTool[getMessageResp](t, "get_message", h2.getMessage, map[string]any{ + "id": float64(58), + "body_format": "html", + }) + assert.Empty(msg.BodyText, "body_text") + assert.Equal(htmlBody, msg.BodyHTML, "body_html") + assert.Equal("html", msg.BodyFormat, "body_format") + assert.Equal(len(htmlBody), msg.BodyLength, "body_length") + assert.Equal(len(htmlBody), msg.BodyReturned, "body_returned") + assert.False(msg.HasMore, "has_more") + }) + + t.Run("truncates long body", func(t *testing.T) { + assert := assertpkg.New(t) + longBody := strings.Repeat("x", 5000) + eng2 := &querytest.MockEngine{ + Messages: map[int64]*query.MessageDetail{ + 50: testutil.NewMessageDetail(50).WithBodyText(longBody).BuildPtr(), + }, + } + h2 := newTestHandlers(eng2) + msg := runTool[getMessageResp](t, "get_message", h2.getMessage, map[string]any{"id": float64(50)}) + assert.Equal(5000, msg.BodyLength, "body_length") + assert.Equal(2000, msg.BodyReturned, "body_returned") + assert.Len(msg.BodyText, 2000, "truncated body_text") + assert.True(msg.HasMore, "has_more") + }) + + t.Run("full_body returns complete selected body", func(t *testing.T) { + assert := assertpkg.New(t) + longBody := strings.Repeat("x", 5000) + eng2 := &querytest.MockEngine{ + Messages: map[int64]*query.MessageDetail{ + 59: testutil.NewMessageDetail(59).WithBodyText(longBody).BuildPtr(), + }, + } + h2 := newTestHandlers(eng2) + msg := runTool[getMessageResp](t, "get_message", h2.getMessage, map[string]any{ + "id": float64(59), + "full_body": true, + "max_chars": float64(10), + "offset": float64(2000), + "center_at": float64(3000), + }) + assert.Equal(longBody, msg.BodyText, "body_text") + assert.Equal("text", msg.BodyFormat, "body_format") + assert.Equal(5000, msg.BodyLength, "body_length") 
+ assert.Equal(5000, msg.BodyReturned, "body_returned") + assert.Equal(0, msg.Offset, "offset") + assert.False(msg.HasMore, "has_more") + }) + + t.Run("offset pagination", func(t *testing.T) { + assert := assertpkg.New(t) + body := strings.Repeat("a", 3000) + eng2 := &querytest.MockEngine{ + Messages: map[int64]*query.MessageDetail{ + 51: testutil.NewMessageDetail(51).WithBodyText(body).BuildPtr(), + }, + } + h2 := newTestHandlers(eng2) + msg := runTool[getMessageResp](t, "get_message", h2.getMessage, map[string]any{ + "id": float64(51), + "offset": float64(2000), + }) + assert.Equal(2000, msg.Offset, "offset") + assert.Equal(1000, msg.BodyReturned, "body_returned") + assert.Len(msg.BodyText, 1000, "second page length") + assert.False(msg.HasMore, "has_more") + }) + + t.Run("center_at mid-body", func(t *testing.T) { + body := strings.Repeat("a", 1000) + "KEYWORD" + strings.Repeat("z", 1000) + matchOffset := 1000 + eng2 := &querytest.MockEngine{ + Messages: map[int64]*query.MessageDetail{ + 52: testutil.NewMessageDetail(52).WithBodyText(body).BuildPtr(), + }, + } + h2 := newTestHandlers(eng2) + msg := runTool[getMessageResp](t, "get_message", h2.getMessage, map[string]any{ + "id": float64(52), + "center_at": float64(matchOffset), + "max_chars": float64(200), + }) + assertpkg.Contains(t, msg.BodyText, "KEYWORD") + assertpkg.LessOrEqual(t, msg.Offset, matchOffset, "window starts before match") + assertpkg.LessOrEqual(t, len(msg.BodyText), 200, "respects max_chars") + }) + + t.Run("center_at near start", func(t *testing.T) { + body := "KEYWORD" + strings.Repeat("z", 1000) + eng2 := &querytest.MockEngine{ + Messages: map[int64]*query.MessageDetail{ + 53: testutil.NewMessageDetail(53).WithBodyText(body).BuildPtr(), + }, + } + h2 := newTestHandlers(eng2) + msg := runTool[getMessageResp](t, "get_message", h2.getMessage, map[string]any{ + "id": float64(53), + "center_at": float64(0), + "max_chars": float64(200), + }) + assertpkg.Contains(t, msg.BodyText, "KEYWORD") + 
assertpkg.Equal(t, 0, msg.Offset, "starts at body start") + }) + + t.Run("max_chars above cap clamps to 4000", func(t *testing.T) { + assert := assertpkg.New(t) + longBody := strings.Repeat("x", 5000) + eng2 := &querytest.MockEngine{ + Messages: map[int64]*query.MessageDetail{ + 54: testutil.NewMessageDetail(54).WithBodyText(longBody).BuildPtr(), + }, + } + h2 := newTestHandlers(eng2) + msg := runTool[getMessageResp](t, "get_message", h2.getMessage, map[string]any{ + "id": float64(54), + "max_chars": float64(5000), + }) + assert.Equal(4000, msg.BodyReturned, "body_returned") + assert.Len(msg.BodyText, 4000, "clamped body_text") + assert.True(msg.HasMore, "has_more") + }) + + t.Run("max_chars zero uses default", func(t *testing.T) { + assert := assertpkg.New(t) + longBody := strings.Repeat("x", 5000) + eng2 := &querytest.MockEngine{ + Messages: map[int64]*query.MessageDetail{ + 55: testutil.NewMessageDetail(55).WithBodyText(longBody).BuildPtr(), + }, + } + h2 := newTestHandlers(eng2) + msg := runTool[getMessageResp](t, "get_message", h2.getMessage, map[string]any{ + "id": float64(55), + "max_chars": float64(0), + }) + assert.Equal(2000, msg.BodyReturned, "body_returned") + assert.Len(msg.BodyText, 2000, "default body_text") + assert.True(msg.HasMore, "has_more") + }) + + t.Run("nil message without error", func(t *testing.T) { + eng2 := &querytest.MockEngine{ + GetMessageFunc: func(context.Context, int64) (*query.MessageDetail, error) { + return nil, nil //nolint:nilnil // mirrors Engine.GetMessage not-found contract + }, + } + h2 := newTestHandlers(eng2) + runToolExpectError(t, "get_message", h2.getMessage, map[string]any{"id": float64(42)}) + }) + + t.Run("utf8 sequential paging", func(t *testing.T) { + assert := assertpkg.New(t) + body := strings.Repeat("a", 10) + "😀" + strings.Repeat("b", 10) + eng2 := &querytest.MockEngine{ + Messages: map[int64]*query.MessageDetail{ + 56: testutil.NewMessageDetail(56).WithBodyText(body).BuildPtr(), + }, + } + h2 := 
newTestHandlers(eng2) + + var parts []string + offset := 0 + for { + msg := runTool[getMessageResp](t, "get_message", h2.getMessage, map[string]any{ + "id": float64(56), + "offset": float64(offset), + "max_chars": float64(5), + }) + parts = append(parts, msg.BodyText) + if !msg.HasMore { + break + } + offset += msg.BodyReturned + } + assert.Equal(body, strings.Join(parts, ""), "rejoined pages") }) errorCases := []struct { @@ -642,6 +920,15 @@ func TestGetMessage(t *testing.T) { } } +func TestGetMessageToolDescriptionDoesNotReferenceFutureTools(t *testing.T) { + tool := getMessageTool() + assertpkg.NotContains(t, tool.Description, "search_in_message") + centerAt := tool.InputSchema.Properties["center_at"] + raw, err := json.Marshal(centerAt) + requirepkg.NoError(t, err, "marshal center_at schema") + assertpkg.NotContains(t, string(raw), "search_in_message") +} + func TestGetStats_VectorDisabled(t *testing.T) { assert := assertpkg.New(t) eng := &querytest.MockEngine{