Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions miles/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1981,6 +1981,26 @@ def miles_validate_args(args):
args.use_dynamic_batch_size is False
), "Dynamic batch size is not supported for bshd format. Please specify --micro-batch-size instead."

_maybe_apply_dumper_overrides(args)


def _maybe_apply_dumper_overrides(args) -> None:
if not args.dumper_enable:
return

if args.use_fault_tolerance:
logger.info("Dumper mode: disabling --use-fault-tolerance to suppress RolloutHealthMonitor heartbeats")
args.use_fault_tolerance = False

logger.info("Dumper mode: all heartbeat mechanisms disabled")
args.router_disable_health_check = True
args.rollout_health_check_interval = 1e18

logger.info("Dumper mode: forced num_rollout=%d, disabled eval and save", args.num_rollout)
args.num_rollout = (args.start_rollout_id or 0) + 1
Comment on lines +1999 to +2000
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The log message on line 1950 is misleading because it logs the value of args.num_rollout before it is updated on the next line. This can cause confusion during debugging, as the logged value will not match the value actually used. To ensure the log reflects the correct value, the assignment should happen before the logging statement.

Suggested change
logger.info("Dumper mode: forced num_rollout=%d, disabled eval and save", args.num_rollout)
args.num_rollout = (args.start_rollout_id or 0) + 1
args.num_rollout = (args.start_rollout_id or 0) + 1
logger.info("Dumper mode: forced num_rollout=%d, disabled eval and save", args.num_rollout)

args.eval_interval = None
args.save_interval = None


def hf_validate_args(args, hf_config):
def equal(x, y):
Expand Down
69 changes: 68 additions & 1 deletion tests/fast/utils/test_arguments.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
import argparse
import sys
from types import SimpleNamespace
from unittest.mock import patch

import pytest

from miles.utils.arguments import get_miles_extra_args_provider
from miles.utils.arguments import _maybe_apply_dumper_overrides, get_miles_extra_args_provider
from miles.utils.misc import function_registry

PATH_ARGS = ["--rollout-function-path", "--custom-generate-function-path"]
Expand Down Expand Up @@ -56,3 +57,69 @@ def test_skips_function_without_add_arguments(self, path_arg):
):
parser = argparse.ArgumentParser()
get_miles_extra_args_provider()(parser)


class TestMaybeApplyDumperOverrides:
def _make_args(
self,
*,
dumper_enable: bool = False,
use_fault_tolerance: bool = False,
router_disable_health_check: bool = False,
rollout_health_check_interval: float = 30.0,
start_rollout_id: int | None = None,
num_rollout: int = 10,
eval_interval: int | None = 5,
save_interval: int | None = 5,
) -> SimpleNamespace:
return SimpleNamespace(
dumper_enable=dumper_enable,
use_fault_tolerance=use_fault_tolerance,
router_disable_health_check=router_disable_health_check,
rollout_health_check_interval=rollout_health_check_interval,
start_rollout_id=start_rollout_id,
num_rollout=num_rollout,
eval_interval=eval_interval,
save_interval=save_interval,
)

def test_noop_when_dumper_disabled(self) -> None:
args = self._make_args(
dumper_enable=False,
use_fault_tolerance=True,
rollout_health_check_interval=30.0,
)
_maybe_apply_dumper_overrides(args)

assert args.use_fault_tolerance is True
assert args.router_disable_health_check is False
assert args.rollout_health_check_interval == 30.0
assert args.num_rollout == 10
assert args.eval_interval == 5
assert args.save_interval == 5

def test_disables_all_heartbeats(self) -> None:
args = self._make_args(
dumper_enable=True,
use_fault_tolerance=True,
rollout_health_check_interval=30.0,
)
_maybe_apply_dumper_overrides(args)

assert args.use_fault_tolerance is False
assert args.router_disable_health_check is True
assert args.rollout_health_check_interval == 1e18

def test_forces_single_rollout(self) -> None:
args = self._make_args(dumper_enable=True, num_rollout=100)
_maybe_apply_dumper_overrides(args)

assert args.num_rollout == 1
assert args.eval_interval is None
assert args.save_interval is None

def test_respects_start_rollout_id(self) -> None:
args = self._make_args(dumper_enable=True, start_rollout_id=5, num_rollout=100)
_maybe_apply_dumper_overrides(args)

assert args.num_rollout == 6
Loading