@@ -60,6 +60,7 @@ class BenchmarkMetadata(BaseModel):
6060 num_tasks : int = Field (ge = 1 )
6161 languages : list [str ]
6262 resumed_from_checkpoint : bool = False
63+ language_aggregation_mode : str = LanguageAggregationMode .MONOLINGUAL_ONLY .value
6364
6465
6566class ResultTagString (BaseModel ):
@@ -99,10 +100,11 @@ class BenchmarkResults(BaseModel):
99100
def __str__(self) -> str:
    """Return a printable summary of the benchmark results.

    The output consists of a title line, an 80-character rule of ``=``
    signs, and the pretty-printed summary metrics, computed with the
    language aggregation mode recorded in ``self.metadata``.
    """
    # The metadata holds the mode as a plain string; turn it back into
    # the enum member that ``get_summary_metrics`` is called with.
    aggregation_mode = LanguageAggregationMode(self.metadata.language_aggregation_mode)
    summary = self.get_summary_metrics(language_aggregation_mode=aggregation_mode)

    title = "BenchmarkResults"
    rule = "=" * 80
    return "\n ".join((title, rule, pprint.pformat(summary)))
108110
@@ -126,43 +128,98 @@ def get_summary_metrics(
126128 How to determine the grouping language for per-language aggregation.
127129 Defaults to ``MONOLINGUAL_ONLY``.
128130 """
129- mean_per_task = self ._aggregate_per_task (
131+ combined = self ._get_summary_metrics (
132+ aggregations = aggregations ,
133+ language_aggregation_mode = language_aggregation_mode ,
134+ )
135+ return {str (k ): v for k , v in combined .items ()}
136+
def _get_summary_metrics(
    self,
    aggregations: tuple = ("mean", "ci_margin"),
    language_aggregation_mode: LanguageAggregationMode = LanguageAggregationMode.MONOLINGUAL_ONLY,
) -> dict[ResultTagString, float]:
    """Run every aggregation level and merge the results into one dict.

    The per-task results are computed first and then fed forward:
    per-task -> per-task-group -> per-task-type -> benchmark. The
    per-language aggregation is computed independently of that chain.

    Parameters
    ----------
    aggregations : tuple
        Statistics to compute (e.g. ``"mean"``, ``"ci_margin"``).
    language_aggregation_mode : LanguageAggregationMode
        How to determine the grouping language for aggregation.

    Returns
    -------
    dict[ResultTagString, float]
        Entries from ``mean_per_language``, ``mean_per_task``,
        ``mean_per_task_group``, ``mean_per_task_type`` and
        ``mean_benchmark``, merged in that order (on a key clash the
        later level wins).
    """
    # Keyword arguments common to every level of the task hierarchy.
    shared = {
        "language_aggregation_mode": language_aggregation_mode,
        "aggregations": aggregations,
    }

    # Each level consumes the output of the previous one.
    per_task = self._aggregate_datasetids_per_task(**shared)
    per_task_group = self._aggregate_per_task_group(task_results=per_task, **shared)
    per_task_type = self._aggregate_per_task_type(
        task_group_results=per_task_group, **shared
    )
    benchmark = self._aggregate_benchmark(task_type_results=per_task_type, **shared)

    # The per-language helper names its mode parameter ``aggregation_mode``.
    per_language = self._aggregate_per_language(
        aggregations=aggregations,
        aggregation_mode=language_aggregation_mode,
    )

    # Merge order matters: it fixes both key order and override precedence.
    summary: dict[ResultTagString, float] = {}
    for level in (per_language, per_task, per_task_group, per_task_type, benchmark):
        summary.update(level)
    return summary
153- return {str (k ): v for k , v in combined .items ()}
154186
155- def _aggregate_per_task (
187+ def _aggregate_datasetids_per_task (
156188 self ,
189+ language_aggregation_mode : LanguageAggregationMode ,
157190 tag_name : str = "mean_per_task" ,
158191 aggregations : tuple = ("mean" , "stderr" , "ci_margin" ),
159192 ) -> dict [ResultTagString , float ]:
160- """Aggregate results per task, by aggregating over languages within tasks."""
161- # Collect metric values per task
193+ """Aggregate dataset results per task, filtering by language aggregation mode.
194+
195+ For each task, only datasets compatible with the given
196+ ``language_aggregation_mode`` are included in the per-task average.
197+ Incompatible datasets are skipped with a warning, using the same
198+ ``_get_language_grouping_key`` logic as ``_aggregate_per_language``.
199+
200+ This is the root aggregation level: per-task results feed into
201+ per-task-group, per-task-type, and benchmark-level aggregations,
202+ so filtering here ensures consistency across the entire chain.
203+ """
162204 raw_results = defaultdict (list )
163205 for task_name , task_result in self .task_results .items ():
164- for lang_metrics_result in task_result .datasetid_results .values ():
165- for metric_name , metric_value in lang_metrics_result .metrics_dict .items ():
206+ for dataset_id , metrics_result in task_result .datasetid_results .items ():
207+ language_key = self ._get_language_grouping_key (
208+ metrics_result , language_aggregation_mode
209+ )
210+ if language_key is None :
211+ logger .warning (
212+ "Skipping dataset '%s' of task '%s' in per-task aggregation: "
213+ "incompatible with mode '%s' "
214+ "(input_languages=%s, output_languages=%s)." ,
215+ dataset_id ,
216+ task_name ,
217+ language_aggregation_mode .value ,
218+ metrics_result .input_languages ,
219+ metrics_result .output_languages ,
220+ )
221+ continue
222+ for metric_name , metric_value in metrics_result .metrics_dict .items ():
166223 raw_results [(task_name , metric_name )].append (metric_value )
167224
168225 # Compute stats
@@ -179,6 +236,7 @@ def _aggregate_per_task(
179236
180237 def _aggregate_per_task_group (
181238 self ,
239+ language_aggregation_mode : LanguageAggregationMode ,
182240 tag_name : str = "mean_per_task_group" ,
183241 aggregations : tuple = ("mean" , "stderr" , "ci_margin" ),
184242 task_results : dict [ResultTagString , float ] | None = None ,
@@ -187,7 +245,9 @@ def _aggregate_per_task_group(
187245
188246 First aggregates over languages within tasks, then over tasks within task groups.
189247 """
190- task_results = task_results or self ._aggregate_per_task (aggregations = ("mean" ,))
248+ task_results = task_results or self ._aggregate_datasetids_per_task (
249+ language_aggregation_mode = language_aggregation_mode , aggregations = ("mean" ,)
250+ )
191251
192252 task_group_list_results = defaultdict (list )
193253 for task_result_tag , value in task_results .items ():
@@ -221,6 +281,7 @@ def _aggregate_per_task_group(
221281
222282 def _aggregate_per_task_type (
223283 self ,
284+ language_aggregation_mode : LanguageAggregationMode ,
224285 tag_name : str = "mean_per_task_type" ,
225286 aggregations : tuple = ("mean" , "stderr" , "ci_margin" ),
226287 task_group_results : dict [ResultTagString , float ] | None = None ,
@@ -231,7 +292,7 @@ def _aggregate_per_task_type(
231292 then over task groups within task types.
232293 """
233294 task_group_results = task_group_results or self ._aggregate_per_task_group (
234- aggregations = ("mean" ,)
295+ language_aggregation_mode = language_aggregation_mode , aggregations = ("mean" ,)
235296 )
236297
237298 # Mapping from task group name to task type name
@@ -275,6 +336,7 @@ def _aggregate_per_task_type(
275336
276337 def _aggregate_benchmark (
277338 self ,
339+ language_aggregation_mode : LanguageAggregationMode ,
278340 tag_name : str = "mean_benchmark" ,
279341 aggregations : tuple = ("mean" , "stderr" , "ci_margin" ),
280342 task_type_results : dict [ResultTagString , float ] | None = None ,
@@ -288,7 +350,7 @@ def _aggregate_benchmark(
288350 4. Aggregates over task types for final benchmark scores
289351 """
290352 task_type_results = task_type_results or self ._aggregate_per_task_type (
291- aggregations = ("mean" ,)
353+ language_aggregation_mode = language_aggregation_mode , aggregations = ("mean" ,)
292354 )
293355
294356 metric_list_results = defaultdict (list )
0 commit comments