Changes from all commits (67 commits)
1e80bbb
feat: implemented stage FQN generation for pipeline parallelism
le1nux Aug 6, 2025
ed93d28
feat: added FQNs per stage calculation
le1nux Aug 7, 2025
6241ea8
feat: generic FQN-based PP staging
le1nux Aug 15, 2025
0ba8fbc
feat: added PP configs
le1nux Aug 15, 2025
4a41b6c
feat: wired up PP within dependency graph
le1nux Aug 15, 2025
ee529b7
feat: added FQN stages generator
le1nux Aug 15, 2025
625de59
feat: implemented scheduled pipeline
le1nux Aug 18, 2025
9677bd6
feat: wired up scheduled and staged pipelines.
le1nux Aug 18, 2025
7ac9edf
feat: added PP test config
le1nux Aug 18, 2025
d9f63c1
refactor: staging is now fully instantiable
le1nux Aug 19, 2025
83c87b9
feat: drafted pp e2e test for fwd/bwd pass
le1nux Aug 19, 2025
95f2470
refactor: renamings in the context of PP
le1nux Aug 29, 2025
521e586
chore: drafted the first PP test.
le1nux Aug 29, 2025
002b0ae
chore: pp config fixes
le1nux Aug 29, 2025
1d4943f
feat: Make test for pipeline parallelism work
rrutmann Sep 5, 2025
5b53ff9
refactor(parallelism): Removed necessity of additional model and loss…
BlueCrescent Sep 8, 2025
5147a7a
refactor(parallelism): Clean up for pp test.
BlueCrescent Sep 8, 2025
1cb9779
test: Print losses to debug tests
rrutmann Sep 8, 2025
27ad56d
feat: Use scheduled_pipeline for forward backward pass
rrutmann Sep 9, 2025
41c4f36
feat: Use scheduled_pipeline for training
rrutmann Sep 9, 2025
6f3d5da
feat: Use scheduled_pipe in evaluation
rrutmann Sep 9, 2025
9b85334
test: Print losses if test fails
rrutmann Sep 9, 2025
84e2702
chore: Run evaluation before training
rrutmann Sep 9, 2025
32fbe94
chore: Increase microbatch size
rrutmann Sep 9, 2025
61ab311
fix: Use dp size instead of world size for last batch aggregation
rrutmann Sep 10, 2025
6952bcc
docs: Add TODOs for later check
rrutmann Sep 10, 2025
90dbe51
fix: Train before evaluation so that pp is initialized for backwards
rrutmann Sep 10, 2025
49df7d6
fix: Add missing parameter seed to GPT2LLMConfig
rrutmann Sep 12, 2025
7996a29
fix: Retrieve all PP ranks for gradient clipping
rrutmann Sep 15, 2025
cbddcbc
test: Add new parameter num_data_parallel_ranks to Trainer
rrutmann Sep 15, 2025
56a917a
fix: Make FSDP1GradientClipperConfig independent of device_mesh
rrutmann Sep 15, 2025
eb47aa9
fix: Handle optional device_mesh correctly
rrutmann Sep 15, 2025
d228351
feat: Consider pipeline parallelism in tensor parallelization
rrutmann Sep 17, 2025
55dad72
test: Use the same data on each rank & test tensor parallelism
rrutmann Sep 17, 2025
b6a1e2d
refactor(parallelism): Some clean-up.
BlueCrescent Sep 17, 2025
16a51af
chore: Merge branch 'pipeline_parallelism_fix' of github.com:Modaliti…
rrutmann Sep 18, 2025
c49895a
test: Update configs for parallelization testing
rrutmann Sep 19, 2025
f685fc5
test: Use correct length to create test sequences
rrutmann Sep 19, 2025
c07fcf6
test: Use realistic std for model initialization
rrutmann Sep 19, 2025
5019bbb
fix: Remove unused third dimension for reduced_losses
rrutmann Sep 19, 2025
a08e555
refactor: Remove unused filtering
rrutmann Sep 19, 2025
45b5418
fix: Aggregate loss of last train batch correctly across pp ranks
rrutmann Sep 22, 2025
a394ab0
docs: Add example config for pipeline and tensor parallelism
rrutmann Sep 22, 2025
cae050e
docs: Add docstrings and type hints
rrutmann Sep 22, 2025
6952230
docs: Add type hints and docstrings
rrutmann Sep 22, 2025
ffa032c
fix: Check if parallelism method is initialized
rrutmann Sep 22, 2025
8d418a1
docs: Add new parameter in docstring
rrutmann Sep 22, 2025
fffd0a1
test: Run only one PP only test
rrutmann Sep 23, 2025
049472f
refactor: Addressed copilot review
rrutmann Sep 24, 2025
608c7fc
chore: Remove requirements for python and torch
rrutmann Oct 15, 2025
16c4bc4
fix: Allow dp shard degree 1
rrutmann Oct 17, 2025
f5a1020
test: Add test for checkpointing with pipeline parallelism
rrutmann Oct 17, 2025
9d1f107
fix(parallelism): Building model stages in PP now also filters the mo…
BlueCrescent Oct 17, 2025
dfc1bde
test(checkpointing): Some fixes for pp checkpointing test.
BlueCrescent Oct 17, 2025
cd9f595
test(checkpointing): Made dcp checkpointing test terminate correctly …
BlueCrescent Oct 20, 2025
edf7a4e
test(checkpointing): Checkpointing equality tests now explicitly only…
BlueCrescent Oct 21, 2025
abcf235
fix: Use ModuleDict for transformer layers for correct checkpointing …
Oct 21, 2025
554cd39
chore: Rename layer_id to layer_idx
Oct 21, 2025
484815e
test: Adapt tests to new gpt2 model structure
Oct 21, 2025
ddb249b
test: Adapt code to latest changes to pass tests
rrutmann Oct 21, 2025
51b7db4
test(data): Added tests for distributed multi dim data sampling.
BlueCrescent Oct 21, 2025
a84db67
fix(parallelism): Use dp degree instead of world size in global_num_t…
BlueCrescent Oct 23, 2025
bdc684e
fix(optimizer): Optimizer groups for with and without weight decay in…
BlueCrescent Oct 23, 2025
0a8ab53
test(parallelism): Added warmstart e2e test with fsdp2 + tp + pp.
BlueCrescent Oct 23, 2025
71d1f81
test(gradient_clipping): Check that gradient clipping in pp setting i…
BlueCrescent Oct 24, 2025
c0cf6a2
refactor(gradient_clipping): Removed duplicate code in fsdp2 gradient…
BlueCrescent Oct 24, 2025
c6fee18
fix(logging): Correct number of parameters computation in case of pip…
BlueCrescent Oct 24, 2025
411 changes: 411 additions & 0 deletions config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml

Large diffs are not rendered by default.

422 changes: 422 additions & 0 deletions config_files/training/config_lorem_ipsum_long_fsdp2_pp_tp.yaml

Large diffs are not rendered by default.

2 changes: 0 additions & 2 deletions pyproject.toml
@@ -1,12 +1,10 @@
[project]
name = "modalities"
version = "0.3.2"
[Review comment — Member] Why did we remove this? Our testing is always against 3.10 and 3.11. Do we need a more recent Python version?

requires-python = ">=3.10,<3.12"
description = "Modalities, a PyTorch-native framework for distributed and reproducible foundation model training."
readme = "README.md"
dependencies = [
"numpy<2.0",
[Review comment — Member] We should update the README to note that, for now, torch-nightly needs to be installed manually by the user.

"torch==2.6.0",
"packaging",
"tqdm",
"pyyaml",
4 changes: 4 additions & 0 deletions src/modalities/config/instantiation_models.py
@@ -8,11 +8,13 @@
PydanticAppStateType,
PydanticCheckpointSavingIFType,
PydanticDatasetIFType,
PydanticDeviceMeshIFType,
PydanticGradientClipperIFType,
PydanticLLMDataLoaderIFType,
PydanticLossIFType,
PydanticMessageSubscriberIFType,
PydanticMFUCalculatorABCType,
PydanticPipelineType,
PydanticPytorchDeviceType,
PydanticPytorchModuleType,
PydanticTextInferenceComponentType,
@@ -178,6 +180,8 @@ def _check_last_step_checkpointed(self) -> "TrainingComponentsInstantiationModel
checkpoint_saving: PydanticCheckpointSavingIFType
gradient_clipper: PydanticGradientClipperIFType
mfu_calculator: Optional[PydanticMFUCalculatorABCType] = None
scheduled_pipeline: Optional[PydanticPipelineType] = None
device_mesh: Optional[PydanticDeviceMeshIFType] = None
model_raw: PydanticPytorchModuleType

@model_validator(mode="after")
5 changes: 5 additions & 0 deletions src/modalities/config/pydantic_if_types.py
@@ -7,6 +7,7 @@
from torch.distributed.device_mesh import DeviceMesh
from torch.distributed.fsdp import FSDPModule as FSDP2
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP1
from torch.distributed.pipelining import PipelineStage
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import Sampler
@@ -21,6 +22,7 @@
from modalities.inference.text.inference_component import TextInferenceComponent
from modalities.logging_broker.subscriber import MessageSubscriberIF
from modalities.loss_functions import Loss
from modalities.models.parallelism.pipeline_parallelism import Pipeline, StagesGenerator
from modalities.nn.model_initialization.initialization_if import ModelInitializationIF
from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper
from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF
@@ -83,3 +85,6 @@ def __get_pydantic_core_schema__(
PydanticDatasetBatchGeneratorIFType = Annotated[
DatasetBatchGeneratorIF, PydanticThirdPartyTypeIF(DatasetBatchGeneratorIF)
]
PydanticStagesGeneratorType = Annotated[StagesGenerator, PydanticThirdPartyTypeIF(StagesGenerator)]
PydanticPipelineType = Annotated[Pipeline, PydanticThirdPartyTypeIF(Pipeline)]
PydanticPipelineStageType = Annotated[PipelineStage, PydanticThirdPartyTypeIF(PipelineStage)]
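The three new `Annotated` aliases follow the repository's `PydanticThirdPartyTypeIF` pattern, which lets pydantic validate otherwise opaque third-party classes such as `PipelineStage`. A minimal sketch of the idea, using stand-in classes (`PipelineStageLike`, `ThirdPartyTypeIF` are illustrative assumptions, not the real implementations):

```python
from typing import Annotated, get_type_hints

class PipelineStageLike:
    """Stand-in for torch.distributed.pipelining.PipelineStage (illustration only)."""

class ThirdPartyTypeIF:
    """Stand-in for PydanticThirdPartyTypeIF: carries the expected class as metadata."""
    def __init__(self, expected_type):
        self.expected_type = expected_type

# Mirrors the pattern in pydantic_if_types.py: the Annotated metadata tells the
# validation layer which third-party class an otherwise untyped field must hold.
PydanticPipelineStageType = Annotated[PipelineStageLike, ThirdPartyTypeIF(PipelineStageLike)]

def configure(stage: PydanticPipelineStageType) -> None:
    pass

# The metadata is recoverable at runtime, which is what the pydantic hook uses.
hints = get_type_hints(configure, include_extras=True)
marker = hints["stage"].__metadata__[0]
```

The real classes additionally implement `__get_pydantic_core_schema__`; this sketch only shows how the expected type travels inside the annotation.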
8 changes: 4 additions & 4 deletions src/modalities/conversion/gpt2/conversion_model.py
@@ -136,10 +136,10 @@ def _copy_weights_model(hf_model: GPT2ForCausalLM, modalities_model: GPT2LLM):
modalities_model (GPT2LLM): The modalities model from which the weights will be copied.
"""
hf_model.model.embed_tokens.weight.data.copy_(modalities_model.transformer.wte.weight.data)
for hf_layer, modalities_layer in zip(hf_model.model.layers, modalities_model.transformer.h):
_copy_weights_attention(hf_layer, modalities_layer)
_copy_weights_mlp(hf_layer, modalities_layer)
_copy_weights_layer_norms(hf_layer, modalities_layer)
for hf_layer, modalities_layer_idx in zip(hf_model.model.layers, modalities_model.transformer.h):
_copy_weights_attention(hf_layer, modalities_model.transformer.h[modalities_layer_idx])
_copy_weights_mlp(hf_layer, modalities_model.transformer.h[modalities_layer_idx])
_copy_weights_layer_norms(hf_layer, modalities_model.transformer.h[modalities_layer_idx])
_copy_weights_base_modules(hf_model.lm_head, modalities_model.transformer.lm_head)
_copy_weights_base_modules(hf_model.model.norm, modalities_model.transformer.lm_head_norm)
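The switch from iterating modules to indexing `transformer.h[...]` reflects the later commit that turns the layer container into a `ModuleDict`: iterating a dict-like container yields its keys, not its values. A plain-dict sketch of the new loop shape (names are illustrative):

```python
hf_layers = ["hf_layer_0", "hf_layer_1"]
# transformer.h is now a (Module)Dict keyed by layer index, so iterating it
# yields the keys ("0", "1", ...), and the module must be looked up explicitly.
modalities_layers = {"0": "mod_layer_0", "1": "mod_layer_1"}

pairs = []
for hf_layer, layer_idx in zip(hf_layers, modalities_layers):
    pairs.append((hf_layer, modalities_layers[layer_idx]))
```

This is why each `_copy_weights_*` call in the diff receives `modalities_model.transformer.h[modalities_layer_idx]` instead of the loop variable itself.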

41 changes: 35 additions & 6 deletions src/modalities/evaluator.py
@@ -9,6 +9,7 @@
from modalities.logging_broker.messages import ExperimentStatus, MessageTypes, ProgressUpdate
from modalities.logging_broker.publisher import MessagePublisher
from modalities.models.model import model_predict_batch
from modalities.models.parallelism.pipeline_parallelism import Pipeline
from modalities.running_env.fsdp.reducer import Reducer
from modalities.trainer import ThroughputAggregationKeys
from modalities.util import Aggregator, TimeRecorder
@@ -36,20 +37,42 @@ def evaluate_batch(
batch: DatasetBatch,
model: nn.Module,
loss_fun: Callable[[InferenceResultBatch], torch.Tensor],
) -> torch.Tensor:
scheduled_pipeline: Pipeline | None = None,
) -> torch.Tensor | None:
"""Evaluate a single batch by forwarding it through the model and calculating the loss.

Args:
batch (DatasetBatch): The batch to evaluate
model (nn.Module): The model to evaluate
loss_fun (Callable[[InferenceResultBatch], torch.Tensor]): The loss function to calculate the loss
scheduled_pipeline (Pipeline | None, optional): In case of pipeline parallelism, this is used to
operate the model. Defaults to None.

Returns:
torch.Tensor: The loss of the batch
torch.Tensor | None: The loss of the batch
None, if a non-last stage was processed in pipeline parallelism
"""
with torch.no_grad():
result_batch = model_predict_batch(model=model, batch=batch)
loss = loss_fun(result_batch)
if scheduled_pipeline is not None:
[Review comment — Member] Basically code duplication from the trainer. Also not a big fan of passing the scheduled_pipeline in here.

pp_schedule = scheduled_pipeline.pp_schedule
targets, losses = (
(batch.targets[loss_fun.target_key].contiguous(), [])
if scheduled_pipeline.is_last_pp_stage
else (None, None)
)

if scheduled_pipeline.is_first_pp_stage:
pp_schedule.eval(batch.samples[model.sample_key].contiguous(), target=targets, losses=losses)
else:
pp_schedule.eval(target=targets, losses=losses)
loss = (
torch.mean(torch.stack(losses)).to(losses[0].device)
if scheduled_pipeline.is_last_pp_stage
else None
)
else:
result_batch = model_predict_batch(model=model, batch=batch)
loss = loss_fun(result_batch)
return loss
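The pipeline branch above boils down to: only the first stage feeds inputs into the schedule, only the last stage passes targets and collects per-microbatch losses, and every other stage returns `None`. A minimal pure-Python sketch of that control flow, with a hypothetical `run_schedule` callable standing in for `pp_schedule.eval`:

```python
def evaluate_batch_pp(stage_rank, num_stages, run_schedule, samples, targets):
    # Mirrors the branch in evaluate_batch: inputs only on the first stage,
    # targets and a losses buffer only on the last stage.
    is_first = stage_rank == 0
    is_last = stage_rank == num_stages - 1
    losses = [] if is_last else None
    run_schedule(samples if is_first else None,
                 targets if is_last else None,
                 losses)
    if is_last:
        return sum(losses) / len(losses)  # mean over microbatch losses
    return None  # non-last stages report no loss


def fake_schedule(inputs, target, losses):
    # Pretend two microbatches produced these losses (illustration only).
    if losses is not None:
        losses.extend([0.5, 1.5])

last_stage_loss = evaluate_batch_pp(1, 2, fake_schedule, "samples", "targets")
first_stage_loss = evaluate_batch_pp(0, 2, fake_schedule, "samples", "targets")
```

This is a sketch of the control flow only; the real code uses `torch.stack`/`torch.mean` on tensors and the `PipelineSchedule` API.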

def evaluate(
@@ -58,6 +81,7 @@
data_loaders: list[LLMDataLoader],
loss_fun: Callable[[InferenceResultBatch], torch.Tensor],
num_train_steps_done: int,
scheduled_pipeline: Pipeline | None = None,
) -> dict[str, EvaluationResultBatch]:
"""Evaluate the model on a set of datasets.

@@ -66,6 +90,8 @@
data_loaders (list[LLMDataLoader]): List of dataloaders to evaluate the model on
loss_fun (Callable[[InferenceResultBatch], torch.Tensor]): The loss function to calculate the loss
num_train_steps_done (int): The number of training steps done so far for logging purposes
scheduled_pipeline (Pipeline | None, optional): In case of pipeline parallelism, this is used to
operate the model. Defaults to None.

Returns:
dict[str, EvaluationResultBatch]: A dictionary containing the evaluation results for each dataloader
@@ -90,10 +116,13 @@
batch=batch,
model=model,
loss_fun=loss_fun,
scheduled_pipeline=scheduled_pipeline,
)

cumulated_loss[0] += batch_loss.item() # sum up batch loss
cumulated_loss[1] += 1
# The batch_loss might be None if we use pipeline parallelism and are not the last stage.
if batch_loss is not None:
cumulated_loss[0] += batch_loss.item() # sum up batch loss
cumulated_loss[1] += 1
batch_length_tensor = torch.tensor(len(batch)).to(device)
thoughput_aggregator.add_value(key=ThroughputAggregationKeys.NUM_SAMPLES, value=batch_length_tensor)
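Because only the last pipeline stage produces a loss, the cumulative-loss update must skip `None` entries rather than count every batch. A small sketch of that accumulation logic (illustrative, not the Evaluator's exact code):

```python
def accumulate_eval_loss(batch_losses):
    # batch_loss is None on non-last PP stages; only real losses enter the sum
    # and the batch count, so the later mean is not diluted by missing values.
    total, count = 0.0, 0
    for loss in batch_losses:
        if loss is not None:
            total += loss
            count += 1
    return total, count

total, count = accumulate_eval_loss([0.25, None, 0.75, None])
```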

10 changes: 9 additions & 1 deletion src/modalities/gym.py
@@ -9,6 +9,7 @@
from modalities.dataloader.dataloader import LLMDataLoader
from modalities.evaluator import Evaluator
from modalities.loss_functions import Loss
from modalities.models.parallelism.pipeline_parallelism import Pipeline
from modalities.trainer import Trainer
from modalities.training.training_progress import TrainingProgress
from modalities.util import print_rank_0
@@ -40,6 +41,7 @@ def run(
train_data_loader: LLMDataLoader,
evaluation_data_loaders: list[LLMDataLoader],
checkpoint_saving: CheckpointSaving,
scheduled_pipeline: Pipeline | None = None,
):
"""Runs the model training, including evaluation and checkpointing.

@@ -51,12 +53,15 @@
train_data_loader (LLMDataLoader): Data loader with the training data.
evaluation_data_loaders (list[LLMDataLoader]): List of data loaders with the evaluation data.
checkpoint_saving (CheckpointSaving): Routine for saving checkpoints.
scheduled_pipeline (Pipeline | None, optional): In case of pipeline parallelism, this is used to
operate the model. Defaults to None.
"""
evaluation_callback: Callable[[int], None] = partial(
self._run_evaluation,
model=app_state.model,
evaluation_data_loaders=evaluation_data_loaders,
evaluation_interval_in_steps=evaluation_interval_in_steps,
scheduled_pipeline=scheduled_pipeline,
)

checkpointing_callback: Callable[[TrainingProgress], None] = partial(
@@ -74,6 +79,7 @@
evaluation_callback=evaluation_callback,
checkpointing_callback=checkpointing_callback,
training_log_interval_in_steps=training_log_interval_in_steps,
scheduled_pipeline=scheduled_pipeline,
)
print_rank_0(f"Training done at {datetime.now()}.")

@@ -101,11 +107,13 @@ def _run_evaluation(
num_train_steps_done: int,
evaluation_data_loaders: list[LLMDataLoader],
evaluation_interval_in_steps: int,
scheduled_pipeline: Pipeline | None = None,
):
if num_train_steps_done % evaluation_interval_in_steps == 0:
if num_train_steps_done > 0 and num_train_steps_done % evaluation_interval_in_steps == 0:
self.evaluator.evaluate(
model=model,
data_loaders=evaluation_data_loaders,
loss_fun=self.loss_fun,
num_train_steps_done=num_train_steps_done,
scheduled_pipeline=scheduled_pipeline,
[Review comment — Member] Same as in previous comments: not a big fan of passing around the scheduled_pipeline.

)
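The changed guard in `_run_evaluation` adds `num_train_steps_done > 0`, matching the earlier fix to train before evaluating so that the pipeline-parallel backward machinery is initialized. As a sketch:

```python
def should_evaluate(num_train_steps_done, evaluation_interval_in_steps):
    # Step 0 is skipped: with pipeline parallelism, evaluating before any
    # training step would hit uninitialized PP state.
    return (num_train_steps_done > 0
            and num_train_steps_done % evaluation_interval_in_steps == 0)
```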
47 changes: 45 additions & 2 deletions src/modalities/loss_functions.py
@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
from typing import overload

import torch
from torch.nn import CrossEntropyLoss
@@ -31,9 +32,16 @@ def __init__(self, target_key: str, prediction_key: str, tag: str = "CLMCrossEnt
# Mean over the tokens in the local-batch (batch per rank)
self.loss_fun = CrossEntropyLoss(reduction="mean")

@overload
def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor:
labels = forward_batch.get_targets(self.target_key)
lm_logits = forward_batch.get_predictions(self.prediction_key)
...

@overload
def __call__(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
...

def __call__(self, *args, **kwargs) -> torch.Tensor:
[Review comment — Copilot AI, Sep 17, 2025] Using *args and **kwargs instead of proper overloads makes the API less type-safe and harder to understand. Consider implementing proper method overloading with specific parameter types.

labels, lm_logits = self._parse_arguments(args, kwargs)
[Review comment — Member] Could be improved from a software engineering point of view.

# move labels to correct device to enable model parallelism
labels = labels.to(lm_logits.device)
@@ -43,6 +51,41 @@ def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor:
loss = self.loss_fun(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
return loss

def _parse_arguments(
self,
args: list[torch.Tensor] | list[InferenceResultBatch],
kwargs: dict[str, torch.Tensor] | dict[str, InferenceResultBatch],
) -> tuple[torch.Tensor, torch.Tensor]:
if len(args) == 1 and isinstance(args[0], InferenceResultBatch):
forward_batch = args[0]
labels = forward_batch.get_targets(self.target_key)
lm_logits = forward_batch.get_predictions(self.prediction_key)
elif "forward_batch" in kwargs and isinstance(kwargs["forward_batch"], InferenceResultBatch):
forward_batch = kwargs["forward_batch"]
labels = forward_batch.get_targets(self.target_key)
lm_logits = forward_batch.get_predictions(self.prediction_key)
elif len(args) == 2 and all(isinstance(arg, torch.Tensor) for arg in args):
lm_logits, labels = args
elif (
"outputs" in kwargs
and "targets" in kwargs
and isinstance(kwargs["outputs"], torch.Tensor)
and isinstance(kwargs["targets"], torch.Tensor)
):
lm_logits = kwargs["outputs"]
labels = kwargs["targets"]
elif (
len(args) == 1
and "targets" in kwargs
and isinstance(args[0], torch.Tensor)
and isinstance(kwargs["targets"], torch.Tensor)
):
lm_logits = args[0]
labels = kwargs["targets"]
else:
raise TypeError("Invalid arguments for CLMCrossEntropyLoss.__call__")
return labels, lm_logits
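The dispatch order in `_parse_arguments` can be exercised without torch by substituting plain lists for tensors and a stand-in for `InferenceResultBatch`. The following sketch mirrors the branch order (result batch first, positionally or by keyword, then the `(outputs, targets)` tensor-style forms); `FakeResultBatch` and `parse_loss_arguments` are illustrative names only:

```python
class FakeResultBatch:
    """Stand-in for InferenceResultBatch (illustration only)."""
    def __init__(self, targets, predictions):
        self._targets, self._predictions = targets, predictions
    def get_targets(self, key):
        return self._targets
    def get_predictions(self, key):
        return self._predictions

def parse_loss_arguments(args, kwargs):
    # Same dispatch order as _parse_arguments: a result batch (positional or
    # keyword) first, then the (outputs, targets) signatures used by the
    # pipeline schedule's loss callback.
    if len(args) == 1 and isinstance(args[0], FakeResultBatch):
        batch = args[0]
        return batch.get_targets("t"), batch.get_predictions("p")
    if "forward_batch" in kwargs:
        batch = kwargs["forward_batch"]
        return batch.get_targets("t"), batch.get_predictions("p")
    if len(args) == 2:
        outputs, targets = args
        return targets, outputs
    if "outputs" in kwargs and "targets" in kwargs:
        return kwargs["targets"], kwargs["outputs"]
    if len(args) == 1 and "targets" in kwargs:
        return kwargs["targets"], args[0]
    raise TypeError("Invalid arguments for loss __call__")
```

Note that every branch returns `(labels, logits)` in the same order, regardless of which calling convention was used.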

[Review comment — Member, on lines +54 to +88] Idea: What about defining a new component "pp-loss", which takes a normal loss function and handles the PP-specific part? Generally, I think this parsing function could be improved.

[Reply — Collaborator] This was also our initial idea and we actually implemented it like this in a first draft. However, this had the disadvantage that we had to define the usage of the correct loss for a certain setup in the configs, making them more complex and less user friendly.


def nce_loss(
embedding1: torch.Tensor, embedding2: torch.Tensor, device: torch.device, is_asymmetric: bool, temperature: float
16 changes: 14 additions & 2 deletions src/modalities/main.py
@@ -20,6 +20,7 @@
from modalities.logging_broker.subscriber import MessageSubscriberIF
from modalities.registry.components import COMPONENTS
from modalities.registry.registry import Registry
from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_num_parallel_ranks
from modalities.trainer import Trainer
from modalities.util import get_synced_experiment_id_of_run, get_total_number_of_trainable_parameters, print_rank_0

@@ -110,11 +111,20 @@ def run(self, components: TrainingComponentsInstantiationModel):
)

# Trainer
# FIXME replace by get_parallel_degree
if components.device_mesh is None:
num_pipeline_parallel_ranks = 1
num_data_parallel_ranks = 1
else:
num_pipeline_parallel_ranks = get_num_parallel_ranks(components.device_mesh, ParallelismDegrees.PP)
num_data_parallel_ranks = get_num_parallel_ranks(
components.device_mesh, ParallelismDegrees.DP_SHARD
) * get_num_parallel_ranks(components.device_mesh, ParallelismDegrees.DP_REPLICATE)
global_num_tokens_per_train_step = (
components.settings.step_profile.local_train_micro_batch_size
* components.settings.step_profile.sequence_length
* components.settings.step_profile.gradient_accumulation_steps
* components.settings.cuda_env.world_size
* num_data_parallel_ranks
)
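The token-accounting fix replaces `world_size` with the number of data-parallel ranks, since pipeline- and tensor-parallel ranks process the same batch rather than additional data. A worked example with hypothetical mesh sizes:

```python
def global_num_tokens_per_train_step(micro_batch_size, sequence_length,
                                     gradient_accumulation_steps, num_dp_ranks):
    # Only data-parallel ranks consume distinct batches; PP and TP ranks share
    # the same data, so multiplying by world_size would overcount tokens.
    return (micro_batch_size * sequence_length
            * gradient_accumulation_steps * num_dp_ranks)

# Hypothetical mesh: world_size 8 = dp_shard 2 x pp 2 x tp 2 -> 2 DP ranks.
# Using world_size here would report 4x too many tokens per step.
tokens = global_num_tokens_per_train_step(4, 1024, 2, 2)
```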
trainer = Trainer(
global_rank=components.settings.cuda_env.global_rank,
Expand All @@ -128,6 +138,7 @@ def run(self, components: TrainingComponentsInstantiationModel):
gradient_clipper=components.gradient_clipper,
global_num_tokens_per_train_step=global_num_tokens_per_train_step,
mfu_calculator=components.mfu_calculator,
num_pipeline_parallel_ranks=num_pipeline_parallel_ranks,
[Review comment — Member] I would prefer if we kept the Trainer high level and abstracted away specifics like PP.

)

# Evaluator
@@ -143,7 +154,7 @@
loss_fun=components.loss_fn,
num_ranks=components.settings.cuda_env.world_size,
)
num_params = get_total_number_of_trainable_parameters(components.app_state.model)
num_params = get_total_number_of_trainable_parameters(components.app_state.model, components.device_mesh)
components.evaluation_subscriber.consume_dict({"No. parameters": num_params})
logging.info(f"Training model with {num_params} parameters.")

@@ -169,6 +180,7 @@
checkpointing_interval_in_steps=components.settings.intervals.checkpointing_interval_in_steps,
evaluation_interval_in_steps=components.settings.intervals.evaluation_interval_in_steps,
training_log_interval_in_steps=components.settings.intervals.training_log_interval_in_steps,
scheduled_pipeline=components.scheduled_pipeline if components.scheduled_pipeline else None,
[Review comment — Member] Same point as for the trainer. Could we wrap the scheduled pipeline instead and use the existing model interfaces?

)

def get_logging_publishers(