Skip to content

Commit 237bd11

Browse files
committed
mc/3806/aggregations
parent 3c40ee2 author lsabor <[email protected]> 1763674087 -0800 committer lsabor <[email protected]> 1764534427 -0800 adjust aggregations to play nicely with placeholders; improve test for compute_weighted_semi_standard_deviations; add support for 0.0s in prediction difference for sorting, plus tests; update prediction difference for display to handle placeholders
1 parent 6a938a3 commit 237bd11

File tree

6 files changed

+379
-49
lines changed

6 files changed

+379
-49
lines changed

tests/unit/test_questions/conftest.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
__all__ = [
1111
"question_binary",
12+
"question_multiple_choice",
1213
"question_numeric",
1314
"conditional_1",
1415
"question_binary_with_forecast_user_1",
Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1,4 @@
1-
from tests.unit.test_questions.conftest import question_binary # noqa
1+
from tests.unit.test_questions.conftest import (
2+
question_binary,
3+
question_multiple_choice,
4+
) # noqa

tests/unit/test_utils/test_the_math/test_aggregations.py

Lines changed: 199 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@
2323
GoldMedalistsAggregation,
2424
JoinedBeforeDateAggregation,
2525
SingleAggregation,
26+
compute_weighted_semi_standard_deviations,
27+
)
28+
from utils.typing import (
29+
ForecastValues,
30+
ForecastsValues,
31+
Weights,
2632
)
2733

2834

@@ -46,6 +52,64 @@ def test_summarize_array(array, max_size, expceted_array):
4652

4753
class TestAggregations:
4854

