|
| 1 | +package claude |
| 2 | + |
| 3 | +import ( |
| 4 | + "encoding/json" |
| 5 | + "path/filepath" |
| 6 | + "runtime" |
| 7 | + "strings" |
| 8 | + "testing" |
| 9 | + "time" |
| 10 | +) |
| 11 | + |
| 12 | +// testdataPath returns the absolute path to a file in testdata/. |
| 13 | +func testdataPath(name string) string { |
| 14 | + _, file, _, _ := runtime.Caller(0) |
| 15 | + return filepath.Join(filepath.Dir(file), "testdata", name) |
| 16 | +} |
| 17 | + |
| 18 | +// --- ParseSessionFile --- |
| 19 | + |
| 20 | +func TestParseSessionFile_Simple(t *testing.T) { |
| 21 | + session, err := ParseSessionFile(testdataPath("simple.jsonl")) |
| 22 | + if err != nil { |
| 23 | + t.Fatalf("unexpected error: %v", err) |
| 24 | + } |
| 25 | + if session.ID != "abc123" { |
| 26 | + t.Errorf("ID = %q, want %q", session.ID, "abc123") |
| 27 | + } |
| 28 | + if session.UserPrompts != 1 { |
| 29 | + t.Errorf("UserPrompts = %d, want 1", session.UserPrompts) |
| 30 | + } |
| 31 | + if session.AssistantMsgs != 1 { |
| 32 | + t.Errorf("AssistantMsgs = %d, want 1", session.AssistantMsgs) |
| 33 | + } |
| 34 | + if session.Summary != "hello world" { |
| 35 | + t.Errorf("Summary = %q, want %q", session.Summary, "hello world") |
| 36 | + } |
| 37 | +} |
| 38 | + |
| 39 | +func TestParseSessionFile_ToolUse(t *testing.T) { |
| 40 | + session, err := ParseSessionFile(testdataPath("tool_use.jsonl")) |
| 41 | + if err != nil { |
| 42 | + t.Fatalf("unexpected error: %v", err) |
| 43 | + } |
| 44 | + |
| 45 | + var toolEntry *Entry |
| 46 | + for i := range session.Transcript { |
| 47 | + if session.Transcript[i].Kind == EntryToolCall { |
| 48 | + toolEntry = &session.Transcript[i] |
| 49 | + break |
| 50 | + } |
| 51 | + } |
| 52 | + if toolEntry == nil { |
| 53 | + t.Fatal("no tool_call entry found in transcript") |
| 54 | + } |
| 55 | + if toolEntry.Title != "Read" { |
| 56 | + t.Errorf("tool Title = %q, want %q", toolEntry.Title, "Read") |
| 57 | + } |
| 58 | + if toolEntry.Content != "/foo/bar.go" { |
| 59 | + t.Errorf("tool Content = %q, want %q", toolEntry.Content, "/foo/bar.go") |
| 60 | + } |
| 61 | +} |
| 62 | + |
| 63 | +func TestParseSessionFile_Empty(t *testing.T) { |
| 64 | + session, err := ParseSessionFile(testdataPath("empty.jsonl")) |
| 65 | + if err != nil { |
| 66 | + t.Fatalf("unexpected error: %v", err) |
| 67 | + } |
| 68 | + if session.UserPrompts != 0 { |
| 69 | + t.Errorf("UserPrompts = %d, want 0", session.UserPrompts) |
| 70 | + } |
| 71 | + if session.AssistantMsgs != 0 { |
| 72 | + t.Errorf("AssistantMsgs = %d, want 0", session.AssistantMsgs) |
| 73 | + } |
| 74 | + if session.Summary != "(no user prompt found)" { |
| 75 | + t.Errorf("Summary = %q, want %q", session.Summary, "(no user prompt found)") |
| 76 | + } |
| 77 | +} |
| 78 | + |
| 79 | +// --- oneLine --- |
| 80 | + |
| 81 | +func TestOneLine(t *testing.T) { |
| 82 | + cases := []struct { |
| 83 | + name string |
| 84 | + input string |
| 85 | + want string |
| 86 | + }{ |
| 87 | + { |
| 88 | + name: "empty string", |
| 89 | + input: "", |
| 90 | + want: "", |
| 91 | + }, |
| 92 | + { |
| 93 | + name: "multiline collapses to single line", |
| 94 | + input: "line one\nline two\nline three", |
| 95 | + want: "line one line two line three", |
| 96 | + }, |
| 97 | + { |
| 98 | + name: "long string truncated with ellipsis", |
| 99 | + input: strings.Repeat("a", 130), |
| 100 | + want: strings.Repeat("a", 117) + "...", |
| 101 | + }, |
| 102 | + { |
| 103 | + name: "short string unchanged", |
| 104 | + input: "short", |
| 105 | + want: "short", |
| 106 | + }, |
| 107 | + { |
| 108 | + name: "exactly 120 runes unchanged", |
| 109 | + input: strings.Repeat("x", 120), |
| 110 | + want: strings.Repeat("x", 120), |
| 111 | + }, |
| 112 | + { |
| 113 | + name: "121 runes truncated", |
| 114 | + input: strings.Repeat("x", 121), |
| 115 | + want: strings.Repeat("x", 117) + "...", |
| 116 | + }, |
| 117 | + } |
| 118 | + |
| 119 | + for _, tc := range cases { |
| 120 | + t.Run(tc.name, func(t *testing.T) { |
| 121 | + got := oneLine(tc.input) |
| 122 | + if got != tc.want { |
| 123 | + t.Errorf("oneLine(%q) = %q, want %q", tc.input, got, tc.want) |
| 124 | + } |
| 125 | + }) |
| 126 | + } |
| 127 | +} |
| 128 | + |
| 129 | +// --- truncateRunes --- |
| 130 | + |
| 131 | +func TestTruncateRunes(t *testing.T) { |
| 132 | + cases := []struct { |
| 133 | + name string |
| 134 | + input string |
| 135 | + limit int |
| 136 | + want string |
| 137 | + }{ |
| 138 | + { |
| 139 | + name: "shorter than limit unchanged", |
| 140 | + input: "hello", |
| 141 | + limit: 10, |
| 142 | + want: "hello", |
| 143 | + }, |
| 144 | + { |
| 145 | + name: "longer than limit truncated with ellipsis", |
| 146 | + input: "hello world", |
| 147 | + limit: 8, |
| 148 | + want: "hello...", |
| 149 | + }, |
| 150 | + { |
| 151 | + name: "limit <= 3 returns value unchanged", |
| 152 | + input: "hello", |
| 153 | + limit: 3, |
| 154 | + want: "hello", |
| 155 | + }, |
| 156 | + { |
| 157 | + name: "limit 1 returns value unchanged", |
| 158 | + input: "hello", |
| 159 | + limit: 1, |
| 160 | + want: "hello", |
| 161 | + }, |
| 162 | + { |
| 163 | + name: "exactly at limit unchanged", |
| 164 | + input: "hello", |
| 165 | + limit: 5, |
| 166 | + want: "hello", |
| 167 | + }, |
| 168 | + } |
| 169 | + |
| 170 | + for _, tc := range cases { |
| 171 | + t.Run(tc.name, func(t *testing.T) { |
| 172 | + got := truncateRunes(tc.input, tc.limit) |
| 173 | + if got != tc.want { |
| 174 | + t.Errorf("truncateRunes(%q, %d) = %q, want %q", tc.input, tc.limit, got, tc.want) |
| 175 | + } |
| 176 | + }) |
| 177 | + } |
| 178 | +} |
| 179 | + |
| 180 | +// --- summarizeToolInput --- |
| 181 | + |
| 182 | +func mustMarshal(t *testing.T, v any) json.RawMessage { |
| 183 | + t.Helper() |
| 184 | + b, err := json.Marshal(v) |
| 185 | + if err != nil { |
| 186 | + t.Fatalf("json.Marshal failed: %v", err) |
| 187 | + } |
| 188 | + return b |
| 189 | +} |
| 190 | + |
| 191 | +func TestSummarizeToolInput(t *testing.T) { |
| 192 | + cases := []struct { |
| 193 | + name string |
| 194 | + toolName string |
| 195 | + input map[string]any |
| 196 | + wantContains string // substring that must appear in result |
| 197 | + wantExact string // exact match (if set, checked instead of wantContains) |
| 198 | + }{ |
| 199 | + { |
| 200 | + name: "Bash with command and description", |
| 201 | + toolName: "Bash", |
| 202 | + input: map[string]any{"command": "ls -la", "description": "list files"}, |
| 203 | + wantExact: "list files ls -la", |
| 204 | + }, |
| 205 | + { |
| 206 | + name: "Read with file_path", |
| 207 | + toolName: "Read", |
| 208 | + input: map[string]any{"file_path": "/some/path.go"}, |
| 209 | + wantExact: "/some/path.go", |
| 210 | + }, |
| 211 | + { |
| 212 | + name: "WebSearch with query", |
| 213 | + toolName: "WebSearch", |
| 214 | + input: map[string]any{"query": "golang testing"}, |
| 215 | + wantExact: "golang testing", |
| 216 | + }, |
| 217 | + { |
| 218 | + name: "unknown tool falls back to key=value pairs", |
| 219 | + toolName: "UnknownTool", |
| 220 | + input: map[string]any{"alpha": "val1", "beta": "val2"}, |
| 221 | + wantContains: "alpha=val1", |
| 222 | + }, |
| 223 | + } |
| 224 | + |
| 225 | + for _, tc := range cases { |
| 226 | + t.Run(tc.name, func(t *testing.T) { |
| 227 | + raw := mustMarshal(t, tc.input) |
| 228 | + got := summarizeToolInput(tc.toolName, raw) |
| 229 | + if tc.wantExact != "" { |
| 230 | + if got != tc.wantExact { |
| 231 | + t.Errorf("summarizeToolInput(%q) = %q, want %q", tc.toolName, got, tc.wantExact) |
| 232 | + } |
| 233 | + } else if !strings.Contains(got, tc.wantContains) { |
| 234 | + t.Errorf("summarizeToolInput(%q) = %q, want it to contain %q", tc.toolName, got, tc.wantContains) |
| 235 | + } |
| 236 | + }) |
| 237 | + } |
| 238 | +} |
| 239 | + |
| 240 | +// --- normalizeRecord --- |
| 241 | + |
| 242 | +func makeTS() time.Time { |
| 243 | + ts, _ := time.Parse(time.RFC3339, "2024-01-01T10:00:00Z") |
| 244 | + return ts |
| 245 | +} |
| 246 | + |
| 247 | +func TestNormalizeRecord_UserStringContent(t *testing.T) { |
| 248 | + record := rawRecord{ |
| 249 | + Type: "user", |
| 250 | + Message: json.RawMessage(`{"role":"user","content":"hello"}`), |
| 251 | + } |
| 252 | + entries := normalizeRecord(record, makeTS()) |
| 253 | + if len(entries) != 1 { |
| 254 | + t.Fatalf("got %d entries, want 1", len(entries)) |
| 255 | + } |
| 256 | + if entries[0].Kind != EntryHumanPrompt { |
| 257 | + t.Errorf("Kind = %q, want %q", entries[0].Kind, EntryHumanPrompt) |
| 258 | + } |
| 259 | + if entries[0].Content != "hello" { |
| 260 | + t.Errorf("Content = %q, want %q", entries[0].Content, "hello") |
| 261 | + } |
| 262 | +} |
| 263 | + |
| 264 | +func TestNormalizeRecord_AssistantTextBlock(t *testing.T) { |
| 265 | + record := rawRecord{ |
| 266 | + Type: "assistant", |
| 267 | + Message: json.RawMessage(`{"role":"assistant","content":[{"type":"text","text":"good morning"}]}`), |
| 268 | + } |
| 269 | + entries := normalizeRecord(record, makeTS()) |
| 270 | + if len(entries) != 1 { |
| 271 | + t.Fatalf("got %d entries, want 1", len(entries)) |
| 272 | + } |
| 273 | + if entries[0].Kind != EntryAssistantText { |
| 274 | + t.Errorf("Kind = %q, want %q", entries[0].Kind, EntryAssistantText) |
| 275 | + } |
| 276 | + if entries[0].Content != "good morning" { |
| 277 | + t.Errorf("Content = %q, want %q", entries[0].Content, "good morning") |
| 278 | + } |
| 279 | +} |
| 280 | + |
| 281 | +func TestNormalizeRecord_ProgressBashProgress(t *testing.T) { |
| 282 | + record := rawRecord{ |
| 283 | + Type: "progress", |
| 284 | + Data: json.RawMessage(`{"type":"bash_progress","output":"running tests..."}`), |
| 285 | + } |
| 286 | + entries := normalizeRecord(record, makeTS()) |
| 287 | + if len(entries) != 1 { |
| 288 | + t.Fatalf("got %d entries, want 1", len(entries)) |
| 289 | + } |
| 290 | + if entries[0].Kind != EntryProgress { |
| 291 | + t.Errorf("Kind = %q, want %q", entries[0].Kind, EntryProgress) |
| 292 | + } |
| 293 | + if entries[0].Content != "running tests..." { |
| 294 | + t.Errorf("Content = %q, want %q", entries[0].Content, "running tests...") |
| 295 | + } |
| 296 | +} |
| 297 | + |
| 298 | +func TestNormalizeRecord_UnknownType(t *testing.T) { |
| 299 | + record := rawRecord{ |
| 300 | + Type: "totally_unknown_xyz", |
| 301 | + } |
| 302 | + entries := normalizeRecord(record, makeTS()) |
| 303 | + if len(entries) != 0 { |
| 304 | + t.Errorf("got %d entries for unknown type, want 0", len(entries)) |
| 305 | + } |
| 306 | +} |
0 commit comments