
Commit ccea424

new async reward impl.
1 parent b8b735b commit ccea424

File tree

4 files changed: +248 −136 lines


open_instruct/ground_truth_utils.py

Lines changed: 161 additions & 2 deletions
@@ -15,10 +15,12 @@
 import string
 import weakref
 from abc import ABC, abstractmethod
-from collections import Counter
-from dataclasses import dataclass
+from collections import Counter, defaultdict
+from collections.abc import Callable
+from dataclasses import dataclass, field
 from typing import Any
 
+import numpy as np
 import requests
 from litellm import acompletion
 
@@ -987,3 +989,160 @@ async def cleanup_all_llm_judge_clients():
     Cleanup function to properly close all LLM judge clients before shutdown.
     """
     await LMJudgeVerifier.cleanup_all_clients()
+
+
+async def apply_verifiable_reward(
+    reward_fn_mapping: dict[str, VerifierFunction],
+    responses: list,
+    decoded_responses: list[str],
+    ground_truths: list,
+    datasets: list[str],
+    reward_mult: int = 10,
+    queries: list[str] | None = None,
+):
+    if queries is None:
+        queries = [None] * len(responses)
+
+    async_tasks = []
+    task_metadata = []
+
+    for i, (tok_prediction, prediction, ground_truth, dataset, query) in enumerate(
+        zip(responses, decoded_responses, ground_truths, datasets, queries)
+    ):
+        ground_truth_list = [ground_truth] if isinstance(ground_truth, str) else ground_truth
+        dataset_list = [dataset] if isinstance(dataset, str) else dataset
+        assert len(ground_truth_list) == len(dataset_list), "Ground truth and dataset list lengths do not match."
+
+        for gt, ds in zip(ground_truth_list, dataset_list):
+            reward_func = reward_fn_mapping.get(ds.lower())
+            if reward_func is None:
+                logger.warning("No reward function found for dataset %s. Skipping reward.", ds)
+                continue
+
+            task = reward_func.async_call(
+                tokenized_prediction=tok_prediction, prediction=prediction, label=gt, query=query
+            )
+            async_tasks.append(task)
+            task_metadata.append(
+                {
+                    "response_idx": i,
+                    "dataset": reward_func.name,
+                    "reward_weight": reward_func.weight,
+                    "reward_mult": reward_mult,
+                }
+            )
+
+    if async_tasks:
+        reward_results = await asyncio.gather(*async_tasks)
+        logger.debug(f"Applied {len(reward_results)} ground truth rewards in parallel")
+    else:
+        reward_results = []
+
+    response_rewards = [0] * len(responses)
+    response_per_func_rewards = [{} for _ in range(len(responses))]
+
+    for result, metadata in zip(reward_results, task_metadata):
+        response_idx = metadata["response_idx"]
+        dataset = metadata["dataset"]
+        reward_weight = metadata["reward_weight"]
+        reward_mult = metadata["reward_mult"]
+
+        score = result.score if hasattr(result, "score") else result
+        weighted_reward = reward_mult * score * reward_weight
+
+        response_rewards[response_idx] += weighted_reward
+        response_per_func_rewards[response_idx][dataset] = (
+            response_per_func_rewards[response_idx].get(dataset, 0) + weighted_reward
+        )
+
+    return response_rewards, response_per_func_rewards
+
+
+@dataclass
+class RewardConfig:
+    """Configuration for reward function computation."""
+
+    apply_r1_style_format_reward: bool = False
+    r1_style_format_reward: float = 1.0
+    apply_verifiable_reward: bool = True
+    verification_reward: int = 10
+    non_stop_penalty: bool = False
+    non_stop_penalty_value: float = -10.0
+    only_reward_good_outputs: bool = False
+    additive_format_reward: bool = False
+    verifier_functions: dict[str, VerifierFunction] = field(default_factory=dict)
+
+    def build(self) -> Callable:
+        """Build and return the reward function."""
+        reward_fn_mapping = self.verifier_functions
+
+        async def reward_fn(
+            responses: list,
+            decoded_responses: list[str],
+            ground_truths: list[Any],
+            datasets: list[str],
+            finish_reasons: list[str],
+            infos,
+            queries: list[str] | None = None,
+        ) -> tuple[list[float], dict[str, Any]]:
+            timeouts = infos.timeouts
+            tool_errors = infos.tool_errors
+            tool_outputs = infos.tool_outputs
+            tool_calleds = infos.tool_calleds
+            good_outputs = [
+                len(tool_outputs[i]) > 0 and tool_calleds[i] and not timeouts[i] and not tool_errors[i]
+                for i in range(len(tool_outputs))
+            ]
+            scores = [0.0] * len(decoded_responses)
+            metrics: dict[str, Any] = {}
+            format_scores: list[float] = []
+
+            if self.apply_r1_style_format_reward:
+                format_scores = soft_format_reward_func(decoded_responses, self.r1_style_format_reward)
+                if len(format_scores) != len(scores):
+                    raise ValueError(f"{len(format_scores)=} != {len(scores)=}")
+                for i in range(len(format_scores)):
+                    scores[i] = format_scores[i] + scores[i]
+                metrics["val/format_scores"] = np.array(format_scores).mean()
+
+            if self.apply_verifiable_reward:
+                verifiable_rewards, per_func_rewards = await apply_verifiable_reward(
+                    reward_fn_mapping,
+                    responses,
+                    decoded_responses,
+                    ground_truths,
+                    datasets,
+                    reward_mult=self.verification_reward,
+                    queries=queries,
+                )
+                if len(verifiable_rewards) != len(scores):
+                    raise ValueError(f"{len(verifiable_rewards)=} != {len(scores)=}")
+                for i in range(len(verifiable_rewards)):
+                    if not self.only_reward_good_outputs or (good_outputs[i] and self.only_reward_good_outputs):
+                        if self.apply_r1_style_format_reward and self.additive_format_reward:
+                            scores[i] = verifiable_rewards[i] + scores[i]
+                        elif self.apply_r1_style_format_reward and not self.additive_format_reward:
+                            scores[i] = verifiable_rewards[i] if format_scores[i] == 1 else 0
+                        else:
+                            scores[i] = verifiable_rewards[i]
+                np_verifiable_rewards = np.array(verifiable_rewards)
+                metrics["objective/verifiable_reward"] = np_verifiable_rewards.mean()
+                metrics["objective/verifiable_correct_rate"] = (np_verifiable_rewards > 0.0).mean()
+                per_func_lists: dict[str, list] = defaultdict(list)
+                for reward_dict in per_func_rewards:
+                    for key, value in reward_dict.items():
+                        per_func_lists[key].append(value)
+                for key, value in per_func_lists.items():
+                    np_value = np.array(value)
+                    metrics[f"objective/{key}_reward"] = np_value.mean()
+                    metrics[f"objective/{key}_correct_rate"] = (np_value > 0.0).mean()
+
+            if self.non_stop_penalty:
+                assert len(finish_reasons) == len(scores)
+                for i in range(len(finish_reasons)):
+                    if finish_reasons[i] != "stop":
+                        scores[i] = self.non_stop_penalty_value
+
+            return scores, metrics
+
+        return reward_fn
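
Below is a minimal sketch of how the new apply_verifiable_reward helper could be exercised on its own. EchoVerifier is a hypothetical stand-in, not part of the commit: it only mimics the pieces of the VerifierFunction interface that the new code actually reads (name, weight, and an awaitable async_call whose result exposes .score). The import path assumes the usual package layout for this file.

import asyncio
from types import SimpleNamespace

from open_instruct.ground_truth_utils import apply_verifiable_reward


class EchoVerifier:
    """Hypothetical verifier stub: exposes only what apply_verifiable_reward reads."""

    name = "echo"
    weight = 1.0

    async def async_call(self, tokenized_prediction, prediction, label, query=None):
        # Return an object with a .score attribute, as the reward loop expects.
        return SimpleNamespace(score=float(prediction.strip() == label))


async def main():
    rewards, per_func = await apply_verifiable_reward(
        reward_fn_mapping={"echo": EchoVerifier()},  # looked up via dataset.lower()
        responses=[[1, 2, 3], [4, 5]],               # token ids; unused by this toy verifier
        decoded_responses=["42", "41"],
        ground_truths=["42", "42"],
        datasets=["echo", "echo"],
        reward_mult=10,
    )
    print(rewards)   # [10.0, 0.0] -> score * weight * reward_mult, summed per response
    print(per_func)  # [{'echo': 10.0}, {'echo': 0.0}]


asyncio.run(main())

Each verifier call contributes reward_mult * score * weight to its response, so a sample whose ground-truth and dataset fields are lists gets one weighted contribution per entry, all summed into the same response reward.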

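A similar sketch for the RewardConfig.build() entry point. The infos argument is only read for its timeouts, tool_errors, tool_outputs, and tool_calleds attributes, so a SimpleNamespace is used here as a hypothetical stand-in for whatever info object the caller actually passes; the verifier stub follows the same assumptions as above.

import asyncio
from types import SimpleNamespace

from open_instruct.ground_truth_utils import RewardConfig


class ExactMatchVerifier:
    """Hypothetical verifier stub, as in the previous sketch."""

    name = "exact_match"
    weight = 1.0

    async def async_call(self, tokenized_prediction, prediction, label, query=None):
        return SimpleNamespace(score=float(prediction.strip() == label))


async def main():
    config = RewardConfig(
        apply_verifiable_reward=True,
        verification_reward=10,
        non_stop_penalty=True,
        non_stop_penalty_value=-10.0,
        verifier_functions={"exact_match": ExactMatchVerifier()},
    )
    reward_fn = config.build()

    n = 2
    infos = SimpleNamespace(  # hypothetical stand-in for the rollout info object
        timeouts=[False] * n,
        tool_errors=[""] * n,
        tool_outputs=[""] * n,
        tool_calleds=[False] * n,
    )
    scores, metrics = await reward_fn(
        responses=[[1, 2], [3, 4]],
        decoded_responses=["42", "41"],
        ground_truths=["42", "42"],
        datasets=["exact_match", "exact_match"],
        finish_reasons=["stop", "length"],
        infos=infos,
    )
    print(scores)   # [10.0, -10.0]: the second sample is overwritten by the non-stop penalty
    print(metrics)  # includes objective/verifiable_reward and per-verifier means


asyncio.run(main())

With apply_r1_style_format_reward enabled, format scores are added to the totals first and, unless additive_format_reward is set, the verifiable reward is gated on a perfect format score (format_scores[i] == 1).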