2323 GoldMedalistsAggregation ,
2424 JoinedBeforeDateAggregation ,
2525 SingleAggregation ,
26+ compute_weighted_semi_standard_deviations ,
27+ )
28+ from utils .typing import (
29+ ForecastValues ,
30+ ForecastsValues ,
31+ Weights ,
2632)
2733
2834
@@ -46,6 +52,64 @@ def test_summarize_array(array, max_size, expceted_array):
4652
4753class TestAggregations :
4854
@pytest.mark.parametrize(
    "forecasts_values, weights, expected",
    [
        (
            [[0.5, 0.5]],
            None,
            ([0.0, 0.0], [0.0, 0.0]),
        ),  # Trivial: a single forecast
        (
            [
                [0.5, 0.5],
                [0.5, 0.5],
                [0.5, 0.5],
            ],
            None,
            ([0.0, 0.0], [0.0, 0.0]),
        ),  # 3 unwavering (identical) forecasts
        (
            [
                [0.2, 0.8],
                [0.5, 0.5],
                [0.8, 0.2],
            ],
            None,
            ([0.3, 0.3], [0.3, 0.3]),
        ),  # 3 differing forecasts spread evenly around 0.5
        (
            [
                [0.6, 0.15, None, 0.25],
                [0.6, 0.15, None, 0.25],
            ],
            None,
            ([0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]),
        ),  # identical forecasts with placeholders
        (
            [
                [0.4, 0.25, None, 0.35],
                [0.6, 0.15, None, 0.25],
            ],
            None,
            ([0.1, 0.05, 0.0, 0.05], [0.1, 0.05, 0.0, 0.05]),
        ),  # slightly different forecasts with placeholders
    ],
)
def test_compute_weighted_semi_standard_deviations(
    self,
    forecasts_values: ForecastsValues,
    weights: Weights | None,
    expected: tuple[ForecastValues, ForecastValues],
):
    """Check both sequences returned by compute_weighted_semi_standard_deviations.

    The function under test is called with ``forecasts_values`` and
    ``weights`` and must return a pair of sequences (presumably the lower
    and upper semi-deviations -- confirm against the implementation). Each
    sequence is compared against the matching half of ``expected``.
    """
    lower, upper = compute_weighted_semi_standard_deviations(
        forecasts_values, weights
    )
    expected_lower, expected_upper = expected

    for actual, wanted in ((lower, expected_lower), (upper, expected_upper)):
        # Compare whole sequences rather than zipping element by element:
        # zip() stops at the shorter input, so a truncated or empty result
        # would previously have passed without a single assertion running.
        # assert_allclose fails on a shape mismatch. The absolute tolerance
        # keeps the comparison meaningful for the 0.0 expectations, where a
        # purely relative check rejects any floating-point noise.
        np.testing.assert_allclose(actual, wanted, rtol=1e-6, atol=1e-12)
