From f0f596ae6e34f1a4e9c6b58f7931e30f5f375b7b Mon Sep 17 00:00:00 2001
From: YanhuiDua
Date: Mon, 1 Dec 2025 16:24:26 +0800
Subject: [PATCH 1/4] fix logprobs diff with all padding tokens

---
 xtuner/v1/rl/base/worker.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py
index 85d1b37b5..f59ae2143 100644
--- a/xtuner/v1/rl/base/worker.py
+++ b/xtuner/v1/rl/base/worker.py
@@ -412,9 +412,6 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int):
                     rollout_entropy if sum_rollout_entropy is None else sum_rollout_entropy + rollout_entropy
                 )

-            if not mask.any():  # all padding tokens, skip
-                continue
-
             if len(rollout_logprobs_list) > 0:
                 # calculate logprob diff
                 rollout_logprobs = rollout_logprobs_list[i][mask]  # type: ignore[index]
@@ -441,6 +438,11 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int):
                         std_diff = torch.std(rollout_logprobs - old_logprobs)
                 all_diffs.append((min_diff, max_diff, mean_diff, std_diff))

+            if not mask.any():  # all padding tokens, skip
+                self.logger.warning(f"Skip batch {i} as all tokens are padding.")
+                continue
+
+            if len(rollout_logprobs_list) > 0:
                 # calculate importance sampling weights
                 cu_seq_lens = seq_ctx_list[i].cu_seq_lens_q
                 num_tokens = cu_seq_lens[1:] - cu_seq_lens[:-1]
@@ -458,7 +460,6 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int):
             all_rollout_is_metrics.append(rollout_is_metrics)

         logger_msg = f"Rollout {rollout_idx}: "
-        tis_logger_msg = ""

         if len(all_rollout_is_metrics) > 0:
             rollout_is_metrics = merge_rollout_is_metrics(all_rollout_is_metrics, DEVICE)

From 6a5d87f11b25051f25089774acbc3f29494454d6 Mon Sep 17 00:00:00 2001
From: YanhuiDua
Date: Mon, 1 Dec 2025 16:51:24 +0800
Subject: [PATCH 2/4] rm padding token logprobs

---
 xtuner/v1/rl/base/worker.py | 19 ++++++++-----------
 1 file changed, 8 insertions(+), 11 deletions(-)

diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py
index f59ae2143..43bc29c46 100644
--- a/xtuner/v1/rl/base/worker.py
+++ b/xtuner/v1/rl/base/worker.py
@@ -424,18 +424,15 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int):
                     f"rollout_logprobs {rollout_logprobs.shape} vs old_logprobs {old_logprobs.shape}"
                 )
                 if rollout_logprobs.numel() == 0:  # empty when all tokens are padding
-                    min_diff = torch.tensor(0.0)
-                    max_diff = min_diff
-                    std_diff = min_diff
-                    mean_diff = min_diff
+                    continue
+
+                min_diff = torch.min(rollout_logprobs - old_logprobs)
+                max_diff = torch.max(rollout_logprobs - old_logprobs)
+                mean_diff = torch.mean(rollout_logprobs - old_logprobs)
+                if rollout_logprobs.numel() == 1:
+                    std_diff = torch.tensor(0.0)
                 else:
-                    min_diff = torch.min(rollout_logprobs - old_logprobs)
-                    max_diff = torch.max(rollout_logprobs - old_logprobs)
-                    mean_diff = torch.mean(rollout_logprobs - old_logprobs)
-                    if rollout_logprobs.numel() == 1:
-                        std_diff = torch.tensor(0.0)
-                    else:
-                        std_diff = torch.std(rollout_logprobs - old_logprobs)
+                    std_diff = torch.std(rollout_logprobs - old_logprobs)
                 all_diffs.append((min_diff, max_diff, mean_diff, std_diff))

             if not mask.any():  # all padding tokens, skip

From 3ea1b19802ca68d7ea313295ec2a99f5d8dffe1c Mon Sep 17 00:00:00 2001
From: YanhuiDua
Date: Mon, 1 Dec 2025 19:50:56 +0800
Subject: [PATCH 3/4] log abs mean and abs std

---
 xtuner/v1/rl/base/worker.py | 26 +++++++++++++++++---------
 1 file changed, 17 insertions(+), 9 deletions(-)

diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py
index 43bc29c46..e2d4db10a 100644
--- a/xtuner/v1/rl/base/worker.py
+++ b/xtuner/v1/rl/base/worker.py
@@ -426,14 +426,19 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int):
                 if rollout_logprobs.numel() == 0:  # empty when all tokens are padding
                     continue

-                min_diff = torch.min(rollout_logprobs - old_logprobs)
-                max_diff = torch.max(rollout_logprobs - old_logprobs)
-                mean_diff = torch.mean(rollout_logprobs - old_logprobs)
+                diff = rollout_logprobs - old_logprobs
+                abs_diff = torch.abs(rollout_logprobs - old_logprobs)
+                min_diff = torch.min(diff)
+                max_diff = torch.max(diff)
+                mean_abs_diff = torch.mean(abs_diff)
+                mean_diff = torch.mean(diff)
                 if rollout_logprobs.numel() == 1:
                     std_diff = torch.tensor(0.0)
+                    std_abs_diff = torch.tensor(0.0)
                 else:
-                    std_diff = torch.std(rollout_logprobs - old_logprobs)
-                all_diffs.append((min_diff, max_diff, mean_diff, std_diff))
+                    std_diff = torch.std(diff)
+                    std_abs_diff = torch.std(abs_diff)
+                all_diffs.append((min_diff, max_diff, mean_diff, mean_abs_diff, std_diff, std_abs_diff))

             if not mask.any():  # all padding tokens, skip
                 self.logger.warning(f"Skip batch {i} as all tokens are padding.")
@@ -469,15 +474,18 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int):
         if len(rollout_logprobs_list) > 0:
             all_diffs_tensor = torch.stack([torch.tensor(d).to(DEVICE) for d in all_diffs]).to(
                 dtype=torch.float32
-            )  # n, 4
+            )  # n, 6
             min_diff_val = torch.min(all_diffs_tensor[:, 0]).item()
             max_diff_val = torch.max(all_diffs_tensor[:, 1]).item()
             mean_diff_val = torch.mean(all_diffs_tensor[:, 2]).item()
-            if all_diffs_tensor[:, 3].numel() <= 1:
+            mean_abs_diff_val = torch.mean(all_diffs_tensor[:, 3]).item()
+            if all_diffs_tensor[:, 4].numel() <= 1:
                 std_diff_val = 0.0
+                std_abs_diff_val = 0.0
             else:
-                std_diff_val = torch.std(all_diffs_tensor[:, 3]).item()
-            logprob_logger_msg = f"\nlogprobs diff min {float(min_diff_val):.4f}, max {float(max_diff_val):.4f}, mean {float(mean_diff_val):.4f}, std {float(std_diff_val):.4f}, "
+                std_diff_val = torch.std(all_diffs_tensor[:, 4]).item()
+                std_abs_diff_val = torch.std(all_diffs_tensor[:, 5]).item()
+            logprob_logger_msg = f"\nlogprobs diff min {float(min_diff_val):.4f}, max {float(max_diff_val):.4f}, mean {float(mean_diff_val):.4f}, std {float(std_diff_val):.4f}, abs_mean {float(mean_abs_diff_val):.4f}, abs_std {float(std_abs_diff_val):.4f}"

         entropy_logger_msg = ""
         sum_entropy = cast(torch.Tensor, sum_entropy)

From e5d39e21abaf7715d0bc99cc99c6007923011d0e Mon Sep 17 00:00:00 2001
From: YanhuiDua
Date: Tue, 2 Dec 2025 16:06:56 +0800
Subject: [PATCH 4/4] unify mismatch log info

---
 xtuner/v1/rl/base/rollout_is.py | 27 ++++++-----
 xtuner/v1/rl/base/worker.py     | 83 +++++----------------------------
 2 files changed, 27 insertions(+), 83 deletions(-)

diff --git a/xtuner/v1/rl/base/rollout_is.py b/xtuner/v1/rl/base/rollout_is.py
index 6a93ecdf3..81c6722a7 100644
--- a/xtuner/v1/rl/base/rollout_is.py
+++ b/xtuner/v1/rl/base/rollout_is.py
@@ -53,14 +53,24 @@ class RolloutImportanceSampling(BaseModel):
     rollout_is_mask_threshold: Optional[Tuple[float, float]] = None
     rollout_is_veto_threshold: Optional[Tuple[float, float]] = None

-    def compute_rollout_importance_weights(
+    def compute_rollout_importance_weights_and_metrics(
         self,
         old_log_prob: torch.Tensor,
         rollout_log_prob: torch.Tensor,
         num_tokens: torch.Tensor,
         response_mask: torch.Tensor,
-    ) -> tuple[Optional[torch.Tensor], torch.Tensor, dict[str, Any]]:
-        return compute_rollout_importance_weights(
+    ) -> tuple[Optional[torch.Tensor], torch.Tensor, dict[str, Any], dict[str, Any]]:
+        mismatch_metrics = compute_mismatch_metrics(
+            old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=response_mask
+        )
+        mismatch_metrics_scalar = {}
+        for key, value in mismatch_metrics.items():
+            if isinstance(value, torch.Tensor):
+                mismatch_metrics_scalar[f"mismatch/{key}"] = value.item()
+            else:
+                mismatch_metrics_scalar[f"mismatch/{key}"] = value
+
+        rollout_is_weights, modified_response_mask, metrics_scalar = compute_rollout_importance_weights(
             old_log_prob,
             rollout_log_prob,
             num_tokens,
@@ -71,6 +81,7 @@ def compute_rollout_importance_weights(
             rollout_is_mask_threshold=self.rollout_is_mask_threshold,
             rollout_is_veto_threshold=self.rollout_is_veto_threshold,
         )
+        return rollout_is_weights, modified_response_mask, mismatch_metrics_scalar, metrics_scalar


 def compute_rollout_importance_weights(
@@ -291,12 +302,6 @@ def compute_rollout_importance_weights(
     # This is different from rejection - padding must be zeroed regardless of mode
     rollout_is_weights = rollout_is_weights * response_mask

-    # Compute mismatch metrics (KL, PPL, etc.) and merge with IS metrics
-    mismatch_metrics = compute_mismatch_metrics(
-        old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=response_mask
-    )
-    metrics.update(mismatch_metrics)
-
     # Convert all tensor metrics to scalars for logging
     # Note: No need to detach since old_log_prob and rollout_log_prob are computed with torch.no_grad()
     metrics_scalar = {}
@@ -477,7 +482,7 @@ def compute_mismatch_metrics(
         metrics["mismatch_training_ppl"] = training_ppl.detach().item()

         # Also log log-ppl for easier analysis (avoids exponential scale)
-        metrics["mismatch_training_log_ppl"] = (-mean_log_prob_training).mean().detach().item()
+        metrics["mismatch_training_entropy"] = (-mean_log_prob_training).mean().detach().item()

     # 2. Compute rollout mismatch metrics (only if rollout_log_probs available)
     if rollout_log_prob is not None:
@@ -497,7 +502,7 @@ def compute_mismatch_metrics(
         mean_log_prob_rollout = masked_mean(rollout_log_prob, response_mask, axis=-1)  # (batch_size,)
         rollout_ppl = torch.exp(-mean_log_prob_rollout).mean()  # Batch mean of per-sequence PPL
         metrics["mismatch_rollout_ppl"] = rollout_ppl.detach().item()
-        metrics["mismatch_rollout_log_ppl"] = (-mean_log_prob_rollout).mean().detach().item()
+        metrics["mismatch_rollout_entropy"] = (-mean_log_prob_rollout).mean().detach().item()

         # 2d. Log PPL difference (sequence-level perplexity difference)
         # log_ppl_diff = mean_log_prob_rollout - mean_log_prob_training
diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py
index e2d4db10a..3d004ce49 100644
--- a/xtuner/v1/rl/base/worker.py
+++ b/xtuner/v1/rl/base/worker.py
@@ -400,8 +400,8 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int):
             f"rollout_logprobs_list {len(rollout_logprobs_list)} vs loss_ctx_input_list {len(loss_ctx_input_list)}"
         )

-        all_diffs = []
         all_rollout_is_metrics = []
+        all_mismatch_metrics = []
         for i, loss_ctx_input in enumerate(loss_ctx_input_list):
             mask = loss_ctx_input.shifted_labels != -100
             entropy = -(cast(torch.Tensor, loss_ctx_input.old_logprobs) * mask).sum()
@@ -412,34 +412,6 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int):
                     rollout_entropy if sum_rollout_entropy is None else sum_rollout_entropy + rollout_entropy
                 )

-            if len(rollout_logprobs_list) > 0:
-                # calculate logprob diff
-                rollout_logprobs = rollout_logprobs_list[i][mask]  # type: ignore[index]
-                old_logprobs = loss_ctx_input.old_logprobs[mask]  # type: ignore[index]
-
-                assert len(rollout_logprobs.size()) == 1, (
-                    f"len(rollout_logprobs.size()): {len(rollout_logprobs.size())}"
-                )
-                assert rollout_logprobs.shape == old_logprobs.shape, (
-                    f"rollout_logprobs {rollout_logprobs.shape} vs old_logprobs {old_logprobs.shape}"
-                )
-                if rollout_logprobs.numel() == 0:  # empty when all tokens are padding
-                    continue
-
-                diff = rollout_logprobs - old_logprobs
-                abs_diff = torch.abs(rollout_logprobs - old_logprobs)
-                min_diff = torch.min(diff)
-                max_diff = torch.max(diff)
-                mean_abs_diff = torch.mean(abs_diff)
-                mean_diff = torch.mean(diff)
-                if rollout_logprobs.numel() == 1:
-                    std_diff = torch.tensor(0.0)
-                    std_abs_diff = torch.tensor(0.0)
-                else:
-                    std_diff = torch.std(diff)
-                    std_abs_diff = torch.std(abs_diff)
-                all_diffs.append((min_diff, max_diff, mean_diff, mean_abs_diff, std_diff, std_abs_diff))
-
             if not mask.any():  # all padding tokens, skip
                 self.logger.warning(f"Skip batch {i} as all tokens are padding.")
                 continue
@@ -449,8 +421,8 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int):
                 cu_seq_lens = seq_ctx_list[i].cu_seq_lens_q
                 num_tokens = cu_seq_lens[1:] - cu_seq_lens[:-1]

-                rollout_is_weights, rollout_is_mask, rollout_is_metrics = (
-                    loss_cfg.rollout_is.compute_rollout_importance_weights(
+                rollout_is_weights, rollout_is_mask, mismatch_metrics, rollout_is_metrics = (
+                    loss_cfg.rollout_is.compute_rollout_importance_weights_and_metrics(
                         old_log_prob=loss_ctx_input.old_logprobs,
                         rollout_log_prob=rollout_logprobs_list[i],
                         num_tokens=num_tokens,
@@ -460,52 +432,19 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int):
                 loss_ctx_input.shifted_labels[~rollout_is_mask.bool()] = -100  # update loss mask
                 loss_ctx_input.is_weights = rollout_is_weights
                 all_rollout_is_metrics.append(rollout_is_metrics)
+                all_mismatch_metrics.append(mismatch_metrics)

         logger_msg = f"Rollout {rollout_idx}: "
-        tis_logger_msg = ""
+
+        if len(all_mismatch_metrics) > 0:
+            mismatch_metrics = merge_rollout_is_metrics(all_mismatch_metrics, DEVICE)
+            if len(mismatch_metrics) > 0:
+                logger_msg += f"\n rollout mismatch metrics:\n{json.dumps(mismatch_metrics, indent=4)}"
+
         if len(all_rollout_is_metrics) > 0:
             rollout_is_metrics = merge_rollout_is_metrics(all_rollout_is_metrics, DEVICE)
             if len(rollout_is_metrics) > 0:
-                tis_logger_msg = (
-                    f"\n\nrollout importance sampling metrics:\n{json.dumps(rollout_is_metrics, indent=4)}"
-                )
-
-        logprob_logger_msg = ""
-        if len(rollout_logprobs_list) > 0:
-            all_diffs_tensor = torch.stack([torch.tensor(d).to(DEVICE) for d in all_diffs]).to(
-                dtype=torch.float32
-            )  # n, 6
-            min_diff_val = torch.min(all_diffs_tensor[:, 0]).item()
-            max_diff_val = torch.max(all_diffs_tensor[:, 1]).item()
-            mean_diff_val = torch.mean(all_diffs_tensor[:, 2]).item()
-            mean_abs_diff_val = torch.mean(all_diffs_tensor[:, 3]).item()
-            if all_diffs_tensor[:, 4].numel() <= 1:
-                std_diff_val = 0.0
-                std_abs_diff_val = 0.0
-            else:
-                std_diff_val = torch.std(all_diffs_tensor[:, 4]).item()
-                std_abs_diff_val = torch.std(all_diffs_tensor[:, 5]).item()
-            logprob_logger_msg = f"\nlogprobs diff min {float(min_diff_val):.4f}, max {float(max_diff_val):.4f}, mean {float(mean_diff_val):.4f}, std {float(std_diff_val):.4f}, abs_mean {float(mean_abs_diff_val):.4f}, abs_std {float(std_abs_diff_val):.4f}"
-
-        entropy_logger_msg = ""
-        sum_entropy = cast(torch.Tensor, sum_entropy)
-        dist.all_reduce(sum_entropy, op=dist.ReduceOp.SUM)
-        avg_gen_entropy = sum_entropy / global_grad_tokens if global_grad_tokens > 0 else 0
-        entropy_logger_msg = f"avg generation entropy: {avg_gen_entropy:.4f}"
-
-        rollout_entropy_logger_msg = ""
-        if sum_rollout_entropy is not None:
-            sum_rollout_entropy = cast(torch.Tensor, sum_rollout_entropy)
-            dist.all_reduce(sum_rollout_entropy, op=dist.ReduceOp.SUM)
-            avg_gen_entropy = sum_rollout_entropy / global_grad_tokens if global_grad_tokens > 0 else 0
-            rollout_entropy_logger_msg = f"avg rollout generation entropy: {avg_gen_entropy:.4f}"
-
-        if tis_logger_msg:
-            logger_msg += entropy_logger_msg
-            logger_msg += tis_logger_msg
-        else:
-            logger_msg += f"{entropy_logger_msg}, {rollout_entropy_logger_msg}"
-        logger_msg += logprob_logger_msg
+                logger_msg += f"\n rollout importance sampling metrics:\n{json.dumps(rollout_is_metrics, indent=4)}"

         self.logger.info(logger_msg)
         if self._has_ref:
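
Reviewer note (not part of the patches above): after PATCH 4, the ad-hoc per-token logprob-diff logging is gone and rollout/training mismatch is reported only through compute_mismatch_metrics, now exposed via the compute_rollout_importance_weights_and_metrics wrapper. The sketch below is an illustrative, stand-alone reimplementation of the core statistics under assumed dense (batch, seq_len) inputs; the helper name mismatch_metrics_sketch and the mismatch_logprob_gap key are hypothetical, while the *_ppl / *_entropy key names follow the renaming in PATCH 4. The actual xtuner implementation operates on packed sequences and reports additional metrics.

import torch


def mismatch_metrics_sketch(
    old_log_prob: torch.Tensor,      # (batch, seq_len) logprobs from the training engine
    rollout_log_prob: torch.Tensor,  # (batch, seq_len) logprobs from the rollout engine
    response_mask: torch.Tensor,     # (batch, seq_len) 1 for response tokens, 0 for padding
) -> dict:
    # Masked per-sequence mean logprob under each policy.
    mask = response_mask.to(old_log_prob.dtype)
    denom = mask.sum(dim=-1).clamp(min=1.0)
    mean_lp_training = (old_log_prob * mask).sum(dim=-1) / denom
    mean_lp_rollout = (rollout_log_prob * mask).sum(dim=-1) / denom

    # Mean per-token logprob gap (training minus rollout): a crude mismatch
    # signal; this key is hypothetical and not taken from the patch.
    logprob_gap = ((old_log_prob - rollout_log_prob) * mask).sum(dim=-1) / denom

    return {
        "mismatch_logprob_gap": logprob_gap.mean().item(),
        # Perplexity is exp(-mean logprob); "entropy" here is the sequence-level
        # mean negative logprob, matching the *_entropy naming from PATCH 4.
        "mismatch_training_ppl": torch.exp(-mean_lp_training).mean().item(),
        "mismatch_training_entropy": (-mean_lp_training).mean().item(),
        "mismatch_rollout_ppl": torch.exp(-mean_lp_rollout).mean().item(),
        "mismatch_rollout_entropy": (-mean_lp_rollout).mean().item(),
    }


if __name__ == "__main__":
    batch, seq_len = 2, 8
    old_lp = -torch.rand(batch, seq_len)                   # logprobs are <= 0
    rollout_lp = old_lp + 0.05 * torch.randn(batch, seq_len)
    mask = torch.ones(batch, seq_len)
    mask[1, 5:] = 0                                        # simulate a padded tail
    print(mismatch_metrics_sketch(old_lp, rollout_lp, mask))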