diff --git a/recipes/Llama-3.2-1B-Instruct/best_of_n_trl_prm.yaml b/recipes/Llama-3.2-1B-Instruct/best_of_n_trl_prm.yaml new file mode 100644 index 00000000..ac25c798 --- /dev/null +++ b/recipes/Llama-3.2-1B-Instruct/best_of_n_trl_prm.yaml @@ -0,0 +1,13 @@ +# refer to src/sal/config.py for more options + +approach: best_of_n +n: 32 +search_batch_size: 25 +sort_completed: true +filter_duplicates: true +# num_samples: 5 # REMOVE THIS LINE TO RUN ON THE WHOLE DATASET +seed: 0 +system_prompt: "Solve the following math problem efficiently and clearly:\n\n- For simple problems (2 steps or fewer):\nProvide a concise solution with minimal explanation.\n\n- For complex problems (3 steps or more):\nUse this step-by-step format:\n\n[Concise description]\n[Brief explanation and calculations]\n\n[Concise description]\n[Brief explanation and calculations]\n\n...\n\nRegardless of the approach, always conclude with:\n\nTherefore, the final answer is: $\\boxed{answer}$. I hope it is correct.\n\nWhere [answer] is just the final number or expression that solves the problem." +prm_path: "HuggingFaceH4/Qwen2.5-Math-1.5B-Instruct-PRM-0.2" +overwrite_hub_revision: true # While testing the new PRM model +prm_batch_size: 32 \ No newline at end of file diff --git a/recipes/Llama-3.2-1B-Instruct/dvts_trl.yaml b/recipes/Llama-3.2-1B-Instruct/dvts_trl.yaml new file mode 100644 index 00000000..b17312c2 --- /dev/null +++ b/recipes/Llama-3.2-1B-Instruct/dvts_trl.yaml @@ -0,0 +1,13 @@ +# refer to src/sal/config.py for more options + +approach: dvts +n: 32 +search_batch_size: 25 +sort_completed: true +filter_duplicates: true +# num_samples: 10 # REMOVE THIS LINE TO RUN ON THE WHOLE DATASET +seed: 0 +system_prompt: "Solve the following math problem efficiently and clearly:\n\n- For simple problems (2 steps or fewer):\nProvide a concise solution with minimal explanation.\n\n- For complex problems (3 steps or more):\nUse this step-by-step format:\n\n[Concise description]\n[Brief explanation and calculations]\n\n[Concise description]\n[Brief explanation and calculations]\n\n...\n\nRegardless of the approach, always conclude with:\n\nTherefore, the final answer is: $\\boxed{answer}$. I hope it is correct.\n\nWhere [answer] is just the final number or expression that solves the problem." +prm_path: "HuggingFaceH4/Qwen2.5-Math-1.5B-Instruct-PRM-0.2" +overwrite_hub_revision: true # While testing the new PRM model +prm_batch_size: 32 \ No newline at end of file diff --git a/src/sal/config.py b/src/sal/config.py index e09cc584..bcd6b082 100644 --- a/src/sal/config.py +++ b/src/sal/config.py @@ -68,6 +68,9 @@ class Config: filter_duplicates: bool = False sort_completed: bool = False + # PRM related options + separator: str = "\n\n" + def __post_init__(self): if self.approach == "dvts": if self.n % self.beam_width != 0: diff --git a/src/sal/models/reward_models.py b/src/sal/models/reward_models.py index 7b1e1e55..fbf5a9d3 100644 --- a/src/sal/models/reward_models.py +++ b/src/sal/models/reward_models.py @@ -16,8 +16,10 @@ from itertools import accumulate import torch +from tqdm import tqdm from transformers import ( AutoModelForCausalLM, + AutoModelForTokenClassification, AutoTokenizer, PreTrainedModel, PreTrainedTokenizer, @@ -30,9 +32,12 @@ prepare_input, ) from sal.models.skywork_o1_prm.prm_model import SkyworkPRMModel +from sal.models.utils import BatchProcessor, Example, process_results CANDIDATE_TOKENS = [648, 387] STEP_TAG_ID = 12902 +LABEL_MAP = {"LABEL_0": False, "LABEL_1": True} +LABEL_FOR_TRUE = "LABEL_1" def batched_math_shepherd_inference( @@ -129,9 +134,9 @@ def score( # stripped_output_scores = [] TODO: strip out the reward for previous steps for output_score, output in zip(output_scores, outputs): - assert len(output_score) == len( - output - ), f"{len(output_score)} != {len(output)}" + assert len(output_score) == len(output), ( + f"{len(output_score)} != {len(output)}" + ) return output_scores @@ -340,6 +345,76 @@ def load_model_and_tokenizer( return SkyworkO1._load_model_and_tokenizer(prm_model_path, **model_kwargs) +class TRLPRM(PRM): + def load_model_and_tokenizer( + self, **model_kwargs + ) -> tuple[PreTrainedModel, PreTrainedTokenizer]: + tokenizer = AutoTokenizer.from_pretrained(self.search_config.prm_path) + tokenizer.padding_side = ( + "left" # To extract the predicted token as the last token of the right + ) + model = AutoModelForTokenClassification.from_pretrained( + self.search_config.prm_path, + device_map="auto", + torch_dtype=torch.float16, + **model_kwargs, + ).eval() + + return model, tokenizer + + def score( + self, questions: list[str], outputs: list[list[str]] + ) -> list[list[float]]: + inputs_for_prm = [ + Example(problem=question, steps=answers, sep=self.search_config.separator) + for question, answers in zip(questions, outputs) + ] + batch_processor = BatchProcessor( + inputs_for_prm, self.search_config.prm_batch_size + ) + processed_data = {} + + for batch_steps, batch_indices in tqdm( + batch_processor, + total=batch_processor.get_total_batches(), + desc="PRM Inference...", + ): + with torch.no_grad(): + # batch_steps = ['Let $a,$ $b,$ and $c$ be positive real numbers. Find the set of all possible values of\n\\[\\frac{c}{a} + \\frac{a}{b + c} + \\frac{b}{c}.\\]\n\nThis problem involves finding the range of an expression involving three variables.', 'Let $a,$ $b,$ and $c$ be positive real numbers. Find the set of all possible values of\n\\[\\frac{c}{a} + \\frac{a}{b + c} + \\frac{b}{c}.\\]\n\nThis problem involves finding the range of an expression involving three variables.\n\nOne possible strategy is to try to eliminate some variables and write the expression in terms of one variable only.'] + tokenized_batch = self.tokenizer( + batch_steps, padding=True, return_tensors="pt" + ).to(self.model.device) + # Get model outputs + batched_outputs = self.model(**tokenized_batch) + # Transform to probabilities, and extract the ones corresponding + # to the TRUE class (LABEL_1, which is the first class) + scores = batched_outputs.logits.softmax(dim=-1)[:, :, 0] + # The probabilities for the batch can be extracted by finding the prob + # of the last token, which should correspond to the + probs = scores[:, -1].tolist() # To extract them from cuda + # batched_outputs = self.pipeline(batch_steps) + + # Assign results back to original structure + process_results(probs, batch_indices, processed_data) + # process_results(batched_outputs, batch_indices, processed_data) + # Clear GPU memory + del batch_steps, batched_outputs, scores, probs + torch.cuda.empty_cache() + + # The "processed_data" comes sorted as a dict with the index, and the different + # scores. Now we group each group of answers to its N. + reshaped_output_scores = [] + counter = 0 + for _, answers in zip(questions, outputs): + scores = [] + for _ in answers: + scores.append(processed_data[counter]) + counter += 1 + reshaped_output_scores.append(scores) + + return reshaped_output_scores + + def load_prm(config: Config) -> PRM: if config.prm_path == "peiyi9979/math-shepherd-mistral-7b-prm": return MathShepherd(config) @@ -353,4 +428,7 @@ def load_prm(config: Config) -> PRM: if config.prm_path == "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-7B": return SkyworkO1_7B(config) + if config.prm_path.startswith("HuggingFaceH4"): + return TRLPRM(config) + raise NotImplementedError(f"PRM {config.prm_path} not implemented") diff --git a/src/sal/models/utils.py b/src/sal/models/utils.py new file mode 100644 index 00000000..36f55cbe --- /dev/null +++ b/src/sal/models/utils.py @@ -0,0 +1,106 @@ +#!/usr/bin/env python +# Copyright 2024 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from functools import cached_property + + +@dataclass +class Example: + problem: str + steps: list[str] + sep: str = "\n" + + @cached_property + def get_texts(self): + """Returns the lists with each problem and solution steps concatenated + with the separator. + """ + return [ + self.sep.join((self.problem, *self.steps[:i])) + self.sep + for i, step in enumerate(self.steps, start=1) + ] + + +class BatchProcessor: + """Helper class to allow passing batches to the model pipeline including different + problem and solutions steps. It allows assigning back the steps of the errors at the + end by finding the corresponding index of the problems in the batches. + """ + + def __init__(self, data: list[Example], batch_size: int = 32): + self.data = data + self.batch_size = batch_size + self.current_idx = 0 + + # Create index mapping for steps + self.step_mapping = [] # [(dataset_idx, step_idx), ...] + for idx, item in enumerate(data): + for step_idx in range(len(item.steps)): + self.step_mapping.append((idx, step_idx)) + + self.total_steps = len(self.step_mapping) + + def __iter__(self): + self.current_idx = 0 + return self + + def __next__(self): + if self.current_idx >= self.total_steps: + raise StopIteration + + batch_indices = [] + batch_steps = [] + step_count = 0 + + while self.current_idx < self.total_steps and step_count < self.batch_size: + dataset_idx, step_idx = self.step_mapping[self.current_idx] + batch_indices.append((dataset_idx, step_idx)) + + # Here the steps have to be already generated + steps = self.data[dataset_idx].get_texts + batch_steps.append(steps[step_idx]) + + step_count += 1 + self.current_idx += 1 + + return batch_steps, batch_indices + + def get_total_batches(self): + """Return the total number of batches.""" + return (self.total_steps + self.batch_size - 1) // self.batch_size + + +def process_results( + results: list[dict[str, bool | str | int]], + batch_indices: list[tuple[int, int]], + processed_data: dict[int, list[dict[str, str | float | int]]], +) -> None: + """ + Assign results back to the original dataset structure. + + Args: + results: List of results from processing the batch, + the outputs from transformers.pipeline(X). + batch_indices: List of (dataset_idx, step_idx) tuples. + processed_data: Dictionary to store results, keyed by dataset index. + """ + for result, (dataset_idx, step_idx) in zip(results, batch_indices): + if dataset_idx not in processed_data: + processed_data[dataset_idx] = [] + # Ensure the list is long enough to insert at step_idx + while len(processed_data[dataset_idx]) <= step_idx: + processed_data[dataset_idx].append(None) + processed_data[dataset_idx][step_idx] = result diff --git a/src/sal/utils/parser.py b/src/sal/utils/parser.py index fa19fa02..60df1b2b 100644 --- a/src/sal/utils/parser.py +++ b/src/sal/utils/parser.py @@ -43,9 +43,17 @@ def parse_yaml_and_args( outputs = [] # strip other args list into dict of key-value pairs - other_args = { - arg.split("=")[0].strip("-"): arg.split("=")[1] for arg in other_args - } + other_args_parsed = {} + for arg in other_args: + values = arg.split("=") + if len(values) == 1: + arg = values[0] # a boolean value, like --push-to-hub + value = "True" + else: + arg, value = values + other_args_parsed[arg.strip("-")] = value + other_args = other_args_parsed + del other_args_parsed used_args = {} # overwrite the default/loaded value with the value provided to the command line