Skip to content

Commit f2d90d0

Browse files
federetyk authored and Mattdl committed
refactor: make language_aggregation_mode a non-optional parameter in evaluate()
Address PR #34 review feedback from @Mattdl. The aggregation mode now flows consistently through the entire evaluation and aggregation pipeline instead of being an optional parameter.
1 parent c93edd9 commit f2d90d0

4 files changed

Lines changed: 289 additions & 115 deletions

File tree

src/workrb/metrics/reporting.py

Lines changed: 55 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
from typing import Literal
88

99
from workrb.results import BenchmarkResults
10+
from workrb.types import LanguageAggregationMode
1011

1112
logger = logging.getLogger(__name__)
1213

@@ -21,6 +22,7 @@ def format_results(
2122
show_error: bool = True,
2223
error_type: Literal["ci_margin", "stderr", "std"] = "ci_margin",
2324
show_only_key_metrics: bool = True,
25+
language_aggregation_mode: LanguageAggregationMode | None = None,
2426
) -> str:
2527
"""
2628
Display benchmark results using BenchmarkResults aggregation methods.
@@ -36,11 +38,19 @@ def format_results(
3638
show_error: Whether to show error bars
3739
error_type: Type of error to show - "ci_margin", "stderr", or "std"
3840
show_only_key_metrics: If True, only show key metrics defined in task groups
41+
language_aggregation_mode: How to determine the grouping language for
42+
aggregation. When ``None``, reads the mode stored in
43+
``results.metadata.language_aggregation_mode``.
3944
4045
Returns
4146
-------
4247
String containing formatted results
4348
"""
49+
if language_aggregation_mode is None:
50+
language_aggregation_mode = LanguageAggregationMode(
51+
results.metadata.language_aggregation_mode
52+
)
53+
4454
# Get aggregations - always include mean and error_type
4555
aggregations = ("mean", error_type) if show_error else ("mean",)
4656

