From 1b6c5e1b22f22fa0564628b12f66f3d09f471d20 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Mon, 3 Nov 2025 15:31:47 -0800 Subject: [PATCH 01/34] trying out accum fix --- open_instruct/grpo_fast.py | 43 ++++++++++++++++++++++++++++++++------ 1 file changed, 37 insertions(+), 6 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 84abd902ba..4d3b893196 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -252,7 +252,9 @@ class Args: masked_mean_axis: int | None = None """the axis to compute the mean of the masked values""" masked_mean_denominator: float | None = None - """Optional constant denominator for masked_mean; if set, divides by this instead of mask.sum""" + """Optional constant denominator for masked_mean; if set, divides by this instead of mask.sum. + Special value -1 means use total_batch_tokens (computed across all ranks in distributed training). + When using -1, total_batch_tokens is gathered via allreduce across all ranks.""" alpha: float = 0.6 """The alpha value for doing polyak updates (ref_param = alpha * param + (1 - alpha) * ref_param) reference: [TR-DPO](https://huggingface.co/papers/2404.09656), but it's actually pretty commonly @@ -1005,6 +1007,28 @@ def train( local_step = 0 # Do multiple epochs of training on on-policy data (PPO-style), with a fresh random shuffle in each epoch with Timer("[Training Processes] Loss calculation", noop=self.rank != 0): + # Calculate total tokens across all minibatches for proper loss normalization + # This ensures loss is normalized by total tokens in the entire batch, not just per-minibatch + # First, calculate tokens for this rank's data + local_total_batch_tokens = 0.0 + for i in range(len(collated_query_responses)): + mb_response_masks = collated_response_masks[i] + mb_response_masks_bool = mb_response_masks[:, 1:].bool() + # Apply same masking logic as in loss computation + if args.mask_tool_use and args.tool_use: + mb_tool_mask = collated_tool_masks[i] + mb_response_masks_bool = mb_response_masks[:, 1:].bool() & mb_tool_mask[:, 1:].bool() + local_total_batch_tokens += mb_response_masks_bool.sum().item() + + # Gather total tokens across all ranks if using distributed training and denominator is -1 + # This ensures normalization is consistent across all ranks + if dist.is_available() and dist.is_initialized(): + local_total_batch_tokens_tensor = torch.tensor( + local_total_batch_tokens, dtype=torch.float32, device=self.device + ) + dist.all_reduce(local_total_batch_tokens_tensor, op=dist.ReduceOp.SUM) + total_batch_tokens = local_total_batch_tokens_tensor.item() + kl1_stats = torch.zeros(len(collated_query_responses)) kl2_stats = torch.zeros(len(collated_query_responses)) kl3_stats = torch.zeros(len(collated_query_responses)) @@ -1149,12 +1173,19 @@ def train( kl = kl4 # grpo change: directly subtract KL in loss (add) + loss_values = pg_loss_max + (args.beta * kl) + + # Three loss cases: + # masked_mean_denominator is set: we use sum and divide loss by this constant. + # masked_mean_denominator is set to -1: we use sum and divide loss by total number of tokens in batch. + # masked_mean_denominator is None, masked_mean_axis is None: we take mean across tokens in minibatch (old behaviour) + # masked_mean_denominator is None, masked_mean_axis is 1: we use sample-wise averaging across the sequence axis. 
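To make the normalization cases listed above concrete, here is a minimal sketch of the masked-mean helper these branches rely on (the definition mirrors the helper later moved into rl_utils.py in this series; the toy tensors and the 100.0 constant are illustrative only):

import torch

def masked_mean(values, mask, axis=None, denominator=None):
    # Sum over unmasked positions, divided by the mask count or by a fixed constant.
    numerator = (values * mask).sum(axis=axis)
    denom = mask.sum(axis=axis) if denominator is None else denominator
    return (numerator / denom).mean()

values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])

masked_mean(values, mask)                     # token mean: (1 + 2 + 4) / 3
masked_mean(values, mask, axis=1)             # per-sequence means, then averaged
masked_mean(values, mask, denominator=100.0)  # sum divided by a fixed constant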
loss = masked_mean( - pg_loss_max + (args.beta * kl), - mb_response_masks_bool, - args.masked_mean_axis, - args.masked_mean_denominator, - ) + loss_values, + mb_response_masks_bool, + args.masked_mean_axis, + args.masked_mean_denominator if args.masked_mean_denominator != -1 else total_batch_tokens, + ) loss = loss / accumulation_steps self.model.backward(loss) if (local_step + 1) % accumulation_steps == 0: From 92a4bc2a3747b6e786aa51370754fef07db107d5 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Sun, 9 Nov 2025 21:02:19 -0800 Subject: [PATCH 02/34] fix --- open_instruct/grpo_fast.py | 46 +++++++++++++++++++++++++------------- 1 file changed, 30 insertions(+), 16 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 1571ee2ab6..249f21da81 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -252,10 +252,10 @@ class Args: """the length of the pack (you should prob set to the max length of the model)""" masked_mean_axis: int | None = None """the axis to compute the mean of the masked values""" - masked_mean_denominator: float | None = None + masked_mean_denominator: float | str | None = None """Optional constant denominator for masked_mean; if set, divides by this instead of mask.sum. - Special value -1 means use total_batch_tokens (computed across all ranks in distributed training). - When using -1, total_batch_tokens is gathered via allreduce across all ranks.""" + Special value "token" means use total_batch_tokens (computed across all ranks in distributed training). + When using "token", total_batch_tokens is gathered via allreduce across all ranks.""" alpha: float = 0.6 """The alpha value for doing polyak updates (ref_param = alpha * param + (1 - alpha) * ref_param) reference: [TR-DPO](https://huggingface.co/papers/2404.09656), but it's actually pretty commonly @@ -444,9 +444,14 @@ def __post_init__(self): "use_vllm_logprobs sets old_logprobs to vLLM logprobs, making importance sampling pointless." ) if self.masked_mean_denominator is not None: - assert self.masked_mean_denominator > 0, ( - f"masked_mean_denominator (={self.masked_mean_denominator}) must be greater than 0!" - ) + if isinstance(self.masked_mean_denominator, str): + assert self.masked_mean_denominator == "token", ( + f"masked_mean_denominator string value must be 'token' or number, got {self.masked_mean_denominator}" + ) + else: + assert self.masked_mean_denominator > 0, ( + f"masked_mean_denominator (={self.masked_mean_denominator}) must be greater than 0!" + ) assert self.num_samples_per_prompt_rollout > 0, "Number of samples per prompt must be greater than 0!" if self.num_samples_per_prompt_rollout == 1: logger.warning("num_samples_per_prompt_rollout is 1. 
This reduces GRPO to REINFORCE.") @@ -1121,14 +1126,18 @@ def train( mb_response_masks_bool = mb_response_masks[:, 1:].bool() & mb_tool_mask[:, 1:].bool() local_total_batch_tokens += mb_response_masks_bool.sum().item() - # Gather total tokens across all ranks if using distributed training and denominator is -1 + # Gather total tokens across all ranks if using distributed training and denominator is "token" # This ensures normalization is consistent across all ranks - if dist.is_available() and dist.is_initialized(): - local_total_batch_tokens_tensor = torch.tensor( - local_total_batch_tokens, dtype=torch.float32, device=self.device - ) - dist.all_reduce(local_total_batch_tokens_tensor, op=dist.ReduceOp.SUM) - total_batch_tokens = local_total_batch_tokens_tensor.item() + if args.masked_mean_denominator == "token": + if dist.is_available() and dist.is_initialized(): + local_total_batch_tokens_tensor = torch.tensor( + local_total_batch_tokens, dtype=torch.float32, device=self.device + ) + dist.all_reduce(local_total_batch_tokens_tensor, op=dist.ReduceOp.SUM) + total_batch_tokens = local_total_batch_tokens_tensor.item() + else: + # Non-distributed case: total_batch_tokens is just local tokens + total_batch_tokens = local_total_batch_tokens kl1_stats = torch.zeros(len(collated_query_responses)) kl2_stats = torch.zeros(len(collated_query_responses)) @@ -1278,16 +1287,21 @@ def train( # Three loss cases: # masked_mean_denominator is set: we use sum and divide loss by this constant. - # masked_mean_denominator is set to -1: we use sum and divide loss by total number of tokens in batch. + # masked_mean_denominator is set to "token": we use sum and divide loss by total number of tokens in batch. # masked_mean_denominator is None, masked_mean_axis is None: we take mean across tokens in minibatch (old behaviour) # masked_mean_denominator is None, masked_mean_axis is 1: we use sample-wise averaging across the sequence axis. loss = masked_mean( loss_values, mb_response_masks_bool, args.masked_mean_axis, - args.masked_mean_denominator if args.masked_mean_denominator != -1 else total_batch_tokens, + args.masked_mean_denominator if args.masked_mean_denominator != "token" else total_batch_tokens, ) - loss = loss / accumulation_steps + # When using global normalization (masked_mean_denominator == "token"), total_batch_tokens already + # includes tokens from all ranks and all minibatches, so we should NOT divide by accumulation_steps. + # For other normalization modes, we divide by accumulation_steps to properly scale gradients + # for gradient accumulation. 
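A quick numeric check of the comment above (toy numbers, illustrative only): when each minibatch loss is already divided by the global token count, the contributions summed over the accumulation steps reproduce the global per-token mean, so dividing again by accumulation_steps would shrink the effective loss by that factor.

# Two accumulation steps with 3 and 2 response tokens respectively.
mb1, mb2 = [1.0, 2.0, 3.0], [4.0, 5.0]
total_batch_tokens = len(mb1) + len(mb2)                       # 5

step_contributions = sum(mb1) / total_batch_tokens + sum(mb2) / total_batch_tokens
global_token_mean = (sum(mb1) + sum(mb2)) / total_batch_tokens
assert abs(step_contributions - global_token_mean) < 1e-9      # both 3.0
# Dividing each step by accumulation_steps (here 2) as well would yield 1.5 instead.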
+ if args.masked_mean_denominator != "token": + loss = loss / accumulation_steps self.model.backward(loss) if (local_step + 1) % accumulation_steps == 0: self.model.step() From a9ae9291892c9d3572cdaf5273a2954c0886334c Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Sun, 9 Nov 2025 21:25:42 -0800 Subject: [PATCH 03/34] fix --- open_instruct/grpo_fast.py | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 249f21da81..2e9458717e 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1307,18 +1307,22 @@ def train( self.model.step() local_step += 1 with torch.no_grad(): + # Convert "token" to total_batch_tokens for statistics computation + stats_denominator = ( + args.masked_mean_denominator if args.masked_mean_denominator != "token" else total_batch_tokens + ) # NOTE: in packed implementation, kl calculation are averages over response tokens kl1_stats[i] = masked_mean( - kl1, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator + kl1, mb_response_masks_bool, args.masked_mean_axis, stats_denominator ).float() kl2_stats[i] = masked_mean( - kl2, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator + kl2, mb_response_masks_bool, args.masked_mean_axis, stats_denominator ).float() kl3_stats[i] = masked_mean( - kl3, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator + kl3, mb_response_masks_bool, args.masked_mean_axis, stats_denominator ).float() kl4_stats[i] = masked_mean( - kl4, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator + kl4, mb_response_masks_bool, args.masked_mean_axis, stats_denominator ).float() if args.kl_estimator == "kl1": kl_loss_stats[i] = kl1_stats[i] * args.beta @@ -1332,19 +1336,19 @@ def train( (pg_losses2 > pg_losses).float(), mb_response_masks_bool, args.masked_mean_axis, - args.masked_mean_denominator, + stats_denominator, ) pg_loss_stats[i] = masked_mean( - pg_loss_max, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator + pg_loss_max, mb_response_masks_bool, args.masked_mean_axis, stats_denominator ) loss_stats[i] = loss ratio_stats[i] = masked_mean( - ratio, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator + ratio, mb_response_masks_bool, args.masked_mean_axis, stats_denominator ) if args.record_entropy: # Calculate entropy statistics entropy_stats[i] = masked_mean( - mb_entropy, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator + mb_entropy, mb_response_masks_bool, args.masked_mean_axis, stats_denominator ).float() with torch.no_grad(): From c7ccc15b68521bce2c6e4719c22ab25c1973de07 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Sun, 9 Nov 2025 22:34:40 -0800 Subject: [PATCH 04/34] Fix up --- open_instruct/grpo_fast.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 2e9458717e..246f2631ad 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1130,11 +1130,16 @@ def train( # This ensures normalization is consistent across all ranks if args.masked_mean_denominator == "token": if dist.is_available() and dist.is_initialized(): + # Ensure all ranks have computed local_total_batch_tokens before all_reduce + dist.barrier() local_total_batch_tokens_tensor = torch.tensor( local_total_batch_tokens, dtype=torch.float32, device=self.device ) - 
dist.all_reduce(local_total_batch_tokens_tensor, op=dist.ReduceOp.SUM) + dist.all_reduce(local_total_batch_tokens_tensor, op=dist.ReduceOp.SUM, group=None) total_batch_tokens = local_total_batch_tokens_tensor.item() + if self.rank == 0: + logger.debug(f"[Rank {self.rank}]: total_batch_tokens across all ranks={total_batch_tokens:.0f}") + else: # Non-distributed case: total_batch_tokens is just local tokens total_batch_tokens = local_total_batch_tokens From 0aad4ac84624ed2fe0364781222ae8fa2f6264af Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Sun, 9 Nov 2025 22:48:02 -0800 Subject: [PATCH 05/34] fix --- open_instruct/grpo_fast.py | 73 ++++++++++++++++++++------------------ 1 file changed, 38 insertions(+), 35 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 246f2631ad..776fb878f4 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1113,37 +1113,6 @@ def train( local_step = 0 # Do multiple epochs of training on on-policy data (PPO-style), with a fresh random shuffle in each epoch with Timer("[Training Processes] Loss calculation", noop=self.rank != 0): - # Calculate total tokens across all minibatches for proper loss normalization - # This ensures loss is normalized by total tokens in the entire batch, not just per-minibatch - # First, calculate tokens for this rank's data - local_total_batch_tokens = 0.0 - for i in range(len(collated_query_responses)): - mb_response_masks = collated_response_masks[i] - mb_response_masks_bool = mb_response_masks[:, 1:].bool() - # Apply same masking logic as in loss computation - if args.mask_tool_use and args.tool_use: - mb_tool_mask = collated_tool_masks[i] - mb_response_masks_bool = mb_response_masks[:, 1:].bool() & mb_tool_mask[:, 1:].bool() - local_total_batch_tokens += mb_response_masks_bool.sum().item() - - # Gather total tokens across all ranks if using distributed training and denominator is "token" - # This ensures normalization is consistent across all ranks - if args.masked_mean_denominator == "token": - if dist.is_available() and dist.is_initialized(): - # Ensure all ranks have computed local_total_batch_tokens before all_reduce - dist.barrier() - local_total_batch_tokens_tensor = torch.tensor( - local_total_batch_tokens, dtype=torch.float32, device=self.device - ) - dist.all_reduce(local_total_batch_tokens_tensor, op=dist.ReduceOp.SUM, group=None) - total_batch_tokens = local_total_batch_tokens_tensor.item() - if self.rank == 0: - logger.debug(f"[Rank {self.rank}]: total_batch_tokens across all ranks={total_batch_tokens:.0f}") - - else: - # Non-distributed case: total_batch_tokens is just local tokens - total_batch_tokens = local_total_batch_tokens - kl1_stats = torch.zeros(len(collated_query_responses)) kl2_stats = torch.zeros(len(collated_query_responses)) kl3_stats = torch.zeros(len(collated_query_responses)) @@ -1155,6 +1124,33 @@ def train( ratio_stats = torch.zeros(len(collated_query_responses)) entropy_stats = torch.zeros(len(collated_query_responses)) for epoch_idx in range(args.num_epochs): + # Pre-compute total tokens for each accumulation group if using "token" normalization + # This ensures all minibatches in an accumulation group are normalized by the same total + accumulation_group_tokens = {} + if args.masked_mean_denominator == "token": + for group_start in range(0, len(collated_query_responses), accumulation_steps): + group_end = min(group_start + accumulation_steps, len(collated_query_responses)) + # Calculate local tokens for all minibatches in this 
accumulation group + local_group_tokens = 0.0 + for i in range(group_start, group_end): + mb_response_masks = collated_response_masks[i] + mb_response_masks_bool = mb_response_masks[:, 1:].bool() + if args.mask_tool_use and args.tool_use: + mb_tool_mask = collated_tool_masks[i] + mb_response_masks_bool = mb_response_masks[:, 1:].bool() & mb_tool_mask[:, 1:].bool() + local_group_tokens += mb_response_masks_bool.sum().item() + + # Gather total tokens across all ranks for this accumulation group + if dist.is_available() and dist.is_initialized(): + dist.barrier() + local_group_tokens_tensor = torch.tensor( + local_group_tokens, dtype=torch.float32, device=self.device + ) + dist.all_reduce(local_group_tokens_tensor, op=dist.ReduceOp.SUM, group=None) + accumulation_group_tokens[group_start] = local_group_tokens_tensor.item() + else: + accumulation_group_tokens[group_start] = local_group_tokens + for i in range(len(collated_query_responses)): mb_ref_logprob = collated_ref_logprobs[i] mb_query_responses = collated_query_responses[i] @@ -1165,6 +1161,13 @@ def train( # if masking snippets, do it here. if args.mask_tool_use and args.tool_use: mb_response_masks_bool = mb_response_masks[:, 1:].bool() & mb_tool_mask[:, 1:].bool() + + # Get total tokens for this accumulation group if using "token" normalization + # This ensures all minibatches in the accumulation group are normalized by the same total + if args.masked_mean_denominator == "token": + group_start = (i // accumulation_steps) * accumulation_steps + total_batch_tokens = accumulation_group_tokens[group_start] + mb_attention_mask = collated_attention_masks[i] mb_position_id = collated_position_ids[i] mb_local_logprobs, mb_entropy = self.forward( @@ -1301,10 +1304,10 @@ def train( args.masked_mean_axis, args.masked_mean_denominator if args.masked_mean_denominator != "token" else total_batch_tokens, ) - # When using global normalization (masked_mean_denominator == "token"), total_batch_tokens already - # includes tokens from all ranks and all minibatches, so we should NOT divide by accumulation_steps. - # For other normalization modes, we divide by accumulation_steps to properly scale gradients - # for gradient accumulation. + # When using global normalization (masked_mean_denominator == "token"), total_batch_tokens is the sum + # of tokens across all minibatches in the accumulation group. Since we normalize by this total, + # we should NOT divide by accumulation_steps (the normalization already accounts for all minibatches). + # For other normalization modes, we divide by accumulation_steps to properly scale gradients. 
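For reference, a condensed standalone sketch of the per-accumulation-group token counting added here, with the collective reduction guarded the same way as in the diff (the helper name count_group_tokens and the toy masks are assumptions for illustration):

import torch
import torch.distributed as dist

def count_group_tokens(response_masks, accumulation_steps, device="cpu"):
    # Sum response tokens per accumulation group; all-reduce across ranks when initialized.
    counts = {}
    for start in range(0, len(response_masks), accumulation_steps):
        local = 0.0
        for mask in response_masks[start:start + accumulation_steps]:
            local += mask[:, 1:].bool().sum().item()   # shifted by one, as in the training loop
        total = torch.tensor(local, dtype=torch.float32, device=device)
        if dist.is_available() and dist.is_initialized():
            dist.all_reduce(total, op=dist.ReduceOp.SUM)
        counts[start] = total.item()
    return counts

masks = [torch.tensor([[1, 1, 1, 0]]), torch.tensor([[2, 2, 0, 0]])]
count_group_tokens(masks, accumulation_steps=2)   # {0: 3.0} on a single process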
if args.masked_mean_denominator != "token": loss = loss / accumulation_steps self.model.backward(loss) From 89fb420edffa9a63b89934b1075f166ab72f3fcb Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Sun, 9 Nov 2025 23:02:36 -0800 Subject: [PATCH 06/34] fix --- open_instruct/grpo_fast.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 776fb878f4..55a081ad17 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -448,6 +448,7 @@ def __post_init__(self): assert self.masked_mean_denominator == "token", ( f"masked_mean_denominator string value must be 'token' or number, got {self.masked_mean_denominator}" ) + assert self.masked_mean_axis is None, "masked_mean_axis must not be provided when using 'token' normalization" else: assert self.masked_mean_denominator > 0, ( f"masked_mean_denominator (={self.masked_mean_denominator}) must be greater than 0!" @@ -1139,7 +1140,7 @@ def train( mb_tool_mask = collated_tool_masks[i] mb_response_masks_bool = mb_response_masks[:, 1:].bool() & mb_tool_mask[:, 1:].bool() local_group_tokens += mb_response_masks_bool.sum().item() - + # Gather total tokens across all ranks for this accumulation group if dist.is_available() and dist.is_initialized(): dist.barrier() @@ -1150,7 +1151,7 @@ def train( accumulation_group_tokens[group_start] = local_group_tokens_tensor.item() else: accumulation_group_tokens[group_start] = local_group_tokens - + for i in range(len(collated_query_responses)): mb_ref_logprob = collated_ref_logprobs[i] mb_query_responses = collated_query_responses[i] @@ -1161,13 +1162,13 @@ def train( # if masking snippets, do it here. if args.mask_tool_use and args.tool_use: mb_response_masks_bool = mb_response_masks[:, 1:].bool() & mb_tool_mask[:, 1:].bool() - + # Get total tokens for this accumulation group if using "token" normalization # This ensures all minibatches in the accumulation group are normalized by the same total if args.masked_mean_denominator == "token": group_start = (i // accumulation_steps) * accumulation_steps total_batch_tokens = accumulation_group_tokens[group_start] - + mb_attention_mask = collated_attention_masks[i] mb_position_id = collated_position_ids[i] mb_local_logprobs, mb_entropy = self.forward( @@ -1299,11 +1300,13 @@ def train( # masked_mean_denominator is None, masked_mean_axis is None: we take mean across tokens in minibatch (old behaviour) # masked_mean_denominator is None, masked_mean_axis is 1: we use sample-wise averaging across the sequence axis. loss = masked_mean( - loss_values, - mb_response_masks_bool, - args.masked_mean_axis, - args.masked_mean_denominator if args.masked_mean_denominator != "token" else total_batch_tokens, - ) + loss_values, + mb_response_masks_bool, + args.masked_mean_axis, + args.masked_mean_denominator + if args.masked_mean_denominator != "token" + else total_batch_tokens, + ) # When using global normalization (masked_mean_denominator == "token"), total_batch_tokens is the sum # of tokens across all minibatches in the accumulation group. Since we normalize by this total, # we should NOT divide by accumulation_steps (the normalization already accounts for all minibatches). @@ -1315,9 +1318,12 @@ def train( self.model.step() local_step += 1 with torch.no_grad(): - # Convert "token" to total_batch_tokens for statistics computation + # for stats computation, for now no denominator is used + # unless masked_mean_denominator is a numeric value. 
stats_denominator = ( - args.masked_mean_denominator if args.masked_mean_denominator != "token" else total_batch_tokens + args.masked_mean_denominator + if args.masked_mean_denominator != "token" + else None ) # NOTE: in packed implementation, kl calculation are averages over response tokens kl1_stats[i] = masked_mean( From b37079c552c8e44ecf7ffbc2de6ed1c50b828e4f Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Sun, 9 Nov 2025 23:03:54 -0800 Subject: [PATCH 07/34] fix --- open_instruct/grpo_fast.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 55a081ad17..a442a35bd5 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -448,7 +448,9 @@ def __post_init__(self): assert self.masked_mean_denominator == "token", ( f"masked_mean_denominator string value must be 'token' or number, got {self.masked_mean_denominator}" ) - assert self.masked_mean_axis is None, "masked_mean_axis must not be provided when using 'token' normalization" + assert self.masked_mean_axis is None, ( + "masked_mean_axis must not be provided when using 'token' normalization" + ) else: assert self.masked_mean_denominator > 0, ( f"masked_mean_denominator (={self.masked_mean_denominator}) must be greater than 0!" @@ -1321,9 +1323,7 @@ def train( # for stats computation, for now no denominator is used # unless masked_mean_denominator is a numeric value. stats_denominator = ( - args.masked_mean_denominator - if args.masked_mean_denominator != "token" - else None + args.masked_mean_denominator if args.masked_mean_denominator != "token" else None ) # NOTE: in packed implementation, kl calculation are averages over response tokens kl1_stats[i] = masked_mean( From 1390a6eb9c7763477b5f32991aef2d19a138d2ae Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Sun, 23 Nov 2025 23:15:32 -0800 Subject: [PATCH 08/34] loss fixes --- open_instruct/grpo_fast.py | 107 +++++++++++++++--------------------- open_instruct/test_utils.py | 13 +++++ open_instruct/utils.py | 15 +++++ 3 files changed, 71 insertions(+), 64 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 1cdbc32aff..6f0373466e 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -118,6 +118,7 @@ combine_reward_metrics, download_latest_checkpoint_from_gs, get_beaker_whoami, + get_denominator, get_eval_ds_config, get_optimizer_grouped_parameters, get_train_ds_config, @@ -249,12 +250,11 @@ class Args: """the KL estimator to use""" pack_length: int = 512 """the length of the pack (you should prob set to the max length of the model)""" - masked_mean_axis: int | None = None - """the axis to compute the mean of the masked values""" - masked_mean_denominator: float | str | None = None + masked_mean_denominator: float | str | None = "token" """Optional constant denominator for masked_mean; if set, divides by this instead of mask.sum. Special value "token" means use total_batch_tokens (computed across all ranks in distributed training). - When using "token", total_batch_tokens is gathered via allreduce across all ranks.""" + When using "token", total_batch_tokens is gathered via allreduce across all ranks. 
+ Special value "num_prompts" means use the number of prompts in the batch.""" alpha: float = 0.6 """The alpha value for doing polyak updates (ref_param = alpha * param + (1 - alpha) * ref_param) reference: [TR-DPO](https://huggingface.co/papers/2404.09656), but it's actually pretty commonly @@ -458,18 +458,7 @@ def __post_init__(self): "Cannot use both `use_vllm_logprobs` and `truncated_importance_sampling_ratio_cap`. " "use_vllm_logprobs sets old_logprobs to vLLM logprobs, making importance sampling pointless." ) - if self.masked_mean_denominator is not None: - if isinstance(self.masked_mean_denominator, str): - assert self.masked_mean_denominator == "token", ( - f"masked_mean_denominator string value must be 'token' or number, got {self.masked_mean_denominator}" - ) - assert self.masked_mean_axis is None, ( - "masked_mean_axis must not be provided when using 'token' normalization" - ) - else: - assert self.masked_mean_denominator > 0, ( - f"masked_mean_denominator (={self.masked_mean_denominator}) must be greater than 0!" - ) + self.masked_mean_denominator = get_denominator(self.masked_mean_denominator) assert self.num_samples_per_prompt_rollout > 0, "Number of samples per prompt must be greater than 0!" if self.num_samples_per_prompt_rollout == 1: logger.warning("num_samples_per_prompt_rollout is 1. This reduces GRPO to REINFORCE.") @@ -1205,11 +1194,16 @@ def train( if args.mask_tool_use and args.tool_use: mb_response_masks_bool = mb_response_masks[:, 1:].bool() & mb_tool_mask[:, 1:].bool() - # Get total tokens for this accumulation group if using "token" normalization - # This ensures all minibatches in the accumulation group are normalized by the same total + # Determine the denominator for masked_mean normalization + loss_denominator = args.masked_mean_denominator + loss_axis = None if args.masked_mean_denominator == "token": group_start = (i // accumulation_steps) * accumulation_steps - total_batch_tokens = accumulation_group_tokens[group_start] + loss_denominator = accumulation_group_tokens[group_start] + elif args.masked_mean_denominator == "num_prompts": + # For prompt-level loss, we average across tokens (axis=1) then across batch + loss_denominator = None + loss_axis = 1 mb_attention_mask = collated_attention_masks[i] mb_position_id = collated_position_ids[i] @@ -1319,7 +1313,9 @@ def train( # for stats computation, for now no denominator is used # unless masked_mean_denominator is a numeric value. 
stats_denominator = ( - args.masked_mean_denominator if args.masked_mean_denominator != "token" else None + args.masked_mean_denominator + if args.masked_mean_denominator not in ["token", "num_prompts"] + else None ) if args.load_ref_policy: @@ -1342,16 +1338,12 @@ def train( kl = kl4 # grpo change: directly subtract KL in loss (add) loss = masked_mean( - pg_loss_max + (args.beta * kl), - mb_response_masks_bool, - args.masked_mean_axis, - args.masked_mean_denominator, + pg_loss_max + (args.beta * kl), mb_response_masks_bool, loss_axis, loss_denominator ) else: - loss = masked_mean( - pg_loss_max, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator - ) - loss = loss / accumulation_steps + loss = masked_mean(pg_loss_max, mb_response_masks_bool, loss_axis, loss_denominator) + if args.masked_mean_denominator != "token": + loss = loss / accumulation_steps # Clear CUDA cache before backward pass to free memory for reduce_scatter operations torch.cuda.empty_cache() self.model.backward(loss) @@ -1362,43 +1354,30 @@ def train( if args.load_ref_policy: # NOTE: in packed implementation, kl calculation are averages over response tokens kl1_stats[i] = masked_mean( - kl1, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator - ).float() - kl2_stats[i] = masked_mean( - kl2, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator - ).float() - kl3_stats[i] = masked_mean( - kl3, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator - ).float() - kl4_stats[i] = masked_mean( - kl4, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator - ).float() - if args.kl_estimator == "kl1": - kl_loss_stats[i] = kl1_stats[i] * args.beta - elif args.kl_estimator == "kl2": - kl_loss_stats[i] = kl2_stats[i] * args.beta - elif args.kl_estimator == "kl3": - kl_loss_stats[i] = kl3_stats[i] * args.beta - elif args.kl_estimator == "kl4": - kl_loss_stats[i] = kl4_stats[i] * args.beta - pg_clipfrac_stats[i] = masked_mean( - (pg_losses2 > pg_losses).float(), - mb_response_masks_bool, - args.masked_mean_axis, - stats_denominator, - ) - pg_loss_stats[i] = masked_mean( - pg_loss_max, mb_response_masks_bool, args.masked_mean_axis, stats_denominator - ) - loss_stats[i] = loss - ratio_stats[i] = masked_mean( - ratio, mb_response_masks_bool, args.masked_mean_axis, stats_denominator - ) - if args.record_entropy: - # Calculate entropy statistics - entropy_stats[i] = masked_mean( - mb_entropy, mb_response_masks_bool, args.masked_mean_axis, stats_denominator + kl1, mb_response_masks_bool, loss_axis, loss_denominator ).float() + kl2_stats[i] = masked_mean(kl2, mb_response_masks_bool, loss_axis, loss_denominator).float() + kl3_stats[i] = masked_mean(kl3, mb_response_masks_bool, loss_axis, loss_denominator).float() + kl4_stats[i] = masked_mean(kl4, mb_response_masks_bool, loss_axis, loss_denominator).float() + if args.kl_estimator == "kl1": + kl_loss_stats[i] = kl1_stats[i] * args.beta + elif args.kl_estimator == "kl2": + kl_loss_stats[i] = kl2_stats[i] * args.beta + elif args.kl_estimator == "kl3": + kl_loss_stats[i] = kl3_stats[i] * args.beta + elif args.kl_estimator == "kl4": + kl_loss_stats[i] = kl4_stats[i] * args.beta + pg_clipfrac_stats[i] = masked_mean( + (pg_losses2 > pg_losses).float(), mb_response_masks_bool, loss_axis, stats_denominator + ) + pg_loss_stats[i] = masked_mean(pg_loss_max, mb_response_masks_bool, loss_axis, stats_denominator) + loss_stats[i] = loss + ratio_stats[i] = masked_mean(ratio, mb_response_masks_bool, 
loss_axis, stats_denominator) + if args.record_entropy: + # Calculate entropy statistics + entropy_stats[i] = masked_mean( + mb_entropy, mb_response_masks_bool, loss_axis, stats_denominator + ).float() with torch.no_grad(): if args.load_ref_policy: diff --git a/open_instruct/test_utils.py b/open_instruct/test_utils.py index c46742e42c..43ccc0775f 100644 --- a/open_instruct/test_utils.py +++ b/open_instruct/test_utils.py @@ -525,3 +525,16 @@ def mock_inspect_return(*args, **kwargs): # "natolambert/tulu-v2-sft-mixture-science": 7468, # original data slightly different # } # _ = get_datasets(dataset_mixer, splits=["train"], columns_to_keep=["messages"]) + + +@pytest.mark.parametrize( + "denominator,axis,expected", [(None, None, None), ("token", None, "token"), (10.0, None, 10.0), (5, 1, 5)] +) +def test_get_denominator_valid_inputs(denominator, axis, expected): + assert utils.get_denominator(denominator, axis) == expected + + +@pytest.mark.parametrize("denominator,axis", [("token", 0), ("not-token", None), (0, None), (-1.0, None)]) +def test_get_denominator_invalid_inputs(denominator, axis): + with pytest.raises(AssertionError): + utils.get_denominator(denominator, axis) diff --git a/open_instruct/utils.py b/open_instruct/utils.py index 1c76d22b4a..2a7950918d 100644 --- a/open_instruct/utils.py +++ b/open_instruct/utils.py @@ -2491,3 +2491,18 @@ def get_beaker_experiment_url() -> str | None: return url except Exception: return None + + +def get_denominator(masked_mean_denominator: float | str | None) -> float | str | None: + """Validate and return the masked mean denominator value.""" + if masked_mean_denominator is None: + return None + + if isinstance(masked_mean_denominator, str): + assert masked_mean_denominator in ["token", "num_prompts"], ( + f"masked_mean_denominator string value must be 'token', 'num_prompts' or number, got {masked_mean_denominator}" + ) + return masked_mean_denominator + + assert masked_mean_denominator > 0, f"masked_mean_denominator (={masked_mean_denominator}) must be greater than 0!" 
+ return masked_mean_denominator From f94b4b2528342361e12ae026bae6957c167e04a4 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Mon, 24 Nov 2025 09:21:22 -0800 Subject: [PATCH 09/34] fix --- open_instruct/grpo_fast.py | 68 +++++++++++++++++++++++-------------- open_instruct/test_utils.py | 13 +++---- 2 files changed, 49 insertions(+), 32 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 6f0373466e..3b23b5a39a 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1052,6 +1052,38 @@ def compute_logprobs( return collated_logprobs, collated_entropies + def calculate_group_tokens( + self, + accumulation_steps: int, + collated_query_responses: list[torch.Tensor], + collated_response_masks: list[torch.Tensor], + collated_tool_masks: list[torch.Tensor], + ) -> dict[int, float]: + accumulation_group_tokens = {} + for group_start in range(0, len(collated_query_responses), accumulation_steps): + group_end = min(group_start + accumulation_steps, len(collated_query_responses)) + # Calculate local tokens for all minibatches in this accumulation group + local_group_tokens = 0.0 + for i in range(group_start, group_end): + mb_response_masks = collated_response_masks[i] + mb_response_masks_bool = mb_response_masks[:, 1:].bool() + if self.args.mask_tool_use and self.args.tool_use: + mb_tool_mask = collated_tool_masks[i] + mb_response_masks_bool = mb_response_masks_bool & mb_tool_mask[:, 1:].bool() + local_group_tokens += mb_response_masks_bool.sum().item() + + # Gather total tokens across all ranks for this accumulation group + if dist.is_available() and dist.is_initialized(): + dist.barrier() + local_group_tokens_tensor = torch.tensor( + local_group_tokens, dtype=torch.float32, device=self.device + ) + dist.all_reduce(local_group_tokens_tensor, op=dist.ReduceOp.SUM, group=None) + accumulation_group_tokens[group_start] = local_group_tokens_tensor.item() + else: + accumulation_group_tokens[group_start] = local_group_tokens + return accumulation_group_tokens + def train( self, collated_query_responses, @@ -1161,28 +1193,12 @@ def train( # This ensures all minibatches in an accumulation group are normalized by the same total accumulation_group_tokens = {} if args.masked_mean_denominator == "token": - for group_start in range(0, len(collated_query_responses), accumulation_steps): - group_end = min(group_start + accumulation_steps, len(collated_query_responses)) - # Calculate local tokens for all minibatches in this accumulation group - local_group_tokens = 0.0 - for i in range(group_start, group_end): - mb_response_masks = collated_response_masks[i] - mb_response_masks_bool = mb_response_masks[:, 1:].bool() - if args.mask_tool_use and args.tool_use: - mb_tool_mask = collated_tool_masks[i] - mb_response_masks_bool = mb_response_masks[:, 1:].bool() & mb_tool_mask[:, 1:].bool() - local_group_tokens += mb_response_masks_bool.sum().item() - - # Gather total tokens across all ranks for this accumulation group - if dist.is_available() and dist.is_initialized(): - dist.barrier() - local_group_tokens_tensor = torch.tensor( - local_group_tokens, dtype=torch.float32, device=self.device - ) - dist.all_reduce(local_group_tokens_tensor, op=dist.ReduceOp.SUM, group=None) - accumulation_group_tokens[group_start] = local_group_tokens_tensor.item() - else: - accumulation_group_tokens[group_start] = local_group_tokens + accumulation_group_tokens = self.calculate_group_tokens( + accumulation_steps, + collated_query_responses, + collated_response_masks, + 
collated_tool_masks, + ) for i in range(len(collated_query_responses)): mb_query_responses = collated_query_responses[i] @@ -1354,11 +1370,11 @@ def train( if args.load_ref_policy: # NOTE: in packed implementation, kl calculation are averages over response tokens kl1_stats[i] = masked_mean( - kl1, mb_response_masks_bool, loss_axis, loss_denominator + kl1, mb_response_masks_bool, loss_axis, stats_denominator ).float() - kl2_stats[i] = masked_mean(kl2, mb_response_masks_bool, loss_axis, loss_denominator).float() - kl3_stats[i] = masked_mean(kl3, mb_response_masks_bool, loss_axis, loss_denominator).float() - kl4_stats[i] = masked_mean(kl4, mb_response_masks_bool, loss_axis, loss_denominator).float() + kl2_stats[i] = masked_mean(kl2, mb_response_masks_bool, loss_axis, stats_denominator).float() + kl3_stats[i] = masked_mean(kl3, mb_response_masks_bool, loss_axis, stats_denominator).float() + kl4_stats[i] = masked_mean(kl4, mb_response_masks_bool, loss_axis, stats_denominator).float() if args.kl_estimator == "kl1": kl_loss_stats[i] = kl1_stats[i] * args.beta elif args.kl_estimator == "kl2": diff --git a/open_instruct/test_utils.py b/open_instruct/test_utils.py index 43ccc0775f..aa3f308eb0 100644 --- a/open_instruct/test_utils.py +++ b/open_instruct/test_utils.py @@ -528,13 +528,14 @@ def mock_inspect_return(*args, **kwargs): @pytest.mark.parametrize( - "denominator,axis,expected", [(None, None, None), ("token", None, "token"), (10.0, None, 10.0), (5, 1, 5)] + "denominator,expected", + [(None, None), ("token", "token"), ("num_prompts", "num_prompts"), (10.0, 10.0), (5, 5)], ) -def test_get_denominator_valid_inputs(denominator, axis, expected): - assert utils.get_denominator(denominator, axis) == expected +def test_get_denominator_valid_inputs(denominator, expected): + assert utils.get_denominator(denominator) == expected -@pytest.mark.parametrize("denominator,axis", [("token", 0), ("not-token", None), (0, None), (-1.0, None)]) -def test_get_denominator_invalid_inputs(denominator, axis): +@pytest.mark.parametrize("denominator", ["not-token", 0, -1.0]) +def test_get_denominator_invalid_inputs(denominator): with pytest.raises(AssertionError): - utils.get_denominator(denominator, axis) + utils.get_denominator(denominator) From 1931b995c2e580a5aaa2116c4c48dc285d5510b1 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Mon, 24 Nov 2025 09:23:08 -0800 Subject: [PATCH 10/34] lint --- open_instruct/grpo_fast.py | 9 ++------- open_instruct/test_utils.py | 3 +-- 2 files changed, 3 insertions(+), 9 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 3b23b5a39a..54f6eb7834 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1075,9 +1075,7 @@ def calculate_group_tokens( # Gather total tokens across all ranks for this accumulation group if dist.is_available() and dist.is_initialized(): dist.barrier() - local_group_tokens_tensor = torch.tensor( - local_group_tokens, dtype=torch.float32, device=self.device - ) + local_group_tokens_tensor = torch.tensor(local_group_tokens, dtype=torch.float32, device=self.device) dist.all_reduce(local_group_tokens_tensor, op=dist.ReduceOp.SUM, group=None) accumulation_group_tokens[group_start] = local_group_tokens_tensor.item() else: @@ -1194,10 +1192,7 @@ def train( accumulation_group_tokens = {} if args.masked_mean_denominator == "token": accumulation_group_tokens = self.calculate_group_tokens( - accumulation_steps, - collated_query_responses, - collated_response_masks, - collated_tool_masks, + accumulation_steps, 
collated_query_responses, collated_response_masks, collated_tool_masks ) for i in range(len(collated_query_responses)): diff --git a/open_instruct/test_utils.py b/open_instruct/test_utils.py index aa3f308eb0..801e356209 100644 --- a/open_instruct/test_utils.py +++ b/open_instruct/test_utils.py @@ -528,8 +528,7 @@ def mock_inspect_return(*args, **kwargs): @pytest.mark.parametrize( - "denominator,expected", - [(None, None), ("token", "token"), ("num_prompts", "num_prompts"), (10.0, 10.0), (5, 5)], + "denominator,expected", [(None, None), ("token", "token"), ("num_prompts", "num_prompts"), (10.0, 10.0), (5, 5)] ) def test_get_denominator_valid_inputs(denominator, expected): assert utils.get_denominator(denominator) == expected From 90eb98e266c0cce5b8ca35f25329919fdf8fd47e Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Mon, 24 Nov 2025 09:37:53 -0800 Subject: [PATCH 11/34] fix --- open_instruct/grpo_fast.py | 11 ++--------- open_instruct/test_utils.py | 4 +--- open_instruct/utils.py | 4 ++-- 3 files changed, 5 insertions(+), 14 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 54f6eb7834..9c88b0b575 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -253,8 +253,7 @@ class Args: masked_mean_denominator: float | str | None = "token" """Optional constant denominator for masked_mean; if set, divides by this instead of mask.sum. Special value "token" means use total_batch_tokens (computed across all ranks in distributed training). - When using "token", total_batch_tokens is gathered via allreduce across all ranks. - Special value "num_prompts" means use the number of prompts in the batch.""" + When using "token", total_batch_tokens is gathered via allreduce across all ranks.""" alpha: float = 0.6 """The alpha value for doing polyak updates (ref_param = alpha * param + (1 - alpha) * ref_param) reference: [TR-DPO](https://huggingface.co/papers/2404.09656), but it's actually pretty commonly @@ -1211,10 +1210,6 @@ def train( if args.masked_mean_denominator == "token": group_start = (i // accumulation_steps) * accumulation_steps loss_denominator = accumulation_group_tokens[group_start] - elif args.masked_mean_denominator == "num_prompts": - # For prompt-level loss, we average across tokens (axis=1) then across batch - loss_denominator = None - loss_axis = 1 mb_attention_mask = collated_attention_masks[i] mb_position_id = collated_position_ids[i] @@ -1324,9 +1319,7 @@ def train( # for stats computation, for now no denominator is used # unless masked_mean_denominator is a numeric value. 
stats_denominator = ( - args.masked_mean_denominator - if args.masked_mean_denominator not in ["token", "num_prompts"] - else None + args.masked_mean_denominator if args.masked_mean_denominator != "token" else None ) if args.load_ref_policy: diff --git a/open_instruct/test_utils.py b/open_instruct/test_utils.py index 801e356209..551cc457ad 100644 --- a/open_instruct/test_utils.py +++ b/open_instruct/test_utils.py @@ -527,9 +527,7 @@ def mock_inspect_return(*args, **kwargs): # _ = get_datasets(dataset_mixer, splits=["train"], columns_to_keep=["messages"]) -@pytest.mark.parametrize( - "denominator,expected", [(None, None), ("token", "token"), ("num_prompts", "num_prompts"), (10.0, 10.0), (5, 5)] -) +@pytest.mark.parametrize("denominator,expected", [(None, None), ("token", "token"), (10.0, 10.0), (5, 5)]) def test_get_denominator_valid_inputs(denominator, expected): assert utils.get_denominator(denominator) == expected diff --git a/open_instruct/utils.py b/open_instruct/utils.py index 2a7950918d..9054bcefba 100644 --- a/open_instruct/utils.py +++ b/open_instruct/utils.py @@ -2499,8 +2499,8 @@ def get_denominator(masked_mean_denominator: float | str | None) -> float | str return None if isinstance(masked_mean_denominator, str): - assert masked_mean_denominator in ["token", "num_prompts"], ( - f"masked_mean_denominator string value must be 'token', 'num_prompts' or number, got {masked_mean_denominator}" + assert masked_mean_denominator == "token", ( + f"masked_mean_denominator string value must be 'token' or number, got {masked_mean_denominator}" ) return masked_mean_denominator From 6e698e99640b4016ba77c8cd92232e1c7cae1cb2 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Mon, 24 Nov 2025 10:37:06 -0800 Subject: [PATCH 12/34] fix --- open_instruct/utils.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/open_instruct/utils.py b/open_instruct/utils.py index 9054bcefba..090bd69618 100644 --- a/open_instruct/utils.py +++ b/open_instruct/utils.py @@ -2499,10 +2499,15 @@ def get_denominator(masked_mean_denominator: float | str | None) -> float | str return None if isinstance(masked_mean_denominator, str): - assert masked_mean_denominator == "token", ( - f"masked_mean_denominator string value must be 'token' or number, got {masked_mean_denominator}" - ) - return masked_mean_denominator + if masked_mean_denominator == "token": + return masked_mean_denominator + # Try to convert numeric strings to float + try: + masked_mean_denominator = float(masked_mean_denominator) + except ValueError: + raise AssertionError( + f"masked_mean_denominator string value must be 'token' or number, got {masked_mean_denominator}" + ) from None assert masked_mean_denominator > 0, f"masked_mean_denominator (={masked_mean_denominator}) must be greater than 0!" 
return masked_mean_denominator From 802a7e7d321d8f7a828fd8fadf783eef09b896e3 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Mon, 24 Nov 2025 13:16:20 -0800 Subject: [PATCH 13/34] quick and dirty group-level --- open_instruct/grpo_fast.py | 44 ++++++++++++++++++-------------- open_instruct/rl_utils.py | 52 ++++++++++++++++++++++++++++++++++++++ open_instruct/utils.py | 4 +-- 3 files changed, 79 insertions(+), 21 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 9c88b0b575..5c7386b321 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -107,7 +107,7 @@ push_folder_to_hub, ) from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo, TokenStatistics -from open_instruct.rl_utils import PackedSequences, Timer, pack_sequences +from open_instruct.rl_utils import PackedSequences, Timer, masked_group_mean, masked_mean, pack_sequences from open_instruct.utils import ( ArgumentParserPlus, BeakerRuntimeConfig, @@ -253,6 +253,7 @@ class Args: masked_mean_denominator: float | str | None = "token" """Optional constant denominator for masked_mean; if set, divides by this instead of mask.sum. Special value "token" means use total_batch_tokens (computed across all ranks in distributed training). + Special value "group" means use group-level averaging (average across tokens in a group, then average across groups). When using "token", total_batch_tokens is gathered via allreduce across all ranks.""" alpha: float = 0.6 """The alpha value for doing polyak updates (ref_param = alpha * param + (1 - alpha) * ref_param) @@ -542,13 +543,6 @@ def __post_init__(self): raise ValueError("`async_steps` must be greater than 0. Fully synchronous training is not supported.") -def masked_mean( - values: torch.Tensor, mask: torch.Tensor, axis: int | None = None, denominator: float | None = None -) -> torch.Tensor: - """Compute mean of tensor with a masked values.""" - numerator = (values * mask).sum(axis=axis) - denom = mask.sum(axis=axis) if denominator is None else denominator - return (numerator / denom).mean() def collate_fn(tensors_list: list[torch.Tensor], pad_token_id: int, pin_memory: bool = True) -> torch.Tensor: @@ -1319,9 +1313,20 @@ def train( # for stats computation, for now no denominator is used # unless masked_mean_denominator is a numeric value. 
stats_denominator = ( - args.masked_mean_denominator if args.masked_mean_denominator != "token" else None + args.masked_mean_denominator + if args.masked_mean_denominator != "token" and args.masked_mean_denominator != "group" + else None ) + # Define reduction function based on configuration + if args.masked_mean_denominator == "group": + group_ids = (mb_response_masks[:, 1:] - 1) // args.num_samples_per_prompt_rollout + + def reduce_fn(v, m, a=None, d=None): + return masked_group_mean(v, m, group_ids) + else: + reduce_fn = masked_mean + if args.load_ref_policy: mb_ref_logprob = collated_ref_logprobs[i] # Here we recalculate kl: we want the KL loss to backpropagate through the model @@ -1341,13 +1346,14 @@ def train( elif args.kl_estimator == "kl4": kl = kl4 # grpo change: directly subtract KL in loss (add) - loss = masked_mean( + loss = reduce_fn( pg_loss_max + (args.beta * kl), mb_response_masks_bool, loss_axis, loss_denominator ) else: - loss = masked_mean(pg_loss_max, mb_response_masks_bool, loss_axis, loss_denominator) + loss = reduce_fn(pg_loss_max, mb_response_masks_bool, loss_axis, loss_denominator) if args.masked_mean_denominator != "token": loss = loss / accumulation_steps + # Clear CUDA cache before backward pass to free memory for reduce_scatter operations torch.cuda.empty_cache() self.model.backward(loss) @@ -1357,12 +1363,12 @@ def train( with torch.no_grad(): if args.load_ref_policy: # NOTE: in packed implementation, kl calculation are averages over response tokens - kl1_stats[i] = masked_mean( + kl1_stats[i] = reduce_fn( kl1, mb_response_masks_bool, loss_axis, stats_denominator ).float() - kl2_stats[i] = masked_mean(kl2, mb_response_masks_bool, loss_axis, stats_denominator).float() - kl3_stats[i] = masked_mean(kl3, mb_response_masks_bool, loss_axis, stats_denominator).float() - kl4_stats[i] = masked_mean(kl4, mb_response_masks_bool, loss_axis, stats_denominator).float() + kl2_stats[i] = reduce_fn(kl2, mb_response_masks_bool, loss_axis, stats_denominator).float() + kl3_stats[i] = reduce_fn(kl3, mb_response_masks_bool, loss_axis, stats_denominator).float() + kl4_stats[i] = reduce_fn(kl4, mb_response_masks_bool, loss_axis, stats_denominator).float() if args.kl_estimator == "kl1": kl_loss_stats[i] = kl1_stats[i] * args.beta elif args.kl_estimator == "kl2": @@ -1371,15 +1377,15 @@ def train( kl_loss_stats[i] = kl3_stats[i] * args.beta elif args.kl_estimator == "kl4": kl_loss_stats[i] = kl4_stats[i] * args.beta - pg_clipfrac_stats[i] = masked_mean( + pg_clipfrac_stats[i] = reduce_fn( (pg_losses2 > pg_losses).float(), mb_response_masks_bool, loss_axis, stats_denominator ) - pg_loss_stats[i] = masked_mean(pg_loss_max, mb_response_masks_bool, loss_axis, stats_denominator) + pg_loss_stats[i] = reduce_fn(pg_loss_max, mb_response_masks_bool, loss_axis, stats_denominator) loss_stats[i] = loss - ratio_stats[i] = masked_mean(ratio, mb_response_masks_bool, loss_axis, stats_denominator) + ratio_stats[i] = reduce_fn(ratio, mb_response_masks_bool, loss_axis, stats_denominator) if args.record_entropy: # Calculate entropy statistics - entropy_stats[i] = masked_mean( + entropy_stats[i] = reduce_fn( mb_entropy, mb_response_masks_bool, loss_axis, stats_denominator ).float() diff --git a/open_instruct/rl_utils.py b/open_instruct/rl_utils.py index 7dad7a6008..926c903b14 100644 --- a/open_instruct/rl_utils.py +++ b/open_instruct/rl_utils.py @@ -249,3 +249,55 @@ def calculate_advantages_packed( advantages = np.stack(advantages_reversed[::-1], axis=1) returns = advantages + values return advantages, 
returns + + +def masked_mean( + values: torch.Tensor, mask: torch.Tensor, axis: int | None = None, denominator: float | None = None +) -> torch.Tensor: + """Compute mean of tensor with a masked values.""" + numerator = (values * mask).sum(axis=axis) + denom = mask.sum(axis=axis) if denominator is None else denominator + return (numerator / denom).mean() + + +def masked_group_mean( + values: torch.Tensor, + mask: torch.Tensor, + group_ids: torch.Tensor, +) -> torch.Tensor: + """ + Compute mean of tensor values masked by mask, but averaged per group first. + 1. Filter values and group_ids by mask. + 2. Sum values per group. + 3. Count items per group. + 4. Compute mean per group. + 5. Compute mean of group means. + """ + # Flatten everything + flat_values = values.flatten() + flat_mask = mask.flatten().bool() + flat_group_ids = group_ids.flatten() + + # Filter invalid items + valid_values = flat_values[flat_mask] + valid_group_ids = flat_group_ids[flat_mask] + + if valid_values.numel() == 0: + return torch.tensor(0.0, device=values.device) + + # Map group_ids to contiguous range 0..num_groups-1 + _, inverse_indices = torch.unique(valid_group_ids, return_inverse=True) + num_groups = inverse_indices.max().item() + 1 + + # Sum values and counts per group + group_sums = torch.zeros(num_groups, device=values.device, dtype=values.dtype) + group_counts = torch.zeros(num_groups, device=values.device, dtype=values.dtype) + + group_sums.scatter_add_(0, inverse_indices, valid_values) + group_counts.scatter_add_(0, inverse_indices, torch.ones_like(valid_values)) + + # Average per group + group_means = group_sums / group_counts + + # Average across groups + return group_means.mean() diff --git a/open_instruct/utils.py b/open_instruct/utils.py index 090bd69618..7917a3d79a 100644 --- a/open_instruct/utils.py +++ b/open_instruct/utils.py @@ -2499,14 +2499,14 @@ def get_denominator(masked_mean_denominator: float | str | None) -> float | str return None if isinstance(masked_mean_denominator, str): - if masked_mean_denominator == "token": + if masked_mean_denominator in ["token", "group"]: return masked_mean_denominator # Try to convert numeric strings to float try: masked_mean_denominator = float(masked_mean_denominator) except ValueError: raise AssertionError( - f"masked_mean_denominator string value must be 'token' or number, got {masked_mean_denominator}" + f"masked_mean_denominator string value must be 'token', 'group' or number, got {masked_mean_denominator}" ) from None assert masked_mean_denominator > 0, f"masked_mean_denominator (={masked_mean_denominator}) must be greater than 0!" From 849ebb4a70f9206ca087aa639b463b25a3e8b5f0 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Mon, 24 Nov 2025 13:22:33 -0800 Subject: [PATCH 14/34] fix quality --- open_instruct/grpo_fast.py | 8 ++------ open_instruct/rl_utils.py | 6 +----- 2 files changed, 3 insertions(+), 11 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 5c7386b321..9bc106101c 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -543,8 +543,6 @@ def __post_init__(self): raise ValueError("`async_steps` must be greater than 0. 
Fully synchronous training is not supported.") - - def collate_fn(tensors_list: list[torch.Tensor], pad_token_id: int, pin_memory: bool = True) -> torch.Tensor: padded_tensor = torch.nn.utils.rnn.pad_sequence(tensors_list, batch_first=True, padding_value=pad_token_id) if pin_memory: @@ -1322,7 +1320,7 @@ def train( if args.masked_mean_denominator == "group": group_ids = (mb_response_masks[:, 1:] - 1) // args.num_samples_per_prompt_rollout - def reduce_fn(v, m, a=None, d=None): + def reduce_fn(v, m, a=None, d=None, group_ids=group_ids): return masked_group_mean(v, m, group_ids) else: reduce_fn = masked_mean @@ -1363,9 +1361,7 @@ def reduce_fn(v, m, a=None, d=None): with torch.no_grad(): if args.load_ref_policy: # NOTE: in packed implementation, kl calculation are averages over response tokens - kl1_stats[i] = reduce_fn( - kl1, mb_response_masks_bool, loss_axis, stats_denominator - ).float() + kl1_stats[i] = reduce_fn(kl1, mb_response_masks_bool, loss_axis, stats_denominator).float() kl2_stats[i] = reduce_fn(kl2, mb_response_masks_bool, loss_axis, stats_denominator).float() kl3_stats[i] = reduce_fn(kl3, mb_response_masks_bool, loss_axis, stats_denominator).float() kl4_stats[i] = reduce_fn(kl4, mb_response_masks_bool, loss_axis, stats_denominator).float() diff --git a/open_instruct/rl_utils.py b/open_instruct/rl_utils.py index 926c903b14..1d5d6860af 100644 --- a/open_instruct/rl_utils.py +++ b/open_instruct/rl_utils.py @@ -260,11 +260,7 @@ def masked_mean( return (numerator / denom).mean() -def masked_group_mean( - values: torch.Tensor, - mask: torch.Tensor, - group_ids: torch.Tensor, -) -> torch.Tensor: +def masked_group_mean(values: torch.Tensor, mask: torch.Tensor, group_ids: torch.Tensor) -> torch.Tensor: """ Compute mean of tensor values masked by mask, but averaged per group first. 1. Filter values and group_ids by mask. 
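Before the rework in the next patch, a small self-contained check of the group-level reduction defined above may help (a compact restatement of masked_group_mean with toy tensors; expected values are worked out by hand):

import torch

def masked_group_mean(values, mask, group_ids):
    # Per-group token means first, then an unweighted mean over groups,
    # mirroring the helper added to rl_utils.py in this series.
    keep = mask.flatten().bool()
    v, g = values.flatten()[keep], group_ids.flatten()[keep]
    _, inv = torch.unique(g, return_inverse=True)
    n = int(inv.max().item()) + 1
    sums = torch.zeros(n).scatter_add_(0, inv, v)
    counts = torch.zeros(n).scatter_add_(0, inv, torch.ones_like(v))
    return (sums / counts).mean()

values = torch.tensor([1.0, 2.0, 3.0, 10.0])
mask = torch.tensor([1, 1, 1, 1])
group_ids = torch.tensor([0, 0, 0, 1])

masked_group_mean(values, mask, group_ids)   # mean of [2.0, 10.0] = 6.0
# A flat token mean over the same values would give (1 + 2 + 3 + 10) / 4 = 4.0.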
From 8c7559d785ceac987633d5f36909970e18689795 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Mon, 24 Nov 2025 14:25:43 -0800 Subject: [PATCH 15/34] correct hacky group level --- open_instruct/grpo_fast.py | 126 ++++++++++++++++++++++++++++++------- 1 file changed, 102 insertions(+), 24 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 9bc106101c..8cd3ae4bc3 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -107,7 +107,7 @@ push_folder_to_hub, ) from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo, TokenStatistics -from open_instruct.rl_utils import PackedSequences, Timer, masked_group_mean, masked_mean, pack_sequences +from open_instruct.rl_utils import PackedSequences, Timer, masked_mean, pack_sequences from open_instruct.utils import ( ArgumentParserPlus, BeakerRuntimeConfig, @@ -1043,35 +1043,74 @@ def compute_logprobs( return collated_logprobs, collated_entropies - def calculate_group_tokens( + def calculate_token_counts( self, accumulation_steps: int, - collated_query_responses: list[torch.Tensor], collated_response_masks: list[torch.Tensor], collated_tool_masks: list[torch.Tensor], - ) -> dict[int, float]: - accumulation_group_tokens = {} - for group_start in range(0, len(collated_query_responses), accumulation_steps): - group_end = min(group_start + accumulation_steps, len(collated_query_responses)) - # Calculate local tokens for all minibatches in this accumulation group - local_group_tokens = 0.0 + mode: str = "token", + ) -> dict[int, float | torch.Tensor]: + accumulation_counts = {} + + max_group_id = 0 + if mode == "group": + # First pass to determine max group index + for i in range(len(collated_response_masks)): + mb_response_masks = collated_response_masks[i] + # group_id = (sample_index) // n + # sample_index starts at 0. response_mask contains sample_index + 1. 
+ # So (response_mask - 1) // n + max_group_id = max( + max_group_id, ((mb_response_masks.max().item() - 1) // self.args.num_samples_per_prompt_rollout) + ) + + # All reduce max_group_id to ensure all ranks have the same tensor size + if dist.is_available() and dist.is_initialized(): + dist.barrier() + max_group_id_tensor = torch.tensor(max_group_id, dtype=torch.long, device=self.device) + dist.all_reduce(max_group_id_tensor, op=dist.ReduceOp.MAX, group=None) + max_group_id = max_group_id_tensor.item() + + for group_start in range(0, len(collated_response_masks), accumulation_steps): + group_end = min(group_start + accumulation_steps, len(collated_response_masks)) + + if mode == "group": + counts = torch.zeros(max_group_id + 1, device=self.device, dtype=torch.float32) + else: + counts = torch.tensor(0.0, device=self.device, dtype=torch.float32) + for i in range(group_start, group_end): mb_response_masks = collated_response_masks[i] mb_response_masks_bool = mb_response_masks[:, 1:].bool() if self.args.mask_tool_use and self.args.tool_use: mb_tool_mask = collated_tool_masks[i] mb_response_masks_bool = mb_response_masks_bool & mb_tool_mask[:, 1:].bool() - local_group_tokens += mb_response_masks_bool.sum().item() - # Gather total tokens across all ranks for this accumulation group + if mode == "group": + # Filter valid tokens + valid_mask = mb_response_masks_bool + flat_mask = valid_mask.flatten() + if flat_mask.any(): + # Get group IDs for valid tokens + flat_response_masks = mb_response_masks[:, 1:].flatten() + valid_response_masks = flat_response_masks[flat_mask] + group_ids = (valid_response_masks - 1) // self.args.num_samples_per_prompt_rollout + # Accumulate counts + counts.scatter_add_(0, group_ids, torch.ones_like(group_ids, dtype=torch.float32)) + else: + counts += mb_response_masks_bool.sum().float() + + # All reduce counts if dist.is_available() and dist.is_initialized(): dist.barrier() - local_group_tokens_tensor = torch.tensor(local_group_tokens, dtype=torch.float32, device=self.device) - dist.all_reduce(local_group_tokens_tensor, op=dist.ReduceOp.SUM, group=None) - accumulation_group_tokens[group_start] = local_group_tokens_tensor.item() + dist.all_reduce(counts, op=dist.ReduceOp.SUM, group=None) + + if mode == "token": + accumulation_counts[group_start] = counts.item() else: - accumulation_group_tokens[group_start] = local_group_tokens - return accumulation_group_tokens + accumulation_counts[group_start] = counts + + return accumulation_counts def train( self, @@ -1180,10 +1219,13 @@ def train( for epoch_idx in range(args.num_epochs): # Pre-compute total tokens for each accumulation group if using "token" normalization # This ensures all minibatches in an accumulation group are normalized by the same total - accumulation_group_tokens = {} - if args.masked_mean_denominator == "token": - accumulation_group_tokens = self.calculate_group_tokens( - accumulation_steps, collated_query_responses, collated_response_masks, collated_tool_masks + accumulation_token_counts = {} + if args.masked_mean_denominator in ["token", "group"]: + accumulation_token_counts = self.calculate_token_counts( + accumulation_steps, + collated_response_masks, + collated_tool_masks, + mode=args.masked_mean_denominator, ) for i in range(len(collated_query_responses)): @@ -1201,7 +1243,7 @@ def train( loss_axis = None if args.masked_mean_denominator == "token": group_start = (i // accumulation_steps) * accumulation_steps - loss_denominator = accumulation_group_tokens[group_start] + loss_denominator = 
accumulation_token_counts[group_start] mb_attention_mask = collated_attention_masks[i] mb_position_id = collated_position_ids[i] @@ -1318,10 +1360,37 @@ def train( # Define reduction function based on configuration if args.masked_mean_denominator == "group": + group_start = (i // accumulation_steps) * accumulation_steps + group_counts = accumulation_token_counts[group_start] + total_active_groups = (group_counts > 0).sum().item() group_ids = (mb_response_masks[:, 1:] - 1) // args.num_samples_per_prompt_rollout - def reduce_fn(v, m, a=None, d=None, group_ids=group_ids): - return masked_group_mean(v, m, group_ids) + def reduce_fn(v, m, a=None, d=None): + flat_v = v.flatten() + flat_m = m.flatten().bool() + flat_g = group_ids.flatten() + + # if no valid tokens in batch. + if not flat_m.any(): + return torch.tensor(0.0, device=v.device) + + valid_v = flat_v[flat_m] + valid_g = flat_g[flat_m] + + valid_counts = group_counts[valid_g] + # Avoid division by zero if count is 0 (should not happen for valid tokens) + valid_counts = torch.max( + valid_counts, torch.tensor(1.0, device=valid_counts.device, dtype=valid_counts.dtype) + ) + + weights = 1.0 / (valid_counts * total_active_groups) + + # Sum weighted values + loss = (valid_v * weights).sum() + scale = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1 + scale *= accumulation_steps + + return loss * scale else: reduce_fn = masked_mean @@ -1349,7 +1418,16 @@ def reduce_fn(v, m, a=None, d=None, group_ids=group_ids): ) else: loss = reduce_fn(pg_loss_max, mb_response_masks_bool, loss_axis, loss_denominator) - if args.masked_mean_denominator != "token": + if args.masked_mean_denominator == "token": + # In token mode, we divide by the GLOBAL total number of tokens. + # DDP averages gradients across ranks (dividing by world_size). + # To get the true global mean gradient (sum_all_gradients / global_tokens), + # we must multiply by world_size to cancel out DDP's division. + if dist.is_available() and dist.is_initialized(): + loss *= dist.get_world_size() + elif args.masked_mean_denominator != "group": + # For "group" mode, the scaling is handled inside reduce_fn. + # For default (None) or numeric modes, we divide by accumulation_steps here. 
loss = loss / accumulation_steps # Clear CUDA cache before backward pass to free memory for reduce_scatter operations From febf44a5a48faa57c6514d1178285983664a14dd Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Wed, 26 Nov 2025 09:54:42 -0800 Subject: [PATCH 16/34] fix --- open_instruct/grpo_fast.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 8cd3ae4bc3..68e2307f70 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -497,6 +497,8 @@ def __post_init__(self): self.gs_checkpoint_state_dir = f"{self.gs_bucket_path}/{beaker_users}/{checkpoint_dir_name}" else: self.gs_checkpoint_state_dir = f"{self.gs_bucket_path}/{checkpoint_dir_name}" + if not checkpoint_dir_name.startswith("/filestore"): + self.checkpoint_state_dir = f"/filestore{self.checkpoint_state_dir}" if self.checkpoint_state_dir is not None: if self.gs_checkpoint_state_dir is not None: From 1f855783a7e6dc183097b717b4d93d031ed60d06 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Sun, 30 Nov 2025 16:36:34 -0800 Subject: [PATCH 17/34] remove group loss --- open_instruct/grpo_fast.py | 130 ++++++------------------------------- open_instruct/rl_utils.py | 39 ----------- open_instruct/utils.py | 2 +- 3 files changed, 21 insertions(+), 150 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 68e2307f70..61052aeeba 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -253,7 +253,6 @@ class Args: masked_mean_denominator: float | str | None = "token" """Optional constant denominator for masked_mean; if set, divides by this instead of mask.sum. Special value "token" means use total_batch_tokens (computed across all ranks in distributed training). - Special value "group" means use group-level averaging (average across tokens in a group, then average across groups). When using "token", total_batch_tokens is gathered via allreduce across all ranks.""" alpha: float = 0.6 """The alpha value for doing polyak updates (ref_param = alpha * param + (1 - alpha) * ref_param) @@ -1050,68 +1049,23 @@ def calculate_token_counts( accumulation_steps: int, collated_response_masks: list[torch.Tensor], collated_tool_masks: list[torch.Tensor], - mode: str = "token", ) -> dict[int, float | torch.Tensor]: accumulation_counts = {} - - max_group_id = 0 - if mode == "group": - # First pass to determine max group index - for i in range(len(collated_response_masks)): - mb_response_masks = collated_response_masks[i] - # group_id = (sample_index) // n - # sample_index starts at 0. response_mask contains sample_index + 1. 
- # So (response_mask - 1) // n - max_group_id = max( - max_group_id, ((mb_response_masks.max().item() - 1) // self.args.num_samples_per_prompt_rollout) - ) - - # All reduce max_group_id to ensure all ranks have the same tensor size - if dist.is_available() and dist.is_initialized(): - dist.barrier() - max_group_id_tensor = torch.tensor(max_group_id, dtype=torch.long, device=self.device) - dist.all_reduce(max_group_id_tensor, op=dist.ReduceOp.MAX, group=None) - max_group_id = max_group_id_tensor.item() - for group_start in range(0, len(collated_response_masks), accumulation_steps): group_end = min(group_start + accumulation_steps, len(collated_response_masks)) - - if mode == "group": - counts = torch.zeros(max_group_id + 1, device=self.device, dtype=torch.float32) - else: - counts = torch.tensor(0.0, device=self.device, dtype=torch.float32) - + counts = torch.tensor(0.0, device=self.device, dtype=torch.float32) for i in range(group_start, group_end): mb_response_masks = collated_response_masks[i] mb_response_masks_bool = mb_response_masks[:, 1:].bool() if self.args.mask_tool_use and self.args.tool_use: mb_tool_mask = collated_tool_masks[i] mb_response_masks_bool = mb_response_masks_bool & mb_tool_mask[:, 1:].bool() - - if mode == "group": - # Filter valid tokens - valid_mask = mb_response_masks_bool - flat_mask = valid_mask.flatten() - if flat_mask.any(): - # Get group IDs for valid tokens - flat_response_masks = mb_response_masks[:, 1:].flatten() - valid_response_masks = flat_response_masks[flat_mask] - group_ids = (valid_response_masks - 1) // self.args.num_samples_per_prompt_rollout - # Accumulate counts - counts.scatter_add_(0, group_ids, torch.ones_like(group_ids, dtype=torch.float32)) - else: - counts += mb_response_masks_bool.sum().float() - + counts += mb_response_masks_bool.sum().float() # All reduce counts if dist.is_available() and dist.is_initialized(): dist.barrier() dist.all_reduce(counts, op=dist.ReduceOp.SUM, group=None) - - if mode == "token": - accumulation_counts[group_start] = counts.item() - else: - accumulation_counts[group_start] = counts - + accumulation_counts[group_start] = counts.item() return accumulation_counts def train( @@ -1222,12 +1176,9 @@ def train( # Pre-compute total tokens for each accumulation group if using "token" normalization # This ensures all minibatches in an accumulation group are normalized by the same total accumulation_token_counts = {} - if args.masked_mean_denominator in ["token", "group"]: + if args.masked_mean_denominator == "token": accumulation_token_counts = self.calculate_token_counts( - accumulation_steps, - collated_response_masks, - collated_tool_masks, - mode=args.masked_mean_denominator, + accumulation_steps, collated_response_masks, collated_tool_masks ) for i in range(len(collated_query_responses)): @@ -1355,47 +1306,9 @@ def train( # for stats computation, for now no denominator is used # unless masked_mean_denominator is a numeric value. 
stats_denominator = ( - args.masked_mean_denominator - if args.masked_mean_denominator != "token" and args.masked_mean_denominator != "group" - else None + args.masked_mean_denominator if args.masked_mean_denominator != "token" else None ) - # Define reduction function based on configuration - if args.masked_mean_denominator == "group": - group_start = (i // accumulation_steps) * accumulation_steps - group_counts = accumulation_token_counts[group_start] - total_active_groups = (group_counts > 0).sum().item() - group_ids = (mb_response_masks[:, 1:] - 1) // args.num_samples_per_prompt_rollout - - def reduce_fn(v, m, a=None, d=None): - flat_v = v.flatten() - flat_m = m.flatten().bool() - flat_g = group_ids.flatten() - - # if no valid tokens in batch. - if not flat_m.any(): - return torch.tensor(0.0, device=v.device) - - valid_v = flat_v[flat_m] - valid_g = flat_g[flat_m] - - valid_counts = group_counts[valid_g] - # Avoid division by zero if count is 0 (should not happen for valid tokens) - valid_counts = torch.max( - valid_counts, torch.tensor(1.0, device=valid_counts.device, dtype=valid_counts.dtype) - ) - - weights = 1.0 / (valid_counts * total_active_groups) - - # Sum weighted values - loss = (valid_v * weights).sum() - scale = dist.get_world_size() if dist.is_available() and dist.is_initialized() else 1 - scale *= accumulation_steps - - return loss * scale - else: - reduce_fn = masked_mean - if args.load_ref_policy: mb_ref_logprob = collated_ref_logprobs[i] # Here we recalculate kl: we want the KL loss to backpropagate through the model @@ -1415,21 +1328,16 @@ def reduce_fn(v, m, a=None, d=None): elif args.kl_estimator == "kl4": kl = kl4 # grpo change: directly subtract KL in loss (add) - loss = reduce_fn( + loss = masked_mean( pg_loss_max + (args.beta * kl), mb_response_masks_bool, loss_axis, loss_denominator ) else: - loss = reduce_fn(pg_loss_max, mb_response_masks_bool, loss_axis, loss_denominator) + loss = masked_mean(pg_loss_max, mb_response_masks_bool, loss_axis, loss_denominator) if args.masked_mean_denominator == "token": - # In token mode, we divide by the GLOBAL total number of tokens. - # DDP averages gradients across ranks (dividing by world_size). - # To get the true global mean gradient (sum_all_gradients / global_tokens), - # we must multiply by world_size to cancel out DDP's division. + # rescale loss by world size if dist.is_available() and dist.is_initialized(): loss *= dist.get_world_size() - elif args.masked_mean_denominator != "group": - # For "group" mode, the scaling is handled inside reduce_fn. - # For default (None) or numeric modes, we divide by accumulation_steps here. 
+ else: loss = loss / accumulation_steps # Clear CUDA cache before backward pass to free memory for reduce_scatter operations @@ -1441,10 +1349,12 @@ def reduce_fn(v, m, a=None, d=None): with torch.no_grad(): if args.load_ref_policy: # NOTE: in packed implementation, kl calculation are averages over response tokens - kl1_stats[i] = reduce_fn(kl1, mb_response_masks_bool, loss_axis, stats_denominator).float() - kl2_stats[i] = reduce_fn(kl2, mb_response_masks_bool, loss_axis, stats_denominator).float() - kl3_stats[i] = reduce_fn(kl3, mb_response_masks_bool, loss_axis, stats_denominator).float() - kl4_stats[i] = reduce_fn(kl4, mb_response_masks_bool, loss_axis, stats_denominator).float() + kl1_stats[i] = masked_mean( + kl1, mb_response_masks_bool, loss_axis, stats_denominator + ).float() + kl2_stats[i] = masked_mean(kl2, mb_response_masks_bool, loss_axis, stats_denominator).float() + kl3_stats[i] = masked_mean(kl3, mb_response_masks_bool, loss_axis, stats_denominator).float() + kl4_stats[i] = masked_mean(kl4, mb_response_masks_bool, loss_axis, stats_denominator).float() if args.kl_estimator == "kl1": kl_loss_stats[i] = kl1_stats[i] * args.beta elif args.kl_estimator == "kl2": @@ -1453,15 +1363,15 @@ def reduce_fn(v, m, a=None, d=None): kl_loss_stats[i] = kl3_stats[i] * args.beta elif args.kl_estimator == "kl4": kl_loss_stats[i] = kl4_stats[i] * args.beta - pg_clipfrac_stats[i] = reduce_fn( + pg_clipfrac_stats[i] = masked_mean( (pg_losses2 > pg_losses).float(), mb_response_masks_bool, loss_axis, stats_denominator ) - pg_loss_stats[i] = reduce_fn(pg_loss_max, mb_response_masks_bool, loss_axis, stats_denominator) + pg_loss_stats[i] = masked_mean(pg_loss_max, mb_response_masks_bool, loss_axis, stats_denominator) loss_stats[i] = loss - ratio_stats[i] = reduce_fn(ratio, mb_response_masks_bool, loss_axis, stats_denominator) + ratio_stats[i] = masked_mean(ratio, mb_response_masks_bool, loss_axis, stats_denominator) if args.record_entropy: # Calculate entropy statistics - entropy_stats[i] = reduce_fn( + entropy_stats[i] = masked_mean( mb_entropy, mb_response_masks_bool, loss_axis, stats_denominator ).float() diff --git a/open_instruct/rl_utils.py b/open_instruct/rl_utils.py index 1d5d6860af..1fadbf914a 100644 --- a/open_instruct/rl_utils.py +++ b/open_instruct/rl_utils.py @@ -258,42 +258,3 @@ def masked_mean( numerator = (values * mask).sum(axis=axis) denom = mask.sum(axis=axis) if denominator is None else denominator return (numerator / denom).mean() - - -def masked_group_mean(values: torch.Tensor, mask: torch.Tensor, group_ids: torch.Tensor) -> torch.Tensor: - """ - Compute mean of tensor values masked by mask, but averaged per group first. - 1. Filter values and group_ids by mask. - 2. Sum values per group. - 3. Count items per group. - 4. Compute mean per group. - 5. Compute mean of group means. 
- """ - # Flatten everything - flat_values = values.flatten() - flat_mask = mask.flatten().bool() - flat_group_ids = group_ids.flatten() - - # Filter invalid items - valid_values = flat_values[flat_mask] - valid_group_ids = flat_group_ids[flat_mask] - - if valid_values.numel() == 0: - return torch.tensor(0.0, device=values.device) - - # Map group_ids to contiguous range 0..num_groups-1 - _, inverse_indices = torch.unique(valid_group_ids, return_inverse=True) - num_groups = inverse_indices.max().item() + 1 - - # Sum values and counts per group - group_sums = torch.zeros(num_groups, device=values.device, dtype=values.dtype) - group_counts = torch.zeros(num_groups, device=values.device, dtype=values.dtype) - - group_sums.scatter_add_(0, inverse_indices, valid_values) - group_counts.scatter_add_(0, inverse_indices, torch.ones_like(valid_values)) - - # Average per group - group_means = group_sums / group_counts - - # Average across groups - return group_means.mean() diff --git a/open_instruct/utils.py b/open_instruct/utils.py index 7917a3d79a..19eeadf5bf 100644 --- a/open_instruct/utils.py +++ b/open_instruct/utils.py @@ -2499,7 +2499,7 @@ def get_denominator(masked_mean_denominator: float | str | None) -> float | str return None if isinstance(masked_mean_denominator, str): - if masked_mean_denominator in ["token", "group"]: + if masked_mean_denominator == "token": return masked_mean_denominator # Try to convert numeric strings to float try: From 6a78727aa753c68151cdd0922bdcf4974148efbe Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Sun, 30 Nov 2025 16:41:27 -0800 Subject: [PATCH 18/34] simplify a little --- open_instruct/grpo_fast.py | 34 +++++++++++++++++++--------------- 1 file changed, 19 insertions(+), 15 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 61052aeeba..6e1134952b 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1049,23 +1049,27 @@ def calculate_token_counts( accumulation_steps: int, collated_response_masks: list[torch.Tensor], collated_tool_masks: list[torch.Tensor], - ) -> dict[int, float | torch.Tensor]: + ) -> dict[int, float]: accumulation_counts = {} - for group_start in range(0, len(collated_response_masks), accumulation_steps): - group_end = min(group_start + accumulation_steps, len(collated_response_masks)) - counts = torch.tensor(0.0, device=self.device, dtype=torch.float32) + args = self.args + device = self.device + total_batches = len(collated_response_masks) + # sometimes we have multiple mini-batches per accumulation group + # so compute counts for each group. 
+ for group_start in range(0, total_batches, accumulation_steps): + group_end = min(group_start + accumulation_steps, total_batches) + count = 0.0 for i in range(group_start, group_end): - mb_response_masks = collated_response_masks[i] - mb_response_masks_bool = mb_response_masks[:, 1:].bool() - if self.args.mask_tool_use and self.args.tool_use: - mb_tool_mask = collated_tool_masks[i] - mb_response_masks_bool = mb_response_masks_bool & mb_tool_mask[:, 1:].bool() - counts += mb_response_masks_bool.sum().float() - # All reduce counts + masks = collated_response_masks[i][:, 1:].bool() + if args.mask_tool_use and args.tool_use: + masks &= collated_tool_masks[i][:, 1:].bool() + count += masks.sum().item() + count_tensor = torch.tensor(count, device=device, dtype=torch.float32) if dist.is_available() and dist.is_initialized(): dist.barrier() - dist.all_reduce(counts, op=dist.ReduceOp.SUM, group=None) - accumulation_counts[group_start] = counts.item() + dist.all_reduce(count_tensor, op=dist.ReduceOp.SUM) + count = count_tensor.item() + accumulation_counts[group_start] = count return accumulation_counts def train( @@ -1195,8 +1199,8 @@ def train( loss_denominator = args.masked_mean_denominator loss_axis = None if args.masked_mean_denominator == "token": - group_start = (i // accumulation_steps) * accumulation_steps - loss_denominator = accumulation_token_counts[group_start] + batch_start = (i // accumulation_steps) * accumulation_steps + loss_denominator = accumulation_token_counts[batch_start] mb_attention_mask = collated_attention_masks[i] mb_position_id = collated_position_ids[i] From 0331d17c2dc185e6e4fd24d0d9413ad530be4345 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Sun, 30 Nov 2025 16:52:08 -0800 Subject: [PATCH 19/34] small fix --- open_instruct/grpo_fast.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 6e1134952b..a55d841a1f 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1337,7 +1337,7 @@ def train( ) else: loss = masked_mean(pg_loss_max, mb_response_masks_bool, loss_axis, loss_denominator) - if args.masked_mean_denominator == "token": + if args.masked_mean_denominator is not None: # rescale loss by world size if dist.is_available() and dist.is_initialized(): loss *= dist.get_world_size() From fe7642ce07bd2777eefe0f9ae1b8da53d9e68228 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Sun, 30 Nov 2025 16:55:12 -0800 Subject: [PATCH 20/34] whoops, fix indent --- open_instruct/grpo_fast.py | 28 +++++++++++++++++----------- 1 file changed, 17 insertions(+), 11 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index a55d841a1f..6407ba97ab 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1356,17 +1356,23 @@ def train( kl1_stats[i] = masked_mean( kl1, mb_response_masks_bool, loss_axis, stats_denominator ).float() - kl2_stats[i] = masked_mean(kl2, mb_response_masks_bool, loss_axis, stats_denominator).float() - kl3_stats[i] = masked_mean(kl3, mb_response_masks_bool, loss_axis, stats_denominator).float() - kl4_stats[i] = masked_mean(kl4, mb_response_masks_bool, loss_axis, stats_denominator).float() - if args.kl_estimator == "kl1": - kl_loss_stats[i] = kl1_stats[i] * args.beta - elif args.kl_estimator == "kl2": - kl_loss_stats[i] = kl2_stats[i] * args.beta - elif args.kl_estimator == "kl3": - kl_loss_stats[i] = kl3_stats[i] * args.beta - elif args.kl_estimator == "kl4": - kl_loss_stats[i] = kl4_stats[i] * args.beta + 
kl2_stats[i] = masked_mean( + kl2, mb_response_masks_bool, loss_axis, stats_denominator + ).float() + kl3_stats[i] = masked_mean( + kl3, mb_response_masks_bool, loss_axis, stats_denominator + ).float() + kl4_stats[i] = masked_mean( + kl4, mb_response_masks_bool, loss_axis, stats_denominator + ).float() + if args.kl_estimator == "kl1": + kl_loss_stats[i] = kl1_stats[i] * args.beta + elif args.kl_estimator == "kl2": + kl_loss_stats[i] = kl2_stats[i] * args.beta + elif args.kl_estimator == "kl3": + kl_loss_stats[i] = kl3_stats[i] * args.beta + elif args.kl_estimator == "kl4": + kl_loss_stats[i] = kl4_stats[i] * args.beta pg_clipfrac_stats[i] = masked_mean( (pg_losses2 > pg_losses).float(), mb_response_masks_bool, loss_axis, stats_denominator ) From 7d526a17dc36f591a279f436a22e537cb523f577 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Sun, 30 Nov 2025 16:55:54 -0800 Subject: [PATCH 21/34] small test --- open_instruct/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_instruct/test_utils.py b/open_instruct/test_utils.py index 551cc457ad..5d6e1f5e4f 100644 --- a/open_instruct/test_utils.py +++ b/open_instruct/test_utils.py @@ -527,7 +527,7 @@ def mock_inspect_return(*args, **kwargs): # _ = get_datasets(dataset_mixer, splits=["train"], columns_to_keep=["messages"]) -@pytest.mark.parametrize("denominator,expected", [(None, None), ("token", "token"), (10.0, 10.0), (5, 5)]) +@pytest.mark.parametrize("denominator,expected", [(None, None), ("token", "token"), (10.0, 10.0), (5, 5), ("5", 5)]) def test_get_denominator_valid_inputs(denominator, expected): assert utils.get_denominator(denominator) == expected From ff86d58149be2f5c310d095d9bf2425b9db9e53f Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Sun, 30 Nov 2025 17:22:48 -0800 Subject: [PATCH 22/34] fix indent --- open_instruct/grpo_fast.py | 24 +++++++++++++----------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 8263fe8926..6059078f50 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1373,17 +1373,19 @@ def train( kl_loss_stats[i] = kl3_stats[i] * args.beta elif args.kl_estimator == "kl4": kl_loss_stats[i] = kl4_stats[i] * args.beta - pg_clipfrac_stats[i] = masked_mean( - (pg_losses2 > pg_losses).float(), mb_response_masks_bool, loss_axis, stats_denominator - ) - pg_loss_stats[i] = masked_mean(pg_loss_max, mb_response_masks_bool, loss_axis, stats_denominator) - loss_stats[i] = loss - ratio_stats[i] = masked_mean(ratio, mb_response_masks_bool, loss_axis, stats_denominator) - if args.record_entropy: - # Calculate entropy statistics - entropy_stats[i] = masked_mean( - mb_entropy, mb_response_masks_bool, loss_axis, stats_denominator - ).float() + pg_clipfrac_stats[i] = masked_mean( + (pg_losses2 > pg_losses).float(), mb_response_masks_bool, loss_axis, stats_denominator + ) + pg_loss_stats[i] = masked_mean( + pg_loss_max, mb_response_masks_bool, loss_axis, stats_denominator + ) + loss_stats[i] = loss + ratio_stats[i] = masked_mean(ratio, mb_response_masks_bool, loss_axis, stats_denominator) + if args.record_entropy: + # Calculate entropy statistics + entropy_stats[i] = masked_mean( + mb_entropy, mb_response_masks_bool, loss_axis, stats_denominator + ).float() with torch.no_grad(): if args.load_ref_policy: From f3ee394b9248ffa697608b9514c2e8a7efc62602 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Mon, 1 Dec 2025 08:49:04 -0800 Subject: [PATCH 23/34] rejig based on feedback --- 
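A note on the normalization this patch settles on: each rank divides its summed token losses by the global token count (gathered via all_reduce), and the loss is then multiplied by the world size because the data-parallel backward averages gradients across ranks; the two factors cancel so the optimizer effectively sees the global token-level mean. Below is a rough plain-Python check of that bookkeeping with made-up numbers (no distributed setup required); it is a sketch of the arithmetic, not code from the repository. The same argument covers gradient accumulation: because the denominator is the token count of the whole accumulation group, no extra division by accumulation_steps is needed.

# Two ranks with made-up per-token losses; the global token count is 3 + 1 = 4.
rank_losses = [[1.0, 2.0, 3.0], [10.0]]
world_size = len(rank_losses)
global_tokens = sum(len(r) for r in rank_losses)

# Each rank contributes sum(local) / global_tokens, scaled up by world_size ...
per_rank = [world_size * sum(r) / global_tokens for r in rank_losses]
# ... and the gradient averaging across ranks then divides the summed contributions by world_size.
after_averaging = sum(per_rank) / world_size

true_global_mean = sum(sum(r) for r in rank_losses) / global_tokens
assert after_averaging == true_global_mean == 4.0
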
open_instruct/grpo_fast.py | 127 ++++++++++++++++++------------------- 1 file changed, 60 insertions(+), 67 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 6059078f50..3389f9c6f2 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -250,10 +250,11 @@ class Args: """the KL estimator to use""" pack_length: int = 512 """the length of the pack (you should prob set to the max length of the model)""" - masked_mean_denominator: float | str | None = "token" - """Optional constant denominator for masked_mean; if set, divides by this instead of mask.sum. - Special value "token" means use total_batch_tokens (computed across all ranks in distributed training). - When using "token", total_batch_tokens is gathered via allreduce across all ranks.""" + loss_denominator: str | float = "token" + """Optional constant denominator for masked_mean; can be "token" or a float value. + when "token", the loss is divided by the total number of tokens in the batch (standard LM training). + when a float value, the loss is divided by this value (ideally, max tokens in batch, per Dr GRPO). + """ alpha: float = 0.6 """The alpha value for doing polyak updates (ref_param = alpha * param + (1 - alpha) * ref_param) reference: [TR-DPO](https://huggingface.co/papers/2404.09656), but it's actually pretty commonly @@ -457,7 +458,13 @@ def __post_init__(self): "Cannot use both `use_vllm_logprobs` and `truncated_importance_sampling_ratio_cap`. " "use_vllm_logprobs sets old_logprobs to vLLM logprobs, making importance sampling pointless." ) - self.masked_mean_denominator = get_denominator(self.masked_mean_denominator) + if self.loss_denominator is not None and self.loss_denominator != "token": + try: + self.loss_denominator = float(self.loss_denominator) + except ValueError: + raise ValueError( + f"loss_denominator must be a float value if not 'token', got: {self.loss_denominator}" + ) assert self.num_samples_per_prompt_rollout > 0, "Number of samples per prompt must be greater than 0!" if self.num_samples_per_prompt_rollout == 1: logger.warning("num_samples_per_prompt_rollout is 1. This reduces GRPO to REINFORCE.") @@ -1050,26 +1057,31 @@ def calculate_token_counts( collated_response_masks: list[torch.Tensor], collated_tool_masks: list[torch.Tensor], ) -> dict[int, float]: - accumulation_counts = {} args = self.args device = self.device total_batches = len(collated_response_masks) - # sometimes we have multiple mini-batches per accumulation group - # so compute counts for each group. - for group_start in range(0, total_batches, accumulation_steps): - group_end = min(group_start + accumulation_steps, total_batches) - count = 0.0 - for i in range(group_start, group_end): - masks = collated_response_masks[i][:, 1:].bool() - if args.mask_tool_use and args.tool_use: - masks &= collated_tool_masks[i][:, 1:].bool() - count += masks.sum().item() - count_tensor = torch.tensor(count, device=device, dtype=torch.float32) - if dist.is_available() and dist.is_initialized(): - dist.barrier() - dist.all_reduce(count_tensor, op=dist.ReduceOp.SUM) - count = count_tensor.item() - accumulation_counts[group_start] = count + + resp = torch.stack(collated_response_masks, dim=0).to(device) + masks = resp[:, :, 1:].bool() + if args.mask_tool_use and args.tool_use: + tools = torch.stack(collated_tool_masks, dim=0).to(device) + masks &= tools[:, :, 1:].bool() + # sum over bsz and seq len. 
+ batch_counts = masks.sum(dim=(1, 2)).to(dtype=torch.float32) + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(batch_counts, op=dist.ReduceOp.SUM) + + # group counts by the batches (called groups here) + num_groups = (total_batches + accumulation_steps - 1) // accumulation_steps + group_ids = torch.arange(total_batches, device=device) // accumulation_steps + + group_counts = torch.zeros(num_groups, device=device, dtype=torch.float32) + group_counts.scatter_add_(0, group_ids, batch_counts) + + accumulation_counts: dict[int, float] = { + int(group_idx * accumulation_steps): float(count) for group_idx, count in enumerate(group_counts) + } + return accumulation_counts def train( @@ -1179,11 +1191,16 @@ def train( for epoch_idx in range(args.num_epochs): # Pre-compute total tokens for each accumulation group if using "token" normalization # This ensures all minibatches in an accumulation group are normalized by the same total - accumulation_token_counts = {} - if args.masked_mean_denominator == "token": + accumulation_token_counts: dict[int, float] = {} + if args.loss_denominator == "token": accumulation_token_counts = self.calculate_token_counts( accumulation_steps, collated_response_masks, collated_tool_masks ) + else: + accumulation_token_counts = { + int(group_idx * accumulation_steps): args.loss_denominator + for group_idx in range((len(collated_query_responses) // accumulation_steps) + 1) + } for i in range(len(collated_query_responses)): mb_query_responses = collated_query_responses[i] @@ -1195,12 +1212,9 @@ def train( if args.mask_tool_use and args.tool_use: mb_response_masks_bool = mb_response_masks[:, 1:].bool() & mb_tool_mask[:, 1:].bool() - # Determine the denominator for masked_mean normalization - loss_denominator = args.masked_mean_denominator - loss_axis = None - if args.masked_mean_denominator == "token": - batch_start = (i // accumulation_steps) * accumulation_steps - loss_denominator = accumulation_token_counts[batch_start] + # retrieve the loss denominator for the current batch + batch_start = (i // accumulation_steps) * accumulation_steps + loss_denominator = accumulation_token_counts[batch_start] mb_attention_mask = collated_attention_masks[i] mb_position_id = collated_position_ids[i] @@ -1307,12 +1321,6 @@ def train( pg_loss_max = torch.max(pg_losses, pg_losses2) - # for stats computation, for now no denominator is used - # unless masked_mean_denominator is a numeric value. 
- stats_denominator = ( - args.masked_mean_denominator if args.masked_mean_denominator != "token" else None - ) - if args.load_ref_policy: mb_ref_logprob = collated_ref_logprobs[i] # Here we recalculate kl: we want the KL loss to backpropagate through the model @@ -1333,16 +1341,14 @@ def train( kl = kl4 # grpo change: directly subtract KL in loss (add) loss = masked_mean( - pg_loss_max + (args.beta * kl), mb_response_masks_bool, loss_axis, loss_denominator + pg_loss_max + (args.beta * kl), mb_response_masks_bool, None, loss_denominator ) else: - loss = masked_mean(pg_loss_max, mb_response_masks_bool, loss_axis, loss_denominator) - if args.masked_mean_denominator is not None: - # rescale loss by world size - if dist.is_available() and dist.is_initialized(): - loss *= dist.get_world_size() - else: - loss = loss / accumulation_steps + loss = masked_mean(pg_loss_max, mb_response_masks_bool, None, loss_denominator) + + # we already took world size into account via the tokens + if dist.is_available() and dist.is_initialized(): + loss *= dist.get_world_size() # Clear CUDA cache before backward pass to free memory for reduce_scatter operations torch.cuda.empty_cache() @@ -1352,19 +1358,12 @@ def train( local_step += 1 with torch.no_grad(): if args.load_ref_policy: - # NOTE: in packed implementation, kl calculation are averages over response tokens - kl1_stats[i] = masked_mean( - kl1, mb_response_masks_bool, loss_axis, stats_denominator - ).float() - kl2_stats[i] = masked_mean( - kl2, mb_response_masks_bool, loss_axis, stats_denominator - ).float() - kl3_stats[i] = masked_mean( - kl3, mb_response_masks_bool, loss_axis, stats_denominator - ).float() - kl4_stats[i] = masked_mean( - kl4, mb_response_masks_bool, loss_axis, stats_denominator - ).float() + # NOTE: for stats, we just average over response tokens in the minibatch + # (we don't take the total token count into account) + kl1_stats[i] = masked_mean(kl1, mb_response_masks_bool).float() + kl2_stats[i] = masked_mean(kl2, mb_response_masks_bool).float() + kl3_stats[i] = masked_mean(kl3, mb_response_masks_bool).float() + kl4_stats[i] = masked_mean(kl4, mb_response_masks_bool).float() if args.kl_estimator == "kl1": kl_loss_stats[i] = kl1_stats[i] * args.beta elif args.kl_estimator == "kl2": @@ -1373,19 +1372,13 @@ def train( kl_loss_stats[i] = kl3_stats[i] * args.beta elif args.kl_estimator == "kl4": kl_loss_stats[i] = kl4_stats[i] * args.beta - pg_clipfrac_stats[i] = masked_mean( - (pg_losses2 > pg_losses).float(), mb_response_masks_bool, loss_axis, stats_denominator - ) - pg_loss_stats[i] = masked_mean( - pg_loss_max, mb_response_masks_bool, loss_axis, stats_denominator - ) + pg_clipfrac_stats[i] = masked_mean((pg_losses2 > pg_losses).float(), mb_response_masks_bool) + pg_loss_stats[i] = masked_mean(pg_loss_max, mb_response_masks_bool) loss_stats[i] = loss - ratio_stats[i] = masked_mean(ratio, mb_response_masks_bool, loss_axis, stats_denominator) + ratio_stats[i] = masked_mean(ratio, mb_response_masks_bool) if args.record_entropy: # Calculate entropy statistics - entropy_stats[i] = masked_mean( - mb_entropy, mb_response_masks_bool, loss_axis, stats_denominator - ).float() + entropy_stats[i] = masked_mean(mb_entropy, mb_response_masks_bool).float() with torch.no_grad(): if args.load_ref_policy: From b76400431b66648b516bb0ade63a57686c48d2b2 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Mon, 1 Dec 2025 08:54:22 -0800 Subject: [PATCH 24/34] clean --- open_instruct/grpo_fast.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 3389f9c6f2..d287703c01 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -458,7 +458,7 @@ def __post_init__(self): "Cannot use both `use_vllm_logprobs` and `truncated_importance_sampling_ratio_cap`. " "use_vllm_logprobs sets old_logprobs to vLLM logprobs, making importance sampling pointless." ) - if self.loss_denominator is not None and self.loss_denominator != "token": + if self.loss_denominator != "token": try: self.loss_denominator = float(self.loss_denominator) except ValueError: From 408eaddeb0abb4157198b2b493297fc3c596f43e Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Mon, 1 Dec 2025 09:15:47 -0800 Subject: [PATCH 25/34] fix --- open_instruct/grpo_fast.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index d287703c01..7647adbdf2 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -250,7 +250,7 @@ class Args: """the KL estimator to use""" pack_length: int = 512 """the length of the pack (you should prob set to the max length of the model)""" - loss_denominator: str | float = "token" + loss_denominator: str = "token" """Optional constant denominator for masked_mean; can be "token" or a float value. when "token", the loss is divided by the total number of tokens in the batch (standard LM training). when a float value, the loss is divided by this value (ideally, max tokens in batch, per Dr GRPO). From f5d8afd893564c53f3d29b676a34e85918c4a10f Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Mon, 1 Dec 2025 09:25:28 -0800 Subject: [PATCH 26/34] fix --- open_instruct/grpo_fast.py | 39 +++++++++++++++++--------------------- 1 file changed, 17 insertions(+), 22 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 7647adbdf2..ead87c141b 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1059,28 +1059,23 @@ def calculate_token_counts( ) -> dict[int, float]: args = self.args device = self.device - total_batches = len(collated_response_masks) - - resp = torch.stack(collated_response_masks, dim=0).to(device) - masks = resp[:, :, 1:].bool() - if args.mask_tool_use and args.tool_use: - tools = torch.stack(collated_tool_masks, dim=0).to(device) - masks &= tools[:, :, 1:].bool() - # sum over bsz and seq len. 
- batch_counts = masks.sum(dim=(1, 2)).to(dtype=torch.float32) - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(batch_counts, op=dist.ReduceOp.SUM) - - # group counts by the batches (called groups here) - num_groups = (total_batches + accumulation_steps - 1) // accumulation_steps - group_ids = torch.arange(total_batches, device=device) // accumulation_steps - - group_counts = torch.zeros(num_groups, device=device, dtype=torch.float32) - group_counts.scatter_add_(0, group_ids, batch_counts) - - accumulation_counts: dict[int, float] = { - int(group_idx * accumulation_steps): float(count) for group_idx, count in enumerate(group_counts) - } + + accumulation_counts: dict[int, float] = {} + + for i, response_mask in enumerate(collated_response_masks): + response_mask = response_mask.to(device) + mask = response_mask[:, 1:].bool() + if args.mask_tool_use and args.tool_use: + tool_mask = collated_tool_masks[i].to(device) + mask &= tool_mask[:, 1:].bool() + + count = mask.sum().float() + if dist.is_available() and dist.is_initialized(): + dist.all_reduce(count, op=dist.ReduceOp.SUM) + + group_idx = i // accumulation_steps + key = int(group_idx * accumulation_steps) + accumulation_counts[key] = accumulation_counts.get(key, 0.0) + count.item() return accumulation_counts From 222d565c674b2e221dc2a919aee337c9b730306a Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Mon, 1 Dec 2025 09:25:40 -0800 Subject: [PATCH 27/34] fix --- open_instruct/grpo_fast.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index ead87c141b..b2c0abf4ba 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1070,8 +1070,7 @@ def calculate_token_counts( mask &= tool_mask[:, 1:].bool() count = mask.sum().float() - if dist.is_available() and dist.is_initialized(): - dist.all_reduce(count, op=dist.ReduceOp.SUM) + dist.all_reduce(count, op=dist.ReduceOp.SUM) group_idx = i // accumulation_steps key = int(group_idx * accumulation_steps) From ccf556395060641831d5bb4ad7666344ffc064ed Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Mon, 1 Dec 2025 09:30:59 -0800 Subject: [PATCH 28/34] cleanup --- open_instruct/grpo_fast.py | 1 - open_instruct/test_utils.py | 11 ----------- open_instruct/utils.py | 20 -------------------- 3 files changed, 32 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index b2c0abf4ba..4eb4a3d57b 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -118,7 +118,6 @@ combine_reward_metrics, download_latest_checkpoint_from_gs, get_beaker_whoami, - get_denominator, get_eval_ds_config, get_optimizer_grouped_parameters, get_train_ds_config, diff --git a/open_instruct/test_utils.py b/open_instruct/test_utils.py index 5d6e1f5e4f..c46742e42c 100644 --- a/open_instruct/test_utils.py +++ b/open_instruct/test_utils.py @@ -525,14 +525,3 @@ def mock_inspect_return(*args, **kwargs): # "natolambert/tulu-v2-sft-mixture-science": 7468, # original data slightly different # } # _ = get_datasets(dataset_mixer, splits=["train"], columns_to_keep=["messages"]) - - -@pytest.mark.parametrize("denominator,expected", [(None, None), ("token", "token"), (10.0, 10.0), (5, 5), ("5", 5)]) -def test_get_denominator_valid_inputs(denominator, expected): - assert utils.get_denominator(denominator) == expected - - -@pytest.mark.parametrize("denominator", ["not-token", 0, -1.0]) -def test_get_denominator_invalid_inputs(denominator): - with pytest.raises(AssertionError): - 
utils.get_denominator(denominator) diff --git a/open_instruct/utils.py b/open_instruct/utils.py index 19eeadf5bf..1c76d22b4a 100644 --- a/open_instruct/utils.py +++ b/open_instruct/utils.py @@ -2491,23 +2491,3 @@ def get_beaker_experiment_url() -> str | None: return url except Exception: return None - - -def get_denominator(masked_mean_denominator: float | str | None) -> float | str | None: - """Validate and return the masked mean denominator value.""" - if masked_mean_denominator is None: - return None - - if isinstance(masked_mean_denominator, str): - if masked_mean_denominator == "token": - return masked_mean_denominator - # Try to convert numeric strings to float - try: - masked_mean_denominator = float(masked_mean_denominator) - except ValueError: - raise AssertionError( - f"masked_mean_denominator string value must be 'token', 'group' or number, got {masked_mean_denominator}" - ) from None - - assert masked_mean_denominator > 0, f"masked_mean_denominator (={masked_mean_denominator}) must be greater than 0!" - return masked_mean_denominator From de40be80329a5cfb54368952547c7c6e3e3b1d11 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Mon, 1 Dec 2025 09:34:47 -0800 Subject: [PATCH 29/34] little refactor --- open_instruct/grpo_fast.py | 20 +++++++++++--------- open_instruct/test_utils.py | 18 ++++++++++++++++++ open_instruct/utils.py | 16 ++++++++++++++++ 3 files changed, 45 insertions(+), 9 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 4eb4a3d57b..d1a361f96a 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -457,13 +457,7 @@ def __post_init__(self): "Cannot use both `use_vllm_logprobs` and `truncated_importance_sampling_ratio_cap`. " "use_vllm_logprobs sets old_logprobs to vLLM logprobs, making importance sampling pointless." ) - if self.loss_denominator != "token": - try: - self.loss_denominator = float(self.loss_denominator) - except ValueError: - raise ValueError( - f"loss_denominator must be a float value if not 'token', got: {self.loss_denominator}" - ) + self.loss_denominator = utils.get_denominator(self.loss_denominator) assert self.num_samples_per_prompt_rollout > 0, "Number of samples per prompt must be greater than 0!" if self.num_samples_per_prompt_rollout == 1: logger.warning("num_samples_per_prompt_rollout is 1. 
This reduces GRPO to REINFORCE.") @@ -1060,6 +1054,7 @@ def calculate_token_counts( device = self.device accumulation_counts: dict[int, float] = {} + local_counts = [] for i, response_mask in enumerate(collated_response_masks): response_mask = response_mask.to(device) @@ -1068,9 +1063,16 @@ def calculate_token_counts( tool_mask = collated_tool_masks[i].to(device) mask &= tool_mask[:, 1:].bool() - count = mask.sum().float() - dist.all_reduce(count, op=dist.ReduceOp.SUM) + local_counts.append(mask.sum().float()) + + if not local_counts: + return accumulation_counts + + # do the all_reduce once to avoid calling each loop + counts_tensor = torch.stack(local_counts) + dist.all_reduce(counts_tensor, op=dist.ReduceOp.SUM) + for i, count in enumerate(counts_tensor): group_idx = i // accumulation_steps key = int(group_idx * accumulation_steps) accumulation_counts[key] = accumulation_counts.get(key, 0.0) + count.item() diff --git a/open_instruct/test_utils.py b/open_instruct/test_utils.py index c46742e42c..a82e4686f3 100644 --- a/open_instruct/test_utils.py +++ b/open_instruct/test_utils.py @@ -525,3 +525,21 @@ def mock_inspect_return(*args, **kwargs): # "natolambert/tulu-v2-sft-mixture-science": 7468, # original data slightly different # } # _ = get_datasets(dataset_mixer, splits=["train"], columns_to_keep=["messages"]) + + +class TestGetDenominator(unittest.TestCase): + @parameterized.expand([("token", "token"), ("0.5", 0.5), (0.5, 0.5), (1, 1.0)]) + def test_valid_inputs(self, input_val, expected): + self.assertEqual(utils.get_denominator(input_val), expected) + + @parameterized.expand( + [ + ("invalid", "loss_denominator must be a float value if not 'token'"), + ("-1", "loss_denominator must be greater than 0"), + (0, "loss_denominator must be greater than 0"), + ("0", "loss_denominator must be greater than 0"), + ] + ) + def test_invalid_inputs(self, input_val, error_msg): + with self.assertRaisesRegex(ValueError, error_msg): + utils.get_denominator(input_val) diff --git a/open_instruct/utils.py b/open_instruct/utils.py index 1c76d22b4a..3689e79106 100644 --- a/open_instruct/utils.py +++ b/open_instruct/utils.py @@ -2491,3 +2491,19 @@ def get_beaker_experiment_url() -> str | None: return url except Exception: return None + + +def get_denominator(loss_denominator: str | float) -> float | str: + """ + Validates and converts the loss_denominator argument. 
+ """ + if loss_denominator == "token": + return "token" + + try: + val = float(loss_denominator) + except ValueError: + raise ValueError(f"loss_denominator must be a float value if not 'token', got: {loss_denominator}") + if val <= 0: + raise ValueError(f"loss_denominator must be greater than 0 if not 'token', got: {loss_denominator}") + return val From 540e34af663a9c5b4d60fd28caf72f2644384d40 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Mon, 1 Dec 2025 09:38:40 -0800 Subject: [PATCH 30/34] ruff fix --- open_instruct/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_instruct/utils.py b/open_instruct/utils.py index 3689e79106..d707efc184 100644 --- a/open_instruct/utils.py +++ b/open_instruct/utils.py @@ -2503,7 +2503,7 @@ def get_denominator(loss_denominator: str | float) -> float | str: try: val = float(loss_denominator) except ValueError: - raise ValueError(f"loss_denominator must be a float value if not 'token', got: {loss_denominator}") + raise ValueError(f"loss_denominator must be a float value if not 'token', got: {loss_denominator}") from None if val <= 0: raise ValueError(f"loss_denominator must be greater than 0 if not 'token', got: {loss_denominator}") return val From b14d1ae1511eab2507a7f2a807719c94efb3b0da Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Mon, 1 Dec 2025 10:40:28 -0800 Subject: [PATCH 31/34] finbarr comments --- open_instruct/grpo_fast.py | 14 +++++++------- open_instruct/utils.py | 5 +---- 2 files changed, 8 insertions(+), 11 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index 4d39af58a2..6ecf03d0e7 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -1053,17 +1053,18 @@ def calculate_token_counts( collated_response_masks: list[torch.Tensor], collated_tool_masks: list[torch.Tensor], ) -> dict[int, float]: - args = self.args - device = self.device - + """ + Compute the number of training tokens in each batch for this set of responses. + Return a dictionary of batch indices to the number of training tokens in that batch. 
+ """ accumulation_counts: dict[int, float] = {} local_counts = [] for i, response_mask in enumerate(collated_response_masks): - response_mask = response_mask.to(device) + response_mask = response_mask.to(self.device) mask = response_mask[:, 1:].bool() - if args.mask_tool_use and args.tool_use: - tool_mask = collated_tool_masks[i].to(device) + if self.args.mask_tool_use and self.args.tool_use: + tool_mask = collated_tool_masks[i].to(self.device) mask &= tool_mask[:, 1:].bool() local_counts.append(mask.sum().float()) @@ -1189,7 +1190,6 @@ def train( for epoch_idx in range(args.num_epochs): # Pre-compute total tokens for each accumulation group if using "token" normalization # This ensures all minibatches in an accumulation group are normalized by the same total - accumulation_token_counts: dict[int, float] = {} if args.loss_denominator == "token": accumulation_token_counts = self.calculate_token_counts( accumulation_steps, collated_response_masks, collated_tool_masks diff --git a/open_instruct/utils.py b/open_instruct/utils.py index d707efc184..eacb3497e7 100644 --- a/open_instruct/utils.py +++ b/open_instruct/utils.py @@ -2500,10 +2500,7 @@ def get_denominator(loss_denominator: str | float) -> float | str: if loss_denominator == "token": return "token" - try: - val = float(loss_denominator) - except ValueError: - raise ValueError(f"loss_denominator must be a float value if not 'token', got: {loss_denominator}") from None + val = float(loss_denominator) if val <= 0: raise ValueError(f"loss_denominator must be greater than 0 if not 'token', got: {loss_denominator}") return val From e2fa6456744659ba812297bc92c9c5e5230247d4 Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Mon, 1 Dec 2025 10:48:20 -0800 Subject: [PATCH 32/34] fix test case --- open_instruct/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/open_instruct/test_utils.py b/open_instruct/test_utils.py index a82e4686f3..e8eae913fa 100644 --- a/open_instruct/test_utils.py +++ b/open_instruct/test_utils.py @@ -534,7 +534,7 @@ def test_valid_inputs(self, input_val, expected): @parameterized.expand( [ - ("invalid", "loss_denominator must be a float value if not 'token'"), + ("invalid", "could not convert string to float"), ("-1", "loss_denominator must be greater than 0"), (0, "loss_denominator must be greater than 0"), ("0", "loss_denominator must be greater than 0"), From 23d5add9e8c55074e4f18bd73d0cef8d232a6b9a Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Mon, 1 Dec 2025 10:59:06 -0800 Subject: [PATCH 33/34] masked mean import --- open_instruct/grpo_fast.py | 1 - open_instruct/model_utils.py | 17 ----------------- 2 files changed, 18 deletions(-) diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py index c4adebe7ee..9aff0588a7 100644 --- a/open_instruct/grpo_fast.py +++ b/open_instruct/grpo_fast.py @@ -103,7 +103,6 @@ get_olmo3_generation_config, load_ref_policy, log_softmax_and_gather, - masked_mean, print_rich_single_line_metrics, print_rich_table, push_folder_to_hub, diff --git a/open_instruct/model_utils.py b/open_instruct/model_utils.py index 9f5a13e61c..43fb4ee806 100644 --- a/open_instruct/model_utils.py +++ b/open_instruct/model_utils.py @@ -769,23 +769,6 @@ def exact_div(a, b, custom_error_message=""): return q -def masked_mean( - values: torch.Tensor, mask: torch.Tensor, axis: int | None = None, denominator: float | None = None -) -> torch.Tensor: - """Compute mean of tensor with masked values.""" - extra_dims = values.ndim - mask.ndim - if axis is None: - sum_dims = 
tuple(range(extra_dims, values.ndim)) - elif axis >= 0: - sum_dims = axis + extra_dims - else: - sum_dims = axis - numerator = (values * mask).sum(dim=sum_dims) - denom = mask.sum(dim=axis) if denominator is None else denominator - result = numerator / denom - return result.flatten(extra_dims).mean(-1) if result.ndim > extra_dims else result - - def estimate_kl(ref_logprobs_diff: torch.Tensor, ratio: torch.Tensor) -> torch.Tensor: """Compute 4 different KL divergence estimators between current and reference policies. From 3c9720bca6c7f24f0d5e6c41c062c6945d00ee6b Mon Sep 17 00:00:00 2001 From: Hamish Ivison Date: Mon, 1 Dec 2025 11:43:49 -0800 Subject: [PATCH 34/34] new masked mean --- open_instruct/rl_utils.py | 16 ++++++--- open_instruct/test_model_utils.py | 59 ------------------------------- open_instruct/test_rl_utils.py | 59 +++++++++++++++++++++++++++++++ 3 files changed, 71 insertions(+), 63 deletions(-) diff --git a/open_instruct/rl_utils.py b/open_instruct/rl_utils.py index 1fadbf914a..2824feb901 100644 --- a/open_instruct/rl_utils.py +++ b/open_instruct/rl_utils.py @@ -254,7 +254,15 @@ def calculate_advantages_packed( def masked_mean( values: torch.Tensor, mask: torch.Tensor, axis: int | None = None, denominator: float | None = None ) -> torch.Tensor: - """Compute mean of tensor with a masked values.""" - numerator = (values * mask).sum(axis=axis) - denom = mask.sum(axis=axis) if denominator is None else denominator - return (numerator / denom).mean() + """Compute mean of tensor with masked values.""" + extra_dims = values.ndim - mask.ndim + if axis is None: + sum_dims = tuple(range(extra_dims, values.ndim)) + elif axis >= 0: + sum_dims = axis + extra_dims + else: + sum_dims = axis + numerator = (values * mask).sum(dim=sum_dims) + denom = mask.sum(dim=axis) if denominator is None else denominator + result = numerator / denom + return result.flatten(extra_dims).mean(-1) if result.ndim > extra_dims else result diff --git a/open_instruct/test_model_utils.py b/open_instruct/test_model_utils.py index 1c1fbd4110..4d4955d0bd 100644 --- a/open_instruct/test_model_utils.py +++ b/open_instruct/test_model_utils.py @@ -42,65 +42,6 @@ def test_batch_slicing_with_none_fields(self): self.assertIsNone(sliced.scores) -class TestMaskedMean(unittest.TestCase): - def test_original_axis_int(self): - values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]]) - result = open_instruct.model_utils.masked_mean(values, mask, axis=1) - expected = ((1.0 + 2.0) / 2 + 4.0 / 1) / 2 - self.assertAlmostEqual(result.item(), expected) - - def test_original_axis_none(self): - values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) - mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]]) - result = open_instruct.model_utils.masked_mean(values, mask, axis=None) - expected = (1.0 + 2.0 + 4.0) / 3 - self.assertAlmostEqual(result.item(), expected, places=5) - - def test_vectorized_axis_int(self): - kl_4BT = torch.tensor( - [ - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], - [[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]], - [[100.0, 200.0, 300.0], [400.0, 500.0, 600.0]], - [[1000.0, 2000.0, 3000.0], [4000.0, 5000.0, 6000.0]], - ] - ) - mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]]) - result = open_instruct.model_utils.masked_mean(kl_4BT, mask, axis=1) - self.assertEqual(result.shape, (4,)) - expected_0 = ((1.0 + 2.0) / 2 + 4.0 / 1) / 2 - expected_1 = ((10.0 + 20.0) / 2 + 40.0 / 1) / 2 - expected_2 = ((100.0 + 200.0) / 2 + 400.0 / 1) / 2 - expected_3 = ((1000.0 + 2000.0) 
/ 2 + 4000.0 / 1) / 2 - self.assertAlmostEqual(result[0].item(), expected_0) - self.assertAlmostEqual(result[1].item(), expected_1) - self.assertAlmostEqual(result[2].item(), expected_2) - self.assertAlmostEqual(result[3].item(), expected_3) - - def test_vectorized_axis_none(self): - kl_4BT = torch.tensor( - [ - [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], - [[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]], - [[100.0, 200.0, 300.0], [400.0, 500.0, 600.0]], - [[1000.0, 2000.0, 3000.0], [4000.0, 5000.0, 6000.0]], - ] - ) - mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]]) - result = open_instruct.model_utils.masked_mean(kl_4BT, mask, axis=None) - self.assertEqual(result.shape, (4,)) - expected = torch.tensor( - [ - (1.0 + 2.0 + 4.0) / 3, - (10.0 + 20.0 + 40.0) / 3, - (100.0 + 200.0 + 400.0) / 3, - (1000.0 + 2000.0 + 4000.0) / 3, - ] - ) - self.assertTrue(torch.allclose(result, expected)) - - class TestLogSoftmaxAndGather(unittest.TestCase): def test_log_softmax_and_gather_sliced_logits(self): batch_size, seq_len, vocab_size = 2, 160, 151936 diff --git a/open_instruct/test_rl_utils.py b/open_instruct/test_rl_utils.py index dfc7403f73..a42e6afc2a 100644 --- a/open_instruct/test_rl_utils.py +++ b/open_instruct/test_rl_utils.py @@ -263,5 +263,64 @@ def test_pack_sequences_logits(self): ) +class TestMaskedMean(unittest.TestCase): + def test_original_axis_int(self): + values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]]) + result = rl_utils.masked_mean(values, mask, axis=1) + expected = ((1.0 + 2.0) / 2 + 4.0 / 1) / 2 + self.assertAlmostEqual(result.item(), expected) + + def test_original_axis_none(self): + values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]]) + result = rl_utils.masked_mean(values, mask, axis=None) + expected = (1.0 + 2.0 + 4.0) / 3 + self.assertAlmostEqual(result.item(), expected, places=5) + + def test_vectorized_axis_int(self): + kl_4BT = torch.tensor( + [ + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], + [[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]], + [[100.0, 200.0, 300.0], [400.0, 500.0, 600.0]], + [[1000.0, 2000.0, 3000.0], [4000.0, 5000.0, 6000.0]], + ] + ) + mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]]) + result = rl_utils.masked_mean(kl_4BT, mask, axis=1) + self.assertEqual(result.shape, (4,)) + expected_0 = ((1.0 + 2.0) / 2 + 4.0 / 1) / 2 + expected_1 = ((10.0 + 20.0) / 2 + 40.0 / 1) / 2 + expected_2 = ((100.0 + 200.0) / 2 + 400.0 / 1) / 2 + expected_3 = ((1000.0 + 2000.0) / 2 + 4000.0 / 1) / 2 + self.assertAlmostEqual(result[0].item(), expected_0) + self.assertAlmostEqual(result[1].item(), expected_1) + self.assertAlmostEqual(result[2].item(), expected_2) + self.assertAlmostEqual(result[3].item(), expected_3) + + def test_vectorized_axis_none(self): + kl_4BT = torch.tensor( + [ + [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], + [[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]], + [[100.0, 200.0, 300.0], [400.0, 500.0, 600.0]], + [[1000.0, 2000.0, 3000.0], [4000.0, 5000.0, 6000.0]], + ] + ) + mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]]) + result = rl_utils.masked_mean(kl_4BT, mask, axis=None) + self.assertEqual(result.shape, (4,)) + expected = torch.tensor( + [ + (1.0 + 2.0 + 4.0) / 3, + (10.0 + 20.0 + 40.0) / 3, + (100.0 + 200.0 + 400.0) / 3, + (1000.0 + 2000.0 + 4000.0) / 3, + ] + ) + self.assertTrue(torch.allclose(result, expected)) + + if __name__ == "__main__": unittest.main()
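
Taken together, the final masked_mean in open_instruct/rl_utils.py supports the constant-denominator path these patches rely on: passing the accumulation group's total token count as the denominator makes the per-minibatch losses sum to a global token-level mean, with no separate division by the number of accumulation steps. A small standalone usage sketch follows; the tensors are illustrative and not taken from the test suite.

import torch

from open_instruct.rl_utils import masked_mean

# Two minibatches in one accumulation group; 3 + 1 = 4 unmasked tokens overall.
values_a = torch.tensor([[1.0, 2.0, 3.0]])
mask_a = torch.tensor([[1.0, 1.0, 1.0]])
values_b = torch.tensor([[10.0, 99.0]])
mask_b = torch.tensor([[1.0, 0.0]])  # the 99.0 token is masked out

total_tokens = (mask_a.sum() + mask_b.sum()).item()  # 4.0, the "token" denominator

loss_a = masked_mean(values_a, mask_a, denominator=total_tokens)  # (1 + 2 + 3) / 4 = 1.5
loss_b = masked_mean(values_b, mask_b, denominator=total_tokens)  # 10 / 4 = 2.5
print((loss_a + loss_b).item())  # 4.0 == (1 + 2 + 3 + 10) / 4, the global token mean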