112+
49113 @pytest .mark .parametrize ("aggregation_name" , [Agg .method for Agg in AGGREGATIONS ])
50114 def test_aggregations_initialize (
51115 self , question_binary : Question , aggregation_name : str
@@ -241,46 +305,120 @@ def test_aggregations_initialize(
241305 histogram = None ,
242306 ),
243307 ),
308+ # Multiple choice with placeholders
309+ (
310+ {},
311+ ForecastSet (
312+ forecasts_values = [
313+ [0.6 , 0.15 , None , 0.25 ],
314+ [0.6 , 0.25 , None , 0.15 ],
315+ ],
316+ timestep = datetime (2024 , 1 , 1 , tzinfo = dt_timezone .utc ),
317+ forecaster_ids = [1 , 2 ],
318+ timesteps = [
319+ datetime (2022 , 1 , 1 , tzinfo = dt_timezone .utc ),
320+ datetime (2023 , 1 , 1 , tzinfo = dt_timezone .utc ),
321+ ],
322+ ),
323+ True ,
324+ False ,
325+ AggregateForecast (
326+ start_time = datetime (2024 , 1 , 1 , tzinfo = dt_timezone .utc ),
327+ method = AggregationMethod .UNWEIGHTED ,
328+ forecast_values = [0.6 , 0.20 , None , 0.20 ],
329+ interval_lower_bounds = [0.6 , 0.15 , None , 0.15 ],
330+ centers = [0.6 , 0.20 , None , 0.20 ],
331+ interval_upper_bounds = [0.6 , 0.25 , None , 0.25 ],
332+ means = [0.6 , 0.20 , None , 0.20 ],
333+ forecaster_count = 2 ,
334+ ),
335+ ),
336+ (
337+ {},
338+ ForecastSet (
339+ forecasts_values = [
340+ [0.6 , 0.15 , None , 0.25 ],
341+ [0.6 , 0.25 , None , 0.15 ],
342+ [0.4 , 0.35 , None , 0.25 ],
343+ ],
344+ timestep = datetime (2024 , 1 , 1 , tzinfo = dt_timezone .utc ),
345+ forecaster_ids = [1 , 2 ],
346+ timesteps = [
347+ datetime (2022 , 1 , 1 , tzinfo = dt_timezone .utc ),
348+ datetime (2023 , 1 , 1 , tzinfo = dt_timezone .utc ),
349+ ],
350+ ),
351+ True ,
352+ False ,
353+ AggregateForecast (
354+ start_time = datetime (2024 , 1 , 1 , tzinfo = dt_timezone .utc ),
355+ method = AggregationMethod .UNWEIGHTED ,
356+ forecast_values = [
357+ 0.5453965360072925 ,
358+ 0.22730173199635367 ,
359+ None ,
360+ 0.22730173199635367 ,
361+ ],
362+ interval_lower_bounds = [
363+ 0.3635976906715284 ,
364+ 0.1363810391978122 ,
365+ None ,
366+ 0.1363810391978122 ,
367+ ],
368+ centers = [
369+ 0.5453965360072925 ,
370+ 0.22730173199635367 ,
371+ None ,
372+ 0.22730173199635367 ,
373+ ],
374+ interval_upper_bounds = [
375+ 0.5453965360072925 ,
376+ 0.3182224247948951 ,
377+ None ,
378+ 0.22730173199635367 ,
379+ ],
380+ means = [
381+ 0.5333333333333333 ,
382+ 0.25 ,
383+ None ,
384+ 0.21666666666666667 ,
385+ ],
386+ forecaster_count = 3 ,
387+ ),
388+ ),
244389 ],
245390 )
246391 def test_UnweightedAggregation (
247392 self ,
248393 question_binary : Question ,
394+ question_multiple_choice : Question ,
249395 init_params : dict ,
250396 forecast_set : ForecastSet ,
251397 include_stats : bool ,
252398 histogram : bool ,
253399 expected : AggregateForecast ,
254400 ):
255- aggregation = UnweightedAggregation (question = question_binary , ** init_params )
256- new_aggregation = aggregation .calculate_aggregation_entry (
401+ if len (forecast_set .forecasts_values [0 ]) == 2 :
402+ question = question_binary
403+ else :
404+ question = question_multiple_choice
405+
406+ aggregation = UnweightedAggregation (question = question , ** init_params )
407+ new_aggregation : AggregateForecast = aggregation .calculate_aggregation_entry (
257408 forecast_set , include_stats , histogram
258409 )
259410
260- assert new_aggregation .start_time == expected .start_time
261- assert (
262- new_aggregation .forecast_values == expected .forecast_values
263- ) or np .allclose (new_aggregation .forecast_values , expected .forecast_values )
264- assert new_aggregation .forecaster_count == expected .forecaster_count
265- assert (
266- new_aggregation .interval_lower_bounds == expected .interval_lower_bounds
267- ) or np .allclose (
268- new_aggregation .interval_lower_bounds , expected .interval_lower_bounds
269- )
270- assert (new_aggregation .centers == expected .centers ) or np .allclose (
271- new_aggregation .centers , expected .centers
272- )
273- assert (
274- new_aggregation .interval_upper_bounds == expected .interval_upper_bounds
275- ) or np .allclose (
276- new_aggregation .interval_upper_bounds , expected .interval_upper_bounds
277- )
278- assert (new_aggregation .means == expected .means ) or np .allclose (
279- new_aggregation .means , expected .means
280- )
281- assert (new_aggregation .histogram == expected .histogram ) or np .allclose (
282- new_aggregation .histogram , expected .histogram
283- )
411+ for r , e in [
412+ (new_aggregation .forecast_values , expected .forecast_values ),
413+ (new_aggregation .interval_lower_bounds , expected .interval_lower_bounds ),
414+ (new_aggregation .centers , expected .centers ),
415+ (new_aggregation .interval_upper_bounds , expected .interval_upper_bounds ),
416+ (new_aggregation .means , expected .means ),
417+ (new_aggregation .histogram , expected .histogram ),
418+ ]:
419+ r = np .where (np .equal (r , None ), np .nan , r ).astype (float )
420+ e = np .where (np .equal (e , None ), np .nan , e ).astype (float )
421+ np .testing .assert_allclose (r , e , equal_nan = True )
284422
285423 @pytest .mark .parametrize (
286424 "init_params, forecast_set, include_stats, histogram, expected" ,
@@ -468,20 +606,52 @@ def test_UnweightedAggregation(
468606 histogram = None ,
469607 ),
470608 ),
609+ # Multiple choice with placeholders
610+ (
611+ {},
612+ ForecastSet (
613+ forecasts_values = [
614+ [0.6 , 0.15 , None , 0.25 ],
615+ [0.6 , 0.25 , None , 0.15 ],
616+ ],
617+ timestep = datetime (2024 , 1 , 1 , tzinfo = dt_timezone .utc ),
618+ forecaster_ids = [1 , 2 ],
619+ timesteps = [
620+ datetime (2022 , 1 , 1 , tzinfo = dt_timezone .utc ),
621+ datetime (2023 , 1 , 1 , tzinfo = dt_timezone .utc ),
622+ ],
623+ ),
624+ True ,
625+ False ,
626+ AggregateForecast (
627+ start_time = datetime (2024 , 1 , 1 , tzinfo = dt_timezone .utc ),
628+ method = AggregationMethod .UNWEIGHTED ,
629+ forecast_values = [0.6 , 0.20 , None , 0.20 ],
630+ interval_lower_bounds = [0.6 , 0.15 , None , 0.15 ],
631+ centers = [0.6 , 0.20 , None , 0.20 ],
632+ interval_upper_bounds = [0.6 , 0.25 , None , 0.25 ],
633+ means = [0.6 , 0.20 , None , 0.20 ],
634+ forecaster_count = 2 ,
635+ ),
636+ ),
471637 ],
472638 )
473639 def test_RecencyWeightedAggregation (
474640 self ,
475641 question_binary : Question ,
642+ question_multiple_choice : Question ,
476643 init_params : dict ,
477644 forecast_set : ForecastSet ,
478645 include_stats : bool ,
479646 histogram : bool ,
480647 expected : AggregateForecast ,
481648 ):
482- aggregation = RecencyWeightedAggregation (
483- question = question_binary , ** init_params
484- )
649+ if len (forecast_set .forecasts_values [0 ]) == 2 :
650+ question = question_binary
651+ else :
652+ question = question_multiple_choice
653+
654+ aggregation = RecencyWeightedAggregation (question = question , ** init_params )
485655 new_aggregation = aggregation .calculate_aggregation_entry (
486656 forecast_set , include_stats , histogram
487657 )
0 commit comments