Conversation
Add dumper CLI arguments (--dumper-enable, --dumper-dir, per-phase config), dumper_utils.py for SGLang/Megatron dumper integration, model.py hooks for forward-only and forward-backward phases, rollout env var plumbing, source patcher wiring in training actors, and basic e2e test.
Add _maybe_apply_dumper_overrides to disable heartbeats, force single rollout, and disable eval/save when --dumper-enable is set.
Add conftest_dumper.py with shared source patcher YAML configs and comparator helpers. Expand test_dumper.py with full MoE parallelism coverage, field verification, and cross-framework (SGLang vs Megatron) activation comparison. Update dumper_utils.py to nest engine dumps under engines/ subdirectory.
Add DataclassArgparseBridge for dataclass-to-argparse conversion, refactor typer_utils with dataclass_cli decorator, and add parallel_utils for running parallel CLI commands. Include tests.
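The dataclass-to-argparse direction can be sketched roughly as follows; note the `RunConfig` fields and `dataclass_to_parser` helper are illustrative stand-ins, not the actual `DataclassArgparseBridge` API:

```python
import argparse
import dataclasses

@dataclasses.dataclass
class RunConfig:
    # Hypothetical config; the real dataclass fields differ.
    lr: float = 1e-4
    steps: int = 100
    name: str = "run"

def dataclass_to_parser(cls: type) -> argparse.ArgumentParser:
    """Build an argparse parser with one typed --flag per dataclass field."""
    parser = argparse.ArgumentParser()
    for f in dataclasses.fields(cls):
        parser.add_argument(f"--{f.name.replace('_', '-')}",
                            type=f.type, default=f.default)
    return parser

# Round-trip: CLI strings -> parsed namespace -> dataclass instance.
args = dataclass_to_parser(RunConfig).parse_args(["--lr", "0.01", "--steps", "50"])
cfg = RunConfig(**vars(args))
```

The round trip is what makes the bridge useful for worker processes: the worker reconstructs a typed dataclass from plain CLI strings.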
Add run_megatron worker batch.py for input batch construction and cross-entropy loss, script_args.py for CLI-to-worker argument passing via DataclassArgparseBridge, and args.py with CommonRunArgs/RunArgs dataclass definitions. Include tests.
Add context parallelism (CP) zigzag slicing to batch preparation with CP-aware next-token labels using position_ids-based gathering. Add cp field to RunArgs. Include comprehensive CP tests.
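For context, zigzag CP slicing typically splits the sequence into `2 * cp` chunks and assigns rank `r` chunks `r` and `2*cp - 1 - r`, so each rank holds a balanced mix of early and late positions under causal attention. A minimal list-based sketch (the function name and layout are assumptions, not the batch.py implementation):

```python
def cp_zigzag_slice(tokens: list, cp_size: int, cp_rank: int) -> list:
    """Return the zigzag context-parallel shard of `tokens` for one CP rank.

    The sequence is split into 2 * cp_size equal chunks; rank r keeps
    chunks r and (2 * cp_size - 1 - r), balancing causal-attention cost.
    """
    num_chunks = 2 * cp_size
    assert len(tokens) % num_chunks == 0, "sequence must divide evenly by 2*cp"
    chunk = len(tokens) // num_chunks
    mirror = num_chunks - 1 - cp_rank
    return (tokens[cp_rank * chunk:(cp_rank + 1) * chunk]
            + tokens[mirror * chunk:(mirror + 1) * chunk])
```

With `cp=2` and positions 0..7, rank 0 holds `[0, 1, 6, 7]` and rank 1 holds `[2, 3, 4, 5]`; because shards are non-contiguous, next-token labels must be gathered via per-shard position_ids, which is why the batch code carries them.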
Add path_utils.py for Megatron path and model script resolution, prompt_utils.py for token ID generation (math/file/text modes), and parallel_utils.py for parallel config parsing. Include tests.
Add worker/main.py for standalone Megatron forward/backward via torchrun, worker_executor.py for building torchrun commands, run.py CLI command, and package entry points (__main__.py, __init__.py). Include tests.
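Building a torchrun command presumably amounts to assembling an argv list; a hedged sketch (the helper name and exact flag set are assumptions, not the actual worker_executor.py builder):

```python
def build_torchrun_cmd(nproc: int, script: str, script_args: dict) -> list:
    """Assemble a single-node torchrun invocation as an argv list."""
    cmd = ["torchrun", "--standalone", f"--nproc-per-node={nproc}", script]
    for key, value in script_args.items():
        # Worker-side flags, parsed back into dataclasses via the argparse bridge.
        cmd += [f"--{key}", str(value)]
    return cmd
```

Keeping the command as a list (rather than a shell string) lets it be passed straight to `subprocess.run` without quoting concerns.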
Add worker/replay.py for routing replay recording and loading with CP zigzag and SP slicing support. Wire replay into worker main.py and CLI (run.py, worker_executor.py, args.py). Include tests.
Add worker/top_k_print.py for printing top-k predictions per position across all ranks. Wire into worker main.py and CLI args. Include tests.
Add worker/output.py for computing and saving per-token logprobs as JSON files per rank. Wire into worker main.py, CLI args, and run.py. Include tests.
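At its core, a label token's logprob is a numerically-stabilized log-softmax lookup on its logits row. A stdlib-only sketch of the idea (names are hypothetical; the real output.py operates on Megatron tensors):

```python
import json
import math

def token_logprobs(logits_rows, token_ids):
    """Log-softmax each logits row and pick out the label token's entry."""
    out = []
    for logits, tok in zip(logits_rows, token_ids):
        m = max(logits)  # subtract the row max for numerical stability
        lse = m + math.log(sum(math.exp(x - m) for x in logits))
        out.append(logits[tok] - lse)
    return out

def save_rank_logprobs(path, rank, logprobs):
    """One JSON file per rank, keyed so a comparator can align runs."""
    with open(path, "w") as f:
        json.dump({"rank": rank, "logprobs": logprobs}, f)
```

For a uniform two-way row `[0.0, 0.0]`, each token's logprob is `-log 2`, which makes the math easy to spot-check.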
Add logprob_comparator.py for comparing per-token logprobs between baseline and target runs with statistical reporting. Include tests.
Summary of Changes

This pull request introduces a utility for validating the consistency of Megatron runner outputs by comparing per-token log probabilities between a baseline run and a target run. It detects and reports discrepancies with statistical summaries, which helps catch regressions in model stability and correctness during development and optimization. This is one step in a larger series of improvements to the dumper and Megatron runner chain.
Code Review
This pull request introduces a log probability comparator and associated tests. The implementation is solid and the tests are comprehensive. I've identified a minor performance improvement opportunity in the statistics calculation logic to avoid redundant computations. Overall, great work on this utility.
    sorted_diffs = sorted(diffs)
    num = len(sorted_diffs)

    baseline_worst = baseline_entries[max_diff_key]
    target_worst = target_entries[max_diff_key]

    return _CompareResult(
        passed=statistics.mean(diffs) <= threshold,
        num_positions=num,
        max_abs_diff=max_abs_diff,
        max_diff_position=max_diff_key[1],
        max_diff_baseline_logprob=baseline_worst.logprob,
        max_diff_target_logprob=target_worst.logprob,
        max_diff_token_id=baseline_worst.token_id,
        mean_abs_diff=statistics.mean(diffs),
        median_abs_diff=statistics.median(diffs),
        p95_abs_diff=sorted_diffs[int(num * 0.95)] if num > 0 else 0.0,
        p99_abs_diff=sorted_diffs[int(num * 0.99)] if num > 0 else 0.0,
        baseline_mean_logprob=statistics.mean(baseline_logprobs),
        target_mean_logprob=statistics.mean(target_logprobs),
        threshold=threshold,
        per_position_diffs=diffs,
    )
There are a couple of minor inefficiencies in the statistics calculation:

- `statistics.mean(diffs)` is called twice (once for `passed` and once for `mean_abs_diff`).
- `statistics.median(diffs)` sorts the list internally, even though `sorted_diffs` is already available, resulting in an unnecessary re-sort.

You can calculate the mean once and reuse it, and compute the median directly from the sorted list to improve performance.
sorted_diffs = sorted(diffs)
num = len(sorted_diffs)
baseline_worst = baseline_entries[max_diff_key]
target_worst = target_entries[max_diff_key]
mean_abs_diff = statistics.mean(diffs)
# Calculate median from the already sorted list to avoid re-sorting.
if num % 2 == 1:
# `num` is guaranteed to be > 0 here because `common_keys` is not empty.
median_abs_diff = sorted_diffs[num // 2]
else:
median_abs_diff = (sorted_diffs[num // 2 - 1] + sorted_diffs[num // 2]) / 2
return _CompareResult(
passed=mean_abs_diff <= threshold,
num_positions=num,
max_abs_diff=max_abs_diff,
max_diff_position=max_diff_key[1],
max_diff_baseline_logprob=baseline_worst.logprob,
max_diff_target_logprob=target_worst.logprob,
max_diff_token_id=baseline_worst.token_id,
mean_abs_diff=mean_abs_diff,
median_abs_diff=median_abs_diff,
p95_abs_diff=sorted_diffs[int(num * 0.95)] if num > 0 else 0.0,
p99_abs_diff=sorted_diffs[int(num * 0.99)] if num > 0 else 0.0,
baseline_mean_logprob=statistics.mean(baseline_logprobs),
target_mean_logprob=statistics.mean(target_logprobs),
threshold=threshold,
per_position_diffs=diffs,
)
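The even/odd branching in the suggestion is easy to sanity-check against `statistics.median` itself; a quick self-contained check (the helper name here is hypothetical):

```python
import statistics

def median_from_sorted(sorted_vals):
    """Median of an already-sorted, non-empty list, with no re-sort."""
    num = len(sorted_vals)
    if num % 2 == 1:
        return sorted_vals[num // 2]
    return (sorted_vals[num // 2 - 1] + sorted_vals[num // 2]) / 2
```

Both the odd-length and even-length branches should agree with the library implementation on any input.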