
Commit eb40338

Merge branch 'main' into clairlee/dev/hybrid
2 parents: cf7eb3f + cfe8cc0

3 files changed

Lines changed: 172 additions & 2 deletions

File tree

primus/backends/megatron/patches/checkpoint_patches.py

Lines changed: 76 additions & 0 deletions
@@ -46,3 +46,79 @@ def patch_filesystem_writer_async(ctx: PatchContext):
     log_rank_0(
         "[Patch:megatron.checkpoint.filesystem_writer_async] Patch FileSystemWriterAsync successfully."
     )
+
+
+@register_patch(
+    "megatron.checkpoint.save_checkpoint",
+    backend="megatron",
+    phase="before_train",
+    description="Wrap save_checkpoint to skip saving at the last iteration",
+)
+def patch_save_checkpoint(ctx: PatchContext):
+    """
+    Wrap Megatron's save_checkpoint to skip saving at the last iteration.
+
+    This patch monkey-patches the save_checkpoint function in the
+    megatron.training.training module to check whether:
+    1. disable_last_saving is True
+    2. the current iteration equals train_iters (the final iteration)
+
+    If both conditions are met, the checkpoint save is skipped.
+    """
+    try:
+        import megatron.training.training as training_module
+    except ImportError as e:
+        log_rank_0(f"[Patch:megatron.checkpoint.save_checkpoint] Skip patch (Megatron not available): {e}")
+        return
+
+    # Save the original function so the wrapper can delegate to it.
+    original_save_checkpoint = training_module.save_checkpoint
+
+    # The signature below matches the original Megatron save_checkpoint interface,
+    # but the wrapper itself only inspects a subset of the arguments.
+    def wrapped_save_checkpoint(
+        iteration,
+        model,
+        optimizer,
+        opt_param_scheduler,
+        num_floating_point_operations_so_far,
+        checkpointing_context=None,
+        pipeline_rank=None,
+        expert_rank=None,
+        tensor_rank=None,
+        pipeline_parallel=None,
+        expert_parallel=None,
+        non_persistent_ckpt=False,
+        train_data_iterator=None,
+        preprocess_common_state_dict_fn=None,
+        release=False,
+    ):
+        args = ctx.extra.get("backend_args", {})
+
+        if args.disable_last_saving and iteration == args.train_iters:
+            log_rank_0(
+                f"[Patch:megatron.checkpoint.save_checkpoint] Skip saving at the last iteration: {iteration}"
+            )
+            return
+
+        # Call the original save_checkpoint with explicit keyword arguments for clarity.
+        return original_save_checkpoint(
+            iteration,
+            model,
+            optimizer,
+            opt_param_scheduler,
+            num_floating_point_operations_so_far,
+            checkpointing_context=checkpointing_context,
+            pipeline_rank=pipeline_rank,
+            expert_rank=expert_rank,
+            tensor_rank=tensor_rank,
+            pipeline_parallel=pipeline_parallel,
+            expert_parallel=expert_parallel,
+            non_persistent_ckpt=non_persistent_ckpt,
+            train_data_iterator=train_data_iterator,
+            preprocess_common_state_dict_fn=preprocess_common_state_dict_fn,
+            release=release,
+        )
+
+    training_module.save_checkpoint = wrapped_save_checkpoint
+    log_rank_0("[Patch:megatron.checkpoint.save_checkpoint] Patch save_checkpoint successfully.")
Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
+###############################################################################
+# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
+#
+# See LICENSE for license information.
+###############################################################################
+
+"""
+Megatron Tokenizer Builder Patches
+
+Override Megatron's build_tokenizer to use the Primus version, which properly
+handles custom tokenizer types (Llama2Tokenizer, Llama3Tokenizer, etc.)
+with HuggingFace Hub ID support.
+
+Background:
+-----------
+Megatron's official _Llama2Tokenizer only supports local SentencePiece files,
+while Primus extends it to support HuggingFace Hub IDs (e.g., meta-llama/Llama-2-7b-hf).
+
+Without this patch, the new architecture (PrimusRuntime) would call Megatron's
+official build_tokenizer, causing failures when using custom tokenizer types
+with Hub IDs.
+
+This patch ensures both legacy and new architectures use the same tokenizer
+building logic.
+"""
+
+from primus.core.patches import PatchContext, register_patch
+from primus.modules.module_utils import log_rank_0
+
+
+@register_patch(
+    "megatron.tokenizer.build_tokenizer_override",
+    backend="megatron",
+    phase="setup",
+    description="Override Megatron's build_tokenizer to support Primus custom tokenizer types with HuggingFace Hub IDs",
+)
+def patch_build_tokenizer_override(ctx: PatchContext):
+    """
+    Monkey-patch Megatron's build_tokenizer with the Primus version.
+
+    This ensures that custom tokenizer types (Llama2Tokenizer, Llama3Tokenizer,
+    DeepSeekV2Tokenizer, etc.) are properly handled:
+
+    - All custom types use _HuggingFaceTokenizer internally
+    - Support for HuggingFace Hub IDs (e.g., meta-llama/Llama-2-7b-hf)
+    - Consistent behavior between legacy and new architectures
+
+    Without this patch:
+    -------------------
+    - tokenizer_type: Llama2Tokenizer
+      tokenizer_model: meta-llama/Llama-2-7b-hf
+      → Calls Megatron's _Llama2Tokenizer
+      → Expects a local file path
+      → ❌ FileNotFoundError
+
+    With this patch:
+    ----------------
+    - tokenizer_type: Llama2Tokenizer
+      tokenizer_model: meta-llama/Llama-2-7b-hf
+      → Calls Primus build_tokenizer
+      → Maps to _HuggingFaceTokenizer
+      → Supports Hub IDs
+      → ✅ Success
+    """
+    try:
+        import megatron.training.global_vars as megatron_global_vars
+        import pretrain_gpt
+    except ImportError as e:
+        log_rank_0(
+            f"[Patch:megatron.tokenizer.build_tokenizer_override] "
+            f"Skip patch (Megatron not available): {e}"
+        )
+        return
+
+    # Import the Primus build_tokenizer
+    from primus.backends.megatron.training.tokenizer.tokenizer import (
+        build_tokenizer as primus_build_tokenizer,
+    )
+
+    # Save the originals for reference (optional); do not overwrite them on re-application
+    if not hasattr(megatron_global_vars, "_original_build_tokenizer"):
+        megatron_global_vars._original_build_tokenizer = megatron_global_vars.build_tokenizer
+    if not hasattr(pretrain_gpt, "_original_build_tokenizer"):
+        pretrain_gpt._original_build_tokenizer = pretrain_gpt.build_tokenizer
+
+    # Replace Megatron's build_tokenizer with the Primus version
+    megatron_global_vars.build_tokenizer = primus_build_tokenizer
+    pretrain_gpt.build_tokenizer = primus_build_tokenizer
+
+    log_rank_0(
+        "[Patch:megatron.tokenizer.build_tokenizer_override] "
+        "✓ Replaced Megatron build_tokenizer with Primus version"
+    )
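
The stash-then-replace step above is what keeps the override idempotent: the original callable is saved only on the first application, so re-running the patch cannot clobber it. A small self-contained sketch of the same idiom (illustrative only; the dummy module and function names are made up, not part of this commit):

    # Illustrative only: the stash-then-replace idiom on a dummy module object.
    import types

    dummy = types.ModuleType("dummy")
    dummy.build_tokenizer = lambda name: f"original:{name}"

    def primus_style_build_tokenizer(name):
        return f"patched:{name}"

    # Save the original once, so applying the patch twice is harmless.
    if not hasattr(dummy, "_original_build_tokenizer"):
        dummy._original_build_tokenizer = dummy.build_tokenizer
    dummy.build_tokenizer = primus_style_build_tokenizer

    assert dummy.build_tokenizer("Llama2Tokenizer") == "patched:Llama2Tokenizer"
    assert dummy._original_build_tokenizer("Llama2Tokenizer") == "original:Llama2Tokenizer"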

primus/cli/subcommands/train.py

Lines changed: 3 additions & 2 deletions
@@ -17,7 +17,7 @@ def _resolve_pretrain_runtime(args) -> str:
 
     Priority:
     1) Explicit env override via PRIMUS_TRAIN_RUNTIME
-    2) Auto-detect by backend framework (TorchTitan -> core, others -> legacy)
+    2) Auto-detect by backend framework (TorchTitan/Megatron -> core, others -> legacy)
     """
     runtime_entry = getenv("PRIMUS_TRAIN_RUNTIME", "").strip().lower()
     if runtime_entry in ("legacy", "core"):
@@ -38,7 +38,8 @@ def _resolve_pretrain_runtime(args) -> str:
     except Exception:
         framework = None
 
-    return "core" if framework == "torchtitan" else "legacy"
+    supported_frameworks = ["torchtitan", "megatron"]
+    return "core" if framework in supported_frameworks else "legacy"
 
 
 def run(args, overrides: List[str]):
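
A minimal sketch (illustrative, not the CLI code itself) of the resolution order that _resolve_pretrain_runtime implements after this change: an explicit PRIMUS_TRAIN_RUNTIME value wins, otherwise the backend framework decides core vs. legacy. The name resolve_runtime is hypothetical; only the logic mirrors the diff above.

    # Hypothetical standalone version of the runtime-resolution logic shown in the diff.
    from os import getenv

    def resolve_runtime(framework):
        runtime_entry = getenv("PRIMUS_TRAIN_RUNTIME", "").strip().lower()
        if runtime_entry in ("legacy", "core"):
            return runtime_entry
        supported_frameworks = ["torchtitan", "megatron"]
        return "core" if framework in supported_frameworks else "legacy"

    print(resolve_runtime("megatron"))  # "core" after this commit (previously "legacy")
    print(resolve_runtime("unknown"))   # "legacy"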
