Merged
Changes from all commits
Commits
39 commits
1b6c5e1
trying out accum fix
hamishivi Nov 3, 2025
23ea9ce
Merge branch 'main' into accum-loss-fix-grpo-fast
hamishivi Nov 10, 2025
92a4bc2
fix
hamishivi Nov 10, 2025
a9ae929
fix
hamishivi Nov 10, 2025
c7ccc15
Fix up
hamishivi Nov 10, 2025
0aad4ac
fix
hamishivi Nov 10, 2025
89fb420
fix
hamishivi Nov 10, 2025
b37079c
fix
hamishivi Nov 10, 2025
0331016
Merge branch 'main' into accum-loss-fix-grpo-fast
hamishivi Nov 24, 2025
1390a6e
loss fixes
hamishivi Nov 24, 2025
f94b4b2
fix
hamishivi Nov 24, 2025
1931b99
lint
hamishivi Nov 24, 2025
90eb98e
fix
hamishivi Nov 24, 2025
6e698e9
fix
hamishivi Nov 24, 2025
802a7e7
quick and dirty group-level
hamishivi Nov 24, 2025
849ebb4
fix quality
hamishivi Nov 24, 2025
8c7559d
correct hacky group level
hamishivi Nov 24, 2025
febf44a
fix
hamishivi Nov 26, 2025
1f85578
remove group loss
hamishivi Dec 1, 2025
6a78727
simplify a little
hamishivi Dec 1, 2025
0331d17
small fix
hamishivi Dec 1, 2025
fe7642c
whoops, fix indent
hamishivi Dec 1, 2025
7d526a1
small test
hamishivi Dec 1, 2025
338b482
Merge branch 'main' into accum-loss-fix-grpo-fast
hamishivi Dec 1, 2025
ff86d58
fix indent
hamishivi Dec 1, 2025
f3ee394
rejig based on feedback
hamishivi Dec 1, 2025
b764004
clean
hamishivi Dec 1, 2025
408eadd
fix
hamishivi Dec 1, 2025
f5d8afd
fix
hamishivi Dec 1, 2025
222d565
fix
hamishivi Dec 1, 2025
ccf5563
cleanup
hamishivi Dec 1, 2025
de40be8
little refactor
hamishivi Dec 1, 2025
637f44f
Merge branch 'main' into accum-loss-fix-grpo-fast
hamishivi Dec 1, 2025
540e34a
ruff fix
hamishivi Dec 1, 2025
b14d1ae
finbarr comments
hamishivi Dec 1, 2025
e2fa645
fix test case
hamishivi Dec 1, 2025
a6341ed
Merge branch 'main' into accum-loss-fix-grpo-fast
hamishivi Dec 1, 2025
23d5add
masked mean import
hamishivi Dec 1, 2025
3c9720b
new masked mean
hamishivi Dec 1, 2025
110 changes: 73 additions & 37 deletions open_instruct/grpo_fast.py
@@ -103,13 +103,12 @@
get_olmo3_generation_config,
load_ref_policy,
log_softmax_and_gather,
masked_mean,
print_rich_single_line_metrics,
print_rich_table,
push_folder_to_hub,
)
from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo, TokenStatistics
from open_instruct.rl_utils import PackedSequences, Timer, pack_sequences
from open_instruct.rl_utils import PackedSequences, Timer, masked_mean, pack_sequences
from open_instruct.utils import (
ArgumentParserPlus,
BeakerRuntimeConfig,
@@ -251,10 +250,11 @@ class Args:
"""the KL estimator to use"""
pack_length: int = 512
"""the length of the pack (you should prob set to the max length of the model)"""
masked_mean_axis: int | None = None
"""the axis to compute the mean of the masked values"""
masked_mean_denominator: float | None = None
"""Optional constant denominator for masked_mean; if set, divides by this instead of mask.sum"""
loss_denominator: str = "token"
"""Optional constant denominator for masked_mean; can be "token" or a float value.
when "token", the loss is divided by the total number of tokens in the batch (standard LM training).
when a float value, the loss is divided by this value (ideally, max tokens in batch, per Dr GRPO).
"""
alpha: float = 0.6
"""The alpha value for doing polyak updates (ref_param = alpha * param + (1 - alpha) * ref_param)
reference: [TR-DPO](https://huggingface.co/papers/2404.09656), but it's actually pretty commonly
@@ -458,10 +458,7 @@ def __post_init__(self):
"Cannot use both `use_vllm_logprobs` and `truncated_importance_sampling_ratio_cap`. "
"use_vllm_logprobs sets old_logprobs to vLLM logprobs, making importance sampling pointless."
)
if self.masked_mean_denominator is not None:
assert self.masked_mean_denominator > 0, (
f"masked_mean_denominator (={self.masked_mean_denominator}) must be greater than 0!"
)
self.loss_denominator = utils.get_denominator(self.loss_denominator)
assert self.num_samples_per_prompt_rollout > 0, "Number of samples per prompt must be greater than 0!"
if self.num_samples_per_prompt_rollout == 1:
logger.warning("num_samples_per_prompt_rollout is 1. This reduces GRPO to REINFORCE.")
@@ -501,7 +498,7 @@ def __post_init__(self):
else:
self.gs_checkpoint_state_dir = f"{self.gs_bucket_path}/{checkpoint_dir_name}"
# On GCP, all checkpointing must happen on filestore.
# TODO(finbarrtimbers): Chanage this so we can checkpoint to GCS.
# TODO(finbarrtimbers): Change this so we can checkpoint to GCS.
# TODO(finbarrtimbers): Move this logic to mason.py once we refactor config.
if not checkpoint_dir_name.startswith("/filestore"):
self.checkpoint_state_dir = f"/filestore{self.checkpoint_state_dir}"
@@ -1051,6 +1048,42 @@ def compute_logprobs(

return collated_logprobs, collated_entropies

def calculate_token_counts(
self,
accumulation_steps: int,
collated_response_masks: list[torch.Tensor],
collated_tool_masks: list[torch.Tensor],
) -> dict[int, float]:
"""
Compute the number of training tokens in each gradient-accumulation group for this set of responses.
Return a dictionary mapping the first minibatch index of each accumulation group to the total number of training tokens in that group (summed across all ranks).
"""
accumulation_counts: dict[int, float] = {}
local_counts = []

for i, response_mask in enumerate(collated_response_masks):
response_mask = response_mask.to(self.device)
mask = response_mask[:, 1:].bool()
if self.args.mask_tool_use and self.args.tool_use:
tool_mask = collated_tool_masks[i].to(self.device)
mask &= tool_mask[:, 1:].bool()

local_counts.append(mask.sum().float())

if not local_counts:
return accumulation_counts

# do the all_reduce once here rather than separately on every loop iteration
counts_tensor = torch.stack(local_counts)
dist.all_reduce(counts_tensor, op=dist.ReduceOp.SUM)

for i, count in enumerate(counts_tensor):
group_idx = i // accumulation_steps
key = int(group_idx * accumulation_steps)
accumulation_counts[key] = accumulation_counts.get(key, 0.0) + count.item()

return accumulation_counts

def train(
self,
collated_query_responses,
@@ -1153,6 +1186,18 @@ def train(
ratio_stats = torch.zeros(len(collated_query_responses))
entropy_stats = torch.zeros(len(collated_query_responses))
for epoch_idx in range(args.num_epochs):
# Pre-compute total tokens for each accumulation group if using "token" normalization
# This ensures all minibatches in an accumulation group are normalized by the same total
if args.loss_denominator == "token":
accumulation_token_counts = self.calculate_token_counts(
accumulation_steps, collated_response_masks, collated_tool_masks
)
else:
accumulation_token_counts = {
int(group_idx * accumulation_steps): args.loss_denominator
for group_idx in range((len(collated_query_responses) // accumulation_steps) + 1)
}

for i in range(len(collated_query_responses)):
mb_query_responses = collated_query_responses[i]
mb_tool_mask = collated_tool_masks[i]
@@ -1162,6 +1207,11 @@
# if masking snippets, do it here.
if args.mask_tool_use and args.tool_use:
mb_response_masks_bool = mb_response_masks[:, 1:].bool() & mb_tool_mask[:, 1:].bool()

# retrieve the loss denominator for the current batch
batch_start = (i // accumulation_steps) * accumulation_steps
loss_denominator = accumulation_token_counts[batch_start]

mb_attention_mask = collated_attention_masks[i]
mb_position_id = collated_position_ids[i]
mb_local_logprobs, mb_entropy = self.forward(
@@ -1277,16 +1327,15 @@ def train(
kl = kl_4BT[args.kl_estimator]
# grpo change: directly subtract KL in loss (add)
loss = masked_mean(
pg_loss_max + (args.beta * kl),
mb_response_masks_bool,
args.masked_mean_axis,
args.masked_mean_denominator,
pg_loss_max + (args.beta * kl), mb_response_masks_bool, None, loss_denominator
)
else:
loss = masked_mean(
pg_loss_max, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
)
loss = loss / accumulation_steps
loss = masked_mean(pg_loss_max, mb_response_masks_bool, None, loss_denominator)

# we already took world size into account via the tokens
if dist.is_available() and dist.is_initialized():
loss *= dist.get_world_size()

# Clear CUDA cache before backward pass to free memory for reduce_scatter operations
torch.cuda.empty_cache()
self.model.backward(loss)
@@ -1296,28 +1345,15 @@ def train(
with torch.no_grad():
if args.load_ref_policy:
# NOTE: in the packed implementation, KL calculations are averages over response tokens
kl_stats_4M[:, i] = masked_mean(
kl_4BT, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
).float()
kl_stats_4M[:, i] = masked_mean(kl_4BT, mb_response_masks_bool).float()
kl_loss_stats[i] = kl_stats_4M[args.kl_estimator, i] * args.beta
pg_clipfrac_stats[i] = masked_mean(
(pg_losses2 > pg_losses).float(),
mb_response_masks_bool,
args.masked_mean_axis,
args.masked_mean_denominator,
)
pg_loss_stats[i] = masked_mean(
pg_loss_max, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
)
pg_clipfrac_stats[i] = masked_mean((pg_losses2 > pg_losses).float(), mb_response_masks_bool)
pg_loss_stats[i] = masked_mean(pg_loss_max, mb_response_masks_bool)
loss_stats[i] = loss
ratio_stats[i] = masked_mean(
ratio, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
)
ratio_stats[i] = masked_mean(ratio, mb_response_masks_bool)
if args.record_entropy:
# Calculate entropy statistics
entropy_stats[i] = masked_mean(
mb_entropy, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
).float()
entropy_stats[i] = masked_mean(mb_entropy, mb_response_masks_bool).float()

with torch.no_grad():
if args.load_ref_policy:
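Taken together, the grpo_fast.py changes above make token-level normalization consistent across gradient accumulation: with loss_denominator="token", every minibatch in an accumulation group is divided by the same group-wide token count (summed across ranks via all_reduce) rather than by its own mask.sum(), so the accumulated microbatch losses add up to a single masked mean over the whole group. The snippet below is a minimal single-process sketch of that idea; the helper name and variables are illustrative, and it omits the cross-rank all_reduce and the world-size rescaling performed in the real training loop.

```python
import torch


def grouped_token_normalized_losses(
    losses_per_token: list[torch.Tensor],  # one [B, T] loss tensor per minibatch
    masks: list[torch.Tensor],             # matching [B, T] boolean response masks
    accumulation_steps: int,
) -> list[torch.Tensor]:
    """Divide each minibatch loss by its accumulation group's total token count."""
    # Total trainable tokens per accumulation group, keyed by the group's first
    # minibatch index (mirrors the keying used by calculate_token_counts above).
    group_tokens: dict[int, float] = {}
    for i, mask in enumerate(masks):
        key = (i // accumulation_steps) * accumulation_steps
        group_tokens[key] = group_tokens.get(key, 0.0) + float(mask.sum())

    scaled = []
    for i, (loss, mask) in enumerate(zip(losses_per_token, masks)):
        denom = group_tokens[(i // accumulation_steps) * accumulation_steps]
        # Masked sum over this minibatch divided by the group-wide count; summing these
        # over the group equals a single masked mean over all of the group's tokens.
        scaled.append((loss * mask).sum() / denom)
    return scaled


# Example: one accumulation group of two minibatches, 3 + 5 = 8 trainable tokens in total.
losses = [torch.ones(2, 4), torch.ones(2, 4)]
masks = [
    torch.tensor([[1, 1, 0, 0], [1, 0, 0, 0]]).bool(),
    torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]]).bool(),
]
per_minibatch = grouped_token_normalized_losses(losses, masks, accumulation_steps=2)
print(sum(per_minibatch))  # tensor(1.), i.e. the mean loss over all 8 masked tokens
```

With a float loss_denominator, the per-group count would simply be replaced by that constant (the Dr. GRPO-style fixed denominator).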
17 changes: 0 additions & 17 deletions open_instruct/model_utils.py
@@ -769,23 +769,6 @@ def exact_div(a, b, custom_error_message=""):
return q


def masked_mean(
values: torch.Tensor, mask: torch.Tensor, axis: int | None = None, denominator: float | None = None
) -> torch.Tensor:
"""Compute mean of tensor with masked values."""
extra_dims = values.ndim - mask.ndim
if axis is None:
sum_dims = tuple(range(extra_dims, values.ndim))
elif axis >= 0:
sum_dims = axis + extra_dims
else:
sum_dims = axis
numerator = (values * mask).sum(dim=sum_dims)
denom = mask.sum(dim=axis) if denominator is None else denominator
result = numerator / denom
return result.flatten(extra_dims).mean(-1) if result.ndim > extra_dims else result


def estimate_kl(ref_logprobs_diff: torch.Tensor, ratio: torch.Tensor) -> torch.Tensor:
"""Compute 4 different KL divergence estimators between current and reference policies.

17 changes: 17 additions & 0 deletions open_instruct/rl_utils.py
@@ -249,3 +249,20 @@ def calculate_advantages_packed(
advantages = np.stack(advantages_reversed[::-1], axis=1)
returns = advantages + values
return advantages, returns


def masked_mean(
values: torch.Tensor, mask: torch.Tensor, axis: int | None = None, denominator: float | None = None
) -> torch.Tensor:
"""Compute mean of tensor with masked values."""
extra_dims = values.ndim - mask.ndim
if axis is None:
sum_dims = tuple(range(extra_dims, values.ndim))
elif axis >= 0:
sum_dims = axis + extra_dims
else:
sum_dims = axis
numerator = (values * mask).sum(dim=sum_dims)
denom = mask.sum(dim=axis) if denominator is None else denominator
result = numerator / denom
return result.flatten(extra_dims).mean(-1) if result.ndim > extra_dims else result
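For reference, a small usage sketch of the relocated masked_mean, contrasting the default mask-sum denominator with the constant-denominator path used when loss_denominator is a float (the values are illustrative; the import path matches this PR):

```python
import torch

from open_instruct.rl_utils import masked_mean

values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])

# Default: divide the masked sum (1 + 2 + 4) by the number of unmasked tokens (3).
print(masked_mean(values, mask))             # tensor(2.3333)

# Constant denominator: divide by a fixed value instead of mask.sum().
print(masked_mean(values, mask, None, 8.0))  # tensor(0.8750)
```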
59 changes: 0 additions & 59 deletions open_instruct/test_model_utils.py
@@ -42,65 +42,6 @@ def test_batch_slicing_with_none_fields(self):
self.assertIsNone(sliced.scores)


class TestMaskedMean(unittest.TestCase):
def test_original_axis_int(self):
values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
result = open_instruct.model_utils.masked_mean(values, mask, axis=1)
expected = ((1.0 + 2.0) / 2 + 4.0 / 1) / 2
self.assertAlmostEqual(result.item(), expected)

def test_original_axis_none(self):
values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
result = open_instruct.model_utils.masked_mean(values, mask, axis=None)
expected = (1.0 + 2.0 + 4.0) / 3
self.assertAlmostEqual(result.item(), expected, places=5)

def test_vectorized_axis_int(self):
kl_4BT = torch.tensor(
[
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
[[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]],
[[100.0, 200.0, 300.0], [400.0, 500.0, 600.0]],
[[1000.0, 2000.0, 3000.0], [4000.0, 5000.0, 6000.0]],
]
)
mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
result = open_instruct.model_utils.masked_mean(kl_4BT, mask, axis=1)
self.assertEqual(result.shape, (4,))
expected_0 = ((1.0 + 2.0) / 2 + 4.0 / 1) / 2
expected_1 = ((10.0 + 20.0) / 2 + 40.0 / 1) / 2
expected_2 = ((100.0 + 200.0) / 2 + 400.0 / 1) / 2
expected_3 = ((1000.0 + 2000.0) / 2 + 4000.0 / 1) / 2
self.assertAlmostEqual(result[0].item(), expected_0)
self.assertAlmostEqual(result[1].item(), expected_1)
self.assertAlmostEqual(result[2].item(), expected_2)
self.assertAlmostEqual(result[3].item(), expected_3)

def test_vectorized_axis_none(self):
kl_4BT = torch.tensor(
[
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
[[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]],
[[100.0, 200.0, 300.0], [400.0, 500.0, 600.0]],
[[1000.0, 2000.0, 3000.0], [4000.0, 5000.0, 6000.0]],
]
)
mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
result = open_instruct.model_utils.masked_mean(kl_4BT, mask, axis=None)
self.assertEqual(result.shape, (4,))
expected = torch.tensor(
[
(1.0 + 2.0 + 4.0) / 3,
(10.0 + 20.0 + 40.0) / 3,
(100.0 + 200.0 + 400.0) / 3,
(1000.0 + 2000.0 + 4000.0) / 3,
]
)
self.assertTrue(torch.allclose(result, expected))


class TestLogSoftmaxAndGather(unittest.TestCase):
def test_log_softmax_and_gather_sliced_logits(self):
batch_size, seq_len, vocab_size = 2, 160, 151936
59 changes: 59 additions & 0 deletions open_instruct/test_rl_utils.py
@@ -263,5 +263,64 @@ def test_pack_sequences_logits(self):
)


class TestMaskedMean(unittest.TestCase):
def test_original_axis_int(self):
values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
result = rl_utils.masked_mean(values, mask, axis=1)
expected = ((1.0 + 2.0) / 2 + 4.0 / 1) / 2
self.assertAlmostEqual(result.item(), expected)

def test_original_axis_none(self):
values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
result = rl_utils.masked_mean(values, mask, axis=None)
expected = (1.0 + 2.0 + 4.0) / 3
self.assertAlmostEqual(result.item(), expected, places=5)

def test_vectorized_axis_int(self):
kl_4BT = torch.tensor(
[
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
[[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]],
[[100.0, 200.0, 300.0], [400.0, 500.0, 600.0]],
[[1000.0, 2000.0, 3000.0], [4000.0, 5000.0, 6000.0]],
]
)
mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
result = rl_utils.masked_mean(kl_4BT, mask, axis=1)
self.assertEqual(result.shape, (4,))
expected_0 = ((1.0 + 2.0) / 2 + 4.0 / 1) / 2
expected_1 = ((10.0 + 20.0) / 2 + 40.0 / 1) / 2
expected_2 = ((100.0 + 200.0) / 2 + 400.0 / 1) / 2
expected_3 = ((1000.0 + 2000.0) / 2 + 4000.0 / 1) / 2
self.assertAlmostEqual(result[0].item(), expected_0)
self.assertAlmostEqual(result[1].item(), expected_1)
self.assertAlmostEqual(result[2].item(), expected_2)
self.assertAlmostEqual(result[3].item(), expected_3)

def test_vectorized_axis_none(self):
kl_4BT = torch.tensor(
[
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
[[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]],
[[100.0, 200.0, 300.0], [400.0, 500.0, 600.0]],
[[1000.0, 2000.0, 3000.0], [4000.0, 5000.0, 6000.0]],
]
)
mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
result = rl_utils.masked_mean(kl_4BT, mask, axis=None)
self.assertEqual(result.shape, (4,))
expected = torch.tensor(
[
(1.0 + 2.0 + 4.0) / 3,
(10.0 + 20.0 + 40.0) / 3,
(100.0 + 200.0 + 400.0) / 3,
(1000.0 + 2000.0 + 4000.0) / 3,
]
)
self.assertTrue(torch.allclose(result, expected))


if __name__ == "__main__":
unittest.main()
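These cases mirror the tests removed from open_instruct/test_model_utils.py. A further, self-contained case exercising the constant-denominator path (not part of this PR; included purely as an illustrative sketch) could look like:

```python
import unittest

import torch

from open_instruct import rl_utils


class TestMaskedMeanDenominator(unittest.TestCase):
    def test_constant_denominator(self):
        values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
        mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
        # With a fixed denominator, the masked sum (1 + 2 + 4) is divided by 8.0
        # instead of mask.sum() == 3.
        result = rl_utils.masked_mean(values, mask, axis=None, denominator=8.0)
        self.assertAlmostEqual(result.item(), (1.0 + 2.0 + 4.0) / 8.0, places=5)


if __name__ == "__main__":
    unittest.main()
```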
18 changes: 18 additions & 0 deletions open_instruct/test_utils.py
@@ -525,3 +525,21 @@ def mock_inspect_return(*args, **kwargs):
# "natolambert/tulu-v2-sft-mixture-science": 7468, # original data slightly different
# }
# _ = get_datasets(dataset_mixer, splits=["train"], columns_to_keep=["messages"])


class TestGetDenominator(unittest.TestCase):
@parameterized.expand([("token", "token"), ("0.5", 0.5), (0.5, 0.5), (1, 1.0)])
def test_valid_inputs(self, input_val, expected):
self.assertEqual(utils.get_denominator(input_val), expected)

@parameterized.expand(
[
("invalid", "could not convert string to float"),
("-1", "loss_denominator must be greater than 0"),
(0, "loss_denominator must be greater than 0"),
("0", "loss_denominator must be greater than 0"),
]
)
def test_invalid_inputs(self, input_val, error_msg):
with self.assertRaisesRegex(ValueError, error_msg):
utils.get_denominator(input_val)
13 changes: 13 additions & 0 deletions open_instruct/utils.py
@@ -2491,3 +2491,16 @@ def get_beaker_experiment_url() -> str | None:
return url
except Exception:
return None


def get_denominator(loss_denominator: str | float) -> float | str:
"""
Validates and converts the loss_denominator argument.
"""
if loss_denominator == "token":
return "token"

val = float(loss_denominator)
if val <= 0:
raise ValueError(f"loss_denominator must be greater than 0 if not 'token', got: {loss_denominator}")
return val
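A brief usage sketch of get_denominator as added above; the expected outputs follow directly from the function body, and the import path assumes the module layout shown in this PR:

```python
from open_instruct.utils import get_denominator

# "token" is passed through unchanged and later triggers per-accumulation-group token counting.
print(get_denominator("token"))  # token

# Numeric strings and numbers become a fixed float denominator for masked_mean.
print(get_denominator("4096"))   # 4096.0

# Non-positive or unparseable values raise ValueError.
try:
    get_denominator(0)
except ValueError as err:
    print(err)  # loss_denominator must be greater than 0 if not 'token', got: 0
```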