Skip to content

Commit 4906b9a

Browse files
committed
improve test for comput_weighted_semi_standard_deviations
1 parent 189c408 commit 4906b9a

File tree

1 file changed

+42
-4
lines changed

1 file changed

+42
-4
lines changed

tests/unit/test_utils/test_the_math/test_aggregations.py

Lines changed: 42 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,10 +56,44 @@ class TestAggregations:
5656
"forecasts_values, weights, expected",
5757
[
5858
(
59-
[[1.0]],
59+
[[0.5, 0.5]],
6060
None,
61-
([0.0], [0.0]),
61+
([0.0, 0.0], [0.0, 0.0]),
6262
), # Trivial
63+
(
64+
[
65+
[0.5, 0.5],
66+
[0.5, 0.5],
67+
[0.5, 0.5],
68+
],
69+
None,
70+
([0.0, 0.0], [0.0, 0.0]),
71+
), # 3 unwavaring forecasts
72+
(
73+
[
74+
[0.2, 0.8],
75+
[0.5, 0.5],
76+
[0.8, 0.2],
77+
],
78+
None,
79+
([0.3, 0.3], [0.3, 0.3]),
80+
), # 3 unwavaring forecasts
81+
(
82+
[
83+
[0.6, 0.15, 0.0, 0.25],
84+
[0.6, 0.15, 0.0, 0.25],
85+
],
86+
None,
87+
([0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]),
88+
), # identical forecasts with placeholders
89+
(
90+
[
91+
[0.4, 0.25, 0.0, 0.35],
92+
[0.6, 0.15, 0.0, 0.25],
93+
],
94+
None,
95+
([0.1, 0.05, 0.0, 0.05], [0.1, 0.05, 0.0, 0.05]),
96+
), # minorly different forecasts with placeholders
6397
],
6498
)
6599
def test_compute_weighted_semi_standard_deviations(
@@ -69,8 +103,12 @@ def test_compute_weighted_semi_standard_deviations(
69103
expected: tuple[ForecastValues, ForecastValues],
70104
):
71105
result = compute_weighted_semi_standard_deviations(forecasts_values, weights)
72-
assert result[0] == expected[0]
73-
assert result[1] == expected[1]
106+
rl, ru = result
107+
el, eu = expected
108+
for v, e in zip(rl, el):
109+
np.testing.assert_approx_equal(v, e)
110+
for v, e in zip(ru, eu):
111+
np.testing.assert_approx_equal(v, e)
74112

75113
@pytest.mark.parametrize("aggregation_name", [Agg.method for Agg in AGGREGATIONS])
76114
def test_aggregations_initialize(

0 commit comments

Comments
 (0)