diff --git a/open_instruct/grpo_fast.py b/open_instruct/grpo_fast.py
index f6ae53428..9aff0588a 100644
--- a/open_instruct/grpo_fast.py
+++ b/open_instruct/grpo_fast.py
@@ -103,13 +103,12 @@
     get_olmo3_generation_config,
     load_ref_policy,
     log_softmax_and_gather,
-    masked_mean,
     print_rich_single_line_metrics,
     print_rich_table,
     push_folder_to_hub,
 )
 from open_instruct.queue_types import GenerationResult, PromptRequest, RequestInfo, TokenStatistics
-from open_instruct.rl_utils import PackedSequences, Timer, pack_sequences
+from open_instruct.rl_utils import PackedSequences, Timer, masked_mean, pack_sequences
 from open_instruct.utils import (
     ArgumentParserPlus,
     BeakerRuntimeConfig,
@@ -251,10 +250,11 @@ class Args:
     """the KL estimator to use"""
     pack_length: int = 512
     """the length of the pack (you should prob set to the max length of the model)"""
-    masked_mean_axis: int | None = None
-    """the axis to compute the mean of the masked values"""
-    masked_mean_denominator: float | None = None
-    """Optional constant denominator for masked_mean; if set, divides by this instead of mask.sum"""
+    loss_denominator: str = "token"
+    """The denominator used to normalize the loss; can be "token" or a float value.
+    When "token", the loss is divided by the total number of tokens in the batch (standard LM training).
+    When a float value, the loss is divided by this value (ideally, max tokens in batch, per Dr GRPO).
+    """
     alpha: float = 0.6
     """The alpha value for doing polyak updates (ref_param = alpha * param + (1 - alpha) * ref_param)
     reference: [TR-DPO](https://huggingface.co/papers/2404.09656), but it's actually pretty commonly
@@ -458,10 +458,7 @@ def __post_init__(self):
                 "Cannot use both `use_vllm_logprobs` and `truncated_importance_sampling_ratio_cap`. "
                 "use_vllm_logprobs sets old_logprobs to vLLM logprobs, making importance sampling pointless."
             )
-        if self.masked_mean_denominator is not None:
-            assert self.masked_mean_denominator > 0, (
-                f"masked_mean_denominator (={self.masked_mean_denominator}) must be greater than 0!"
-            )
+        self.loss_denominator = utils.get_denominator(self.loss_denominator)
         assert self.num_samples_per_prompt_rollout > 0, "Number of samples per prompt must be greater than 0!"
         if self.num_samples_per_prompt_rollout == 1:
             logger.warning("num_samples_per_prompt_rollout is 1. This reduces GRPO to REINFORCE.")
@@ -501,7 +498,7 @@ def __post_init__(self):
         else:
             self.gs_checkpoint_state_dir = f"{self.gs_bucket_path}/{checkpoint_dir_name}"
         # On GCP, all checkpointing must happen on filestore.
-        # TODO(finbarrtimbers): Chanage this so we can checkpoint to GCS.
+        # TODO(finbarrtimbers): Change this so we can checkpoint to GCS.
         # TODO(finbarrtimbers): Move this logic to mason.py once we refactor config.
         if not checkpoint_dir_name.startswith("/filestore"):
             self.checkpoint_state_dir = f"/filestore{self.checkpoint_state_dir}"
@@ -1051,6 +1048,42 @@ def compute_logprobs(
 
         return collated_logprobs, collated_entropies
 
+    def calculate_token_counts(
+        self,
+        accumulation_steps: int,
+        collated_response_masks: list[torch.Tensor],
+        collated_tool_masks: list[torch.Tensor],
+    ) -> dict[int, float]:
+        """
+        Compute the number of training tokens in each accumulation group for this set of responses.
+        Return a dictionary mapping each group's starting batch index to its total token count.
+        """
+        accumulation_counts: dict[int, float] = {}
+        local_counts = []
+
+        for i, response_mask in enumerate(collated_response_masks):
+            response_mask = response_mask.to(self.device)
+            mask = response_mask[:, 1:].bool()
+            if self.args.mask_tool_use and self.args.tool_use:
+                tool_mask = collated_tool_masks[i].to(self.device)
+                mask &= tool_mask[:, 1:].bool()
+
+            local_counts.append(mask.sum().float())
+
+        if not local_counts:
+            return accumulation_counts
+
+        # do a single all_reduce here instead of one call per loop iteration
+        counts_tensor = torch.stack(local_counts)
+        dist.all_reduce(counts_tensor, op=dist.ReduceOp.SUM)
+
+        for i, count in enumerate(counts_tensor):
+            group_idx = i // accumulation_steps
+            key = int(group_idx * accumulation_steps)
+            accumulation_counts[key] = accumulation_counts.get(key, 0.0) + count.item()
+
+        return accumulation_counts
+
     def train(
         self,
         collated_query_responses,
@@ -1153,6 +1186,18 @@ def train(
         ratio_stats = torch.zeros(len(collated_query_responses))
         entropy_stats = torch.zeros(len(collated_query_responses))
         for epoch_idx in range(args.num_epochs):
+            # Pre-compute total tokens for each accumulation group if using "token" normalization.
+            # This ensures all minibatches in an accumulation group are normalized by the same total.
+            if args.loss_denominator == "token":
+                accumulation_token_counts = self.calculate_token_counts(
+                    accumulation_steps, collated_response_masks, collated_tool_masks
+                )
+            else:
+                accumulation_token_counts = {
+                    int(group_idx * accumulation_steps): args.loss_denominator
+                    for group_idx in range((len(collated_query_responses) // accumulation_steps) + 1)
+                }
+
             for i in range(len(collated_query_responses)):
                 mb_query_responses = collated_query_responses[i]
                 mb_tool_mask = collated_tool_masks[i]
@@ -1162,6 +1207,11 @@ def train(
                 # if masking snippets, do it here.
                 if args.mask_tool_use and args.tool_use:
                     mb_response_masks_bool = mb_response_masks[:, 1:].bool() & mb_tool_mask[:, 1:].bool()
+
+                # retrieve the loss denominator for the current batch
+                batch_start = (i // accumulation_steps) * accumulation_steps
+                loss_denominator = accumulation_token_counts[batch_start]
+
                 mb_attention_mask = collated_attention_masks[i]
                 mb_position_id = collated_position_ids[i]
                 mb_local_logprobs, mb_entropy = self.forward(
@@ -1277,16 +1327,15 @@ def train(
                     kl = kl_4BT[args.kl_estimator]
                     # grpo change: directly subtract KL in loss (add)
                     loss = masked_mean(
-                        pg_loss_max + (args.beta * kl),
-                        mb_response_masks_bool,
-                        args.masked_mean_axis,
-                        args.masked_mean_denominator,
+                        pg_loss_max + (args.beta * kl), mb_response_masks_bool, None, loss_denominator
                     )
                 else:
-                    loss = masked_mean(
-                        pg_loss_max, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
-                    )
-                loss = loss / accumulation_steps
+                    loss = masked_mean(pg_loss_max, mb_response_masks_bool, None, loss_denominator)
+
+                # world size is already accounted for via the summed token counts
+                if dist.is_available() and dist.is_initialized():
+                    loss *= dist.get_world_size()
+
                 # Clear CUDA cache before backward pass to free memory for reduce_scatter operations
                 torch.cuda.empty_cache()
                 self.model.backward(loss)
@@ -1296,28 +1345,15 @@ def train(
                 with torch.no_grad():
                     if args.load_ref_policy:
                         # NOTE: in packed implementation, kl calculation are averages over response tokens
-                        kl_stats_4M[:, i] = masked_mean(
-                            kl_4BT, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
-                        ).float()
+                        kl_stats_4M[:, i] = masked_mean(kl_4BT, mb_response_masks_bool).float()
                         kl_loss_stats[i] = kl_stats_4M[args.kl_estimator, i] * args.beta
-                    pg_clipfrac_stats[i] = masked_mean(
-                        (pg_losses2 > pg_losses).float(),
-                        mb_response_masks_bool,
-                        args.masked_mean_axis,
-                        args.masked_mean_denominator,
-                    )
-                    pg_loss_stats[i] = masked_mean(
-                        pg_loss_max, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
-                    )
+                    pg_clipfrac_stats[i] = masked_mean((pg_losses2 > pg_losses).float(), mb_response_masks_bool)
+                    pg_loss_stats[i] = masked_mean(pg_loss_max, mb_response_masks_bool)
                     loss_stats[i] = loss
-                    ratio_stats[i] = masked_mean(
-                        ratio, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
-                    )
+                    ratio_stats[i] = masked_mean(ratio, mb_response_masks_bool)
                     if args.record_entropy:
                         # Calculate entropy statistics
-                        entropy_stats[i] = masked_mean(
-                            mb_entropy, mb_response_masks_bool, args.masked_mean_axis, args.masked_mean_denominator
-                        ).float()
+                        entropy_stats[i] = masked_mean(mb_entropy, mb_response_masks_bool).float()
 
         with torch.no_grad():
             if args.load_ref_policy:
diff --git a/open_instruct/model_utils.py b/open_instruct/model_utils.py
index 9f5a13e61..43fb4ee80 100644
--- a/open_instruct/model_utils.py
+++ b/open_instruct/model_utils.py
@@ -769,23 +769,6 @@ def exact_div(a, b, custom_error_message=""):
     return q
 
 
-def masked_mean(
-    values: torch.Tensor, mask: torch.Tensor, axis: int | None = None, denominator: float | None = None
-) -> torch.Tensor:
-    """Compute mean of tensor with masked values."""
-    extra_dims = values.ndim - mask.ndim
-    if axis is None:
-        sum_dims = tuple(range(extra_dims, values.ndim))
-    elif axis >= 0:
-        sum_dims = axis + extra_dims
-    else:
-        sum_dims = axis
-    numerator = (values * mask).sum(dim=sum_dims)
-    denom = mask.sum(dim=axis) if denominator is None else denominator
-    result = numerator / denom
-    return result.flatten(extra_dims).mean(-1) if result.ndim > extra_dims else result
-
-
 def estimate_kl(ref_logprobs_diff: torch.Tensor, ratio: torch.Tensor) -> torch.Tensor:
     """Compute 4 different KL divergence estimators between current and reference policies.
 
diff --git a/open_instruct/rl_utils.py b/open_instruct/rl_utils.py
index 7dad7a600..2824feb90 100644
--- a/open_instruct/rl_utils.py
+++ b/open_instruct/rl_utils.py
@@ -249,3 +249,20 @@ def calculate_advantages_packed(
     advantages = np.stack(advantages_reversed[::-1], axis=1)
     returns = advantages + values
     return advantages, returns
+
+
+def masked_mean(
+    values: torch.Tensor, mask: torch.Tensor, axis: int | None = None, denominator: float | None = None
+) -> torch.Tensor:
+    """Compute mean of tensor with masked values."""
+    extra_dims = values.ndim - mask.ndim
+    if axis is None:
+        sum_dims = tuple(range(extra_dims, values.ndim))
+    elif axis >= 0:
+        sum_dims = axis + extra_dims
+    else:
+        sum_dims = axis
+    numerator = (values * mask).sum(dim=sum_dims)
+    denom = mask.sum(dim=axis) if denominator is None else denominator
+    result = numerator / denom
+    return result.flatten(extra_dims).mean(-1) if result.ndim > extra_dims else result
diff --git a/open_instruct/test_model_utils.py b/open_instruct/test_model_utils.py
index 1c1fbd411..4d4955d0b 100644
--- a/open_instruct/test_model_utils.py
+++ b/open_instruct/test_model_utils.py
@@ -42,65 +42,6 @@ def test_batch_slicing_with_none_fields(self):
         self.assertIsNone(sliced.scores)
 
 
-class TestMaskedMean(unittest.TestCase):
-    def test_original_axis_int(self):
-        values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
-        mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
-        result = open_instruct.model_utils.masked_mean(values, mask, axis=1)
-        expected = ((1.0 + 2.0) / 2 + 4.0 / 1) / 2
-        self.assertAlmostEqual(result.item(), expected)
-
-    def test_original_axis_none(self):
-        values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
-        mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
-        result = open_instruct.model_utils.masked_mean(values, mask, axis=None)
-        expected = (1.0 + 2.0 + 4.0) / 3
-        self.assertAlmostEqual(result.item(), expected, places=5)
-
-    def test_vectorized_axis_int(self):
-        kl_4BT = torch.tensor(
-            [
-                [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
-                [[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]],
-                [[100.0, 200.0, 300.0], [400.0, 500.0, 600.0]],
-                [[1000.0, 2000.0, 3000.0], [4000.0, 5000.0, 6000.0]],
-            ]
-        )
-        mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
-        result = open_instruct.model_utils.masked_mean(kl_4BT, mask, axis=1)
-        self.assertEqual(result.shape, (4,))
-        expected_0 = ((1.0 + 2.0) / 2 + 4.0 / 1) / 2
-        expected_1 = ((10.0 + 20.0) / 2 + 40.0 / 1) / 2
-        expected_2 = ((100.0 + 200.0) / 2 + 400.0 / 1) / 2
-        expected_3 = ((1000.0 + 2000.0) / 2 + 4000.0 / 1) / 2
-        self.assertAlmostEqual(result[0].item(), expected_0)
-        self.assertAlmostEqual(result[1].item(), expected_1)
-        self.assertAlmostEqual(result[2].item(), expected_2)
-        self.assertAlmostEqual(result[3].item(), expected_3)
-
-    def test_vectorized_axis_none(self):
-        kl_4BT = torch.tensor(
-            [
-                [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
-                [[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]],
-                [[100.0, 200.0, 300.0], [400.0, 500.0, 600.0]],
-                [[1000.0, 2000.0, 3000.0], [4000.0, 5000.0, 6000.0]],
-            ]
-        )
-        mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
-        result = open_instruct.model_utils.masked_mean(kl_4BT, mask, axis=None)
-        self.assertEqual(result.shape, (4,))
-        expected = torch.tensor(
-            [
-                (1.0 + 2.0 + 4.0) / 3,
-                (10.0 + 20.0 + 40.0) / 3,
-                (100.0 + 200.0 + 400.0) / 3,
-                (1000.0 + 2000.0 + 4000.0) / 3,
-            ]
-        )
-        self.assertTrue(torch.allclose(result, expected))
-
-
 class TestLogSoftmaxAndGather(unittest.TestCase):
     def test_log_softmax_and_gather_sliced_logits(self):
         batch_size, seq_len, vocab_size = 2, 160, 151936
diff --git a/open_instruct/test_rl_utils.py b/open_instruct/test_rl_utils.py
index dfc7403f7..a42e6afc2 100644
--- a/open_instruct/test_rl_utils.py
+++ b/open_instruct/test_rl_utils.py
@@ -263,5 +263,64 @@ def test_pack_sequences_logits(self):
         )
 
 
+class TestMaskedMean(unittest.TestCase):
+    def test_original_axis_int(self):
+        values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+        mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
+        result = rl_utils.masked_mean(values, mask, axis=1)
+        expected = ((1.0 + 2.0) / 2 + 4.0 / 1) / 2
+        self.assertAlmostEqual(result.item(), expected)
+
+    def test_original_axis_none(self):
+        values = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
+        mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
+        result = rl_utils.masked_mean(values, mask, axis=None)
+        expected = (1.0 + 2.0 + 4.0) / 3
+        self.assertAlmostEqual(result.item(), expected, places=5)
+
+    def test_vectorized_axis_int(self):
+        kl_4BT = torch.tensor(
+            [
+                [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
+                [[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]],
+                [[100.0, 200.0, 300.0], [400.0, 500.0, 600.0]],
+                [[1000.0, 2000.0, 3000.0], [4000.0, 5000.0, 6000.0]],
+            ]
+        )
+        mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
+        result = rl_utils.masked_mean(kl_4BT, mask, axis=1)
+        self.assertEqual(result.shape, (4,))
+        expected_0 = ((1.0 + 2.0) / 2 + 4.0 / 1) / 2
+        expected_1 = ((10.0 + 20.0) / 2 + 40.0 / 1) / 2
+        expected_2 = ((100.0 + 200.0) / 2 + 400.0 / 1) / 2
+        expected_3 = ((1000.0 + 2000.0) / 2 + 4000.0 / 1) / 2
+        self.assertAlmostEqual(result[0].item(), expected_0)
+        self.assertAlmostEqual(result[1].item(), expected_1)
+        self.assertAlmostEqual(result[2].item(), expected_2)
+        self.assertAlmostEqual(result[3].item(), expected_3)
+
+    def test_vectorized_axis_none(self):
+        kl_4BT = torch.tensor(
+            [
+                [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]],
+                [[10.0, 20.0, 30.0], [40.0, 50.0, 60.0]],
+                [[100.0, 200.0, 300.0], [400.0, 500.0, 600.0]],
+                [[1000.0, 2000.0, 3000.0], [4000.0, 5000.0, 6000.0]],
+            ]
+        )
+        mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
+        result = rl_utils.masked_mean(kl_4BT, mask, axis=None)
+        self.assertEqual(result.shape, (4,))
+        expected = torch.tensor(
+            [
+                (1.0 + 2.0 + 4.0) / 3,
+                (10.0 + 20.0 + 40.0) / 3,
+                (100.0 + 200.0 + 400.0) / 3,
+                (1000.0 + 2000.0 + 4000.0) / 3,
+            ]
+        )
+        self.assertTrue(torch.allclose(result, expected))
+
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/open_instruct/test_utils.py b/open_instruct/test_utils.py
index c46742e42..e8eae913f 100644
--- a/open_instruct/test_utils.py
+++ b/open_instruct/test_utils.py
@@ -525,3 +525,21 @@
 #     "natolambert/tulu-v2-sft-mixture-science": 7468,  # original data slightly different
 # }
 #     _ = get_datasets(dataset_mixer, splits=["train"], columns_to_keep=["messages"])
+
+
+class TestGetDenominator(unittest.TestCase):
+    @parameterized.expand([("token", "token"), ("0.5", 0.5), (0.5, 0.5), (1, 1.0)])
+    def test_valid_inputs(self, input_val, expected):
+        self.assertEqual(utils.get_denominator(input_val), expected)
+
+    @parameterized.expand(
+        [
+            ("invalid", "could not convert string to float"),
+            ("-1", "loss_denominator must be greater than 0"),
+            (0, "loss_denominator must be greater than 0"),
+            ("0", "loss_denominator must be greater than 0"),
+        ]
+    )
+    def test_invalid_inputs(self, input_val, error_msg):
+        with self.assertRaisesRegex(ValueError, error_msg):
+            utils.get_denominator(input_val)
diff --git a/open_instruct/utils.py b/open_instruct/utils.py
index 1c76d22b4..eacb3497e 100644
--- a/open_instruct/utils.py
+++ b/open_instruct/utils.py
@@ -2491,3 +2491,16 @@ def get_beaker_experiment_url() -> str | None:
         return url
     except Exception:
         return None
+
+
+def get_denominator(loss_denominator: str | float) -> float | str:
+    """
+    Validates and converts the loss_denominator argument.
+    """
+    if loss_denominator == "token":
+        return "token"
+
+    val = float(loss_denominator)
+    if val <= 0:
+        raise ValueError(f"loss_denominator must be greater than 0 if not 'token', got: {loss_denominator}")
+    return val
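
Note (not part of the diff): a minimal, self-contained sketch of the normalization change this PR makes. The masked_mean body below follows the version moved into open_instruct/rl_utils.py; the example tensors and the constant denominator of 6.0 are made up for illustration, chosen to mirror the expectations in the new tests.

import torch


def masked_mean(
    values: torch.Tensor, mask: torch.Tensor, axis: int | None = None, denominator: float | None = None
) -> torch.Tensor:
    """Compute mean of tensor with masked values (same logic as open_instruct/rl_utils.py)."""
    extra_dims = values.ndim - mask.ndim
    if axis is None:
        sum_dims = tuple(range(extra_dims, values.ndim))
    elif axis >= 0:
        sum_dims = axis + extra_dims
    else:
        sum_dims = axis
    numerator = (values * mask).sum(dim=sum_dims)
    denom = mask.sum(dim=axis) if denominator is None else denominator
    result = numerator / denom
    return result.flatten(extra_dims).mean(-1) if result.ndim > extra_dims else result


per_token_loss = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
response_mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])

# Default: divide by the number of unmasked tokens (3 here), i.e. a true per-token mean.
print(masked_mean(per_token_loss, response_mask))  # (1 + 2 + 4) / 3 ~= 2.333

# Constant denominator (what a float loss_denominator does): divide by a fixed value,
# e.g. the max tokens per batch as in Dr GRPO, regardless of how many tokens are unmasked.
print(masked_mean(per_token_loss, response_mask, None, 6.0))  # (1 + 2 + 4) / 6 ~= 1.167

With loss_denominator="token", grpo_fast.py instead builds the constant per accumulation group: calculate_token_counts all-reduces the unmasked token counts across ranks and keys them by each group's starting batch index, so every minibatch in a group is divided by the same total.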