Pipeline parallelism continued #399
base: main
Changes from all commits (67 commits).
**`pyproject.toml`**

```diff
@@ -1,12 +1,10 @@
 [project]
 name = "modalities"
 version = "0.3.2"
-requires-python = ">=3.10,<3.12"
 description = "Modalities, a PyTorch-native framework for distributed and reproducible foundation model training."
 readme = "README.md"
 dependencies = [
     "numpy<2.0",
-    "torch==2.6.0",
     "packaging",
     "tqdm",
     "pyyaml",
```

> **Review comment** (on the removed `torch` dependency pin): we should update the README that for now we need torch-nightly to be installed manually by the user.
**`modalities/evaluator.py`**

```diff
@@ -9,6 +9,7 @@
 from modalities.logging_broker.messages import ExperimentStatus, MessageTypes, ProgressUpdate
 from modalities.logging_broker.publisher import MessagePublisher
 from modalities.models.model import model_predict_batch
+from modalities.models.parallelism.pipeline_parallelism import Pipeline
 from modalities.running_env.fsdp.reducer import Reducer
 from modalities.trainer import ThroughputAggregationKeys
 from modalities.util import Aggregator, TimeRecorder
@@ -36,20 +37,42 @@ def evaluate_batch(
     batch: DatasetBatch,
     model: nn.Module,
     loss_fun: Callable[[InferenceResultBatch], torch.Tensor],
-) -> torch.Tensor:
+    scheduled_pipeline: Pipeline | None = None,
+) -> torch.Tensor | None:
     """Evaluate a single batch by forwarding it through the model and calculating the loss.
 
     Args:
         batch (DatasetBatch): The batch to evaluate
         model (nn.Module): The model to evaluate
         loss_fun (Callable[[InferenceResultBatch], torch.Tensor]): The loss function to calculate the loss
+        scheduled_pipeline (Pipeline | None, optional): In case of pipeline parallelism, this is used to
+            operate the model. Defaults to None.
 
     Returns:
-        torch.Tensor: The loss of the batch
+        torch.Tensor | None: The loss of the batch
+            None, if a non-last stage was processed in pipeline parallelism
     """
     with torch.no_grad():
-        result_batch = model_predict_batch(model=model, batch=batch)
-        loss = loss_fun(result_batch)
+        if scheduled_pipeline is not None:
+            pp_schedule = scheduled_pipeline.pp_schedule
+            targets, losses = (
+                (batch.targets[loss_fun.target_key].contiguous(), [])
+                if scheduled_pipeline.is_last_pp_stage
+                else (None, None)
+            )
+
+            if scheduled_pipeline.is_first_pp_stage:
+                pp_schedule.eval(batch.samples[model.sample_key].contiguous(), target=targets, losses=losses)
+            else:
+                pp_schedule.eval(target=targets, losses=losses)
+            loss = (
+                torch.mean(torch.stack(losses)).to(losses[0].device)
+                if scheduled_pipeline.is_last_pp_stage
+                else None
+            )
+        else:
+            result_batch = model_predict_batch(model=model, batch=batch)
+            loss = loss_fun(result_batch)
     return loss
 
 
 def evaluate(
```

> **Review comment** (on the pipeline-parallel branch): basically code duplication from the trainer.
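For orientation, here is a minimal single-process sketch of the control flow this diff introduces. The stub class below is a hypothetical stand-in (not from the codebase) for the real pipeline schedule, which materializes per-microbatch losses only on the last pipeline stage:

```python
from dataclasses import dataclass

import torch


@dataclass
class StubPipeline:
    """Hypothetical stand-in for the scheduled Pipeline wrapper."""

    is_first_pp_stage: bool
    is_last_pp_stage: bool

    def eval(self, inputs=None, target=None, losses=None):
        # The real pp_schedule.eval() runs this stage's model chunk; on the
        # last stage it appends one loss tensor per microbatch to `losses`.
        if self.is_last_pp_stage:
            losses.extend([torch.tensor(0.5), torch.tensor(0.7)])


def evaluate_batch_sketch(pipeline: StubPipeline, samples, targets) -> torch.Tensor | None:
    # Only the last stage owns the targets and collects the losses.
    target, losses = (targets, []) if pipeline.is_last_pp_stage else (None, None)
    pipeline.eval(samples if pipeline.is_first_pp_stage else None, target=target, losses=losses)
    # Non-last stages only forwarded activations, so they have no loss to report.
    return torch.mean(torch.stack(losses)) if pipeline.is_last_pp_stage else None


print(evaluate_batch_sketch(StubPipeline(True, False), torch.ones(2), torch.ones(2)))  # None
print(evaluate_batch_sketch(StubPipeline(False, True), torch.ones(2), torch.ones(2)))  # tensor(0.6000)
```

This is why `evaluate_batch` now returns `torch.Tensor | None`: every rank participates in the schedule, but only the last stage produces a number.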
```diff
@@ -58,6 +81,7 @@ def evaluate(
     data_loaders: list[LLMDataLoader],
     loss_fun: Callable[[InferenceResultBatch], torch.Tensor],
     num_train_steps_done: int,
+    scheduled_pipeline: Pipeline | None = None,
 ) -> dict[str, EvaluationResultBatch]:
     """Evaluate the model on a set of datasets.
@@ -66,6 +90,8 @@
         data_loaders (list[LLMDataLoader]): List of dataloaders to evaluate the model on
         loss_fun (Callable[[InferenceResultBatch], torch.Tensor]): The loss function to calculate the loss
         num_train_steps_done (int): The number of training steps done so far for logging purposes
+        scheduled_pipeline (Pipeline | None, optional): In case of pipeline parallelism, this is used to
+            operate the model. Defaults to None.
 
     Returns:
         dict[str, EvaluationResultBatch]: A dictionary containing the evaluation results for each dataloader
@@ -90,10 +116,13 @@
                 batch=batch,
                 model=model,
                 loss_fun=loss_fun,
+                scheduled_pipeline=scheduled_pipeline,
             )
 
-            cumulated_loss[0] += batch_loss.item()  # sum up batch loss
-            cumulated_loss[1] += 1
+            # The batch_loss might be None if we use pipeline parallelism and are not the last stage.
+            if batch_loss is not None:
+                cumulated_loss[0] += batch_loss.item()  # sum up batch loss
+                cumulated_loss[1] += 1
             batch_length_tensor = torch.tensor(len(batch)).to(device)
             thoughput_aggregator.add_value(key=ThroughputAggregationKeys.NUM_SAMPLES, value=batch_length_tensor)
```
**Training run orchestration**

```diff
@@ -9,6 +9,7 @@
 from modalities.dataloader.dataloader import LLMDataLoader
 from modalities.evaluator import Evaluator
 from modalities.loss_functions import Loss
+from modalities.models.parallelism.pipeline_parallelism import Pipeline
 from modalities.trainer import Trainer
 from modalities.training.training_progress import TrainingProgress
 from modalities.util import print_rank_0
@@ -40,6 +41,7 @@ def run(
     train_data_loader: LLMDataLoader,
     evaluation_data_loaders: list[LLMDataLoader],
     checkpoint_saving: CheckpointSaving,
+    scheduled_pipeline: Pipeline | None = None,
 ):
     """Runs the model training, including evaluation and checkpointing.
@@ -51,12 +53,15 @@
         train_data_loader (LLMDataLoader): Data loader with the training data.
         evaluation_data_loaders (list[LLMDataLoader]): List of data loaders with the evaluation data.
         checkpoint_saving (CheckpointSaving): Routine for saving checkpoints.
+        scheduled_pipeline (Pipeline | None, optional): In case of pipeline parallelism, this is used to
+            operate the model. Defaults to None.
     """
     evaluation_callback: Callable[[int], None] = partial(
         self._run_evaluation,
         model=app_state.model,
         evaluation_data_loaders=evaluation_data_loaders,
         evaluation_interval_in_steps=evaluation_interval_in_steps,
+        scheduled_pipeline=scheduled_pipeline,
     )
 
     checkpointing_callback: Callable[[TrainingProgress], None] = partial(
@@ -74,6 +79,7 @@
         evaluation_callback=evaluation_callback,
         checkpointing_callback=checkpointing_callback,
         training_log_interval_in_steps=training_log_interval_in_steps,
+        scheduled_pipeline=scheduled_pipeline,
     )
     print_rank_0(f"Training done at {datetime.now()}.")
@@ -101,11 +107,13 @@ def _run_evaluation(
     num_train_steps_done: int,
     evaluation_data_loaders: list[LLMDataLoader],
     evaluation_interval_in_steps: int,
+    scheduled_pipeline: Pipeline | None = None,
 ):
-    if num_train_steps_done % evaluation_interval_in_steps == 0:
+    if num_train_steps_done > 0 and num_train_steps_done % evaluation_interval_in_steps == 0:
         self.evaluator.evaluate(
             model=model,
             data_loaders=evaluation_data_loaders,
             loss_fun=self.loss_fun,
             num_train_steps_done=num_train_steps_done,
+            scheduled_pipeline=scheduled_pipeline,
         )
```

> **Review comment** (on passing `scheduled_pipeline` into the evaluation call): same as in previous comments; not a big fan of passing around the `scheduled_pipeline`.
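The `evaluation_callback` above is built with `functools.partial`, freezing everything except the step counter so the trainer can invoke it with just `num_train_steps_done`. A tiny self-contained illustration of that wiring (toy names), which also shows the effect of the new `> 0` guard: evaluation no longer fires at step 0:

```python
from functools import partial
from typing import Callable


def run_evaluation(model: str, interval: int, num_steps_done: int) -> None:
    # Mirrors _run_evaluation's guard: skip step 0 and off-interval steps.
    if num_steps_done > 0 and num_steps_done % interval == 0:
        print(f"evaluating {model} at step {num_steps_done}")


# Freeze all arguments except the step counter, as the orchestration code does.
evaluation_callback: Callable[[int], None] = partial(run_evaluation, "my-model", 100)

for step in range(0, 301, 100):
    evaluation_callback(step)  # fires at steps 100, 200, 300 but not at 0
```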
**`modalities/loss_functions.py`**

```diff
@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from typing import overload
 
 import torch
 from torch.nn import CrossEntropyLoss
@@ -31,9 +32,16 @@ def __init__(self, target_key: str, prediction_key: str, tag: str = "CLMCrossEnt
         # Mean over the tokens in the local-batch (batch per rank)
         self.loss_fun = CrossEntropyLoss(reduction="mean")
 
+    @overload
     def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor:
-        labels = forward_batch.get_targets(self.target_key)
-        lm_logits = forward_batch.get_predictions(self.prediction_key)
+        ...
+
+    @overload
+    def __call__(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
+        ...
+
+    def __call__(self, *args, **kwargs) -> torch.Tensor:
+        labels, lm_logits = self._parse_arguments(args, kwargs)
 
         # move labels to correct device to enable model parallelism
         labels = labels.to(lm_logits.device)
```

> **Review comment** (on the `_parse_arguments` dispatch): Could be improved from a software engineering point of view.
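As background on the mechanism used above: `@overload` declarations are visible only to type checkers; at runtime the single concrete `__call__` must dispatch on the actual argument types itself, which is what `_parse_arguments` does. A minimal self-contained toy (not from the codebase) showing the same pattern:

```python
from typing import overload

import torch


class ToyLoss:
    @overload
    def __call__(self, batch: dict) -> torch.Tensor: ...

    @overload
    def __call__(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: ...

    def __call__(self, *args, **kwargs) -> torch.Tensor:
        # Runtime dispatch: the @overload stubs above only inform type checkers.
        if len(args) == 1 and isinstance(args[0], dict):
            outputs, targets = args[0]["outputs"], args[0]["targets"]
        elif len(args) == 2:
            outputs, targets = args
        else:
            raise TypeError("Invalid arguments for ToyLoss.__call__")
        return torch.mean((outputs - targets) ** 2)


loss = ToyLoss()
t = torch.ones(3)
print(loss({"outputs": t, "targets": t * 2}))  # batch form   -> tensor(1.)
print(loss(t, t * 2))                          # tensor form  -> tensor(1.)
```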
@@ -43,6 +51,41 @@ def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: | |
| loss = self.loss_fun(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) | ||
| return loss | ||
|
|
||
| def _parse_arguments( | ||
| self, | ||
| args: list[torch.Tensor] | list[InferenceResultBatch], | ||
| kwargs: dict[str, torch.Tensor] | dict[str, InferenceResultBatch], | ||
| ) -> tuple[torch.Tensor, torch.Tensor]: | ||
| if len(args) == 1 and isinstance(args[0], InferenceResultBatch): | ||
| forward_batch = args[0] | ||
| labels = forward_batch.get_targets(self.target_key) | ||
| lm_logits = forward_batch.get_predictions(self.prediction_key) | ||
| elif "forward_batch" in kwargs and isinstance(kwargs["forward_batch"], InferenceResultBatch): | ||
| forward_batch = kwargs["forward_batch"] | ||
| labels = forward_batch.get_targets(self.target_key) | ||
| lm_logits = forward_batch.get_predictions(self.prediction_key) | ||
| elif len(args) == 2 and all(isinstance(arg, torch.Tensor) for arg in args): | ||
| lm_logits, labels = args | ||
| elif ( | ||
| "outputs" in kwargs | ||
| and "targets" in kwargs | ||
| and isinstance(kwargs["outputs"], torch.Tensor) | ||
| and isinstance(kwargs["targets"], torch.Tensor) | ||
| ): | ||
| lm_logits = kwargs["outputs"] | ||
| labels = kwargs["targets"] | ||
| elif ( | ||
| len(args) == 1 | ||
| and "targets" in kwargs | ||
| and isinstance(args[0], torch.Tensor) | ||
| and isinstance(kwargs["targets"], torch.Tensor) | ||
| ): | ||
| lm_logits = args[0] | ||
| labels = kwargs["targets"] | ||
| else: | ||
| raise TypeError("Invalid arguments for CLMCrossEntropyLoss.__call__") | ||
| return labels, lm_logits | ||
|
|
||
|
Comment on lines
+54
to
+88
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Idea: What about defining a new component "pp-loss", which takes a normal loss function and handles the PP-specific part? Generally, I think this parsing function could be improved. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This was also our initial idea and we actually implemented it like this in a first draft. However, this had the disadvantage that we had to define the usage of the correct loss for a certain setup in the configs, making them more complex and less user friendly. |
||
|
|
||
| def nce_loss( | ||
| embedding1: torch.Tensor, embedding2: torch.Tensor, device: torch.device, is_asymmetric: bool, temperature: float | ||
|
|
||
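For comparison, a hypothetical sketch of the reviewer's "pp-loss" idea: a separate component that owns the two-tensor `(outputs, targets)` signature the pipeline schedule calls, delegating the actual math to a wrapped loss. The names are invented for illustration; as the author reply notes, the drawback is that configs must then select the right loss variant per setup:

```python
import torch
from torch.nn import CrossEntropyLoss


class PPLoss:
    """Hypothetical 'pp-loss' component: wraps a normal loss function and adds
    the (outputs, targets) call signature that the pipeline schedule uses."""

    def __init__(self, wrapped_loss: CrossEntropyLoss):
        self.wrapped_loss = wrapped_loss

    def __call__(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        # Causal-LM shift, then flatten for CrossEntropyLoss, mirroring the
        # shift_logits/shift_labels logic in CLMCrossEntropyLoss.
        shift_logits = outputs[..., :-1, :].contiguous()
        shift_labels = targets[..., 1:].contiguous()
        return self.wrapped_loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))


# Usage sketch: batch of 2 sequences, length 8, vocabulary of 16 tokens.
logits = torch.randn(2, 8, 16)
labels = torch.randint(0, 16, (2, 8))
pp_loss = PPLoss(CrossEntropyLoss(reduction="mean"))
print(pp_loss(logits, labels))  # scalar loss tensor
```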
**Main training setup (`run(components)`)**

```diff
@@ -20,6 +20,7 @@
 from modalities.logging_broker.subscriber import MessageSubscriberIF
 from modalities.registry.components import COMPONENTS
 from modalities.registry.registry import Registry
+from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_num_parallel_ranks
 from modalities.trainer import Trainer
 from modalities.util import get_synced_experiment_id_of_run, get_total_number_of_trainable_parameters, print_rank_0
```
@@ -110,11 +111,20 @@ def run(self, components: TrainingComponentsInstantiationModel): | |
| ) | ||
|
|
||
| # Trainer | ||
| # FIXME replace by get_parallel_degree | ||
| if components.device_mesh is None: | ||
| num_pipeline_parallel_ranks = 1 | ||
| num_data_parallel_ranks = 1 | ||
| else: | ||
| num_pipeline_parallel_ranks = get_num_parallel_ranks(components.device_mesh, ParallelismDegrees.PP) | ||
| num_data_parallel_ranks = get_num_parallel_ranks( | ||
| components.device_mesh, ParallelismDegrees.DP_SHARD | ||
| ) * get_num_parallel_ranks(components.device_mesh, ParallelismDegrees.DP_REPLICATE) | ||
| global_num_tokens_per_train_step = ( | ||
| components.settings.step_profile.local_train_micro_batch_size | ||
| * components.settings.step_profile.sequence_length | ||
| * components.settings.step_profile.gradient_accumulation_steps | ||
| * components.settings.cuda_env.world_size | ||
| * num_data_parallel_ranks | ||
| ) | ||
| trainer = Trainer( | ||
| global_rank=components.settings.cuda_env.global_rank, | ||
|
|
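To see why `world_size` is replaced by the number of data-parallel ranks: with pipeline parallelism, ranks in the same PP group process the same microbatches, so only data-parallel replication multiplies the token count. A worked example under an assumed 8-GPU mesh (2 PP x 2 DP-shard x 2 DP-replicate):

```python
# Assumed step profile and mesh; numbers are illustrative only.
local_train_micro_batch_size = 4
sequence_length = 2048
gradient_accumulation_steps = 8

pp, dp_shard, dp_replicate = 2, 2, 2
world_size = pp * dp_shard * dp_replicate          # 8 ranks in total
num_data_parallel_ranks = dp_shard * dp_replicate  # only 4 ranks see distinct data

# Old formula counted every rank, overcounting by the PP degree:
old = local_train_micro_batch_size * sequence_length * gradient_accumulation_steps * world_size
# New formula counts only ranks that consume distinct batches:
new = local_train_micro_batch_size * sequence_length * gradient_accumulation_steps * num_data_parallel_ranks

print(old, new, old // new)  # 524288 262144 2  (overcount factor equals the PP degree)
```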
@@ -128,6 +138,7 @@ def run(self, components: TrainingComponentsInstantiationModel): | |
| gradient_clipper=components.gradient_clipper, | ||
| global_num_tokens_per_train_step=global_num_tokens_per_train_step, | ||
| mfu_calculator=components.mfu_calculator, | ||
| num_pipeline_parallel_ranks=num_pipeline_parallel_ranks, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would prefer if we kept the |
||
| ) | ||
|
|
||
| # Evaluator | ||
|
|
@@ -143,7 +154,7 @@ def run(self, components: TrainingComponentsInstantiationModel): | |
| loss_fun=components.loss_fn, | ||
| num_ranks=components.settings.cuda_env.world_size, | ||
| ) | ||
| num_params = get_total_number_of_trainable_parameters(components.app_state.model) | ||
| num_params = get_total_number_of_trainable_parameters(components.app_state.model, components.device_mesh) | ||
| components.evaluation_subscriber.consume_dict({"No. parameters": num_params}) | ||
| logging.info(f"Training model with {num_params} parameters.") | ||
|
|
||
|
|
@@ -169,6 +180,7 @@ def run(self, components: TrainingComponentsInstantiationModel): | |
| checkpointing_interval_in_steps=components.settings.intervals.checkpointing_interval_in_steps, | ||
| evaluation_interval_in_steps=components.settings.intervals.evaluation_interval_in_steps, | ||
| training_log_interval_in_steps=components.settings.intervals.training_log_interval_in_steps, | ||
| scheduled_pipeline=components.scheduled_pipeline if components.scheduled_pipeline else None, | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same point as for the trainer. Could we wrap the scheduled pipeline instead and use the existing model interfaces? |
||
| ) | ||
|
|
||
| def get_logging_publishers( | ||
|
|
||
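Finally, a hypothetical sketch of the reviewer's alternative: hide the scheduled pipeline behind the existing model interface so that `scheduled_pipeline` does not need to be threaded through the orchestration, Trainer, and Evaluator. Every name here (`PipelinedModelFacade`, the `step` call, the key attributes) is assumed for illustration, not taken from the PR:

```python
import torch
from torch import nn


class PipelinedModelFacade(nn.Module):
    """Hypothetical adapter: presents a scheduled pipeline through the plain
    nn.Module call convention that Trainer/Evaluator already consume."""

    def __init__(self, scheduled_pipeline, sample_key: str, target_key: str):
        super().__init__()
        self.scheduled_pipeline = scheduled_pipeline
        self.sample_key = sample_key
        self.target_key = target_key

    def forward(self, batch: dict[str, torch.Tensor]) -> torch.Tensor | None:
        pp = self.scheduled_pipeline
        # Same stage logic as in the PR's evaluate_batch, but owned by the wrapper.
        target, losses = (
            (batch[self.target_key].contiguous(), []) if pp.is_last_pp_stage else (None, None)
        )
        if pp.is_first_pp_stage:
            pp.pp_schedule.step(batch[self.sample_key].contiguous(), target=target, losses=losses)
        else:
            pp.pp_schedule.step(target=target, losses=losses)
        # Callers keep the familiar "model(batch) -> loss or None" contract.
        return torch.mean(torch.stack(losses)) if pp.is_last_pp_stage else None
```

Whether this is preferable depends on how much of the existing `model_predict_batch` contract the rest of the code relies on; the author reply above explains why a similar wrapping of the loss ran into config complexity.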
> **Review comment** (on the removed `requires-python = ">=3.10,<3.12"` in `pyproject.toml`): Why did we remove this? Our testing is always against 3.10 and 3.11. Do we need a more recent python version?