questions/models.py (24 changes: 14 additions & 10 deletions)
@@ -637,14 +637,16 @@ def get_prediction_values(self) -> list[float | None]:
             return self.probability_yes_per_category
         return self.continuous_cdf
 
-    def get_pmf(self) -> list[float]:
+    def get_pmf(self, replace_none: bool = False) -> list[float]:
         """
-        gets the PMF for this forecast, replacing None values with 0.0
-        Not for serialization use (keep None values in that case)
+        gets the PMF for this forecast
+        replaces None values with 0.0 if replace_none is True
         """
         if self.probability_yes:
             return [1 - self.probability_yes, self.probability_yes]
         if self.probability_yes_per_category:
+            if not replace_none:
+                return self.probability_yes_per_category
             return [
                 v or 0.0 for v in self.probability_yes_per_category
             ]  # replace None with 0.0
@@ -719,18 +721,20 @@ def get_cdf(self) -> list[float | None] | None:
             return self.forecast_values
         return None
 
-    def get_pmf(self) -> list[float]:
+    def get_pmf(self, replace_none: bool = False) -> list[float | None]:
         """
-        gets the PMF for this forecast, replacing None values with 0.0
-        Not for serialization use (keep None values in that case)
+        gets the PMF for this forecast
+        replacing None values with 0.0 if replace_none is True
         """
         # grab annotation if it exists for efficiency
         question_type = getattr(self, "question_type", self.question.type)
-        forecast_values = [
-            v or 0.0 for v in self.forecast_values
-        ]  # replace None with 0.0
+        forecast_values = self.forecast_values
+        if question_type == Question.QuestionType.MULTIPLE_CHOICE:
+            if not replace_none:
+                return forecast_values
+            return [v or 0.0 for v in forecast_values]  # replace None with 0.0
         if question_type in QUESTION_CONTINUOUS_TYPES:
-            cdf: list[float] = forecast_values
+            cdf: list[float] = forecast_values  # type: ignore
             pmf = [cdf[0]]
             for i in range(1, len(cdf)):
                 pmf.append(cdf[i] - cdf[i - 1])
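Note on the API change above: a minimal sketch, with invented values, of what the new `replace_none` flag does for a multiple-choice forecast whose third option did not exist when the forecast was made (`forecast` stands for any such forecast instance):

    # probability_yes_per_category keeps None for options that were not
    # yet available when the forecast was made
    forecast.probability_yes_per_category = [0.6, 0.3, None, 0.1]

    forecast.get_pmf()                   # [0.6, 0.3, None, 0.1], keeps None (serialization)
    forecast.get_pmf(replace_none=True)  # [0.6, 0.3, 0.0, 0.1], safe for arithmetic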
scoring/score_math.py (48 changes: 32 additions & 16 deletions)
@@ -20,7 +20,7 @@
 
 @dataclass
 class AggregationEntry:
-    pmf: np.ndarray | list[float]
+    pmf: np.ndarray | list[float | None]
     num_forecasters: int
     timestamp: float
 
@@ -36,7 +36,7 @@ def get_geometric_means(
         timesteps.add(forecast.end_time.timestamp())
     for timestep in sorted(timesteps):
         prediction_values = [
-            f.get_pmf()
+            f.get_pmf(replace_none=True)
             for f in forecasts
             if f.start_time.timestamp() <= timestep
             and (f.end_time is None or f.end_time.timestamp() > timestep)
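The `replace_none=True` here matters because these pmfs feed numeric aggregation; a `None` would break the array math, while 0.0 simply zeroes the unavailable option in the geometric mean. A toy sketch of that behavior (invented values, not the repository's actual aggregation code):

    import numpy as np

    # two pmfs whose unavailable third option is already mapped to 0.0
    pmfs = np.array([[0.6, 0.3, 0.0, 0.1],
                     [0.5, 0.4, 0.0, 0.1]])
    geo_mean = np.prod(pmfs, axis=0) ** (1 / len(pmfs))
    # -> [0.548, 0.346, 0.0, 0.1]: an option zeroed anywhere stays 0.0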
@@ -84,9 +84,12 @@ def evaluate_forecasts_baseline_accuracy(
         forecast_coverage = forecast_duration / total_duration
         pmf = forecast.get_pmf()
         if question_type in ["binary", "multiple_choice"]:
-            forecast_score = (
-                100 * np.log(pmf[resolution_bucket] * len(pmf)) / np.log(len(pmf))
-            )
+            # forecasts always have `None` assigned to MC options that aren't
+            # available at the time. Detecting these allows us to avoid trying to
+            # follow the question's options_history.
+            options_at_time = len([p for p in pmf if p is not None])
+            p = pmf[resolution_bucket] or pmf[-1]  # if None, read from Other
+            forecast_score = 100 * np.log(p * options_at_time) / np.log(options_at_time)
         else:
             if resolution_bucket in [0, len(pmf) - 1]:
                 baseline = 0.05
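A worked example of the new scoring path, with invented numbers: only three options existed at forecast time, and the question resolves to the not-yet-available option, so the probability is read from the trailing Other bucket:

    import numpy as np

    pmf = [0.6, 0.3, None, 0.1]
    options_at_time = len([v for v in pmf if v is not None])    # 3
    p = pmf[2] or pmf[-1]                                       # None -> 0.1 (Other)
    score = 100 * np.log(p * options_at_time) / np.log(options_at_time)  # ≈ -109.6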
@@ -116,8 +119,13 @@ def evaluate_forecasts_baseline_spot_forecast(
         if start <= spot_forecast_timestamp < end:
             pmf = forecast.get_pmf()
             if question_type in ["binary", "multiple_choice"]:
+                # forecasts always have `None` assigned to MC options that aren't
+                # available at the time. Detecting these allows us to avoid trying to
+                # follow the question's options_history.
+                options_at_time = len([p for p in pmf if p is not None])
+                p = pmf[resolution_bucket] or pmf[-1]  # if None, read from Other
                 forecast_score = (
-                    100 * np.log(pmf[resolution_bucket] * len(pmf)) / np.log(len(pmf))
+                    100 * np.log(p * options_at_time) / np.log(options_at_time)
                 )
             else:
                 if resolution_bucket in [0, len(pmf) - 1]:
@@ -159,17 +167,21 @@ def evaluate_forecasts_peer_accuracy(
             continue
 
         pmf = forecast.get_pmf()
+        p = pmf[resolution_bucket] or pmf[-1]  # if None, read from Other
         interval_scores: list[float | None] = []
         for gm in geometric_mean_forecasts:
             if forecast_start <= gm.timestamp < forecast_end:
-                score = (
+                gmp = (
+                    gm.pmf[resolution_bucket] or gm.pmf[-1]
+                )  # if None, read from Other
+                interval_score = (
                     100
                     * (gm.num_forecasters / (gm.num_forecasters - 1))
-                    * np.log(pmf[resolution_bucket] / gm.pmf[resolution_bucket])
+                    * np.log(p / gmp)
                 )
                 if question_type in QUESTION_CONTINUOUS_TYPES:
-                    score /= 2
-                interval_scores.append(score)
+                    interval_score /= 2
+                interval_scores.append(interval_score)
             else:
                 interval_scores.append(None)
 
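With the same invented numbers: a forecaster holding p = 0.3 on the resolved option, against a geometric-mean aggregate of gmp = 0.2 built from 5 forecasters, gets

    100 * (5 / (5 - 1)) * np.log(0.3 / 0.2)   # ≈ +50.7, then halved for continuous types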
@@ -218,10 +230,10 @@ def evaluate_forecasts_peer_spot_forecast(
         )
         if start <= spot_forecast_timestamp < end:
             pmf = forecast.get_pmf()
+            p = pmf[resolution_bucket] or pmf[-1]  # if None, read from Other
+            gmp = gm.pmf[resolution_bucket] or gm.pmf[-1]  # if None, read from Other
             forecast_score = (
-                100
-                * (gm.num_forecasters / (gm.num_forecasters - 1))
-                * np.log(pmf[resolution_bucket] / gm.pmf[resolution_bucket])
+                100 * (gm.num_forecasters / (gm.num_forecasters - 1)) * np.log(p / gmp)
             )
             if question_type in QUESTION_CONTINUOUS_TYPES:
                 forecast_score /= 2
@@ -260,11 +272,15 @@ def evaluate_forecasts_legacy_relative(
             continue
 
         pmf = forecast.get_pmf()
+        p = pmf[resolution_bucket] or pmf[-1]  # if None, read from Other
        interval_scores: list[float | None] = []
         for bf in baseline_forecasts:
             if forecast_start <= bf.timestamp < forecast_end:
-                score = np.log2(pmf[resolution_bucket] / bf.pmf[resolution_bucket])
-                interval_scores.append(score)
+                bfp = (
+                    bf.pmf[resolution_bucket] or bf.pmf[-1]
+                )  # if None, read from Other
+                interval_score = np.log2(p / bfp)
+                interval_scores.append(interval_score)
             else:
                 interval_scores.append(None)
 
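And the legacy relative variant with the same None-to-Other fallback: p = 0.3 against a baseline bfp = 0.15 gives np.log2(0.3 / 0.15) = 1.0 for that interval.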
@@ -316,7 +332,7 @@ def evaluate_question(
     if spot_forecast_time:
         spot_forecast_timestamp = min(spot_forecast_time.timestamp(), actual_close_time)
 
-    # We need all user forecasts to calculated GeoMean even
+    # We need all user forecasts to calculate GeoMean even
     # if we're only scoring some or none of the users
     user_forecasts = question.user_forecasts.all()
     if only_include_user_ids: