13 changes: 13 additions & 0 deletions recipes/Llama-3.2-1B-Instruct/best_of_n_trl_prm.yaml
@@ -0,0 +1,13 @@
# refer to src/sal/config.py for more options

approach: best_of_n
n: 32
search_batch_size: 25
sort_completed: true
filter_duplicates: true
# num_samples: 5 # REMOVE THIS LINE TO RUN ON THE WHOLE DATASET
seed: 0
system_prompt: "Solve the following math problem efficiently and clearly:\n\n- For simple problems (2 steps or fewer):\nProvide a concise solution with minimal explanation.\n\n- For complex problems (3 steps or more):\nUse this step-by-step format:\n\n[Concise description]\n[Brief explanation and calculations]\n\n[Concise description]\n[Brief explanation and calculations]\n\n...\n\nRegardless of the approach, always conclude with:\n\nTherefore, the final answer is: $\\boxed{answer}$. I hope it is correct.\n\nWhere [answer] is just the final number or expression that solves the problem."
prm_path: "HuggingFaceH4/Qwen2.5-Math-1.5B-Instruct-PRM-0.2"
overwrite_hub_revision: true # While testing the new PRM model
prm_batch_size: 32
13 changes: 13 additions & 0 deletions recipes/Llama-3.2-1B-Instruct/dvts_trl.yaml
@@ -0,0 +1,13 @@
# refer to src/sal/config.py for more options

approach: dvts
n: 32
search_batch_size: 25
sort_completed: true
filter_duplicates: true
# num_samples: 10 # REMOVE THIS LINE TO RUN ON THE WHOLE DATASET
seed: 0
system_prompt: "Solve the following math problem efficiently and clearly:\n\n- For simple problems (2 steps or fewer):\nProvide a concise solution with minimal explanation.\n\n- For complex problems (3 steps or more):\nUse this step-by-step format:\n\n[Concise description]\n[Brief explanation and calculations]\n\n[Concise description]\n[Brief explanation and calculations]\n\n...\n\nRegardless of the approach, always conclude with:\n\nTherefore, the final answer is: $\\boxed{answer}$. I hope it is correct.\n\nWhere [answer] is just the final number or expression that solves the problem."
prm_path: "HuggingFaceH4/Qwen2.5-Math-1.5B-Instruct-PRM-0.2"
overwrite_hub_revision: true # While testing the new PRM model
prm_batch_size: 32
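Both recipes above map one-to-one onto fields of the `Config` dataclass in `src/sal/config.py`. A minimal sketch of loading such a recipe into a `Config` object, assuming PyYAML is available, that every YAML key is a `Config` field, and that omitted fields have usable defaults (the repo's own `parse_yaml_and_args` helper does this in practice):

```python
# Minimal sketch: turn a recipe YAML into a Config instance.
# Assumes PyYAML is installed and that all YAML keys are Config fields.
import yaml

from sal.config import Config

with open("recipes/Llama-3.2-1B-Instruct/best_of_n_trl_prm.yaml") as f:
    recipe = yaml.safe_load(f)  # {'approach': 'best_of_n', 'n': 32, ...}

config = Config(**recipe)
print(config.approach, config.n, config.prm_path, config.prm_batch_size)
```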
3 changes: 3 additions & 0 deletions src/sal/config.py
@@ -68,6 +68,9 @@ class Config:
filter_duplicates: bool = False
sort_completed: bool = False

# PRM related options
separator: str = "\n\n"

def __post_init__(self):
if self.approach == "dvts":
if self.n % self.beam_width != 0:
84 changes: 81 additions & 3 deletions src/sal/models/reward_models.py
@@ -16,8 +16,10 @@
from itertools import accumulate

import torch
from tqdm import tqdm
from transformers import (
AutoModelForCausalLM,
AutoModelForTokenClassification,
AutoTokenizer,
PreTrainedModel,
PreTrainedTokenizer,
@@ -30,9 +32,12 @@
prepare_input,
)
from sal.models.skywork_o1_prm.prm_model import SkyworkPRMModel
from sal.models.utils import BatchProcessor, Example, process_results

CANDIDATE_TOKENS = [648, 387]
STEP_TAG_ID = 12902
LABEL_MAP = {"LABEL_0": False, "LABEL_1": True}
LABEL_FOR_TRUE = "LABEL_1"


def batched_math_shepherd_inference(
@@ -129,9 +134,9 @@ def score(

# stripped_output_scores = [] TODO: strip out the reward for previous steps
for output_score, output in zip(output_scores, outputs):
assert len(output_score) == len(
output
), f"{len(output_score)} != {len(output)}"
assert len(output_score) == len(output), (
f"{len(output_score)} != {len(output)}"
)

return output_scores

@@ -340,6 +345,76 @@ def load_model_and_tokenizer(
return SkyworkO1._load_model_and_tokenizer(prm_model_path, **model_kwargs)


class TRLPRM(PRM):
def load_model_and_tokenizer(
self, **model_kwargs
) -> tuple[PreTrainedModel, PreTrainedTokenizer]:
tokenizer = AutoTokenizer.from_pretrained(self.search_config.prm_path)
        tokenizer.padding_side = (
            "left"  # pad on the left so the token to score is always the last one
        )
model = AutoModelForTokenClassification.from_pretrained(
self.search_config.prm_path,
device_map="auto",
torch_dtype=torch.float16,
**model_kwargs,
).eval()

return model, tokenizer

def score(
self, questions: list[str], outputs: list[list[str]]
) -> list[list[float]]:
inputs_for_prm = [
Example(problem=question, steps=answers, sep=self.search_config.separator)
for question, answers in zip(questions, outputs)
]
batch_processor = BatchProcessor(
inputs_for_prm, self.search_config.prm_batch_size
)
processed_data = {}

for batch_steps, batch_indices in tqdm(
batch_processor,
total=batch_processor.get_total_batches(),
desc="PRM Inference...",
):
with torch.no_grad():
# batch_steps = ['Let $a,$ $b,$ and $c$ be positive real numbers. Find the set of all possible values of\n\\[\\frac{c}{a} + \\frac{a}{b + c} + \\frac{b}{c}.\\]\n\nThis problem involves finding the range of an expression involving three variables.', 'Let $a,$ $b,$ and $c$ be positive real numbers. Find the set of all possible values of\n\\[\\frac{c}{a} + \\frac{a}{b + c} + \\frac{b}{c}.\\]\n\nThis problem involves finding the range of an expression involving three variables.\n\nOne possible strategy is to try to eliminate some variables and write the expression in terms of one variable only.']
tokenized_batch = self.tokenizer(
batch_steps, padding=True, return_tensors="pt"
).to(self.model.device)
# Get model outputs
batched_outputs = self.model(**tokenized_batch)
                # Transform the logits into probabilities and keep the column the
                # model treats as the positive (TRUE) class
                scores = batched_outputs.logits.softmax(dim=-1)[:, :, 0]
                # With left padding, the last token of each sequence carries the
                # prediction for the current step, so take its probability
                probs = scores[:, -1].tolist()  # move the scores off the GPU to Python floats
# batched_outputs = self.pipeline(batch_steps)

# Assign results back to original structure
process_results(probs, batch_indices, processed_data)
# process_results(batched_outputs, batch_indices, processed_data)
# Clear GPU memory
del batch_steps, batched_outputs, scores, probs
torch.cuda.empty_cache()

# The "processed_data" comes sorted as a dict with the index, and the different
# scores. Now we group each group of answers to its N.
reshaped_output_scores = []
counter = 0
for _, answers in zip(questions, outputs):
scores = []
for _ in answers:
scores.append(processed_data[counter])
counter += 1
reshaped_output_scores.append(scores)

return reshaped_output_scores


def load_prm(config: Config) -> PRM:
if config.prm_path == "peiyi9979/math-shepherd-mistral-7b-prm":
return MathShepherd(config)
@@ -353,4 +428,7 @@ def load_prm(config: Config) -> PRM:
if config.prm_path == "Skywork/Skywork-o1-Open-PRM-Qwen-2.5-7B":
return SkyworkO1_7B(config)

if config.prm_path.startswith("HuggingFaceH4"):
return TRLPRM(config)

raise NotImplementedError(f"PRM {config.prm_path} not implemented")
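With the new branch in `load_prm`, any `prm_path` under the `HuggingFaceH4` namespace is routed to `TRLPRM`. A hedged usage sketch of the scoring path (the question and candidate answer are made up, a GPU is assumed, and the remaining `Config` fields are assumed to have workable defaults):

```python
# Illustrative only: scoring a made-up candidate answer with the new TRLPRM class.
from sal.config import Config
from sal.models.reward_models import load_prm

config = Config(
    prm_path="HuggingFaceH4/Qwen2.5-Math-1.5B-Instruct-PRM-0.2",
    prm_batch_size=32,
)
prm = load_prm(config)  # resolves to TRLPRM via the HuggingFaceH4 prefix

questions = ["What is 2 + 2?"]
outputs = [["2 + 2 = 4\n\nTherefore, the final answer is: $\\boxed{4}$."]]

scores = prm.score(questions, outputs)  # list[list[float]], one score per answer
print(scores)
```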
106 changes: 106 additions & 0 deletions src/sal/models/utils.py
@@ -0,0 +1,106 @@
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass
from functools import cached_property


@dataclass
class Example:
problem: str
steps: list[str]
sep: str = "\n"

@cached_property
def get_texts(self):
"""Returns the lists with each problem and solution steps concatenated
with the separator.
"""
return [
self.sep.join((self.problem, *self.steps[:i])) + self.sep
for i, step in enumerate(self.steps, start=1)
]


class BatchProcessor:
"""Helper class to allow passing batches to the model pipeline including different
problem and solutions steps. It allows assigning back the steps of the errors at the
end by finding the corresponding index of the problems in the batches.
"""

def __init__(self, data: list[Example], batch_size: int = 32):
self.data = data
self.batch_size = batch_size
self.current_idx = 0

# Create index mapping for steps
self.step_mapping = [] # [(dataset_idx, step_idx), ...]
for idx, item in enumerate(data):
for step_idx in range(len(item.steps)):
self.step_mapping.append((idx, step_idx))

self.total_steps = len(self.step_mapping)

def __iter__(self):
self.current_idx = 0
return self

def __next__(self):
if self.current_idx >= self.total_steps:
raise StopIteration

batch_indices = []
batch_steps = []
step_count = 0

while self.current_idx < self.total_steps and step_count < self.batch_size:
dataset_idx, step_idx = self.step_mapping[self.current_idx]
batch_indices.append((dataset_idx, step_idx))

            # get_texts builds (and caches) the cumulative step texts on first access
steps = self.data[dataset_idx].get_texts
batch_steps.append(steps[step_idx])

step_count += 1
self.current_idx += 1

return batch_steps, batch_indices

def get_total_batches(self):
"""Return the total number of batches."""
return (self.total_steps + self.batch_size - 1) // self.batch_size


def process_results(
results: list[dict[str, bool | str | int]],
batch_indices: list[tuple[int, int]],
processed_data: dict[int, list[dict[str, str | float | int]]],
) -> None:
"""
Assign results back to the original dataset structure.

Args:
        results: List of per-step results for the batch (e.g. scores, or the
            outputs of a transformers pipeline).
batch_indices: List of (dataset_idx, step_idx) tuples.
processed_data: Dictionary to store results, keyed by dataset index.
"""
for result, (dataset_idx, step_idx) in zip(results, batch_indices):
if dataset_idx not in processed_data:
processed_data[dataset_idx] = []
# Ensure the list is long enough to insert at step_idx
while len(processed_data[dataset_idx]) <= step_idx:
processed_data[dataset_idx].append(None)
processed_data[dataset_idx][step_idx] = result
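To see how the three pieces in `utils.py` fit together, here is a small self-contained sketch; the toy problems, steps, and the length-based "scorer" standing in for the PRM are illustrative only:

```python
# Toy walk-through of Example, BatchProcessor and process_results.
from sal.models.utils import BatchProcessor, Example, process_results

examples = [
    Example(problem="Compute 1 + 1.", steps=["1 + 1 = 2", "So the answer is 2."]),
    Example(problem="Compute 2 * 3.", steps=["2 * 3 = 6"]),
]

processed: dict[int, list] = {}
for batch_steps, batch_indices in BatchProcessor(examples, batch_size=2):
    # batch_steps holds the cumulative texts (problem + steps so far);
    # a trivial length-based "score" stands in for the PRM here.
    fake_scores = [len(text) / 100 for text in batch_steps]
    process_results(fake_scores, batch_indices, processed)

# processed maps example index -> per-step scores, in the original step order
print(processed)  # {0: [0.25, 0.45], 1: [0.25]} with this toy scorer
```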
14 changes: 11 additions & 3 deletions src/sal/utils/parser.py
@@ -43,9 +43,17 @@ def parse_yaml_and_args(

outputs = []
# strip other args list into dict of key-value pairs
other_args = {
arg.split("=")[0].strip("-"): arg.split("=")[1] for arg in other_args
}
other_args_parsed = {}
for arg in other_args:
values = arg.split("=")
if len(values) == 1:
arg = values[0] # a boolean value, like --push-to-hub
value = "True"
else:
arg, value = values
other_args_parsed[arg.strip("-")] = value
other_args = other_args_parsed
del other_args_parsed
used_args = {}

# overwrite the default/loaded value with the value provided to the command line
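The reworked loop also accepts bare flags without a value. A stand-alone illustration of the parsing behaviour (the argument list here is made up):

```python
# Stand-alone illustration of the new flag parsing; the CLI arguments are made up.
other_args = ["--push-to-hub", "--n=16", "--prm_batch_size=8"]

other_args_parsed = {}
for arg in other_args:
    values = arg.split("=")
    if len(values) == 1:
        arg = values[0]  # a bare boolean flag, like --push-to-hub
        value = "True"
    else:
        arg, value = values
    other_args_parsed[arg.strip("-")] = value

print(other_args_parsed)
# {'push-to-hub': 'True', 'n': '16', 'prm_batch_size': '8'}
```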