compute_metrics.py
"""
Compute metrics from combined result JSON files.
"""
import json
from collections import defaultdict
from typing import List, Dict, Any
def compute_metrics_for_group(results: List[Dict[str, Any]]) -> Dict[str, Any]:
"""Compute all metrics for a group of results (one strategy/retrieval/top_k combo)."""
if not results:
return {}
n = len(results)
# String-based metrics
em_with = sum(r.get('em_with', False) for r in results) / n
f1_with = sum(r.get('f1_with', 0) for r in results) / n
em_without = sum(r.get('em_without', False) for r in results) / n
f1_without = sum(r.get('f1_without', 0) for r in results) / n
# LLM-judge accuracy (uses answer_with_correct from utilization probe)
n_with_correct = sum(r.get('answer_with_correct', False) for r in results)
n_without_correct = sum(r.get('answer_without_correct', False) for r in results)
accuracy_with = n_with_correct / n
accuracy_without = n_without_correct / n
# Retrieval metrics
avg_retrieval_precision = sum(r.get('retrieval_precision', 0) for r in results) / n
avg_relevant_per_query = sum(r.get('n_relevant_retrieved', 0) for r in results) / n
# Utilization breakdown
utilization_counts = defaultdict(int)
for r in results:
cat = r.get('utilization_category', 'unknown')
utilization_counts[cat] += 1
utilization = {
'ignored': utilization_counts.get('ignored', 0) / n,
'beneficial': utilization_counts.get('beneficial', 0) / n,
'harmful': utilization_counts.get('harmful', 0) / n,
'neutral': utilization_counts.get('neutral', 0) / n,
'counts': dict(utilization_counts)
}
# Failure mode breakdown
failure_counts = defaultdict(int)
for r in results:
cat = r.get('failure_category', 'unknown')
failure_counts[cat] += 1
failure_modes = {cat: count / n for cat, count in failure_counts.items()}
# By category breakdown
by_category = {}
categories = set(r.get('category') for r in results if r.get('category'))
for cat in categories:
cat_results = [r for r in results if r.get('category') == cat]
if cat_results:
n_cat = len(cat_results)
by_category[cat] = {
'n': n_cat,
'em': sum(r.get('em_with', False) for r in cat_results) / n_cat,
'f1': sum(r.get('f1_with', 0) for r in cat_results) / n_cat,
'llm_accuracy': sum(r.get('answer_with_correct', False) for r in cat_results) / n_cat,
'avg_retrieval_precision': sum(r.get('retrieval_precision', 0) for r in cat_results) / n_cat,
}
return {
'n_questions': n,
'top_k': results[0].get('top_k', 5) if results else 5,
'em_with': em_with,
'f1_with': f1_with,
'em_without': em_without,
'f1_without': f1_without,
'em_delta': em_with - em_without,
'f1_delta': f1_with - f1_without,
'accuracy_with_memory': accuracy_with,
'accuracy_without_memory': accuracy_without,
'accuracy_delta': accuracy_with - accuracy_without,
'avg_retrieval_precision': avg_retrieval_precision,
'avg_relevant_per_query': avg_relevant_per_query,
'utilization': utilization,
'failure_modes': failure_modes,
'failure_counts': dict(failure_counts),
'by_category': by_category,
}
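
# Each element of `results` is a per-question record; compute_metrics_for_group
# reads the keys 'em_with', 'f1_with', 'em_without', 'f1_without',
# 'answer_with_correct', 'answer_without_correct', 'retrieval_precision',
# 'n_relevant_retrieved', 'utilization_category', 'failure_category', and
# (optionally) 'category'. Missing keys fall back to 0 / False / 'unknown'
# via .get(), so incomplete records are tolerated rather than raising.
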
def compute_all_metrics(results: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Group results and compute metrics for each group."""
    # Group by strategy / retrieval_method / top_k
    groups = defaultdict(list)
    for r in results:
        strategy = r.get('strategy', 'unknown')
        retrieval = r.get('retrieval_method', 'cosine')
        top_k = r.get('top_k', 5)
        key = f"{strategy} / {retrieval} / k={top_k}"
        groups[key].append(r)

    # Compute metrics for each group
    all_metrics = {}
    for key, group_results in sorted(groups.items()):
        all_metrics[key] = compute_metrics_for_group(group_results)
    return all_metrics


def main():
    """Load two result files, combine them, and compute metrics."""
    print("Loading result files...")

    # Load first file (sessions 0-3)
    with open('results/combined.json', 'r') as f:
        results_0_3 = json.load(f)
    print(f"Loaded {len(results_0_3)} results from sessions 0-3")

    # Load second file (sessions 4-9)
    with open('results/results_20260211_143917_4_9.json', 'r') as f:
        results_4_9 = json.load(f)
    print(f"Loaded {len(results_4_9)} results from sessions 4-9")

    # Combine
    combined_results = results_0_3 + results_4_9
    print(f"Combined: {len(combined_results)} total results")

    # Save combined results
    print("\nSaving combined results...")
    with open('results/combined_full.json', 'w') as f:
        json.dump(combined_results, f, indent=2)
    print("✓ Saved to results/combined_full.json")

    # Compute metrics
    print("\nComputing metrics...")
    metrics = compute_all_metrics(combined_results)

    # Print summary
    print("\nMetrics summary:")
    for key, m in metrics.items():
        print(f"\n{key}:")
        print(f" N questions: {m['n_questions']}")
        print(f" F1 with memory: {m['f1_with']:.4f}")
        print(f" LLM accuracy: {m['accuracy_with_memory']:.4f}")
        print(f" Retrieval precision: {m['avg_retrieval_precision']:.4f}")

    # Save metrics
    print("\nSaving metrics...")
    with open('results/combined_full_metrics.json', 'w') as f:
        json.dump(metrics, f, indent=2)
    print("✓ Saved to results/combined_full_metrics.json")

    print("\n✓ Done!")


if __name__ == '__main__':
    main()
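
# Illustrative programmatic use: a minimal sketch of calling compute_all_metrics
# on a synthetic record. The field values below are made up for demonstration;
# real records come from the combined result JSON files loaded in main().
#
#     sample = [{
#         'strategy': 'baseline',            # illustrative strategy name
#         'retrieval_method': 'cosine',
#         'top_k': 5,
#         'em_with': True, 'f1_with': 0.8,
#         'em_without': False, 'f1_without': 0.1,
#         'answer_with_correct': True, 'answer_without_correct': False,
#         'retrieval_precision': 0.6, 'n_relevant_retrieved': 3,
#         'utilization_category': 'beneficial',
#         'failure_category': 'none',        # illustrative failure label
#         'category': 'single_hop',          # illustrative question category
#     }]
#     print(json.dumps(compute_all_metrics(sample), indent=2))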