Skip to content
5 changes: 5 additions & 0 deletions miles/backends/fsdp_utils/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ class FSDPTrainRayActor(TrainRayActor):
def init(self, args: Namespace, role: str, with_ref: bool = False) -> int: # type: ignore[override]
super().init(args, role, with_ref)

if args.dumper_enable:
from sglang.srt.debug_utils.dumper import dumper

dumper.apply_source_patches()

# Setup ParallelState for both CP and non-CP cases
self.parallel_state = create_fsdp_parallel_state(args)

Expand Down
5 changes: 5 additions & 0 deletions miles/backends/megatron_utils/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ def init(

init(args)

if args.dumper_enable:
from sglang.srt.debug_utils.dumper import dumper

dumper.apply_source_patches()

self._is_main_rank = is_megatron_main_rank()

if self._is_main_rank:
Expand Down
3 changes: 2 additions & 1 deletion miles/backends/megatron_utils/arguments.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import logging
import os

from megatron.core.tokenizers.utils.build_tokenizer import vocab_size_with_padding as _vocab_size_with_padding
from megatron.training.arguments import parse_args, validate_args
from megatron.training.tokenizer.tokenizer import _vocab_size_with_padding

__all__ = ["validate_args", "parse_args", "set_default_megatron_args"]

Expand Down Expand Up @@ -32,4 +32,5 @@ def set_default_megatron_args(args):
logger.info("--tokenizer-model not set, use --hf-checkpoint as tokenizer model.")
args.tokenizer_model = args.hf_checkpoint
args.tokenizer_type = "HuggingFaceTokenizer"

return args
11 changes: 11 additions & 0 deletions miles/backends/megatron_utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from megatron.training.global_vars import get_args
from megatron.training.training import get_model

from miles.utils.dumper_utils import DumperMegatronUtil, DumperPhase
from miles.utils.memory_utils import clear_memory

from ..training_utils.ci_utils import check_grad_norm, check_kl
Expand Down Expand Up @@ -204,12 +205,16 @@ def forward_only(
Returns:
Aggregated outputs keyed by ``store_prefix + key``.
"""

dumper_phase_util = DumperMegatronUtil(args, model, DumperPhase.FWD_ONLY)

# reset data iterator
for iterator in data_iterator:
iterator.reset()

config = get_model_config(model[0])

@dumper_phase_util.wrap_forward_step
def forward_step(
data_iterator: DataIterator, model: GPTModel, return_schedule_plan: bool = False
) -> tuple[torch.Tensor, Callable[[torch.Tensor], dict[str, list[torch.Tensor]]]]:
Expand Down Expand Up @@ -301,6 +306,8 @@ def forward_step(
for model_module in model:
model_module.train()

dumper_phase_util.finalize(model)

rollout_data = {}
# Store the results on the last stage
if mpu.is_pipeline_last_stage():
Expand Down Expand Up @@ -340,6 +347,7 @@ def train_one_step(
Reduced loss dictionary (last stage only) and gradient norm for logging.
"""
args = get_args()
dumper_phase_util = DumperMegatronUtil(args, model, DumperPhase.FWD_BWD)

# Set grad to zero.
for model_chunk in model:
Expand All @@ -352,6 +360,7 @@ def train_one_step(
custom_before_train_step_hook = load_function(args.custom_megatron_before_train_step_hook_path)
custom_before_train_step_hook(args, rollout_id, step_id, model, optimizer, opt_param_scheduler)

@dumper_phase_util.wrap_forward_step
def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_plan: bool = False) -> tuple[
torch.Tensor,
Callable[[torch.Tensor], tuple[torch.Tensor, int, dict[str, torch.Tensor | list[str]]]],
Expand Down Expand Up @@ -479,6 +488,8 @@ def forward_step(data_iterator: DataIterator, model: GPTModel, return_schedule_p
model_chunk.zero_grad_buffer()
optimizer.zero_grad()

dumper_phase_util.finalize(model)

if mpu.is_pipeline_last_stage(ignore_virtual=True):
loss_reduced = aggregate_train_losses(losses_reduced, parallel_state)
return loss_reduced, grad_norm
Expand Down
26 changes: 26 additions & 0 deletions miles/backends/training_utils/cp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,32 @@ def pad_tokens(tokens, pad):
return torch.cat([tokens[start_1:end_1], tokens[start_2:end_2]])


def natural_to_zigzag_slice(
    tensor: torch.Tensor, dim: int, cp_size: int, cp_rank: int
) -> torch.Tensor:
    """Extract rank-local data in the zigzag ring-attention CP layout.

    The full tensor is viewed as ``2 * cp_size`` equal partitions along
    ``dim``; rank ``cp_rank`` owns the chunk pair ``cp_rank`` and
    ``2 * cp_size - 1 - cp_rank``. This is the inverse of an all-gather over
    the zigzag CP layout (hence "natural → zigzag").

    No padding is performed (unlike :func:`slice_with_cp`): when the size
    along ``dim`` is not divisible by ``2 * cp_size``, a warning is printed
    and the tensor is returned untouched.
    """
    num_chunks = 2 * cp_size
    total = tensor.shape[dim]
    if total % num_chunks != 0:
        print(f"Warning: dim {dim} size {total} not divisible by 2*cp_size={num_chunks}")
        return tensor

    part_len = total // num_chunks
    # Zigzag pairing: rank r owns chunk r and its mirror chunk (2*cp_size-1-r).
    first = tensor.narrow(dim, cp_rank * part_len, part_len)
    mirror_idx = num_chunks - 1 - cp_rank
    second = tensor.narrow(dim, mirror_idx * part_len, part_len)
    return torch.cat([first, second], dim=dim)


def _allgather_cp_redistribute(
res: dict[str, list[torch.Tensor]],
*,
Expand Down
3 changes: 3 additions & 0 deletions miles/ray/actor_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def _allocate_gpus_for_actor(self, pg, num_gpus_per_actor):
**self.args.train_env_vars,
}

if source_patcher_config := self.args.dumper_source_patcher_config_train:
env_vars["DUMPER_SOURCE_PATCHER_CONFIG"] = source_patcher_config

if self.args.offload_train and self.args.train_backend == "megatron":
import torch_memory_saver

Expand Down
3 changes: 2 additions & 1 deletion miles/ray/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
call_rollout_fn,
)
from miles.rollout.inference_rollout.compatibility import call_rollout_function, load_rollout_function
from miles.utils import tracking_utils
from miles.utils import dumper_utils, tracking_utils
from miles.utils.environ import enable_experimental_rollout_refactor
from miles.utils.health_monitor import RolloutHealthMonitor
from miles.utils.http_utils import _wrap_ipv6, find_available_port, get_host_info, init_http_client
Expand Down Expand Up @@ -131,6 +131,7 @@ def start_engines(self, port_cursors: dict[int, int] | None = None) -> tuple[lis
"SGLANG_ENABLE_STRICT_MEM_CHECK_DURING_IDLE": "false",
}.items()
}
env_vars.update(dumper_utils.get_sglang_env(self.args))

rollout_engine = RolloutRayActor.options(
num_cpus=num_cpus,
Expand Down
3 changes: 3 additions & 0 deletions miles/rollout/inference_rollout/inference_rollout_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from miles.rollout.base_types import RolloutFnTrainOutput
from miles.rollout.filter_hub.base_types import MetricGatherer, call_dynamic_filter
from miles.rollout.inference_rollout.inference_rollout_common import GenerateState, generate_and_rm_group
from miles.utils import dumper_utils
from miles.utils.http_utils import get, post
from miles.utils.misc import as_completed_async, load_function
from miles.utils.types import Sample
Expand Down Expand Up @@ -75,6 +76,8 @@ async def generate_rollout_async(
args = state.args
assert args.rollout_global_dataset

await dumper_utils.configure_sglang(args)

# instantiate data filters
dynamic_filter = load_function(args.dynamic_sampling_filter_path)

Expand Down
3 changes: 3 additions & 0 deletions miles/rollout/sglang_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from miles.backends.megatron_utils.lora_utils import LORA_ADAPTER_NAME, is_lora_enabled
from miles.rollout.base_types import RolloutFnEvalOutput, RolloutFnTrainOutput
from miles.rollout.filter_hub.base_types import MetricGatherer, call_dynamic_filter
from miles.utils import dumper_utils
from miles.utils.async_utils import run
from miles.utils.data import Dataset
from miles.utils.eval_config import EvalDatasetConfig
Expand Down Expand Up @@ -374,6 +375,8 @@ async def generate_rollout_async(
"""
assert args.rollout_global_dataset

await dumper_utils.configure_sglang(args)

state = GenerateState(args)

# instantiate data filters
Expand Down
131 changes: 131 additions & 0 deletions miles/utils/argparse_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
"""Dataclass ↔ argparse bridge.

Supported field types: str, int, float, Path, bool, and their ``X | None``
variants (except ``bool | None`` which is not supported).
"""

from __future__ import annotations

import argparse
import dataclasses
import shlex
import types
from pathlib import Path
from typing import Generic, TypeVar, Union, get_args, get_origin, get_type_hints

T = TypeVar("T")

# Scalar field types the bridge can handle; maps the annotation to the
# callable passed as argparse's ``type=`` converter.
_SCALAR_TYPES: dict[type, type] = {str: str, int: int, float: float, Path: Path}


def _is_bool(tp: type) -> bool:
    """Return ``True`` iff *tp* is exactly the ``bool`` type."""
    return bool is tp


def _is_optional(tp: type) -> tuple[bool, type | None]:
if not isinstance(tp, types.UnionType):
return False, None

args: tuple[type, ...] = tp.__args__
non_none: list[type] = [a for a in args if a is not type(None)]
if type(None) not in args or len(non_none) != 1:
return False, None

return True, non_none[0]


def _resolve_default(field: dataclasses.Field[object]) -> object:
    """Return the effective default for a dataclass field, or ``MISSING``.

    Checks ``default`` first, then ``default_factory`` (calling it to
    materialize the value); ``dataclasses.MISSING`` means no default.
    """
    if field.default is dataclasses.MISSING:
        if field.default_factory is dataclasses.MISSING:  # type: ignore[misc]
            return dataclasses.MISSING
        return field.default_factory()  # type: ignore[misc]
    return field.default


class DataclassArgparseBridge(Generic[T]):
    """Bi-directional converter: dataclass ↔ argparse.

    *prefix* controls the CLI flag prefix: ``"script"`` → ``--script-field-name``,
    ``""`` → ``--field-name``.
    """

    def __init__(
        self,
        dataclass_type: type[T],
        *,
        prefix: str,
        group_title: str | None = None,
    ) -> None:
        """Bind the bridge to *dataclass_type*.

        Raises:
            TypeError: if *dataclass_type* is not a dataclass.
        """
        if not dataclasses.is_dataclass(dataclass_type):
            raise TypeError(f"{dataclass_type!r} is not a dataclass")

        self._cls: type[T] = dataclass_type
        self._prefix: str = prefix
        self._group_title: str = group_title or f"{dataclass_type.__name__} args"
        # Resolve annotations once up front; with ``from __future__ import
        # annotations`` the raw annotations are strings.
        self._hints: dict[str, type] = get_type_hints(dataclass_type)

    def _flag(self, field_name: str) -> str:
        """CLI flag for *field_name*: ``--<prefix>-<field-name>`` (kebab-case)."""
        stem: str = field_name.replace("_", "-")
        if self._prefix:
            return f"--{self._prefix}-{stem}"
        return f"--{stem}"

    def _dest(self, field_name: str) -> str:
        """argparse ``dest`` for *field_name*: ``<prefix>_<field_name>``."""
        if self._prefix:
            return f"{self._prefix}_{field_name}"
        return field_name

    def register_on_parser(self, parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
        """Add one argument per dataclass field to *parser*.

        Field-type handling:
          * ``bool`` → ``store_true`` flag, default ``False``;
          * ``X | None`` (scalar ``X``) → typed argument defaulting to ``None``;
          * scalar (str/int/float/Path) → typed argument, ``required`` when the
            dataclass declares no default.

        Returns *parser* for chaining.

        Raises:
            TypeError: for any unsupported field type.
        """
        group: argparse._ArgumentGroup = parser.add_argument_group(self._group_title)

        for field in dataclasses.fields(self._cls):
            flag: str = self._flag(field.name)
            dest: str = self._dest(field.name)
            tp: type = self._hints[field.name]

            if _is_bool(tp):
                # NOTE(review): store_true pins the CLI default to False, so a
                # bool field declared with ``= True`` would silently diverge
                # from its dataclass default — confirm no such field exists,
                # or switch to argparse.BooleanOptionalAction.
                group.add_argument(flag, dest=dest, action="store_true", default=False)
                continue

            is_opt, inner = _is_optional(tp)
            if is_opt:
                if inner not in _SCALAR_TYPES:
                    raise TypeError(f"Unsupported optional inner type {inner!r} for field {field.name}")
                group.add_argument(flag, dest=dest, type=_SCALAR_TYPES[inner], default=None)
                continue

            if tp in _SCALAR_TYPES:
                default: object = _resolve_default(field)
                kwargs: dict[str, object] = {"dest": dest, "type": _SCALAR_TYPES[tp]}
                if default is not dataclasses.MISSING:
                    kwargs["default"] = default
                else:
                    # No dataclass default → the CLI must supply the value.
                    kwargs["required"] = True
                group.add_argument(flag, **kwargs)  # type: ignore[arg-type]
                continue

            raise TypeError(f"Unsupported field type {tp!r} for field {field.name}")

        return parser

    def from_namespace(self, namespace: argparse.Namespace) -> T:
        """Build a dataclass instance from parsed *namespace* values."""
        kwargs: dict[str, object] = {}
        for field in dataclasses.fields(self._cls):
            kwargs[field.name] = getattr(namespace, self._dest(field.name))
        return self._cls(**kwargs)  # type: ignore[call-arg]

    def to_cli_args(self, instance: T) -> str:
        """Render *instance* as a CLI string parseable by :meth:`register_on_parser`.

        ``True`` booleans emit the bare flag, ``None`` values are omitted, and
        every other value is shell-quoted so paths or strings containing
        spaces or shell metacharacters survive a later ``shlex.split``.
        """
        parts: list[str] = []

        for field in dataclasses.fields(self._cls):  # type: ignore[arg-type]
            value: object = getattr(instance, field.name)
            flag: str = self._flag(field.name)
            tp: type = self._hints[field.name]

            if _is_bool(tp):
                if value:
                    parts.append(flag)
            elif value is not None:
                # Quote the value: an unquoted embedded space would be split
                # into a bogus extra token by the consumer. shlex.quote is a
                # no-op for plain alphanumeric values, so simple cases render
                # exactly as before.
                parts.append(f"{flag} {shlex.quote(str(value))}")

        return " ".join(parts)
Loading
Loading