55+
@pytest.mark.parametrize(
56+
"forecasts_values, weights, expected",
57+
[
58+
(
59+
[[0.5, 0.5]],
60+
None,
61+
([0.0, 0.0], [0.0, 0.0]),
62+
), # Trivial
63+
(
64+
[
65+
[0.5, 0.5],
66+
[0.5, 0.5],
67+
[0.5, 0.5],
68+
],
69+
None,
70+
([0.0, 0.0], [0.0, 0.0]),
71+
), # 3 unwavering forecasts
72+
(
73+
[
74+
[0.2, 0.8],
75+
[0.5, 0.5],
76+
[0.8, 0.2],
77+
],
78+
None,
79+
([0.3, 0.3], [0.3, 0.3]),
80+
), # 3 differing forecasts spread symmetrically around 0.5
81+
(
82+
[
83+
[0.6, 0.15, None, 0.25],
84+
[0.6, 0.15, None, 0.25],
85+
],
86+
None,
87+
([0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]),
88+
), # identical forecasts with placeholders
89+
(
90+
[
91+
[0.4, 0.25, None, 0.35],
92+
[0.6, 0.15, None, 0.25],
93+
],
94+
None,
95+
([0.1, 0.05, 0.0, 0.05], [0.1, 0.05, 0.0, 0.05]),
96+
), # slightly different forecasts with placeholders
97+
],
98+
)
99+
def test_compute_weighted_semi_standard_deviations(
100+
self,
101+
forecasts_values: ForecastsValues,
102+
weights: Weights | None,
103+
expected: tuple[ForecastValues, ForecastValues],
104+
):
105+
result = compute_weighted_semi_standard_deviations(forecasts_values, weights)
106+
rl, ru = result
107+
el, eu = expected
108+
for v, e in zip(rl, el):
109+
np.testing.assert_approx_equal(v, e)
110+
for v, e in zip(ru, eu):
111+
np.testing.assert_approx_equal(v, e)
112+
49113
@pytest.mark.parametrize("aggregation_name", [Agg.method for Agg in AGGREGATIONS])
50114
def test_aggregations_initialize(
51115
self, question_binary: Question, aggregation_name: str
@@ -241,46 +305,120 @@ def test_aggregations_initialize(
241305
histogram=None,
242306
),
243307
),
308+
# Multiple choice with placeholders
309+
(
310+
{},
311+
ForecastSet(
312+
forecasts_values=[
313+
[0.6, 0.15, None, 0.25],
314+
[0.6, 0.25, None, 0.15],
315+
],
316+
timestep=datetime(2024, 1, 1, tzinfo=dt_timezone.utc),
317+
forecaster_ids=[1, 2],
318+
timesteps=[
319+
datetime(2022, 1, 1, tzinfo=dt_timezone.utc),
320+
datetime(2023, 1, 1, tzinfo=dt_timezone.utc),
321+
],
322+
),
323+
True,
324+
False,
325+
AggregateForecast(
326+
start_time=datetime(2024, 1, 1, tzinfo=dt_timezone.utc),
327+
method=AggregationMethod.UNWEIGHTED,
328+
forecast_values=[0.6, 0.20, None, 0.20],
329+
interval_lower_bounds=[0.6, 0.15, None, 0.15],
330+
centers=[0.6, 0.20, None, 0.20],
331+
interval_upper_bounds=[0.6, 0.25, None, 0.25],
332+
means=[0.6, 0.20, None, 0.20],
333+
forecaster_count=2,
334+
),
335+
),
336+
(
337+
{},
338+
ForecastSet(
339+
forecasts_values=[
340+
[0.6, 0.15, None, 0.25],
341+
[0.6, 0.25, None, 0.15],
342+
[0.4, 0.35, None, 0.25],
343+
],
344+
timestep=datetime(2024, 1, 1, tzinfo=dt_timezone.utc),
345+
forecaster_ids=[1, 2],
346+
timesteps=[
347+
datetime(2022, 1, 1, tzinfo=dt_timezone.utc),
348+
datetime(2023, 1, 1, tzinfo=dt_timezone.utc),
349+
],
350+
),
351+
True,
352+
False,
353+
AggregateForecast(
354+
start_time=datetime(2024, 1, 1, tzinfo=dt_timezone.utc),
355+
method=AggregationMethod.UNWEIGHTED,
356+
forecast_values=[
357+
0.5453965360072925,
358+
0.22730173199635367,
359+
None,
360+
0.22730173199635367,
361+
],
362+
interval_lower_bounds=[
363+
0.3635976906715284,
364+
0.1363810391978122,
365+
None,
366+
0.1363810391978122,
367+
],
368+
centers=[
369+
0.5453965360072925,
370+
0.22730173199635367,
371+
None,
372+
0.22730173199635367,
373+
],
374+
interval_upper_bounds=[
375+
0.5453965360072925,
376+
0.3182224247948951,
377+
None,
378+
0.22730173199635367,
379+
],
380+
means=[
381+
0.5333333333333333,
382+
0.25,
383+
None,
384+
0.21666666666666667,
385+
],
386+
forecaster_count=3,
387+
),
388+
),
244389
],
245390
)
246391
def test_UnweightedAggregation(
247392
self,
248393
question_binary: Question,
394+
question_multiple_choice: Question,
249395
init_params: dict,
250396
forecast_set: ForecastSet,
251397
include_stats: bool,
252398
histogram: bool,
253399
expected: AggregateForecast,
254400
):
255-
aggregation = UnweightedAggregation(question=question_binary, **init_params)
256-
new_aggregation = aggregation.calculate_aggregation_entry(
401+
if len(forecast_set.forecasts_values[0]) == 2:
402+
question = question_binary
403+
else:
404+
question = question_multiple_choice
405+
406+
aggregation = UnweightedAggregation(question=question, **init_params)
407+
new_aggregation: AggregateForecast = aggregation.calculate_aggregation_entry(
257408
forecast_set, include_stats, histogram
258409
)
259410

260-
assert new_aggregation.start_time == expected.start_time
261-
assert (
262-
new_aggregation.forecast_values == expected.forecast_values
263-
) or np.allclose(new_aggregation.forecast_values, expected.forecast_values)
264-
assert new_aggregation.forecaster_count == expected.forecaster_count
265-
assert (
266-
new_aggregation.interval_lower_bounds == expected.interval_lower_bounds
267-
) or np.allclose(
268-
new_aggregation.interval_lower_bounds, expected.interval_lower_bounds
269-
)
270-
assert (new_aggregation.centers == expected.centers) or np.allclose(
271-
new_aggregation.centers, expected.centers
272-
)
273-
assert (
274-
new_aggregation.interval_upper_bounds == expected.interval_upper_bounds
275-
) or np.allclose(
276-
new_aggregation.interval_upper_bounds, expected.interval_upper_bounds
277-
)
278-
assert (new_aggregation.means == expected.means) or np.allclose(
279-
new_aggregation.means, expected.means
280-
)
281-
assert (new_aggregation.histogram == expected.histogram) or np.allclose(
282-
new_aggregation.histogram, expected.histogram
283-
)
411+
for r, e in [
412+
(new_aggregation.forecast_values, expected.forecast_values),
413+
(new_aggregation.interval_lower_bounds, expected.interval_lower_bounds),
414+
(new_aggregation.centers, expected.centers),
415+
(new_aggregation.interval_upper_bounds, expected.interval_upper_bounds),
416+
(new_aggregation.means, expected.means),
417+
(new_aggregation.histogram, expected.histogram),
418+
]:
419+
r = np.where(np.equal(r, None), np.nan, r).astype(float)
420+
e = np.where(np.equal(e, None), np.nan, e).astype(float)
421+
np.testing.assert_allclose(r, e, equal_nan=True)
284422

285423
@pytest.mark.parametrize(
286424
"init_params, forecast_set, include_stats, histogram, expected",
@@ -468,20 +606,52 @@ def test_UnweightedAggregation(
468606
histogram=None,
469607
),
470608
),
609+
# Multiple choice with placeholders
610+
(
611+
{},
612+
ForecastSet(
613+
forecasts_values=[
614+
[0.6, 0.15, None, 0.25],
615+
[0.6, 0.25, None, 0.15],
616+
],
617+
timestep=datetime(2024, 1, 1, tzinfo=dt_timezone.utc),
618+
forecaster_ids=[1, 2],
619+
timesteps=[
620+
datetime(2022, 1, 1, tzinfo=dt_timezone.utc),
621+
datetime(2023, 1, 1, tzinfo=dt_timezone.utc),
622+
],
623+
),
624+
True,
625+
False,
626+
AggregateForecast(
627+
start_time=datetime(2024, 1, 1, tzinfo=dt_timezone.utc),
628+
method=AggregationMethod.UNWEIGHTED,
629+
forecast_values=[0.6, 0.20, None, 0.20],
630+
interval_lower_bounds=[0.6, 0.15, None, 0.15],
631+
centers=[0.6, 0.20, None, 0.20],
632+
interval_upper_bounds=[0.6, 0.25, None, 0.25],
633+
means=[0.6, 0.20, None, 0.20],
634+
forecaster_count=2,
635+
),
636+
),
471637
],
472638
)
473639
def test_RecencyWeightedAggregation(
474640
self,
475641
question_binary: Question,
642+
question_multiple_choice: Question,
476643
init_params: dict,
477644
forecast_set: ForecastSet,
478645
include_stats: bool,
479646
histogram: bool,
480647
expected: AggregateForecast,
481648
):
482-
aggregation = RecencyWeightedAggregation(
483-
question=question_binary, **init_params
484-
)
649+
if len(forecast_set.forecasts_values[0]) == 2:
650+
question = question_binary
651+
else:
652+
question = question_multiple_choice
653+
654+
aggregation = RecencyWeightedAggregation(question=question, **init_params)
485655
new_aggregation = aggregation.calculate_aggregation_entry(
486656
forecast_set, include_stats, histogram
487657
)

0 commit comments

Comments
 (0)