@@ -21,7 +21,7 @@
 from pydantic import BaseModel, ConfigDict, field_serializer, field_validator, model_validator
 from torch.distributed import init_process_group
 from torch.distributed.device_mesh import init_device_mesh
-from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR, LinearLR, SequentialLR
+from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR, LinearLR, LRScheduler, SequentialLR
 from typing_extensions import NotRequired, Self, TypedDict
 
 from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
@@ -764,7 +764,7 @@ def build_engine(
         engine.model.set_hf(model_path)
         return engine
 
-    def build_lr_scheduler(self, lr_cfg: LRConfig, scheduler_step: int) -> torch.optim.lr_scheduler.LRScheduler:
+    def build_lr_scheduler(self, lr_cfg: LRConfig, scheduler_step: int) -> LRScheduler:
         """Build the learning rate scheduler.
 
         Args:
@@ -774,36 +774,49 @@ def build_lr_scheduler(self, lr_cfg: LRConfig, scheduler_step: int) -> torch.optim.lr_scheduler.LRScheduler:
             torch.optim.lr_scheduler.LRScheduler: Configured learning rate scheduler.
         """
         if lr_cfg.warmup_ratio < 1:
-            warmup_steps = int(lr_cfg.warmup_ratio * scheduler_step)
+            warmup_step = int(lr_cfg.warmup_ratio * scheduler_step)
         else:
-            warmup_steps = int(lr_cfg.warmup_ratio)
+            warmup_step = int(lr_cfg.warmup_ratio)
 
         def warmup_fn(x):
-            return x / warmup_steps if x < warmup_steps else 1
+            return x / warmup_step if x < warmup_step else 1
 
         warmup_scheduler = LambdaLR(self._engine.optimizer, warmup_fn)
 
-        scheduler: torch.optim.lr_scheduler.LRScheduler
-        if lr_cfg.lr_type == "linear":
-            scheduler = LinearLR(
-                self._engine.optimizer,
-                start_factor=1.0,
-                end_factor=lr_cfg.lr_min / self._engine.optimizer.defaults["lr"],
-                total_iters=scheduler_step - warmup_steps,
+        scheduler_after_warmup: LRScheduler
+        lr_scheduler: LRScheduler
+
+        if warmup_step < scheduler_step:
+            if lr_cfg.lr_type == "linear":
+                scheduler_after_warmup = LinearLR(
+                    self._engine.optimizer,
+                    start_factor=1.0,
+                    end_factor=lr_cfg.lr_min / self._engine.optimizer.defaults["lr"],
+                    total_iters=scheduler_step - warmup_step,
+                )
+            elif lr_cfg.lr_type == "cosine":
+                scheduler_after_warmup = CosineAnnealingLR(
+                    self._engine.optimizer, T_max=scheduler_step - warmup_step, eta_min=lr_cfg.lr_min
+                )
+            elif lr_cfg.lr_type == "constant":
+                scheduler_after_warmup = LambdaLR(self._engine.optimizer, lambda x: 1.0)
+            else:
+                raise ValueError(f"Unsupported lr type: {lr_cfg.lr_type}")
+            lr_scheduler = SequentialLR(
+                optimizer=self._engine.optimizer,
+                schedulers=[warmup_scheduler, scheduler_after_warmup],
+                milestones=[warmup_step],
             )
-        elif lr_cfg.lr_type == "cosine":
-            scheduler = CosineAnnealingLR(
-                self._engine.optimizer, T_max=scheduler_step - warmup_steps, eta_min=lr_cfg.lr_min
+        elif warmup_step == scheduler_step:
+            self.logger.warning(
+                f"You're setting warmup_step ({warmup_step}) to be equal to scheduler_step ({scheduler_step}), "
+                "which is generally not recommended."
             )
-        elif lr_cfg.lr_type == "constant":
-            scheduler = LambdaLR(self._engine.optimizer, lambda x: 1.0)
+            lr_scheduler = warmup_scheduler
         else:
-            raise ValueError(f"Unsupported lr type: {lr_cfg.lr_type}")
-        lr_scheduler = SequentialLR(
-            optimizer=self._engine.optimizer,
-            schedulers=[warmup_scheduler, scheduler],
-            milestones=[warmup_steps],
-        )
+            raise ValueError(
+                f"Expected warmup_step ({warmup_step}) to be no more than scheduler_step ({scheduler_step})"
+            )
         return lr_scheduler
 
     def _maybe_save(self, is_snapshot: bool = False) -> bool:
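For context, a minimal standalone sketch of the warmup-then-decay composition this patch builds: a linear-warmup `LambdaLR` chained into a `CosineAnnealingLR` via `SequentialLR`, switching at `warmup_step`. The model, optimizer, and step counts below are illustrative placeholders, not values from this repository.

```python
import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR, SequentialLR

# Assumed example values (not from the repo): total scheduler steps, warmup ratio, minimum LR.
scheduler_step = 100
warmup_ratio = 0.1
lr_min = 1e-5

# Same convention as the patch: a ratio < 1 is a fraction of scheduler_step,
# otherwise it is treated as an absolute step count.
warmup_step = int(warmup_ratio * scheduler_step) if warmup_ratio < 1 else int(warmup_ratio)

model = torch.nn.Linear(4, 4)
optimizer = SGD(model.parameters(), lr=1e-3)

# Linear warmup from 0 toward the base LR over warmup_step steps.
warmup_scheduler = LambdaLR(optimizer, lambda x: x / warmup_step if x < warmup_step else 1)

# Cosine decay from the base LR down to lr_min over the remaining steps.
cosine_scheduler = CosineAnnealingLR(optimizer, T_max=scheduler_step - warmup_step, eta_min=lr_min)

# SequentialLR hands control from the warmup scheduler to the cosine scheduler at the milestone.
lr_scheduler = SequentialLR(
    optimizer, schedulers=[warmup_scheduler, cosine_scheduler], milestones=[warmup_step]
)

for step in range(scheduler_step):
    optimizer.step()
    lr_scheduler.step()
    if step in (0, warmup_step - 1, warmup_step, scheduler_step - 1):
        print(step, lr_scheduler.get_last_lr())
```

The new `warmup_step < scheduler_step` guard matters for exactly this composition: `SequentialLR` needs a post-warmup scheduler with a positive horizon, so when warmup covers every step the patch falls back to the warmup scheduler alone (with a warning), and a warmup longer than the schedule is rejected outright.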