From e6c82bd895ed6bcd881bf802fe6d31c053758da7 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Fri, 4 Jul 2025 00:50:25 +0200 Subject: [PATCH 1/7] refactor: app_state can now be partially loaded --- .../checkpointing/stateful/app_state.py | 62 +++++++++++++------ .../stateful/app_state_factory.py | 19 +++++- 2 files changed, 62 insertions(+), 19 deletions(-) diff --git a/src/modalities/checkpointing/stateful/app_state.py b/src/modalities/checkpointing/stateful/app_state.py index 6f42074cf..1b5a583b9 100644 --- a/src/modalities/checkpointing/stateful/app_state.py +++ b/src/modalities/checkpointing/stateful/app_state.py @@ -15,6 +15,8 @@ from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler +from modalities.utils.logging import get_logger + class StatefulComponents(Enum): MODEL = "model" @@ -34,13 +36,19 @@ class AppState(Stateful): https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html """ - def __init__(self, model: nn.Module, optimizer: Optimizer, lr_scheduler: Optional[LRScheduler] = None): + def __init__( + self, + model: Optional[nn.Module] = None, + optimizer: Optional[Optimizer] = None, + lr_scheduler: Optional[LRScheduler] = None, + ): """Initializes the AppState object. Args: - model (nn.Module): The model can be either a non-sharded model, FSDP1 or FSDP2 model. - optimizer (Optimizer): The optimizer can be either a non-sharded optimizer, FSDP1 or FSDP2 optimizer. - lr_scheduler (Optional[LRScheduler], optional): The lr scheduler used during training. Defaults to None. + model (nn.Module, optional): The model can be either a non-sharded model, FSDP1 or FSDP2 model. + optimizer (Optimizer, optional): The optimizer can be either a non-sharded optimizer, + FSDP1 or FSDP2 optimizer. + lr_scheduler (LRScheduler, optional): The lr scheduler used during training. Defaults to None. """ self._model = model self._optimizer = optimizer @@ -76,12 +84,13 @@ def state_dict(self) -> dict[str, Any]: # this line automatically manages FSDP FQN's, as well as sets the default # state dict type to FSDP.SHARDED_STATE_DICT # model_state_dict, optimizer_state_dict = get_state_dict(self._model, self._optimizer) - sd = { - StatefulComponents.MODEL.value: ModelStateRetriever.get_state_dict(app_state=self), - StatefulComponents.OPTIMIZER.value: OptimizerStateRetriever.get_state_dict( - app_state=self, - ), - } + sd = {} + if self._model is not None: + sd[StatefulComponents.MODEL.value] = ModelStateRetriever.get_state_dict(app_state=self) + + if self._optimizer is not None: + sd[StatefulComponents.OPTIMIZER.value] = OptimizerStateRetriever.get_state_dict(app_state=self) + if self._lr_scheduler is not None: sd[StatefulComponents.LR_SCHEDULER.value] = LRSchedulerStateRetriever.get_state_dict(app_state=self) return sd @@ -101,15 +110,32 @@ def load_state_dict(self, state_dict: dict[str, Any]) -> None: "Cannot call load_state_dict twice on the same AppState object. " "State dict has already been loaded." 
) - ModelStateRetriever.load_state_dict_(app_state=self, state_dict=state_dict[StatefulComponents.MODEL.value]) - OptimizerStateRetriever.load_state_dict_( - app_state=self, - state_dict=state_dict[StatefulComponents.OPTIMIZER.value], - ) + if self._model is not None: + ModelStateRetriever.load_state_dict_(app_state=self, state_dict=state_dict[StatefulComponents.MODEL.value]) + + if self._optimizer is not None: + if StatefulComponents.OPTIMIZER.value in state_dict: + OptimizerStateRetriever.load_state_dict_( + app_state=self, + state_dict=state_dict[StatefulComponents.OPTIMIZER.value], + ) + else: + get_logger(name="app_state").warning( + "Did not load optimizer checkpoint! " + f"Optimizer state dict not found in state_dict: {state_dict.keys()}." + ) + if self._lr_scheduler is not None: - LRSchedulerStateRetriever.load_state_dict_( - app_state=self, state_dict=state_dict[StatefulComponents.LR_SCHEDULER.value] - ) + if StatefulComponents.LR_SCHEDULER.value in state_dict: + LRSchedulerStateRetriever.load_state_dict_( + app_state=self, state_dict=state_dict[StatefulComponents.LR_SCHEDULER.value] + ) + else: + get_logger(name="app_state").warning( + "Did not load lr scheduler checkpoint! " + f"LR scheduler state dict not found in state_dict: {state_dict.keys()}." + ) + self._is_loaded = True diff --git a/src/modalities/checkpointing/stateful/app_state_factory.py b/src/modalities/checkpointing/stateful/app_state_factory.py index bad48d44c..c504855ed 100644 --- a/src/modalities/checkpointing/stateful/app_state_factory.py +++ b/src/modalities/checkpointing/stateful/app_state_factory.py @@ -35,6 +35,9 @@ def get_raw_app_state( def get_dcp_checkpointed_app_state_( raw_app_state: AppState, checkpoint_dir_path: Path, + load_model_checkpoint: bool = True, + load_optimizer_checkpoint: bool = True, + load_lr_scheduler_checkpoint: bool = True, ) -> AppState: """Loads the checkpointed state dict into the raw AppState object (i.e., non-checkpoint loaded AppState) in-place. @@ -54,5 +57,19 @@ def get_dcp_checkpointed_app_state_( "Cannot call load_state_dict twice on the same AppState object. " "State dict has already been loaded." 
) cp_loading = DCPCheckpointLoading(global_rank=dist.get_rank()) - cp_loading.load_checkpoint_(app_state=raw_app_state, checkpoint_dir_path=checkpoint_dir_path) + + tmp_app_state = AppStateFactory.get_raw_app_state( + model=raw_app_state.model if load_model_checkpoint else None, + optimizer=raw_app_state.optimizer if load_optimizer_checkpoint else None, + lr_scheduler=raw_app_state.lr_scheduler if load_lr_scheduler_checkpoint else None, + ) + + cp_loading.load_checkpoint_(app_state=tmp_app_state, checkpoint_dir_path=checkpoint_dir_path) + raw_app_state.model = tmp_app_state.model if tmp_app_state.model is not None else raw_app_state.model + raw_app_state.optimizer = ( + tmp_app_state.optimizer if tmp_app_state.optimizer is not None else raw_app_state.optimizer + ) + raw_app_state.lr_scheduler = ( + tmp_app_state.lr_scheduler if tmp_app_state.lr_scheduler is not None else raw_app_state.lr_scheduler + ) return raw_app_state From b9e7baf69589ef7da79bf95f2273c6ba62b74ae3 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Fri, 4 Jul 2025 00:54:33 +0200 Subject: [PATCH 2/7] refactor: collate fn is now composable --- src/modalities/config/config.py | 17 +-- src/modalities/config/pydantic_if_types.py | 3 +- .../dataloader/collate_fns/__init__.py | 1 + .../dataloader/collate_fns/collate_if.py | 19 +++- .../dataloader/collate_fns/collator.py | 100 ++++++++++++++++++ ..._masking.py => loss_masking_collate_fn.py} | 33 +++--- .../dataloader/dataloader_factory.py | 9 +- src/modalities/models/gpt2/collator.py | 36 ------- src/modalities/registry/components.py | 24 +++-- tests/dataloader/test_packed_dataset.py | 4 +- tests/instruction_tuning/test_loss_masking.py | 6 +- 11 files changed, 169 insertions(+), 83 deletions(-) create mode 100644 src/modalities/dataloader/collate_fns/__init__.py create mode 100644 src/modalities/dataloader/collate_fns/collator.py rename src/modalities/dataloader/collate_fns/{collator_fn_wrapper_for_loss_masking.py => loss_masking_collate_fn.py} (84%) delete mode 100644 src/modalities/models/gpt2/collator.py diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index b1292af0d..443cdb136 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -16,7 +16,7 @@ PydanticAppStateType, PydanticCheckpointSavingExecutionIFType, PydanticCheckpointSavingStrategyIFType, - PydanticCollateFnIFType, + PydanticCollatorIFType, PydanticDatasetIFType, PydanticDeviceMeshIFType, PydanticFSDP1CheckpointLoadingIFType, @@ -313,6 +313,9 @@ class RawAppStateConfig(BaseModel): class DCPAppStateConfig(BaseModel): raw_app_state: PydanticAppStateType checkpoint_dir_path: Path + load_model_checkpoint: bool = True + load_optimizer_checkpoint: bool = True + load_lr_scheduler_checkpoint: bool = True class PreTrainedHFTokenizerConfig(BaseModel): @@ -366,6 +369,11 @@ class PackedMemMapDatasetContinuousConfig(BaseModel): reuse_last_target: bool = Field(default=True) +class MemMapDatasetIterativeConfig(BaseModel): + raw_data_path: Path + sample_key: str + + class PackedMemMapDatasetMegatronConfig(BaseModel): raw_data_path: Path block_size: Annotated[int, Field(strict=True, gt=1)] @@ -382,16 +390,11 @@ class BatchSamplerConfig(BaseModel): drop_last: Literal[True] = True -class GPT2LLMCollateFnConfig(BaseModel): - sample_key: str - target_key: str - - class LLMDataLoaderConfig(BaseModel): dataloader_tag: str dataset: PydanticDatasetIFType batch_sampler: PydanticSamplerIFType - collate_fn: Optional[PydanticCollateFnIFType] 
= None + collator: Optional[PydanticCollatorIFType] = None num_workers: Annotated[int, Field(strict=True, ge=0)] pin_memory: bool diff --git a/src/modalities/config/pydantic_if_types.py b/src/modalities/config/pydantic_if_types.py index d7a4162d8..d02fad09d 100644 --- a/src/modalities/config/pydantic_if_types.py +++ b/src/modalities/config/pydantic_if_types.py @@ -16,7 +16,7 @@ from modalities.checkpointing.checkpoint_saving import CheckpointSaving, CheckpointSavingExecutionABC from modalities.checkpointing.checkpoint_saving_strategies import CheckpointSavingStrategyIF from modalities.checkpointing.stateful.app_state import AppState -from modalities.dataloader.collate_fns.collate_if import CollateFnIF +from modalities.dataloader.collate_fns.collate_if import CollateFnIF, CollatorIF from modalities.dataloader.dataloader import LLMDataLoader from modalities.inference.text.inference_component import TextInferenceComponent from modalities.logging_broker.subscriber import MessageSubscriberIF @@ -67,6 +67,7 @@ def __get_pydantic_core_schema__( PydanticDatasetIFType = Annotated[Dataset, PydanticThirdPartyTypeIF(Dataset)] PydanticSamplerIFType = Annotated[Sampler, PydanticThirdPartyTypeIF(Sampler)] PydanticCollateFnIFType = Annotated[CollateFnIF, PydanticThirdPartyTypeIF(CollateFnIF)] +PydanticCollatorIFType = Annotated[CollatorIF, PydanticThirdPartyTypeIF(CollatorIF)] PydanticLLMDataLoaderIFType = Annotated[LLMDataLoader, PydanticThirdPartyTypeIF(LLMDataLoader)] PydanticOptimizerIFType = Annotated[Optimizer, PydanticThirdPartyTypeIF(Optimizer)] PydanticLRSchedulerIFType = Annotated[LRScheduler, PydanticThirdPartyTypeIF(LRScheduler)] diff --git a/src/modalities/dataloader/collate_fns/__init__.py b/src/modalities/dataloader/collate_fns/__init__.py new file mode 100644 index 000000000..8b1378917 --- /dev/null +++ b/src/modalities/dataloader/collate_fns/__init__.py @@ -0,0 +1 @@ + diff --git a/src/modalities/dataloader/collate_fns/collate_if.py b/src/modalities/dataloader/collate_fns/collate_if.py index 88e851132..b954e12fb 100644 --- a/src/modalities/dataloader/collate_fns/collate_if.py +++ b/src/modalities/dataloader/collate_fns/collate_if.py @@ -5,11 +5,28 @@ from modalities.batch import DatasetBatch +class CollatorIF(ABC): + def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch: + """ + Process a batch of data. + + Args: + batch (list[dict[str, torch.Tensor]]): A list of dictionaries containing 1-dim tensors. + + Returns: + DatasetBatch: The processed batch of data. + + Raises: + NotImplementedError: This abstract method should be implemented in a subclass. + """ + raise NotImplementedError + + class CollateFnIF(ABC): """CollateFnIF class to define a collate function interface.""" @abstractmethod - def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch: + def __call__(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ Process a batch of data. 
diff --git a/src/modalities/dataloader/collate_fns/collator.py b/src/modalities/dataloader/collate_fns/collator.py
new file mode 100644
index 000000000..597d94c0b
--- /dev/null
+++ b/src/modalities/dataloader/collate_fns/collator.py
@@ -0,0 +1,100 @@
+from typing import Optional
+
+import torch
+from pydantic import BaseModel, model_validator
+
+from modalities.batch import DatasetBatch
+from modalities.config.pydantic_if_types import PydanticCollateFnIFType
+from modalities.dataloader.collate_fns.collate_if import CollateFnIF, CollatorIF
+
+
+class DefaultWrappingCollatorConfig(BaseModel):
+    input_keys: list[str]
+    sample_keys: list[str]
+    target_keys: list[str]
+    collate_fns: Optional[list[PydanticCollateFnIFType]] = None
+    sequence_length: Optional[int] = None
+    padding_token_id: Optional[int] = None
+
+    @model_validator(mode="after")
+    def validate_sequence_length_and_padding(self):
+        if (self.sequence_length is None) != (self.padding_token_id is None):
+            raise ValueError("sequence_length and padding_token_id must either both be set or both be None.")
+        return self
+
+
+class DefaultWrappingCollator(CollatorIF):
+    """DefaultWrappingCollator class to define a collate function that pads and
+    truncates sequences to a fixed length and applies the passed collate functions
+    sequentially."""
+
+    def __init__(
+        self,
+        input_keys: list[str],
+        sample_keys: list[str],
+        target_keys: list[str],
+        collate_fns: Optional[list[CollateFnIF]] = None,
+        sequence_length: Optional[int] = None,
+        padding_token_id: Optional[int] = None,
+    ):
+        """
+        Initializes the DefaultWrappingCollator object.
+
+        Args:
+            input_keys (list[str]): List of keys for the input data.
+            sample_keys (list[str]): List of keys for the resulting sample data.
+            target_keys (list[str]): List of keys for the resulting target data.
+            collate_fns (list[CollateFnIF], optional): List of wrapped collate functions to apply sequentially.
+                Defaults to None.
+            sequence_length (int, optional): Fixed sequence length for padding/truncating. Defaults to None.
+            padding_token_id (int, optional): Token ID used for padding. Defaults to None.
+
+        Raises:
+            ValueError: If sequence_length is set but padding_token_id is not set or vice versa.
+        """
+        self.input_keys = input_keys
+        self.sample_keys = sample_keys
+        self.target_keys = target_keys
+        self.collate_fns = collate_fns if collate_fns is not None else []
+        self.sequence_length = sequence_length
+        self.padding_token_id = padding_token_id
+        if (sequence_length is None) != (padding_token_id is None):
+            raise ValueError("sequence_length and padding_token_id must either both be set or both be None.")
+
+    def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch:
+        """Process a batch of data by applying the wrapped collate functions sequentially.
+
+        Args:
+            batch (list[dict[str, torch.Tensor]]): A list of dictionaries containing 1-dim tensors.
+
+        Returns:
+            DatasetBatch: The processed batch of data.
+        """
+        if self.sequence_length is not None and self.padding_token_id is not None:
+            # Pad and truncate the sequences in the batch to the fixed sequence length
+            self._pad_and_truncate_inplace(batch)
+
+        sample_tensor_dict = {key: torch.stack([torch.tensor(d[key]) for d in batch]) for key in self.sample_keys}
+        for wrapped_collate_fn in self.collate_fns:
+            sample_tensor_dict = wrapped_collate_fn(sample_tensor_dict)
+
+        samples = {sample_key: sample_tensor_dict[sample_key] for sample_key in self.sample_keys}
+        targets = {target_key: sample_tensor_dict[target_key] for target_key in self.target_keys}
+        return DatasetBatch(targets=targets, samples=samples)
+
+    def _pad_and_truncate_inplace(self, batch: list[dict[str, torch.Tensor]]) -> None:
+        for sample in batch:
+            for key in sample.keys():
+                seq = sample[key]
+                if seq.dim() != 1:
+                    raise ValueError(
+                        f"Expected a 1-dimensional tensor, got {seq.dim()} dimensions for key '{key}'."
+                    )
+
+                # Truncate or pad to fixed sequence length
+                if seq.size(0) > self.sequence_length:
+                    seq = seq[: self.sequence_length]
+                elif seq.size(0) < self.sequence_length:
+                    padding = torch.full((self.sequence_length - seq.size(0),), self.padding_token_id, dtype=seq.dtype)
+                    seq = torch.cat([seq, padding], dim=0)
+                sample[key] = seq
diff --git a/src/modalities/dataloader/collate_fns/collator_fn_wrapper_for_loss_masking.py b/src/modalities/dataloader/collate_fns/loss_masking_collate_fn.py
similarity index 84%
rename from src/modalities/dataloader/collate_fns/collator_fn_wrapper_for_loss_masking.py
rename to src/modalities/dataloader/collate_fns/loss_masking_collate_fn.py
index 5e0b75d5a..31cabbfbb 100644
--- a/src/modalities/dataloader/collate_fns/collator_fn_wrapper_for_loss_masking.py
+++ b/src/modalities/dataloader/collate_fns/loss_masking_collate_fn.py
@@ -1,10 +1,9 @@
-from typing import Dict, List
+from typing import List
 
 import torch
 from pydantic import BaseModel
 
-from modalities.batch import DatasetBatch
-from modalities.config.pydantic_if_types import PydanticCollateFnIFType, PydanticTokenizerIFType
+from modalities.config.pydantic_if_types import PydanticTokenizerIFType
 from modalities.dataloader.collate_fns.collate_if import CollateFnIF
 from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper
 from modalities.util import warn_rank_0
@@ -15,18 +14,16 @@ class LossMaskingTokenConfig(BaseModel):
     e_include_to_loss_token: str
 
 
-class LossMaskingCollateFnWrapperConfig(BaseModel):
-    wrapped_collate_fn: PydanticCollateFnIFType
+class LossMaskingCollateFnConfig(BaseModel):
     target_keys_to_mask: List[str]
     loss_ignore_index: int
    mask_tokens: LossMaskingTokenConfig
     tokenizer: PydanticTokenizerIFType
 
 
-class LossMaskingCollateFnWrapper(CollateFnIF):
+class LossMaskingCollateFn(CollateFnIF):
     def __init__(
         self,
-        wrapped_collate_fn: CollateFnIF,
         target_keys_to_mask: List[str],
         loss_ignore_index: int,
         mask_tokens: LossMaskingTokenConfig,
@@ -34,7 +31,7 @@ def __init__(
     ):
         """
         Initializes the LossMaskingCollateFnWrapper.
-        Wraps the given wrapped_collate_fn and masks the target keys if not within the given special mask tokens.
+        The collate function masks the target keys if not within the given special mask tokens.
         Does not include both mask tokens into the loss.
         If you need a token to indicate the end of the assistant, use another special token for this!
Works also for the continuous dataset reading, as if the "end-include-to-loss" token is detected in the front, @@ -44,7 +41,6 @@ def __init__( Args: - wrapped_collate_fn (CollateFnIF): The wrapped collate function. target_keys_to_mask (List[str]): The list of target keys to mask. loss_ignore_index (int): The index to ignore in the loss calculation. mask_tokens (MaskingTokenConfig): Entails begin and end tokens, which mark (exclusive) inclusion to the @@ -54,7 +50,6 @@ def __init__( Raises: ValueError: If b_mask_token_id and e_mask_token_id are the same. """ - self.wrapped_collate_fn = wrapped_collate_fn self.target_keys_to_mask = target_keys_to_mask self.loss_ignore_index = loss_ignore_index self.tokenizer = tokenizer @@ -65,30 +60,30 @@ def __init__( "b_mask_token_id and e_mask_token_id of the LossMaskingCollateFnWrapper must be different!" ) - def __call__(self, batch: List[Dict[str, torch.Tensor]]) -> DatasetBatch: + def __call__(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ - Collates a batch of data by calling the wrapped collate function and applies target masking. + Collates a batch of data by applying target masking. Args: - batch (List[Dict[str, torch.Tensor]]): A list of dictionaries, where each dictionary represents a sample - in the batch. Each dictionary contains keys corresponding to different data modalities and their + batch (dict[str, torch.Tensor]): The batch contains keys corresponding to different + data modalities and their respective tensors. Returns: - DatasetBatch: A batch of collated data with masked targets. + dict[str, torch.Tensor]: A batch dict with masked targets. """ - dataset_batch = self.wrapped_collate_fn(batch) + for target_key_to_mask in self.target_keys_to_mask: - target = dataset_batch.targets[target_key_to_mask] + target = batch[target_key_to_mask] masked_target = self._mask_target( target=target, b_mask_token_id=self.b_mask_token_id, e_mask_token_id=self.e_mask_token_id, loss_ignore_index=self.loss_ignore_index, ) - dataset_batch.targets[target_key_to_mask] = masked_target - return dataset_batch + batch[target_key_to_mask] = masked_target + return batch def _mask_target( self, target: torch.Tensor, b_mask_token_id: int, e_mask_token_id: int, loss_ignore_index: int diff --git a/src/modalities/dataloader/dataloader_factory.py b/src/modalities/dataloader/dataloader_factory.py index 56d9db1ba..fbc74647b 100644 --- a/src/modalities/dataloader/dataloader_factory.py +++ b/src/modalities/dataloader/dataloader_factory.py @@ -1,8 +1,7 @@ -from typing import Callable - from torch.utils.data import BatchSampler from torch.utils.data.dataset import Dataset +from modalities.dataloader.collate_fns.collate_if import CollatorIF from modalities.dataloader.dataloader import LLMDataLoader @@ -12,7 +11,7 @@ def get_dataloader( dataloader_tag: str, dataset: Dataset, batch_sampler: BatchSampler, - collate_fn: Callable, + collator: CollatorIF, num_workers: int, pin_memory: bool, ) -> LLMDataLoader: @@ -23,7 +22,7 @@ def get_dataloader( dataloader_tag (str): Tag for the dataloader dataset (Dataset): Dataset to be used batch_sampler (BatchSampler): batch sampler for batch-wise sampling from the dataset - collate_fn (Callable): Callable for shaping the batch + collator (CollatorIF): Collator for shaping the batch num_workers (int): Number of workers for the dataloader pin_memory (bool): Flag indicating whether to pin memory Returns: @@ -33,7 +32,7 @@ def get_dataloader( dataloader_tag=dataloader_tag, batch_sampler=batch_sampler, dataset=dataset, - 
collate_fn=collate_fn, + collate_fn=collator, num_workers=num_workers, pin_memory=pin_memory, ) diff --git a/src/modalities/models/gpt2/collator.py b/src/modalities/models/gpt2/collator.py deleted file mode 100644 index f4cf9b531..000000000 --- a/src/modalities/models/gpt2/collator.py +++ /dev/null @@ -1,36 +0,0 @@ -import torch - -from modalities.batch import DatasetBatch -from modalities.dataloader.collate_fns.collate_if import CollateFnIF - - -class GPT2LLMCollateFn(CollateFnIF): - """GPT2LLMCollateFn class to define a collate function for GPT2 language model.""" - - def __init__(self, sample_key: str, target_key: str): - """ - Initializes the Collator object. - - Args: - sample_key (str): The key for accessing the sample data. - target_key (str): The key for accessing the target data. - """ - self.sample_key = sample_key - self.target_key = target_key - - def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch: - """ - Process a batch of data. - - Args: - batch (list[dict[str, torch.Tensor]]): A list of dictionaries containing tensors. - - Returns: - DatasetBatch: A processed batch of data where sample and target sequences are created. - - """ - - sample_tensor = torch.stack([torch.tensor(d[self.sample_key]) for d in batch]) - samples = {self.sample_key: sample_tensor[:, :-1]} - targets = {self.target_key: sample_tensor[:, 1:]} - return DatasetBatch(targets=targets, samples=samples) diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 312ff5957..f1f82a303 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -39,11 +39,11 @@ FSDP1CheckpointSavingConfig, FSDP2WrappedModelConfig, FSDPWrappedModelConfig, - GPT2LLMCollateFnConfig, GPT2MFUCalculatorConfig, LinearLRSchedulerConfig, LLMDataLoaderConfig, MemMapDatasetConfig, + MemMapDatasetIterativeConfig, OneCycleLRSchedulerConfig, PackedMemMapDatasetContinuousConfig, PackedMemMapDatasetMegatronConfig, @@ -61,10 +61,12 @@ WandBEvaluationResultSubscriberConfig, WeightInitializedModelConfig, ) -from modalities.dataloader.collate_fns.collator_fn_wrapper_for_loss_masking import ( - LossMaskingCollateFnWrapper, - LossMaskingCollateFnWrapperConfig, +from modalities.dataloader.collate_fns.autoregressive_collate_fn import ( + AutoregressiveCollateFn, + AutoregressiveCollateFnConfig, ) +from modalities.dataloader.collate_fns.collator import DefaultWrappingCollator, DefaultWrappingCollatorConfig +from modalities.dataloader.collate_fns.loss_masking_collate_fn import LossMaskingCollateFn, LossMaskingCollateFnConfig from modalities.dataloader.dataloader_factory import DataloaderFactory from modalities.dataloader.dataset import DummyDatasetConfig from modalities.dataloader.dataset_factory import DatasetFactory @@ -77,7 +79,6 @@ from modalities.models.coca.coca_model import CoCa, CoCaConfig from modalities.models.coca.collator import CoCaCollateFnConfig, CoCaCollatorFn from modalities.models.components.layer_norms import LayerNormConfig, RMSLayerNorm, RMSLayerNormConfig -from modalities.models.gpt2.collator import GPT2LLMCollateFn from modalities.models.gpt2.gpt2_model import GPT2LLMConfig from modalities.models.huggingface.huggingface_model import HuggingFacePretrainedModel, HuggingFacePretrainedModelConfig from modalities.models.model_factory import GPT2ModelFactory, ModelFactory @@ -202,6 +203,12 @@ class ComponentEntity: DatasetFactory.get_packed_mem_map_dataset_continuous, PackedMemMapDatasetContinuousConfig, ), + ComponentEntity( + "dataset", + 
"mem_map_dataset_iterative", + DatasetFactory.get_mem_map_dataset_iterative, + MemMapDatasetIterativeConfig, + ), ComponentEntity( "dataset", "packed_mem_map_dataset_megatron", @@ -219,11 +226,10 @@ class ComponentEntity: # batch samplers ComponentEntity("batch_sampler", "default", BatchSampler, BatchSamplerConfig), # collators - ComponentEntity("collate_fn", "gpt_2_llm_collator", GPT2LLMCollateFn, GPT2LLMCollateFnConfig), + ComponentEntity("collator", "default_wrapping_collator", DefaultWrappingCollator, DefaultWrappingCollatorConfig), + ComponentEntity("collate_fn", "autoregressive", AutoregressiveCollateFn, AutoregressiveCollateFnConfig), ComponentEntity("collate_fn", "coca_collator", CoCaCollatorFn, CoCaCollateFnConfig), - ComponentEntity( - "collate_fn", "mask_loss_collator_wrapper", LossMaskingCollateFnWrapper, LossMaskingCollateFnWrapperConfig - ), + ComponentEntity("collate_fn", "masked_loss", LossMaskingCollateFn, LossMaskingCollateFnConfig), # data loaders ComponentEntity("data_loader", "default", DataloaderFactory.get_dataloader, LLMDataLoaderConfig), # checkpointing diff --git a/tests/dataloader/test_packed_dataset.py b/tests/dataloader/test_packed_dataset.py index ebb4a228e..166917787 100644 --- a/tests/dataloader/test_packed_dataset.py +++ b/tests/dataloader/test_packed_dataset.py @@ -10,7 +10,7 @@ PackedMemMapDatasetContinuous, PackedMemMapDatasetMegatron, ) -from modalities.models.gpt2.collator import GPT2LLMCollateFn +from modalities.models.gpt2.collator import AutoregressiveCollateFn @pytest.mark.parametrize("block_size, expected_length", [(1, 4), (2, 3), (3, 3), (10, 2), (6, 2), (20, 1), (25, 0)]) @@ -231,7 +231,7 @@ def test_conversion_tokens_represented_as_unsigned_ints(tmpdir, token_size_in_by ) assert list(ds) - collator = GPT2LLMCollateFn(sample_key=sample_key, target_key="abc") + collator = AutoregressiveCollateFn(sample_key=sample_key, target_key="abc") for batch in zip(ds, ds): collator(list(batch)) diff --git a/tests/instruction_tuning/test_loss_masking.py b/tests/instruction_tuning/test_loss_masking.py index 346b12685..b66361fa2 100644 --- a/tests/instruction_tuning/test_loss_masking.py +++ b/tests/instruction_tuning/test_loss_masking.py @@ -4,12 +4,12 @@ import torch from modalities.batch import DatasetBatch -from modalities.dataloader.collate_fns.collator_fn_wrapper_for_loss_masking import ( +from modalities.dataloader.collate_fns.autoregressive_collate_fn import AutoregressiveCollateFn +from modalities.dataloader.collate_fns.loss_masking_collate_fn import ( LossMaskingCollateFnWrapper, LossMaskingCollateFnWrapperConfig, LossMaskingTokenConfig, ) -from modalities.models.gpt2.collator import GPT2LLMCollateFn from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper @@ -28,7 +28,7 @@ def dummy_tokenizer(): @pytest.fixture def loss_masking_config(dummy_tokenizer) -> LossMaskingCollateFnWrapperConfig: return dict( - wrapped_collate_fn=GPT2LLMCollateFn(sample_key="sample", target_key="target"), + wrapped_collate_fn=AutoregressiveCollateFn(sample_key="sample", target_key="target"), target_keys_to_mask=["target"], loss_ignore_index=-100, mask_tokens=LossMaskingTokenConfig(b_include_to_loss_token="begin", e_include_to_loss_token="end"), From 78f605849b449b0393a0a9250bc55220e55f9e8d Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Fri, 4 Jul 2025 00:55:11 +0200 Subject: [PATCH 3/7] feat: added iterative memmap dataset --- src/modalities/dataloader/dataset_factory.py | 18 ++++++++++++++++++ 1 file changed, 18 
insertions(+) diff --git a/src/modalities/dataloader/dataset_factory.py b/src/modalities/dataloader/dataset_factory.py index 1eab9e328..7e28aba13 100644 --- a/src/modalities/dataloader/dataset_factory.py +++ b/src/modalities/dataloader/dataset_factory.py @@ -10,6 +10,7 @@ DummyDataset, DummySampleConfig, MemMapDataset, + PackedMemMapDatasetBase, PackedMemMapDatasetContinuous, PackedMemMapDatasetMegatron, ) @@ -71,6 +72,23 @@ def get_raw_index(raw_index_path: Path) -> list[tuple[int, int]]: index = pickle.load(f) return index + @staticmethod + def get_mem_map_dataset_iterative(raw_data_path: Path, sample_key: str) -> PackedMemMapDatasetBase: + """ + Initializes a PackedMemMapDatasetBase object for iterative memory-mapped datasets. + In contrast to the packed version, this dataset always returns the respective sample for + a given index. The packed version returns a block of samples with a fixed sequence length. + + Args: + raw_data_path (Path): The path to the raw data. + sample_key (str): The key used to retrieve the samples from the dataset. + + Returns: + PackedMemMapDatasetBase: The iterative memory-mapped dataset. + """ + dataset = PackedMemMapDatasetBase(raw_data_path=raw_data_path, sample_key=sample_key, load_index=True) + return dataset + @staticmethod def get_packed_mem_map_dataset_continuous( raw_data_path: Path, sequence_length: int, sample_key: str, reuse_last_target: bool From 1cbe27283c15ae0809aea08c66d2b985c65c76c3 Mon Sep 17 00:00:00 2001 From: Max Luebbering <2804731+le1nux@users.noreply.github.com> Date: Fri, 4 Jul 2025 00:55:38 +0200 Subject: [PATCH 4/7] feat: added autoregressive collate fn --- .../collate_fns/autoregressive_collate_fn.py | 39 +++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 src/modalities/dataloader/collate_fns/autoregressive_collate_fn.py diff --git a/src/modalities/dataloader/collate_fns/autoregressive_collate_fn.py b/src/modalities/dataloader/collate_fns/autoregressive_collate_fn.py new file mode 100644 index 000000000..7e9b6085a --- /dev/null +++ b/src/modalities/dataloader/collate_fns/autoregressive_collate_fn.py @@ -0,0 +1,39 @@ +import torch +from pydantic import BaseModel + +from modalities.dataloader.collate_fns.collate_if import CollateFnIF + + +class AutoregressiveCollateFnConfig(BaseModel): + sample_key: str + target_key: str + + +class AutoregressiveCollateFn(CollateFnIF): + """AutoregressiveCollateFn class to define a collate function for language modeling.""" + + def __init__(self, sample_key: str, target_key: str): + """ + Initializes the Collator object. + + Args: + sample_key (str): The key for accessing the sample data. + target_key (str): The key for accessing the target data. + """ + self.sample_key = sample_key + self.target_key = target_key + + def __call__(self, batch: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + """ + Process a batch of data. + + Args: + batch (dict[str, torch.Tensor]): A dictionary containing tensors of the batch. + + Returns: + dict[str, torch.Tensor]: The processed batch with sample and target tensors. 
+        """
+        sample_tensor = batch[self.sample_key]
+        batch[self.sample_key] = sample_tensor[:, :-1]
+        batch[self.target_key] = sample_tensor[:, 1:]
+        return batch

From cf870460ed99a7d718e46860f305c468b4566174 Mon Sep 17 00:00:00 2001
From: Max Luebbering <2804731+le1nux@users.noreply.github.com>
Date: Fri, 4 Jul 2025 00:55:59 +0200
Subject: [PATCH 5/7] refactor: added first running config with refactored collator

---
 .../config_lorem_ipsum_long_fsdp2.yaml        | 28 +++++++++++++------
 1 file changed, 19 insertions(+), 9 deletions(-)

diff --git a/config_files/training/config_lorem_ipsum_long_fsdp2.yaml b/config_files/training/config_lorem_ipsum_long_fsdp2.yaml
index 7e27b7e33..549a6ae7d 100644
--- a/config_files/training/config_lorem_ipsum_long_fsdp2.yaml
+++ b/config_files/training/config_lorem_ipsum_long_fsdp2.yaml
@@ -51,12 +51,22 @@ settings:
     num_seen_samples: 0
     last_step: -1
 
-collate_fn:
-  component_key: collate_fn
-  variant_key: gpt_2_llm_collator
+collator:
+  component_key: collator
+  variant_key: default_wrapping_collator
   config:
-    sample_key: ${settings.referencing_keys.sample_key}
-    target_key: ${settings.referencing_keys.target_key}
+    input_keys:
+      - ${settings.referencing_keys.sample_key}
+    sample_keys:
+      - ${settings.referencing_keys.sample_key}
+    target_keys:
+      - ${settings.referencing_keys.target_key}
+    collate_fns:
+      - component_key: collate_fn
+        variant_key: autoregressive
+        config:
+          sample_key: ${settings.referencing_keys.sample_key}
+          target_key: ${settings.referencing_keys.target_key}
 
 train_dataset:
   component_key: dataset
@@ -95,8 +105,8 @@ train_dataloader:
       seed: 42
      drop_last: true
       skip_num_global_samples: ${settings.training_progress.num_seen_samples}
-  collate_fn:
-    instance_key: collate_fn
+  collator:
+    instance_key: collator
     pass_type: BY_REFERENCE
 
 test_dataset:
@@ -134,8 +144,8 @@ test_dataloader:
   dataset:
     instance_key: test_dataset
     pass_type: BY_REFERENCE
-  collate_fn:
-    instance_key: collate_fn
+  collator:
+    instance_key: collator
     pass_type: BY_REFERENCE
 
 eval_dataloaders:

From adf8e19a62e3afc0b759cf0cb673150badadcf93 Mon Sep 17 00:00:00 2001
From: Max Luebbering <2804731+le1nux@users.noreply.github.com>
Date: Sun, 6 Jul 2025 20:53:55 +0200
Subject: [PATCH 6/7] refactor: app state properties can now be set from outside

---
 .../checkpointing/stateful/app_state.py      | 27 +++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/src/modalities/checkpointing/stateful/app_state.py b/src/modalities/checkpointing/stateful/app_state.py
index 1b5a583b9..9169cb902 100644
--- a/src/modalities/checkpointing/stateful/app_state.py
+++ b/src/modalities/checkpointing/stateful/app_state.py
@@ -67,14 +67,41 @@ def is_loaded(self) -> bool:
     def model(self) -> nn.Module:
         return self._model
 
+    @model.setter
+    def model(self, model: nn.Module) -> None:
+        """Sets the model in the AppState object.
+
+        Args:
+            model (nn.Module): The model to set in the AppState object.
+        """
+        self._model = model
+
     @property
     def optimizer(self) -> Optimizer:
         return self._optimizer
 
+    @optimizer.setter
+    def optimizer(self, optimizer: Optimizer) -> None:
+        """Sets the optimizer in the AppState object.
+
+        Args:
+            optimizer (Optimizer): The optimizer to set in the AppState object.
+        """
+        self._optimizer = optimizer
+
     @property
     def lr_scheduler(self) -> LRScheduler:
         return self._lr_scheduler
 
+    @lr_scheduler.setter
+    def lr_scheduler(self, lr_scheduler: LRScheduler) -> None:
+        """Sets the learning rate scheduler in the AppState object.
+
+        Args:
+            lr_scheduler (LRScheduler): The learning rate scheduler to set in the AppState object.
+        """
+        self._lr_scheduler = lr_scheduler
+
     def state_dict(self) -> dict[str, Any]:
         """Returns the state dict of the AppState object.
 

From 331473b52bcc1bf417348377a682fc11b27640ca Mon Sep 17 00:00:00 2001
From: Max Luebbering <2804731+le1nux@users.noreply.github.com>
Date: Sun, 6 Jul 2025 20:55:03 +0200
Subject: [PATCH 7/7] feat: added multi-round inference (not multi-turn)

---
 src/modalities/dataloader/collate_fns/collator.py | 5 ++++-
 .../collate_fns/loss_masking_collate_fn.py        | 3 +--
 .../inference/text/inference_component.py         | 15 ++++++++++-----
 3 files changed, 15 insertions(+), 8 deletions(-)

diff --git a/src/modalities/dataloader/collate_fns/collator.py b/src/modalities/dataloader/collate_fns/collator.py
index 597d94c0b..15592863b 100644
--- a/src/modalities/dataloader/collate_fns/collator.py
+++ b/src/modalities/dataloader/collate_fns/collator.py
@@ -70,11 +70,14 @@ def __call__(self, batch: list[dict[str, torch.Tensor]]) -> DatasetBatch:
         Returns:
             DatasetBatch: The processed batch of data.
         """
+        # convert to tensors
+        batch = [{k: torch.tensor(v) for k, v in tensor_dict.items()} for tensor_dict in batch]
+
         if self.sequence_length is not None and self.padding_token_id is not None:
             # Pad and truncate the sequences in the batch to the fixed sequence length
             self._pad_and_truncate_inplace(batch)
 
-        sample_tensor_dict = {key: torch.stack([torch.tensor(d[key]) for d in batch]) for key in self.sample_keys}
+        sample_tensor_dict = {key: torch.stack([d[key] for d in batch]) for key in self.sample_keys}
         for wrapped_collate_fn in self.collate_fns:
             sample_tensor_dict = wrapped_collate_fn(sample_tensor_dict)
 
diff --git a/src/modalities/dataloader/collate_fns/loss_masking_collate_fn.py b/src/modalities/dataloader/collate_fns/loss_masking_collate_fn.py
index 31cabbfbb..f911eb465 100644
--- a/src/modalities/dataloader/collate_fns/loss_masking_collate_fn.py
+++ b/src/modalities/dataloader/collate_fns/loss_masking_collate_fn.py
@@ -157,8 +157,7 @@ def _mask_target(
         if not ((0 <= include_to_loss_mask).all() and (include_to_loss_mask <= 1).all()):
             raise ValueError(
                 "end mask token indicator is before begin mask token indicator in the target. "
-                + "This is not supported by the LossMaskingCollateFnWrapper."
-                + "Make sure to use padding and truncation with the tokenizer for PackedMemMapDatasetContinuous"
+                + "This is not supported by the LossMaskingCollateFn."
             )
 
         # apply mask: if mask is 1, keep the target, otherwise replace with loss_ignore_index
diff --git a/src/modalities/inference/text/inference_component.py b/src/modalities/inference/text/inference_component.py
index 939ccadc0..b991b813c 100644
--- a/src/modalities/inference/text/inference_component.py
+++ b/src/modalities/inference/text/inference_component.py
@@ -74,11 +74,16 @@ def generate_tokens(
             print("\n max tokens reached", end="")
 
     def run(self):
-        prompt = TextInferenceComponent._get_prompt(self.prompt_template)
-        try:
-            self.generate_tokens(context=prompt)
-        except KeyboardInterrupt:
-            print("closing app...")
+        round = 1
+        while True:
+            print(f"\n\n--------------------ROUND-{round}--------------------")
+            prompt = TextInferenceComponent._get_prompt(self.prompt_template)
+            try:
+                self.generate_tokens(context=prompt)
+            except KeyboardInterrupt:
+                print("closing app...")
+                break
+            round += 1
 
     @staticmethod
     def _get_prompt(template: str) -> str:
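Illustrative usage sketch (not part of the patch series above): how the composable collator from patches 2, 4 and 7 is meant to be wired together — DefaultWrappingCollator pads/truncates and stacks the raw samples, then applies AutoregressiveCollateFn to derive next-token targets. The keys, token ids, sequence length and padding id below are made-up example values; the sketch assumes the series has been applied and the modalities package is importable.

from modalities.dataloader.collate_fns.autoregressive_collate_fn import AutoregressiveCollateFn
from modalities.dataloader.collate_fns.collator import DefaultWrappingCollator

# Collator that pads/truncates every 1-dim sample to 8 tokens (padding id 0),
# stacks the samples and then applies the autoregressive shift.
collator = DefaultWrappingCollator(
    input_keys=["input_ids"],
    sample_keys=["input_ids"],
    target_keys=["target_ids"],
    collate_fns=[AutoregressiveCollateFn(sample_key="input_ids", target_key="target_ids")],
    sequence_length=8,
    padding_token_id=0,
)

# Two raw samples of different lengths, as an iterative memmap dataset might yield them.
batch = [
    {"input_ids": [5, 6, 7, 8, 9]},
    {"input_ids": [1, 2, 3]},
]

dataset_batch = collator(batch)
print(dataset_batch.samples["input_ids"].shape)   # torch.Size([2, 7]) -> inputs without the last token
print(dataset_batch.targets["target_ids"].shape)  # torch.Size([2, 7]) -> next-token targets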
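Likewise illustrative (not part of the patches): the flags added to AppStateFactory.get_dcp_checkpointed_app_state_ in patch 1 allow loading only part of a DCP checkpoint. The sketch assumes an initialized torch.distributed process group and an existing DCP checkpoint; the model, optimizer and checkpoint path are placeholders.

from pathlib import Path

import torch.nn as nn
from torch.optim import AdamW

from modalities.checkpointing.stateful.app_state_factory import AppStateFactory

# Placeholder components; in a real run these are the FSDP-wrapped training model and optimizer.
model = nn.Linear(8, 8)
optimizer = AdamW(model.parameters(), lr=1e-4)

raw_app_state = AppStateFactory.get_raw_app_state(model=model, optimizer=optimizer, lr_scheduler=None)

app_state = AppStateFactory.get_dcp_checkpointed_app_state_(
    raw_app_state=raw_app_state,
    checkpoint_dir_path=Path("/checkpoints/step_1000"),  # hypothetical DCP checkpoint directory
    load_model_checkpoint=True,          # restore model weights from the checkpoint
    load_optimizer_checkpoint=False,     # keep the freshly initialized optimizer state
    load_lr_scheduler_checkpoint=False,  # no scheduler is attached here anyway
)
# Loading happens in-place on the raw app state, which is also returned.
assert app_state is raw_app_state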