162 changes: 160 additions & 2 deletions open_instruct/ground_truth_utils.py
@@ -15,10 +15,12 @@
 import string
 import weakref
 from abc import ABC, abstractmethod
-from collections import Counter
-from dataclasses import dataclass
+from collections import Counter, defaultdict
+from collections.abc import Callable
+from dataclasses import dataclass, field
 from typing import Any
 
+import numpy as np
 import requests
 from litellm import acompletion
 
@@ -987,3 +989,159 @@ async def cleanup_all_llm_judge_clients():
    Cleanup function to properly close all LLM judge clients before shutdown.
    """
    await LMJudgeVerifier.cleanup_all_clients()


async def apply_verifiable_reward(
    reward_fn_mapping: dict[str, VerifierFunction],
    responses: list,
    decoded_responses: list[str],
    ground_truths: list,
    datasets: list[str],
    reward_mult: int = 10,
    queries: list[str] | None = None,
):
    if queries is None:
        queries = [None] * len(responses)

    async_tasks = []
    task_metadata = []

    # Fan out one async verification task per (response, ground truth) pair.
    for i, (tok_prediction, prediction, ground_truth, dataset, query) in enumerate(
        zip(responses, decoded_responses, ground_truths, datasets, queries)
    ):
        # A response may carry several ground truths, each checked by its own dataset's verifier.
        ground_truth_list = [ground_truth] if isinstance(ground_truth, str) else ground_truth
        dataset_list = [dataset] if isinstance(dataset, str) else dataset
        assert len(ground_truth_list) == len(dataset_list), "Ground truth and dataset list lengths do not match."

        for gt, ds in zip(ground_truth_list, dataset_list):
            reward_func = reward_fn_mapping.get(ds.lower())
            if reward_func is None:
                logger.warning("No reward function found for dataset %s. Skipping reward.", ds)
                continue

            task = reward_func.async_call(
                tokenized_prediction=tok_prediction, prediction=prediction, label=gt, query=query
            )
            async_tasks.append(task)
            task_metadata.append(
                {
                    "response_idx": i,
                    "dataset": reward_func.name,
                    "reward_weight": reward_func.weight,
                    "reward_mult": reward_mult,
                }
            )

    if async_tasks:
        reward_results = await asyncio.gather(*async_tasks)
        logger.debug("Applied %d ground truth rewards in parallel", len(reward_results))
    else:
        reward_results = []

    # Aggregate per-response totals and per-verifier breakdowns.
    response_rewards = [0.0] * len(responses)
    response_per_func_rewards = [{} for _ in range(len(responses))]

    for result, metadata in zip(reward_results, task_metadata):
        response_idx = metadata["response_idx"]
        dataset = metadata["dataset"]
        reward_weight = metadata["reward_weight"]
        reward_mult = metadata["reward_mult"]

        # Verifier results may expose a .score attribute or be a bare numeric score.
        score = result.score if hasattr(result, "score") else result
        weighted_reward = reward_mult * score * reward_weight

        response_rewards[response_idx] += weighted_reward
        response_per_func_rewards[response_idx][dataset] = (
            response_per_func_rewards[response_idx].get(dataset, 0) + weighted_reward
        )

    return response_rewards, response_per_func_rewards
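

For orientation, here is a minimal, self-contained sketch of how apply_verifiable_reward composes its outputs. The MathVerifier stub and the "gsm8k" key are hypothetical stand-ins for a real VerifierFunction registered in reward_fn_mapping; only the attributes the helper actually touches (name, weight, async_call) are stubbed. With reward_mult=10, a score of 1.0 and weight 1.0 yield a weighted reward of 10 * 1.0 * 1.0 = 10.0.

import asyncio

# Hypothetical stand-in for a VerifierFunction subclass.
class MathVerifier:
    name = "gsm8k"
    weight = 1.0

    async def async_call(self, tokenized_prediction, prediction, label, query=None):
        # Score 1.0 on an exact string match, else 0.0.
        return 1.0 if prediction.strip() == label.strip() else 0.0


async def main():
    rewards, per_func = await apply_verifiable_reward(
        reward_fn_mapping={"gsm8k": MathVerifier()},
        responses=[[101, 2045]],  # token ids (unused by this stub)
        decoded_responses=["42"],
        ground_truths=["42"],
        datasets=["gsm8k"],
        reward_mult=10,
    )
    print(rewards)   # [10.0]  -> 10 (reward_mult) * 1.0 (score) * 1.0 (weight)
    print(per_func)  # [{'gsm8k': 10.0}]


asyncio.run(main())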


@dataclass
class RewardConfig:
    """Configuration for reward function computation."""

    apply_r1_style_format_reward: bool = False
    r1_style_format_reward: float = 1.0
    apply_verifiable_reward: bool = True
    verification_reward: int = 10
    non_stop_penalty: bool = False
    non_stop_penalty_value: float = -10.0
    only_reward_good_outputs: bool = False
    additive_format_reward: bool = False
    verifier_functions: dict[str, VerifierFunction] = field(default_factory=dict)

    def build(self) -> Callable:
        """Build and return the reward function."""

        async def reward_fn(
            responses: list,
            decoded_responses: list[str],
            ground_truths: list[Any],
            datasets: list[str],
            finish_reasons: list[str],
            infos,
            queries: list[str] | None = None,
        ) -> tuple[list[float], dict[str, Any]]:
            timeouts = infos.timeouts
            tool_errors = infos.tool_errors
            tool_outputs = infos.tool_outputs
            tool_calleds = infos.tool_calleds
            # A "good" output called a tool, produced output, and hit no timeout or error.
            good_outputs = [
                len(tool_outputs[i]) > 0 and tool_calleds[i] and not timeouts[i] and not tool_errors[i]
                for i in range(len(tool_outputs))
            ]
            scores = [0.0] * len(decoded_responses)
            metrics: dict[str, Any] = {}
            format_scores: list[float] = []

            if self.apply_r1_style_format_reward:
                format_scores = soft_format_reward_func(decoded_responses, self.r1_style_format_reward)
                if len(format_scores) != len(scores):
                    raise ValueError(f"{len(format_scores)=} != {len(scores)=}")
                for i in range(len(format_scores)):
                    scores[i] += format_scores[i]
                metrics["val/format_scores"] = np.array(format_scores).mean()

            if self.apply_verifiable_reward:
                verifiable_rewards, per_func_rewards = await apply_verifiable_reward(
                    self.verifier_functions,
                    responses,
                    decoded_responses,
                    ground_truths,
                    datasets,
                    reward_mult=self.verification_reward,
                    queries=queries,
                )
                if len(verifiable_rewards) != len(scores):
                    raise ValueError(f"{len(verifiable_rewards)=} != {len(scores)=}")
                for i in range(len(verifiable_rewards)):
                    if not self.only_reward_good_outputs or good_outputs[i]:
                        if self.apply_r1_style_format_reward and self.additive_format_reward:
                            # Stack the verifiable reward on top of the format reward.
                            scores[i] = verifiable_rewards[i] + scores[i]
                        elif self.apply_r1_style_format_reward:
                            # Gate the verifiable reward on a perfect format score.
                            scores[i] = verifiable_rewards[i] if format_scores[i] == 1 else 0
                        else:
                            scores[i] = verifiable_rewards[i]
                np_verifiable_rewards = np.array(verifiable_rewards)
                metrics["objective/verifiable_reward"] = np_verifiable_rewards.mean()
                metrics["objective/verifiable_correct_rate"] = (np_verifiable_rewards > 0.0).mean()
                # Per-verifier breakdowns, e.g. objective/<name>_reward.
                per_func_lists: dict[str, list] = defaultdict(list)
                for reward_dict in per_func_rewards:
                    for key, value in reward_dict.items():
                        per_func_lists[key].append(value)
                for key, value in per_func_lists.items():
                    np_value = np.array(value)
                    metrics[f"objective/{key}_reward"] = np_value.mean()
                    metrics[f"objective/{key}_correct_rate"] = (np_value > 0.0).mean()

            if self.non_stop_penalty:
                assert len(finish_reasons) == len(scores)
                for i in range(len(finish_reasons)):
                    if finish_reasons[i] != "stop":
                        scores[i] = self.non_stop_penalty_value

            return scores, metrics

        return reward_fn
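

And a usage sketch of RewardConfig end to end, reusing the hypothetical MathVerifier stub from the sketch above. The SimpleNamespace stands in for whatever info object the training loop passes, since reward_fn only reads its timeouts, tool_errors, tool_outputs, and tool_calleds attributes; all values shown are illustrative.

import asyncio
from types import SimpleNamespace

config = RewardConfig(
    apply_verifiable_reward=True,
    verification_reward=10,
    verifier_functions={"gsm8k": MathVerifier()},  # hypothetical stub from above
)
reward_fn = config.build()

# Per-response tool bookkeeping; a SimpleNamespace suffices for this sketch.
infos = SimpleNamespace(timeouts=[False], tool_errors=[""], tool_outputs=[""], tool_calleds=[False])

scores, metrics = asyncio.run(
    reward_fn(
        responses=[[101, 2045]],
        decoded_responses=["42"],
        ground_truths=["42"],
        datasets=["gsm8k"],
        finish_reasons=["stop"],
        infos=infos,
    )
)
print(scores)   # [10.0]
print(metrics)  # {'objective/verifiable_reward': 10.0, 'objective/verifiable_correct_rate': 1.0, ...}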