@@ -50,30 +60,67 @@ def format_results(
5060
for metrics in results.key_metrics_by_task_group.values():
5161
key_metrics.update(metrics)
5262

63+
# Compute all aggregation levels at once
64+
all_results = results._get_summary_metrics(
65+
aggregations=aggregations,
66+
language_aggregation_mode=language_aggregation_mode,
67+
)
68+
69+
# Partition results by tag name prefix for selective display
70+
results_by_level: dict[str, dict] = {
71+
"mean_per_task": {},
72+
"mean_per_task_group": {},
73+
"mean_per_language": {},
74+
"mean_benchmark": {},
75+
}
76+
for tag, value in all_results.items():
77+
if tag.name in results_by_level:
78+
results_by_level[tag.name][tag] = value
79+
5380
# Display each requested aggregation level
5481
metric_strs = []
5582
if display_per_task:
56-
agg_results = results._aggregate_per_task(aggregations=aggregations)
5783
metric_strs.append(
58-
_display_aggregation(agg_results, key_metrics, value_format, show_error, error_type)
84+
_display_aggregation(
85+
results_by_level["mean_per_task"],
86+
key_metrics,
87+
value_format,
88+
show_error,
89+
error_type,
90+
)
5991
)
6092

6193
if display_per_task_group:
62-
agg_results = results._aggregate_per_task_group(aggregations=aggregations)
6394
metric_strs.append(
64-
_display_aggregation(agg_results, key_metrics, value_format, show_error, error_type)
95+
_display_aggregation(
96+
results_by_level["mean_per_task_group"],
97+
key_metrics,
98+
value_format,
99+
show_error,
100+
error_type,
101+
)
65102
)
66103

67104
if display_per_language:
68-
agg_results = results._aggregate_per_language(aggregations=aggregations)
69105
metric_strs.append(
70-
_display_aggregation(agg_results, key_metrics, value_format, show_error, error_type)
106+
_display_aggregation(
107+
results_by_level["mean_per_language"],
108+
key_metrics,
109+
value_format,
110+
show_error,
111+
error_type,
112+
)
71113
)
72114

73115
if display_overall:
74-
agg_results = results._aggregate_benchmark(aggregations=aggregations)
75116
metric_strs.append(
76-
_display_aggregation(agg_results, key_metrics, value_format, show_error, error_type)
117+
_display_aggregation(
118+
results_by_level["mean_benchmark"],
119+
key_metrics,
120+
value_format,
121+
show_error,
122+
error_type,
123+
)
77124
)
78125

79126
return "\n".join(metric_strs)

src/workrb/results.py

Lines changed: 77 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -60,6 +60,7 @@ class BenchmarkMetadata(BaseModel):
6060
num_tasks: int = Field(ge=1)
6161
languages: list[str]
6262
resumed_from_checkpoint: bool = False
63+
language_aggregation_mode: str = LanguageAggregationMode.MONOLINGUAL_ONLY.value
6364

6465

6566
class ResultTagString(BaseModel):
@@ -99,10 +100,11 @@ class BenchmarkResults(BaseModel):
99100

100101
def __str__(self) -> str:
101102
"""String representation of the benchmark results."""
103+
mode = LanguageAggregationMode(self.metadata.language_aggregation_mode)
102104
lines = [
103105
"BenchmarkResults",
104106
"=" * 80,
105-
pprint.pformat(self.get_summary_metrics()),
107+
pprint.pformat(self.get_summary_metrics(language_aggregation_mode=mode)),
106108
]
107109
return "\n".join(lines)
108110

@@ -126,43 +128,98 @@ def get_summary_metrics(
126128
How to determine the grouping language for per-language aggregation.
127129
Defaults to ``MONOLINGUAL_ONLY``.
128130
"""
129-
mean_per_task = self._aggregate_per_task(
131+
combined = self._get_summary_metrics(
132+
aggregations=aggregations,
133+
language_aggregation_mode=language_aggregation_mode,
134+
)
135+
return {str(k): v for k, v in combined.items()}
136+
137+
def _get_summary_metrics(
138+
self,
139+
aggregations: tuple = ("mean", "ci_margin"),
140+
language_aggregation_mode: LanguageAggregationMode = LanguageAggregationMode.MONOLINGUAL_ONLY,
141+
) -> dict[ResultTagString, float]:
142+
"""Compute all aggregation levels and return combined results.
143+
144+
Returns a single dict with ``ResultTagString`` keys covering:
145+
``mean_per_task``, ``mean_per_task_group``, ``mean_per_task_type``,
146+
``mean_per_language``, and ``mean_benchmark``.
147+
148+
Parameters
149+
----------
150+
aggregations : tuple
151+
Statistics to compute (e.g. ``"mean"``, ``"ci_margin"``).
152+
language_aggregation_mode : LanguageAggregationMode
153+
How to determine the grouping language for aggregation.
154+
"""
155+
mean_per_task = self._aggregate_datasetids_per_task(
156+
language_aggregation_mode=language_aggregation_mode,
130157
aggregations=aggregations,
131158
)
132159
mean_per_task_group = self._aggregate_per_task_group(
133-
aggregations=aggregations, task_results=mean_per_task
160+
language_aggregation_mode=language_aggregation_mode,
161+
aggregations=aggregations,
162+
task_results=mean_per_task,
134163
)
135164
mean_per_task_type = self._aggregate_per_task_type(
136-
aggregations=aggregations, task_group_results=mean_per_task_group
165+
language_aggregation_mode=language_aggregation_mode,
166+
aggregations=aggregations,
167+
task_group_results=mean_per_task_group,
137168
)
138169
mean_benchmark = self._aggregate_benchmark(
139-
aggregations=aggregations, task_type_results=mean_per_task_type
170+
language_aggregation_mode=language_aggregation_mode,
171+
aggregations=aggregations,
172+
task_type_results=mean_per_task_type,
140173
)
141174
mean_per_language = self._aggregate_per_language(
142175
aggregations=aggregations,
143176
aggregation_mode=language_aggregation_mode,
144177
)
145178

146-
combined = {
179+
return {
147180
**mean_per_language,
148181
**mean_per_task,
149182
**mean_per_task_group,
150183
**mean_per_task_type,
151184
**mean_benchmark,
152185
}
153-
return {str(k): v for k, v in combined.items()}
154186

155-
def _aggregate_per_task(
187+
def _aggregate_datasetids_per_task(
156188
self,
189+
language_aggregation_mode: LanguageAggregationMode,
157190
tag_name: str = "mean_per_task",
158191
aggregations: tuple = ("mean", "stderr", "ci_margin"),
159192
) -> dict[ResultTagString, float]:
160-
"""Aggregate results per task, by aggregating over languages within tasks."""
161-
# Collect metric values per task
193+
"""Aggregate dataset results per task, filtering by language aggregation mode.
194+
195+
For each task, only datasets compatible with the given
196+
``language_aggregation_mode`` are included in the per-task average.
197+
Incompatible datasets are skipped with a warning, using the same
198+
``_get_language_grouping_key`` logic as ``_aggregate_per_language``.
199+
200+
This is the root aggregation level: per-task results feed into
201+
per-task-group, per-task-type, and benchmark-level aggregations,
202+
so filtering here ensures consistency across the entire chain.
203+
"""
162204
raw_results = defaultdict(list)
163205
for task_name, task_result in self.task_results.items():
164-
for lang_metrics_result in task_result.datasetid_results.values():
165-
for metric_name, metric_value in lang_metrics_result.metrics_dict.items():
206+
for dataset_id, metrics_result in task_result.datasetid_results.items():
207+
language_key = self._get_language_grouping_key(
208+
metrics_result, language_aggregation_mode
209+
)
210+
if language_key is None:
211+
logger.warning(
212+
"Skipping dataset '%s' of task '%s' in per-task aggregation: "
213+
"incompatible with mode '%s' "
214+
"(input_languages=%s, output_languages=%s).",
215+
dataset_id,
216+
task_name,
217+
language_aggregation_mode.value,
218+
metrics_result.input_languages,
219+
metrics_result.output_languages,
220+
)
221+
continue
222+
for metric_name, metric_value in metrics_result.metrics_dict.items():
166223
raw_results[(task_name, metric_name)].append(metric_value)
167224

168225
# Compute stats
@@ -179,6 +236,7 @@ def _aggregate_per_task(
179236

180237
def _aggregate_per_task_group(
181238
self,
239+
language_aggregation_mode: LanguageAggregationMode,
182240
tag_name: str = "mean_per_task_group",
183241
aggregations: tuple = ("mean", "stderr", "ci_margin"),
184242
task_results: dict[ResultTagString, float] | None = None,
@@ -187,7 +245,9 @@ def _aggregate_per_task_group(
187245
188246
First aggregates over languages within tasks, then over tasks within task groups.
189247
"""
190-
task_results = task_results or self._aggregate_per_task(aggregations=("mean",))
248+
task_results = task_results or self._aggregate_datasetids_per_task(
249+
language_aggregation_mode=language_aggregation_mode, aggregations=("mean",)
250+
)
191251

192252
task_group_list_results = defaultdict(list)
193253
for task_result_tag, value in task_results.items():
@@ -221,6 +281,7 @@ def _aggregate_per_task_group(
221281

222282
def _aggregate_per_task_type(
223283
self,
284+
language_aggregation_mode: LanguageAggregationMode,
224285
tag_name: str = "mean_per_task_type",
225286
aggregations: tuple = ("mean", "stderr", "ci_margin"),
226287
task_group_results: dict[ResultTagString, float] | None = None,
@@ -231,7 +292,7 @@ def _aggregate_per_task_type(
231292
then over task groups within task types.
232293
"""
233294
task_group_results = task_group_results or self._aggregate_per_task_group(
234-
aggregations=("mean",)
295+
language_aggregation_mode=language_aggregation_mode, aggregations=("mean",)
235296
)
236297

237298
# Mapping from task group name to task type name
@@ -275,6 +336,7 @@ def _aggregate_per_task_type(
275336

276337
def _aggregate_benchmark(
277338
self,
339+
language_aggregation_mode: LanguageAggregationMode,
278340
tag_name: str = "mean_benchmark",
279341
aggregations: tuple = ("mean", "stderr", "ci_margin"),
280342
task_type_results: dict[ResultTagString, float] | None = None,
@@ -288,7 +350,7 @@ def _aggregate_benchmark(
288350
4. Aggregates over task types for final benchmark scores
289351
"""
290352
task_type_results = task_type_results or self._aggregate_per_task_type(
291-
aggregations=("mean",)
353+
language_aggregation_mode=language_aggregation_mode, aggregations=("mean",)
292354
)
293355

294356
metric_list_results = defaultdict(list)

0 commit comments

Comments
 (0)