 import string
 import weakref
 from abc import ABC, abstractmethod
-from collections import Counter
-from dataclasses import dataclass
+from collections import Counter, defaultdict
+from collections.abc import Callable
+from dataclasses import dataclass, field
 from typing import Any

+import numpy as np
 import requests
 from litellm import acompletion

@@ -987,3 +989,160 @@ async def cleanup_all_llm_judge_clients():
     Cleanup function to properly close all LLM judge clients before shutdown.
     """
     await LMJudgeVerifier.cleanup_all_clients()
+
+
+async def apply_verifiable_reward(
+    reward_fn_mapping: dict[str, VerifierFunction],
+    responses: list,
+    decoded_responses: list[str],
+    ground_truths: list,
+    datasets: list[str],
+    reward_mult: int = 10,
+    queries: list[str] | None = None,
+):
+    if queries is None:
+        queries = [None] * len(responses)
+
+    async_tasks = []
+    task_metadata = []
+
+    for i, (tok_prediction, prediction, ground_truth, dataset, query) in enumerate(
+        zip(responses, decoded_responses, ground_truths, datasets, queries)
+    ):
+        ground_truth_list = [ground_truth] if isinstance(ground_truth, str) else ground_truth
+        dataset_list = [dataset] if isinstance(dataset, str) else dataset
+        assert len(ground_truth_list) == len(dataset_list), "Ground truth and dataset list lengths do not match."
+
+        for gt, ds in zip(ground_truth_list, dataset_list):
+            reward_func = reward_fn_mapping.get(ds.lower())
+            if reward_func is None:
+                logger.warning("No reward function found for dataset %s. Skipping reward.", ds)
+                continue
+
+            task = reward_func.async_call(
+                tokenized_prediction=tok_prediction, prediction=prediction, label=gt, query=query
+            )
+            async_tasks.append(task)
+            task_metadata.append(
+                {
+                    "response_idx": i,
+                    "dataset": reward_func.name,
+                    "reward_weight": reward_func.weight,
+                    "reward_mult": reward_mult,
+                }
+            )
+
+    if async_tasks:
+        reward_results = await asyncio.gather(*async_tasks)
+        logger.debug(f"Applied {len(reward_results)} ground truth rewards in parallel")
+    else:
+        reward_results = []
+
+    response_rewards = [0] * len(responses)
+    response_per_func_rewards = [{} for _ in range(len(responses))]
+
+    for result, metadata in zip(reward_results, task_metadata):
+        response_idx = metadata["response_idx"]
+        dataset = metadata["dataset"]
+        reward_weight = metadata["reward_weight"]
+        reward_mult = metadata["reward_mult"]
+
+        score = result.score if hasattr(result, "score") else result
+        weighted_reward = reward_mult * score * reward_weight
+
+        response_rewards[response_idx] += weighted_reward
+        response_per_func_rewards[response_idx][dataset] = (
+            response_per_func_rewards[response_idx].get(dataset, 0) + weighted_reward
+        )
+
+    return response_rewards, response_per_func_rewards
+
+
+@dataclass
+class RewardConfig:
+    """Configuration for reward function computation."""
+
+    apply_r1_style_format_reward: bool = False
+    r1_style_format_reward: float = 1.0
+    apply_verifiable_reward: bool = True
+    verification_reward: int = 10
+    non_stop_penalty: bool = False
+    non_stop_penalty_value: float = -10.0
+    only_reward_good_outputs: bool = False
+    additive_format_reward: bool = False
+    verifier_functions: dict[str, VerifierFunction] = field(default_factory=dict)
+
+    def build(self) -> Callable:
+        """Build and return the reward function."""
+        reward_fn_mapping = self.verifier_functions
+
+        async def reward_fn(
+            responses: list,
+            decoded_responses: list[str],
+            ground_truths: list[Any],
+            datasets: list[str],
+            finish_reasons: list[str],
+            infos,
+            queries: list[str] | None = None,
+        ) -> tuple[list[float], dict[str, Any]]:
+            timeouts = infos.timeouts
+            tool_errors = infos.tool_errors
+            tool_outputs = infos.tool_outputs
+            tool_calleds = infos.tool_calleds
+            good_outputs = [
+                len(tool_outputs[i]) > 0 and tool_calleds[i] and not timeouts[i] and not tool_errors[i]
+                for i in range(len(tool_outputs))
+            ]
+            scores = [0.0] * len(decoded_responses)
+            metrics: dict[str, Any] = {}
+            format_scores: list[float] = []
+
+            if self.apply_r1_style_format_reward:
+                format_scores = soft_format_reward_func(decoded_responses, self.r1_style_format_reward)
+                if len(format_scores) != len(scores):
+                    raise ValueError(f"{len(format_scores)=} != {len(scores)=}")
+                for i in range(len(format_scores)):
+                    scores[i] = format_scores[i] + scores[i]
+                metrics["val/format_scores"] = np.array(format_scores).mean()
+
+            if self.apply_verifiable_reward:
+                verifiable_rewards, per_func_rewards = await apply_verifiable_reward(
+                    reward_fn_mapping,
+                    responses,
+                    decoded_responses,
+                    ground_truths,
+                    datasets,
+                    reward_mult=self.verification_reward,
+                    queries=queries,
+                )
+                if len(verifiable_rewards) != len(scores):
+                    raise ValueError(f"{len(verifiable_rewards)=} != {len(scores)=}")
+                for i in range(len(verifiable_rewards)):
+                    if not self.only_reward_good_outputs or good_outputs[i]:
+                        if self.apply_r1_style_format_reward and self.additive_format_reward:
+                            scores[i] = verifiable_rewards[i] + scores[i]
+                        elif self.apply_r1_style_format_reward and not self.additive_format_reward:
+                            scores[i] = verifiable_rewards[i] if format_scores[i] == 1 else 0
+                        else:
+                            scores[i] = verifiable_rewards[i]
+                np_verifiable_rewards = np.array(verifiable_rewards)
+                metrics["objective/verifiable_reward"] = np_verifiable_rewards.mean()
+                metrics["objective/verifiable_correct_rate"] = (np_verifiable_rewards > 0.0).mean()
+                per_func_lists: dict[str, list] = defaultdict(list)
+                for reward_dict in per_func_rewards:
+                    for key, value in reward_dict.items():
+                        per_func_lists[key].append(value)
+                for key, value in per_func_lists.items():
+                    np_value = np.array(value)
+                    metrics[f"objective/{key}_reward"] = np_value.mean()
+                    metrics[f"objective/{key}_correct_rate"] = (np_value > 0.0).mean()
+
+            if self.non_stop_penalty:
+                assert len(finish_reasons) == len(scores)
+                for i in range(len(finish_reasons)):
+                    if finish_reasons[i] != "stop":
+                        scores[i] = self.non_stop_penalty_value
+
+            return scores, metrics
+
+        return reward_fn
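
Below is a minimal sketch of how `apply_verifiable_reward` can be exercised on its own. `EchoVerifier` is a hypothetical stand-in for a `VerifierFunction`: the real class's constructor and return type are not shown in this patch, only that a verifier exposes `name`, `weight`, and an awaitable `async_call(tokenized_prediction=..., prediction=..., label=..., query=...)` whose result is either a bare score or an object with a `.score` attribute.

```python
import asyncio
from dataclasses import dataclass


@dataclass
class EchoVerifier:
    # Hypothetical stand-in verifier: scores 1.0 when the label appears verbatim in the decoded text.
    name: str = "gsm8k"
    weight: float = 1.0

    async def async_call(self, tokenized_prediction, prediction, label, query=None):
        return 1.0 if label in prediction else 0.0


async def main():
    rewards, per_func = await apply_verifiable_reward(
        reward_fn_mapping={"gsm8k": EchoVerifier()},
        responses=[[1, 2, 3], [4, 5, 6]],  # token ids; unused by the stand-in verifier
        decoded_responses=["the answer is 42", "no idea"],
        ground_truths=["42", "42"],
        datasets=["gsm8k", "gsm8k"],
        reward_mult=10,
    )
    print(rewards)   # [10.0, 0.0]
    print(per_func)  # [{'gsm8k': 10.0}, {'gsm8k': 0.0}]


asyncio.run(main())
```

Each response's total reward is the sum of `reward_mult * score * weight` over the verifiers matched by its dataset names, and the per-function breakdown is returned alongside it.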
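
A similarly hedged sketch of the `RewardConfig.build()` path, reusing the `EchoVerifier` stand-in above. The `infos` argument is faked with a `SimpleNamespace`, since the patch only shows the four attributes read from it (`timeouts`, `tool_errors`, `tool_outputs`, `tool_calleds`); the real object presumably comes from the generation backend.

```python
import asyncio
from types import SimpleNamespace

# Fake `infos`: only the four attributes read inside reward_fn are provided.
infos = SimpleNamespace(
    timeouts=[False, False],
    tool_errors=["", ""],
    tool_outputs=["", ""],
    tool_calleds=[False, False],
)

config = RewardConfig(
    apply_verifiable_reward=True,
    verification_reward=10,
    non_stop_penalty=True,
    non_stop_penalty_value=-10.0,
    verifier_functions={"gsm8k": EchoVerifier()},  # stand-in from the sketch above
)
reward_fn = config.build()

scores, metrics = asyncio.run(
    reward_fn(
        responses=[[1, 2, 3], [4, 5, 6]],
        decoded_responses=["the answer is 42", "no idea"],
        ground_truths=["42", "42"],
        datasets=["gsm8k", "gsm8k"],
        finish_reasons=["stop", "length"],
        infos=infos,
    )
)
print(scores)   # [10.0, -10.0]; the truncated sample is overridden by the non-stop penalty
print(metrics)  # includes objective/verifiable_reward and objective/gsm8k_correct_rate
```

With `non_stop_penalty=True`, any sample whose finish reason is not `"stop"` has its score overwritten by `non_stop_penalty_value`, regardless of its verifiable reward.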