Skip to content

Add logprob comparator for megatron runner#777

Open
fzyzcjy wants to merge 13 commits intomainfrom
ac8452/1/11
Open

Add logprob comparator for megatron runner#777
fzyzcjy wants to merge 13 commits intomainfrom
ac8452/1/11

Conversation

@fzyzcjy
Copy link
Collaborator

@fzyzcjy fzyzcjy commented Mar 20, 2026

No description provided.

fzyzcjy added 12 commits March 20, 2026 18:47
Add dumper CLI arguments (--dumper-enable, --dumper-dir, per-phase config),
dumper_utils.py for SGLang/Megatron dumper integration, model.py hooks
for forward-only and forward-backward phases, rollout env var plumbing,
source patcher wiring in training actors, and basic e2e test.
Add _maybe_apply_dumper_overrides to disable heartbeats, force single
rollout, and disable eval/save when --dumper-enable is set.
Add conftest_dumper.py with shared source patcher YAML configs and
comparator helpers. Expand test_dumper.py with full MoE parallelism
coverage, field verification, and cross-framework (SGLang vs Megatron)
activation comparison. Update dumper_utils.py to nest engine dumps under
engines/ subdirectory.
Add DataclassArgparseBridge for dataclass-to-argparse conversion,
refactor typer_utils with dataclass_cli decorator, and add
parallel_utils for running parallel CLI commands. Include tests.
Add run_megatron worker batch.py for input batch construction and
cross-entropy loss, script_args.py for CLI-to-worker argument passing
via DataclassArgparseBridge, and args.py with CommonRunArgs/RunArgs
dataclass definitions. Include tests.
Add context parallelism (CP) zigzag slicing to batch preparation with
CP-aware next-token labels using position_ids-based gathering. Add cp
field to RunArgs. Include comprehensive CP tests.
Add path_utils.py for Megatron path and model script resolution,
prompt_utils.py for token ID generation (math/file/text modes),
and parallel_utils.py for parallel config parsing. Include tests.
Add worker/main.py for standalone Megatron forward/backward via torchrun,
worker_executor.py for building torchrun commands, run.py CLI command,
and package entry points (__main__.py, __init__.py). Include tests.
Add worker/replay.py for routing replay recording and loading with CP
zigzag and SP slicing support. Wire replay into worker main.py and CLI
(run.py, worker_executor.py, args.py). Include tests.
Add worker/top_k_print.py for printing top-k predictions per position
across all ranks. Wire into worker main.py and CLI args. Include tests.
Add worker/output.py for computing and saving per-token logprobs as
JSON files per rank. Wire into worker main.py, CLI args, and run.py.
Include tests.
Add logprob_comparator.py for comparing per-token logprobs between
baseline and target runs with statistical reporting. Include tests.
@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a crucial utility for validating the consistency of Megatron runner outputs by enabling the comparison of per-token log probabilities. It provides a robust mechanism to detect and report discrepancies between different runs, which is essential for ensuring model stability and correctness during development and optimization. This enhancement is a significant step in a larger series of improvements to the dumper and Megatron runner chain.

Highlights

  • New Logprob Comparator Module: A new Python module, logprob_comparator.py, has been added to compare per-token log probabilities between baseline and target Megatron runner runs.
  • Configurable Tolerance and Detailed Reporting: The comparator supports a configurable threshold for determining pass/fail status and provides detailed statistics on logprob differences, including max, mean, median, and percentile absolute differences, along with information about the worst-performing position.
  • Data Loading and Merging: The module includes functionality to load and merge logprob entries from multiple rank JSON files, handling deduplication for identical logprobs from TP ranks and skipping invalid entries.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces a log probability comparator and associated tests. The implementation is solid and the tests are comprehensive. I've identified a minor performance improvement opportunity in the statistics calculation logic to avoid redundant computations. Overall, great work on this utility.

Comment on lines +147 to +169
sorted_diffs = sorted(diffs)
num = len(sorted_diffs)

baseline_worst = baseline_entries[max_diff_key]
target_worst = target_entries[max_diff_key]

return _CompareResult(
passed=statistics.mean(diffs) <= threshold,
num_positions=num,
max_abs_diff=max_abs_diff,
max_diff_position=max_diff_key[1],
max_diff_baseline_logprob=baseline_worst.logprob,
max_diff_target_logprob=target_worst.logprob,
max_diff_token_id=baseline_worst.token_id,
mean_abs_diff=statistics.mean(diffs),
median_abs_diff=statistics.median(diffs),
p95_abs_diff=sorted_diffs[int(num * 0.95)] if num > 0 else 0.0,
p99_abs_diff=sorted_diffs[int(num * 0.99)] if num > 0 else 0.0,
baseline_mean_logprob=statistics.mean(baseline_logprobs),
target_mean_logprob=statistics.mean(target_logprobs),
threshold=threshold,
per_position_diffs=diffs,
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

There are a couple of minor inefficiencies in the statistics calculation:

  1. statistics.mean(diffs) is called twice (once for passed and once for mean_abs_diff).
  2. statistics.median(diffs) is called, which sorts the list internally, even though sorted_diffs is already available. This results in an unnecessary re-sort.

You can calculate the mean once and reuse it, and compute the median directly from the sorted list to improve performance.

    sorted_diffs = sorted(diffs)
    num = len(sorted_diffs)

    baseline_worst = baseline_entries[max_diff_key]
    target_worst = target_entries[max_diff_key]

    mean_abs_diff = statistics.mean(diffs)

    # Calculate median from the already sorted list to avoid re-sorting.
    if num % 2 == 1:
        # `num` is guaranteed to be > 0 here because `common_keys` is not empty.
        median_abs_diff = sorted_diffs[num // 2]
    else:
        median_abs_diff = (sorted_diffs[num // 2 - 1] + sorted_diffs[num // 2]) / 2

    return _CompareResult(
        passed=mean_abs_diff <= threshold,
        num_positions=num,
        max_abs_diff=max_abs_diff,
        max_diff_position=max_diff_key[1],
        max_diff_baseline_logprob=baseline_worst.logprob,
        max_diff_target_logprob=target_worst.logprob,
        max_diff_token_id=baseline_worst.token_id,
        mean_abs_diff=mean_abs_diff,
        median_abs_diff=median_abs_diff,
        p95_abs_diff=sorted_diffs[int(num * 0.95)] if num > 0 else 0.0,
        p99_abs_diff=sorted_diffs[int(num * 0.99)] if num > 0 else 0.0,
        baseline_mean_logprob=statistics.mean(baseline_logprobs),
        target_mean_logprob=statistics.mean(target_logprobs),
        threshold=threshold,
        per_position_diffs=diffs,
    )

Co-authored-by: Yueming Yuan <112649537+yueming-yuan@users.noreply.github.com>
@fzyzcjy fzyzcjy changed the base branch from ac8452/1/10 to main March 20, 2026 13:58
@fzyzcjy fzyzcjy requested a review from yushengsu-thu as a code owner March 20, 2026 13:58
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant