27 changes: 16 additions & 11 deletions xtuner/v1/rl/base/rollout_is.py
@@ -53,14 +53,24 @@ class RolloutImportanceSampling(BaseModel):
rollout_is_mask_threshold: Optional[Tuple[float, float]] = None
rollout_is_veto_threshold: Optional[Tuple[float, float]] = None

def compute_rollout_importance_weights(
def compute_rollout_importance_weights_and_metrics(
self,
old_log_prob: torch.Tensor,
rollout_log_prob: torch.Tensor,
num_tokens: torch.Tensor,
response_mask: torch.Tensor,
) -> tuple[Optional[torch.Tensor], torch.Tensor, dict[str, Any]]:
return compute_rollout_importance_weights(
) -> tuple[Optional[torch.Tensor], torch.Tensor, dict[str, Any], dict[str, Any]]:
mismatch_metrics = compute_mismatch_metrics(
old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=response_mask
)
mismatch_metrics_scalar = {}
for key, value in mismatch_metrics.items():
if isinstance(value, torch.Tensor):
mismatch_metrics_scalar[f"mismatch/{key}"] = value.item()
else:
mismatch_metrics_scalar[f"mismatch/{key}"] = value

rollout_is_weights, modified_response_mask, metrics_scalar = compute_rollout_importance_weights(
old_log_prob,
rollout_log_prob,
num_tokens,
@@ -71,6 +81,7 @@ def compute_rollout_importance_weights(
rollout_is_mask_threshold=self.rollout_is_mask_threshold,
rollout_is_veto_threshold=self.rollout_is_veto_threshold,
)
return rollout_is_weights, modified_response_mask, mismatch_metrics_scalar, metrics_scalar


def compute_rollout_importance_weights(
@@ -291,12 +302,6 @@ def compute_rollout_importance_weights(
# This is different from rejection - padding must be zeroed regardless of mode
rollout_is_weights = rollout_is_weights * response_mask

# Compute mismatch metrics (KL, PPL, etc.) and merge with IS metrics
mismatch_metrics = compute_mismatch_metrics(
old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=response_mask
)
metrics.update(mismatch_metrics)

# Convert all tensor metrics to scalars for logging
# Note: No need to detach since old_log_prob and rollout_log_prob are computed with torch.no_grad()
metrics_scalar = {}
Expand Down Expand Up @@ -477,7 +482,7 @@ def compute_mismatch_metrics(
metrics["mismatch_training_ppl"] = training_ppl.detach().item()

# Also log log-ppl for easier analysis (avoids exponential scale)
metrics["mismatch_training_log_ppl"] = (-mean_log_prob_training).mean().detach().item()
metrics["mismatch_training_entropy"] = (-mean_log_prob_training).mean().detach().item()

# 2. Compute rollout mismatch metrics (only if rollout_log_probs available)
if rollout_log_prob is not None:
@@ -497,7 +502,7 @@
mean_log_prob_rollout = masked_mean(rollout_log_prob, response_mask, axis=-1) # (batch_size,)
rollout_ppl = torch.exp(-mean_log_prob_rollout).mean() # Batch mean of per-sequence PPL
metrics["mismatch_rollout_ppl"] = rollout_ppl.detach().item()
metrics["mismatch_rollout_log_ppl"] = (-mean_log_prob_rollout).mean().detach().item()
metrics["mismatch_rollout_entropy"] = (-mean_log_prob_rollout).mean().detach().item()

# 2d. Log PPL difference (sequence-level perplexity difference)
# log_ppl_diff = mean_log_prob_rollout - mean_log_prob_training
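A minimal usage sketch of the reworked API after this change: the four-tuple return separates the mismatch diagnostics (logged under a `mismatch/` prefix) from the importance-sampling metrics. The tensor shapes, the default-constructed config, and the printed keys below are illustrative assumptions, not taken from this PR:

```python
import torch

from xtuner.v1.rl.base.rollout_is import RolloutImportanceSampling

# Assumption: the remaining config fields of RolloutImportanceSampling have
# usable defaults; in practice the instance comes from the loss config.
rollout_is = RolloutImportanceSampling()

batch_size, seq_len = 2, 8
old_log_prob = torch.randn(batch_size, seq_len)      # trainer log-probs (computed under no_grad)
rollout_log_prob = torch.randn(batch_size, seq_len)  # inference-engine log-probs
response_mask = torch.ones(batch_size, seq_len)
num_tokens = response_mask.sum(dim=-1).long()

weights, modified_mask, mismatch_scalar, is_scalar = (
    rollout_is.compute_rollout_importance_weights_and_metrics(
        old_log_prob=old_log_prob,
        rollout_log_prob=rollout_log_prob,
        num_tokens=num_tokens,
        response_mask=response_mask,
    )
)
# mismatch_scalar holds keys such as "mismatch/mismatch_training_entropy";
# is_scalar holds the importance-sampling statistics returned by
# compute_rollout_importance_weights.
```

On the metric rename in the last two hunks: `(-mean_log_prob).mean()` is the batch mean of the per-sequence negative mean log-probability, which is both the log of the perplexity logged one line above and a Monte-Carlo estimate of the model's entropy on those tokens, so `mismatch_training_log_ppl` / `mismatch_rollout_log_ppl` become `mismatch_training_entropy` / `mismatch_rollout_entropy` without changing the computed value.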
77 changes: 11 additions & 66 deletions xtuner/v1/rl/base/worker.py
@@ -400,8 +400,8 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int):
f"rollout_logprobs_list {len(rollout_logprobs_list)} vs loss_ctx_input_list {len(loss_ctx_input_list)}"
)

all_diffs = []
all_rollout_is_metrics = []
all_mismatch_metrics = []
for i, loss_ctx_input in enumerate(loss_ctx_input_list):
mask = loss_ctx_input.shifted_labels != -100
entropy = -(cast(torch.Tensor, loss_ctx_input.old_logprobs) * mask).sum()
@@ -413,40 +413,16 @@
)

if not mask.any(): # all padding tokens, skip
self.logger.warning(f"Skip batch {i} as all tokens are padding.")
continue

if len(rollout_logprobs_list) > 0:
# calculate logprob diff
rollout_logprobs = rollout_logprobs_list[i][mask] # type: ignore[index]
old_logprobs = loss_ctx_input.old_logprobs[mask] # type: ignore[index]

assert len(rollout_logprobs.size()) == 1, (
f"len(rollout_logprobs.size()): {len(rollout_logprobs.size())}"
)
assert rollout_logprobs.shape == old_logprobs.shape, (
f"rollout_logprobs {rollout_logprobs.shape} vs old_logprobs {old_logprobs.shape}"
)
if rollout_logprobs.numel() == 0: # empty in the all-padding case
min_diff = torch.tensor(0.0)
max_diff = min_diff
std_diff = min_diff
mean_diff = min_diff
else:
min_diff = torch.min(rollout_logprobs - old_logprobs)
max_diff = torch.max(rollout_logprobs - old_logprobs)
mean_diff = torch.mean(rollout_logprobs - old_logprobs)
if rollout_logprobs.numel() == 1:
std_diff = torch.tensor(0.0)
else:
std_diff = torch.std(rollout_logprobs - old_logprobs)
all_diffs.append((min_diff, max_diff, mean_diff, std_diff))

# calculate importance sampling weights
cu_seq_lens = seq_ctx_list[i].cu_seq_lens_q
num_tokens = cu_seq_lens[1:] - cu_seq_lens[:-1]

rollout_is_weights, rollout_is_mask, rollout_is_metrics = (
loss_cfg.rollout_is.compute_rollout_importance_weights(
rollout_is_weights, rollout_is_mask, mismatch_metrics, rollout_is_metrics = (
loss_cfg.rollout_is.compute_rollout_importance_weights_and_metrics(
old_log_prob=loss_ctx_input.old_logprobs,
rollout_log_prob=rollout_logprobs_list[i],
num_tokens=num_tokens,
@@ -456,50 +432,19 @@
loss_ctx_input.shifted_labels[~rollout_is_mask.bool()] = -100 # update loss mask
loss_ctx_input.is_weights = rollout_is_weights
all_rollout_is_metrics.append(rollout_is_metrics)
all_mismatch_metrics.append(mismatch_metrics)

logger_msg = f"Rollout {rollout_idx}: "

tis_logger_msg = ""
if len(all_mismatch_metrics) > 0:
mismatch_metrics = merge_rollout_is_metrics(all_mismatch_metrics, DEVICE)
if len(mismatch_metrics) > 0:
logger_msg += f"\n rollout mismatch metrics:\n{json.dumps(mismatch_metrics, indent=4)}"

if len(all_rollout_is_metrics) > 0:
rollout_is_metrics = merge_rollout_is_metrics(all_rollout_is_metrics, DEVICE)
if len(rollout_is_metrics) > 0:
tis_logger_msg = (
f"\n\nrollout importance sampling metrics:\n{json.dumps(rollout_is_metrics, indent=4)}"
)

logprob_logger_msg = ""
if len(rollout_logprobs_list) > 0:
all_diffs_tensor = torch.stack([torch.tensor(d).to(DEVICE) for d in all_diffs]).to(
dtype=torch.float32
) # n, 4
min_diff_val = torch.min(all_diffs_tensor[:, 0]).item()
max_diff_val = torch.max(all_diffs_tensor[:, 1]).item()
mean_diff_val = torch.mean(all_diffs_tensor[:, 2]).item()
if all_diffs_tensor[:, 3].numel() <= 1:
std_diff_val = 0.0
else:
std_diff_val = torch.std(all_diffs_tensor[:, 3]).item()
logprob_logger_msg = f"\nlogprobs diff min {float(min_diff_val):.4f}, max {float(max_diff_val):.4f}, mean {float(mean_diff_val):.4f}, std {float(std_diff_val):.4f}, "

entropy_logger_msg = ""
sum_entropy = cast(torch.Tensor, sum_entropy)
dist.all_reduce(sum_entropy, op=dist.ReduceOp.SUM)
avg_gen_entropy = sum_entropy / global_grad_tokens if global_grad_tokens > 0 else 0
entropy_logger_msg = f"avg generation entropy: {avg_gen_entropy:.4f}"

rollout_entropy_logger_msg = ""
if sum_rollout_entropy is not None:
sum_rollout_entropy = cast(torch.Tensor, sum_rollout_entropy)
dist.all_reduce(sum_rollout_entropy, op=dist.ReduceOp.SUM)
avg_gen_entropy = sum_rollout_entropy / global_grad_tokens if global_grad_tokens > 0 else 0
rollout_entropy_logger_msg = f"avg rollout generation entropy: {avg_gen_entropy:.4f}"

if tis_logger_msg:
logger_msg += entropy_logger_msg
logger_msg += tis_logger_msg
else:
logger_msg += f"{entropy_logger_msg}, {rollout_entropy_logger_msg}"
logger_msg += logprob_logger_msg
logger_msg += f"\n rollout importance sampling metrics:\n{json.dumps(rollout_is_metrics, indent=4)}"
self.logger.info(logger_msg)

if self._has_ref:
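For context on the logging path above, a hypothetical stand-in for `merge_rollout_is_metrics`, a repo helper whose body is not shown in this diff. The sketch assumes it averages each scalar key across the per-batch dicts on the given device; the real helper likely also all-reduces across ranks, which this sketch omits:

```python
import torch


def merge_rollout_is_metrics_stub(
    all_metrics: list[dict[str, float]], device: torch.device
) -> dict[str, float]:
    """Hypothetical stand-in: average each metric key across per-batch dicts."""
    if not all_metrics:
        return {}
    merged: dict[str, float] = {}
    for key in all_metrics[0]:
        values = torch.tensor(
            [m[key] for m in all_metrics], dtype=torch.float32, device=device
        )
        merged[key] = values.mean().item()
    return merged
```

With this shape, `all_mismatch_metrics` and `all_rollout_is_metrics` are lists of same-keyed scalar dicts, which is why the hunk can reuse the same `merge_rollout_is_metrics` call for both before serializing them into the log message.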