diff --git a/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml b/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml new file mode 100644 index 000000000..5d5557e6f --- /dev/null +++ b/config_files/training/config_lorem_ipsum_long_fsdp2_pp.yaml @@ -0,0 +1,411 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + checkpoint_saving_path: data/checkpoints + train_dataset_path: ./data/lorem_ipsum_long.pbin + test_dataset_path: ./data/lorem_ipsum.pbin + intervals: + training_log_interval_in_steps: 1 + checkpointing_interval_in_steps: 32 + evaluation_interval_in_steps: 32 + consistency_enforcement: + enforce_tokens_per_step_consistency: true + enforce_last_step_logged: false + enforce_last_step_evaluated: false + enforce_last_step_checkpointed: false + step_profile: + gradient_accumulation_steps: 1 + local_train_micro_batch_size: 4 + sequence_length: 256 + training_target: + num_target_tokens: + component_key: number_conversion + variant_key: num_tokens_from_packed_mem_map_dataset_continuous + config: + dataset_path: ${settings.paths.train_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + num_ranks: ${settings.cuda_env.world_size} + local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + num_target_steps: # for the batch progress subscriber + component_key: number_conversion + variant_key: num_steps_from_num_tokens + config: + num_ranks: ${settings.cuda_env.world_size} + local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} + global_num_tokens: ${settings.training_target.num_target_tokens} + sequence_length: ${settings.step_profile.sequence_length} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + training_progress: + global_num_seen_tokens: 0 + num_seen_steps: 0 + num_seen_samples: 0 + last_step: -1 + +collate_fn: + component_key: collate_fn + variant_key: gpt_2_llm_collator + config: + sample_key: ${settings.referencing_keys.sample_key} + target_key: ${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: ${settings.paths.train_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +train_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + dataloader_tag: train + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.step_profile.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: resumable_distributed_sampler + config: + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: true + seed: 42 + drop_last: true + skip_num_global_samples: ${settings.training_progress.num_seen_samples} + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +test_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: ${settings.paths.test_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +test_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + dataloader_tag: test + dataset: + instance_key: test_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.step_profile.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: distributed_sampler + config: + rank: ${settings.cuda_env.global_rank} + num_replicas: ${settings.cuda_env.world_size} + shuffle: false + drop_last: true + dataset: + instance_key: test_dataset + pass_type: BY_REFERENCE + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: + - instance_key: test_dataloader + pass_type: BY_REFERENCE + +checkpoint_saving: + component_key: checkpoint_saving + variant_key: default + config: + checkpoint_saving_strategy: + component_key: checkpoint_saving_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpoint_saving_execution: + component_key: checkpoint_saving_execution + variant_key: dcp + config: + checkpoint_path: ${settings.paths.checkpoint_saving_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + pipeline_parallel_degree: 2 + data_parallel_shard_degree: -1 + world_size: ${settings.cuda_env.world_size} + +app_state: + component_key: app_state + variant_key: raw + config: + model: + instance_key: initialized_model + pass_type: BY_REFERENCE + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + lr_scheduler: + instance_key: lr_scheduler + pass_type: BY_REFERENCE + + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: scheduled_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL_PART + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.n_layer} + +scheduled_pipeline: + component_key: pipeline + variant_key: scheduled + config: + loss_fn: + instance_key: loss_fn + pass_type: BY_REFERENCE + pp_schedule_name: gpipe + batch_size: ${settings.step_profile.local_train_micro_batch_size} + microbatch_size: 2 + pp_degree: ${device_mesh.config.pipeline_parallel_degree} + # maybe better to use the fsdp model and the schedule here + # instead of passing in the staged pipeline? + # If fsdp_model creates a copy then this is not in the scope of + # the staged pipeline. + pipeline: + component_key: pipeline + variant_key: builder + config: + pp_stage: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: PP_STAGE + model_part: + instance_key: fsdp_model + pass_type: BY_REFERENCE + +fsdp_model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: model_part + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + block_names: [GPT2Block] + +model_part: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL_PART + + +staged_pipeline: + component_key: pipeline + variant_key: staged + config: + whole_model: + instance_key: model_raw + pass_type: BY_REFERENCE + stages_generator: + component_key: stages_generator + variant_key: gpt2_stages_generator + config: + num_model_layers: ${model_raw.config.n_layer} + input_layer_equivalence: 1 + output_layer_equivalence: 1 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + local_rank: ${settings.cuda_env.local_rank} + pp_schedule_name: gpipe + num_layers_per_stage: 2 + +model_raw: + component_key: model + variant_key: gpt2 + config: + use_meta_device: true + use_weight_tying: false + sample_key: ${settings.referencing_keys.sample_key} + poe_type: NOPE + sequence_length: ${settings.step_profile.sequence_length} + prediction_key: ${loss_fn.config.prediction_key} + vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: 2 + n_head_q: 8 + n_head_kv: 4 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster + attention_config: + qkv_transforms: + - type_hint: RotaryTransform + config: + n_embd: ${model_raw.config.n_embd} + n_head: ${model_raw.config.n_head_q} #it has to be head_q here + seq_length_dim: -2 + base_freq: 10000 + attention_implementation: manual + activation_type: swiglu + attention_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + ffn_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + lm_head_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + +lr_scheduler: + component_key: scheduler + variant_key: onecycle_lr + config: + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + max_lr: 6e-4 + div_factor: 10 + final_div_factor: 1 + total_steps: ${settings.training_target.num_target_steps} + pct_start: 0.01 + anneal_strategy: cos + last_epoch: ${settings.training_progress.last_step} + +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 0.0001 + betas: [0.9, 0.95] + eps: 1e-8 + weight_decay: 1e-1 + weight_decay_groups_excluded: [embedding, layernorm] + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + +gradient_clipper: + component_key: gradient_clipper + variant_key: fsdp2 + config: + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + norm_type: P2_NORM + max_norm: 1.0 + +progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + global_rank: ${settings.cuda_env.global_rank} + num_seen_steps: ${settings.training_progress.num_seen_steps} + num_target_steps: ${settings.training_target.num_target_steps} + train_dataloader_tag: ${train_dataloader.config.dataloader_tag} + eval_dataloaders: + instance_key: eval_dataloaders + pass_type: BY_REFERENCE + +evaluation_subscriber: + component_key: results_subscriber + variant_key: wandb + config: + global_rank: ${settings.cuda_env.global_rank} + project: modalities_dcp_tests + mode: OFFLINE + experiment_id: ${settings.experiment_id} + directory: wandb_storage + config_file_path: ${settings.config_file_path} + +# mfu_calculator: +# component_key: mfu_calculator +# variant_key: gpt2 +# config: +# n_layer: ${model_raw.config.n_layer} +# sequence_length: ${settings.step_profile.sequence_length} +# n_embd: ${model_raw.config.n_embd} +# world_size: ${settings.cuda_env.world_size} +# raw_model: +# instance_key: model_raw +# pass_type: BY_REFERENCE +# wrapped_model: +# instance_key: initialized_model +# pass_type: BY_REFERENCE \ No newline at end of file diff --git a/config_files/training/config_lorem_ipsum_long_fsdp2_pp_tp.yaml b/config_files/training/config_lorem_ipsum_long_fsdp2_pp_tp.yaml new file mode 100644 index 000000000..f7b4835f6 --- /dev/null +++ b/config_files/training/config_lorem_ipsum_long_fsdp2_pp_tp.yaml @@ -0,0 +1,422 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: 8 + paths: + checkpoint_saving_path: data/checkpoints + train_dataset_path: /raid/s3/opengptx/user/richard-rutmann/data/modalities/gpt2_tokenized/000_00000.pbin + test_dataset_path: ./data/lorem_ipsum.pbin + intervals: + training_log_interval_in_steps: 2 + checkpointing_interval_in_steps: 100000 + evaluation_interval_in_steps: 15 + consistency_enforcement: + enforce_tokens_per_step_consistency: true + enforce_last_step_logged: false + enforce_last_step_evaluated: false + enforce_last_step_checkpointed: false + step_profile: + gradient_accumulation_steps: 1 + local_train_micro_batch_size: 4 + sequence_length: 16 + training_target: + num_target_tokens: + component_key: number_conversion + variant_key: num_tokens_from_num_steps + config: + sequence_length: ${settings.step_profile.sequence_length} + num_ranks: ${settings.cuda_env.world_size} + local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + num_steps: ${settings.training_target.num_target_steps} + num_target_steps: 20 + training_progress: + global_num_seen_tokens: 0 + num_seen_steps: 0 + num_seen_samples: 0 + last_step: -1 + +collate_fn: + component_key: collate_fn + variant_key: gpt_2_llm_collator + config: + sample_key: ${settings.referencing_keys.sample_key} + target_key: ${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: ${settings.paths.train_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +train_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + dataloader_tag: train + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.step_profile.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: resumable_distributed_multi_dim_sampler + config: + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + data_parallel_key: dp_shard + shuffle: true + seed: 42 + drop_last: true + skip_num_global_samples: ${settings.training_progress.num_seen_samples} + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +test_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: ${settings.paths.test_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +test_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + dataloader_tag: test + dataset: + instance_key: test_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.step_profile.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: resumable_distributed_multi_dim_sampler + config: + dataset: + instance_key: test_dataset + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + data_parallel_key: dp_shard + shuffle: true + seed: 42 + drop_last: true + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: + - instance_key: test_dataloader + pass_type: BY_REFERENCE + +checkpoint_saving: + component_key: checkpoint_saving + variant_key: default + config: + checkpoint_saving_strategy: + component_key: checkpoint_saving_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpoint_saving_execution: + component_key: checkpoint_saving_execution + variant_key: dcp + config: + checkpoint_path: ${settings.paths.checkpoint_saving_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + pipeline_parallel_degree: 2 + tensor_parallel_degree: 2 + data_parallel_shard_degree: -1 + world_size: ${settings.cuda_env.world_size} + +app_state: + component_key: app_state + variant_key: raw + config: + model: + instance_key: initialized_model + pass_type: BY_REFERENCE + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + lr_scheduler: + instance_key: lr_scheduler + pass_type: BY_REFERENCE + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: scheduled_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL_PART + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.n_layer} + +scheduled_pipeline: + component_key: pipeline + variant_key: scheduled + config: + loss_fn: + instance_key: loss_fn + pass_type: BY_REFERENCE + pp_schedule_name: gpipe + batch_size: ${settings.step_profile.local_train_micro_batch_size} + microbatch_size: 2 + pp_degree: ${device_mesh.config.pipeline_parallel_degree} + # maybe better to use the fsdp model and the schedule here + # instead of passing in the staged pipeline? + # If fsdp_model creates a copy then this is not in the scope of + # the staged pipeline. + pipeline: + component_key: pipeline + variant_key: builder + config: + pp_stage: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: PP_STAGE + model_part: + instance_key: fsdp_model + pass_type: BY_REFERENCE + +fsdp_model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: gpt2_tp_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + block_names: [GPT2Block] + +gpt2_tp_model: + component_key: model + variant_key: gpt2_tp + config: + model: + instance_key: model_part + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +model_part: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL_PART + +staged_pipeline: + component_key: pipeline + variant_key: staged + config: + whole_model: + instance_key: model_raw + pass_type: BY_REFERENCE + stages_generator: + component_key: stages_generator + variant_key: gpt2_stages_generator + config: + num_model_layers: ${model_raw.config.n_layer} + input_layer_equivalence: 1 + output_layer_equivalence: 1 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + local_rank: ${settings.cuda_env.local_rank} + pp_schedule_name: gpipe + num_layers_per_stage: 2 + +model_raw: + component_key: model + variant_key: gpt2 + config: + seed: 42 + use_meta_device: true + use_weight_tying: false + sample_key: ${settings.referencing_keys.sample_key} + poe_type: NOPE + sequence_length: ${settings.step_profile.sequence_length} + prediction_key: ${loss_fn.config.prediction_key} + vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: 2 + n_head_q: 8 + n_head_kv: 4 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster + attention_config: + qkv_transforms: + - type_hint: RotaryTransform + config: + n_embd: ${model_raw.config.n_embd} + n_head: ${model_raw.config.n_head_q} #it has to be head_q here + seq_length_dim: -2 + base_freq: 10000 + attention_implementation: manual + activation_type: swiglu + attention_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + ffn_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + lm_head_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + +lr_scheduler: + component_key: scheduler + variant_key: onecycle_lr + config: + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + max_lr: 6e-4 + div_factor: 10 + final_div_factor: 1 + total_steps: ${settings.training_target.num_target_steps} + pct_start: 0.01 + anneal_strategy: cos + last_epoch: ${settings.training_progress.last_step} + +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 0.0001 + betas: [0.9, 0.95] + eps: 1e-8 + weight_decay: 1e-1 + weight_decay_groups_excluded: [embedding, layernorm] + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + +gradient_clipper: + component_key: gradient_clipper + variant_key: fsdp2 + config: + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + norm_type: P2_NORM + max_norm: 1.0 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + global_rank: ${settings.cuda_env.global_rank} + num_seen_steps: ${settings.training_progress.num_seen_steps} + num_target_steps: ${settings.training_target.num_target_steps} + train_dataloader_tag: ${train_dataloader.config.dataloader_tag} + eval_dataloaders: + instance_key: eval_dataloaders + pass_type: BY_REFERENCE + +evaluation_subscriber: + component_key: results_subscriber + variant_key: wandb + config: + global_rank: ${settings.cuda_env.global_rank} + project: modalities_dcp_tests + mode: OFFLINE + experiment_id: ${settings.experiment_id} + directory: wandb_storage + config_file_path: ${settings.config_file_path} + +# mfu_calculator: +# component_key: mfu_calculator +# variant_key: gpt2 +# config: +# n_layer: ${model_raw.config.n_layer} +# sequence_length: ${settings.step_profile.sequence_length} +# n_embd: ${model_raw.config.n_embd} +# world_size: ${settings.cuda_env.world_size} +# raw_model: +# instance_key: model_raw +# pass_type: BY_REFERENCE +# wrapped_model: +# instance_key: initialized_model +# pass_type: BY_REFERENCE \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 5a3c84bf1..1396b0cf5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,12 +1,10 @@ [project] name = "modalities" version = "0.3.2" -requires-python = ">=3.10,<3.12" description = "Modalities, a PyTorch-native framework for distributed and reproducible foundation model training." readme = "README.md" dependencies = [ "numpy<2.0", - "torch==2.6.0", "packaging", "tqdm", "pyyaml", diff --git a/src/modalities/config/instantiation_models.py b/src/modalities/config/instantiation_models.py index 4c57c133d..c4bb105e1 100644 --- a/src/modalities/config/instantiation_models.py +++ b/src/modalities/config/instantiation_models.py @@ -8,11 +8,13 @@ PydanticAppStateType, PydanticCheckpointSavingIFType, PydanticDatasetIFType, + PydanticDeviceMeshIFType, PydanticGradientClipperIFType, PydanticLLMDataLoaderIFType, PydanticLossIFType, PydanticMessageSubscriberIFType, PydanticMFUCalculatorABCType, + PydanticPipelineType, PydanticPytorchDeviceType, PydanticPytorchModuleType, PydanticTextInferenceComponentType, @@ -178,6 +180,8 @@ def _check_last_step_checkpointed(self) -> "TrainingComponentsInstantiationModel checkpoint_saving: PydanticCheckpointSavingIFType gradient_clipper: PydanticGradientClipperIFType mfu_calculator: Optional[PydanticMFUCalculatorABCType] = None + scheduled_pipeline: Optional[PydanticPipelineType] = None + device_mesh: Optional[PydanticDeviceMeshIFType] = None model_raw: PydanticPytorchModuleType @model_validator(mode="after") diff --git a/src/modalities/config/pydantic_if_types.py b/src/modalities/config/pydantic_if_types.py index aa12a444d..2aeceb53c 100644 --- a/src/modalities/config/pydantic_if_types.py +++ b/src/modalities/config/pydantic_if_types.py @@ -7,6 +7,7 @@ from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp import FSDPModule as FSDP2 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP1 +from torch.distributed.pipelining import PipelineStage from torch.optim import Optimizer from torch.optim.lr_scheduler import LRScheduler from torch.utils.data import Sampler @@ -21,6 +22,7 @@ from modalities.inference.text.inference_component import TextInferenceComponent from modalities.logging_broker.subscriber import MessageSubscriberIF from modalities.loss_functions import Loss +from modalities.models.parallelism.pipeline_parallelism import Pipeline, StagesGenerator from modalities.nn.model_initialization.initialization_if import ModelInitializationIF from modalities.tokenization.tokenizer_wrapper import TokenizerWrapper from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF @@ -83,3 +85,6 @@ def __get_pydantic_core_schema__( PydanticDatasetBatchGeneratorIFType = Annotated[ DatasetBatchGeneratorIF, PydanticThirdPartyTypeIF(DatasetBatchGeneratorIF) ] +PydanticStagesGeneratorType = Annotated[StagesGenerator, PydanticThirdPartyTypeIF(StagesGenerator)] +PydanticPipelineType = Annotated[Pipeline, PydanticThirdPartyTypeIF(Pipeline)] +PydanticPipelineStageType = Annotated[PipelineStage, PydanticThirdPartyTypeIF(PipelineStage)] diff --git a/src/modalities/conversion/gpt2/conversion_model.py b/src/modalities/conversion/gpt2/conversion_model.py index 7b06e3ec0..89fbf194a 100644 --- a/src/modalities/conversion/gpt2/conversion_model.py +++ b/src/modalities/conversion/gpt2/conversion_model.py @@ -136,10 +136,10 @@ def _copy_weights_model(hf_model: GPT2ForCausalLM, modalities_model: GPT2LLM): modalities_model (GPT2LLM): The modalities model from which the weights will be copied. """ hf_model.model.embed_tokens.weight.data.copy_(modalities_model.transformer.wte.weight.data) - for hf_layer, modalities_layer in zip(hf_model.model.layers, modalities_model.transformer.h): - _copy_weights_attention(hf_layer, modalities_layer) - _copy_weights_mlp(hf_layer, modalities_layer) - _copy_weights_layer_norms(hf_layer, modalities_layer) + for hf_layer, modalities_layer_idx in zip(hf_model.model.layers, modalities_model.transformer.h): + _copy_weights_attention(hf_layer, modalities_model.transformer.h[modalities_layer_idx]) + _copy_weights_mlp(hf_layer, modalities_model.transformer.h[modalities_layer_idx]) + _copy_weights_layer_norms(hf_layer, modalities_model.transformer.h[modalities_layer_idx]) _copy_weights_base_modules(hf_model.lm_head, modalities_model.transformer.lm_head) _copy_weights_base_modules(hf_model.model.norm, modalities_model.transformer.lm_head_norm) diff --git a/src/modalities/evaluator.py b/src/modalities/evaluator.py index 456fcb47f..3f9f8f343 100644 --- a/src/modalities/evaluator.py +++ b/src/modalities/evaluator.py @@ -9,6 +9,7 @@ from modalities.logging_broker.messages import ExperimentStatus, MessageTypes, ProgressUpdate from modalities.logging_broker.publisher import MessagePublisher from modalities.models.model import model_predict_batch +from modalities.models.parallelism.pipeline_parallelism import Pipeline from modalities.running_env.fsdp.reducer import Reducer from modalities.trainer import ThroughputAggregationKeys from modalities.util import Aggregator, TimeRecorder @@ -36,20 +37,42 @@ def evaluate_batch( batch: DatasetBatch, model: nn.Module, loss_fun: Callable[[InferenceResultBatch], torch.Tensor], - ) -> torch.Tensor: + scheduled_pipeline: Pipeline | None = None, + ) -> torch.Tensor | None: """Evaluate a single batch by forwarding it through the model and calculating the loss. Args: batch (DatasetBatch): The batch to evaluate model (nn.Module): The model to evaluate loss_fun (Callable[[InferenceResultBatch], torch.Tensor]): The loss function to calculate the loss + scheduled_pipeline (Pipeline | None, optional): In case of pipeline parallelism, this is used to + operate the model. Defaults to None. Returns: - torch.Tensor: The loss of the batch + torch.Tensor | None: The loss of the batch + None, if a non-last stage was processed in pipeline parallelism """ with torch.no_grad(): - result_batch = model_predict_batch(model=model, batch=batch) - loss = loss_fun(result_batch) + if scheduled_pipeline is not None: + pp_schedule = scheduled_pipeline.pp_schedule + targets, losses = ( + (batch.targets[loss_fun.target_key].contiguous(), []) + if scheduled_pipeline.is_last_pp_stage + else (None, None) + ) + + if scheduled_pipeline.is_first_pp_stage: + pp_schedule.eval(batch.samples[model.sample_key].contiguous(), target=targets, losses=losses) + else: + pp_schedule.eval(target=targets, losses=losses) + loss = ( + torch.mean(torch.stack(losses)).to(losses[0].device) + if scheduled_pipeline.is_last_pp_stage + else None + ) + else: + result_batch = model_predict_batch(model=model, batch=batch) + loss = loss_fun(result_batch) return loss def evaluate( @@ -58,6 +81,7 @@ def evaluate( data_loaders: list[LLMDataLoader], loss_fun: Callable[[InferenceResultBatch], torch.Tensor], num_train_steps_done: int, + scheduled_pipeline: Pipeline | None = None, ) -> dict[str, EvaluationResultBatch]: """Evaluate the model on a set of datasets. @@ -66,6 +90,8 @@ def evaluate( data_loaders (list[LLMDataLoader]): List of dataloaders to evaluate the model on loss_fun (Callable[[InferenceResultBatch], torch.Tensor]): The loss function to calculate the loss num_train_steps_done (int): The number of training steps done so far for logging purposes + scheduled_pipeline (Pipeline | None, optional): In case of pipeline parallelism, this is used to + operate the model. Defaults to None. Returns: dict[str, EvaluationResultBatch]: A dictionary containing the evaluation results for each dataloader @@ -90,10 +116,13 @@ def evaluate( batch=batch, model=model, loss_fun=loss_fun, + scheduled_pipeline=scheduled_pipeline, ) - cumulated_loss[0] += batch_loss.item() # sum up batch loss - cumulated_loss[1] += 1 + # The batch_loss might be None if we use pipeline parallelism and are not the last stage. + if batch_loss is not None: + cumulated_loss[0] += batch_loss.item() # sum up batch loss + cumulated_loss[1] += 1 batch_length_tensor = torch.tensor(len(batch)).to(device) thoughput_aggregator.add_value(key=ThroughputAggregationKeys.NUM_SAMPLES, value=batch_length_tensor) diff --git a/src/modalities/gym.py b/src/modalities/gym.py index 0394b7a28..7ea5e660f 100644 --- a/src/modalities/gym.py +++ b/src/modalities/gym.py @@ -9,6 +9,7 @@ from modalities.dataloader.dataloader import LLMDataLoader from modalities.evaluator import Evaluator from modalities.loss_functions import Loss +from modalities.models.parallelism.pipeline_parallelism import Pipeline from modalities.trainer import Trainer from modalities.training.training_progress import TrainingProgress from modalities.util import print_rank_0 @@ -40,6 +41,7 @@ def run( train_data_loader: LLMDataLoader, evaluation_data_loaders: list[LLMDataLoader], checkpoint_saving: CheckpointSaving, + scheduled_pipeline: Pipeline | None = None, ): """Runs the model training, including evaluation and checkpointing. @@ -51,12 +53,15 @@ def run( train_data_loader (LLMDataLoader): Data loader with the training data. evaluation_data_loaders (list[LLMDataLoader]): List of data loaders with the evaluation data. checkpoint_saving (CheckpointSaving): Routine for saving checkpoints. + scheduled_pipeline (Pipeline | None, optional): In case of pipeline parallelism, this is used to + operate the model. Defaults to None. """ evaluation_callback: Callable[[int], None] = partial( self._run_evaluation, model=app_state.model, evaluation_data_loaders=evaluation_data_loaders, evaluation_interval_in_steps=evaluation_interval_in_steps, + scheduled_pipeline=scheduled_pipeline, ) checkpointing_callback: Callable[[TrainingProgress], None] = partial( @@ -74,6 +79,7 @@ def run( evaluation_callback=evaluation_callback, checkpointing_callback=checkpointing_callback, training_log_interval_in_steps=training_log_interval_in_steps, + scheduled_pipeline=scheduled_pipeline, ) print_rank_0(f"Training done at {datetime.now()}.") @@ -101,11 +107,13 @@ def _run_evaluation( num_train_steps_done: int, evaluation_data_loaders: list[LLMDataLoader], evaluation_interval_in_steps: int, + scheduled_pipeline: Pipeline | None = None, ): - if num_train_steps_done % evaluation_interval_in_steps == 0: + if num_train_steps_done > 0 and num_train_steps_done % evaluation_interval_in_steps == 0: self.evaluator.evaluate( model=model, data_loaders=evaluation_data_loaders, loss_fun=self.loss_fun, num_train_steps_done=num_train_steps_done, + scheduled_pipeline=scheduled_pipeline, ) diff --git a/src/modalities/loss_functions.py b/src/modalities/loss_functions.py index 54d8de36b..e3be6100d 100644 --- a/src/modalities/loss_functions.py +++ b/src/modalities/loss_functions.py @@ -1,4 +1,5 @@ from abc import ABC, abstractmethod +from typing import overload import torch from torch.nn import CrossEntropyLoss @@ -31,9 +32,16 @@ def __init__(self, target_key: str, prediction_key: str, tag: str = "CLMCrossEnt # Mean over the tokens in the local-batch (batch per rank) self.loss_fun = CrossEntropyLoss(reduction="mean") + @overload def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: - labels = forward_batch.get_targets(self.target_key) - lm_logits = forward_batch.get_predictions(self.prediction_key) + ... + + @overload + def __call__(self, outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + ... + + def __call__(self, *args, **kwargs) -> torch.Tensor: + labels, lm_logits = self._parse_arguments(args, kwargs) # move labels to correct device to enable model parallelism labels = labels.to(lm_logits.device) @@ -43,6 +51,41 @@ def __call__(self, forward_batch: InferenceResultBatch) -> torch.Tensor: loss = self.loss_fun(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)) return loss + def _parse_arguments( + self, + args: list[torch.Tensor] | list[InferenceResultBatch], + kwargs: dict[str, torch.Tensor] | dict[str, InferenceResultBatch], + ) -> tuple[torch.Tensor, torch.Tensor]: + if len(args) == 1 and isinstance(args[0], InferenceResultBatch): + forward_batch = args[0] + labels = forward_batch.get_targets(self.target_key) + lm_logits = forward_batch.get_predictions(self.prediction_key) + elif "forward_batch" in kwargs and isinstance(kwargs["forward_batch"], InferenceResultBatch): + forward_batch = kwargs["forward_batch"] + labels = forward_batch.get_targets(self.target_key) + lm_logits = forward_batch.get_predictions(self.prediction_key) + elif len(args) == 2 and all(isinstance(arg, torch.Tensor) for arg in args): + lm_logits, labels = args + elif ( + "outputs" in kwargs + and "targets" in kwargs + and isinstance(kwargs["outputs"], torch.Tensor) + and isinstance(kwargs["targets"], torch.Tensor) + ): + lm_logits = kwargs["outputs"] + labels = kwargs["targets"] + elif ( + len(args) == 1 + and "targets" in kwargs + and isinstance(args[0], torch.Tensor) + and isinstance(kwargs["targets"], torch.Tensor) + ): + lm_logits = args[0] + labels = kwargs["targets"] + else: + raise TypeError("Invalid arguments for CLMCrossEntropyLoss.__call__") + return labels, lm_logits + def nce_loss( embedding1: torch.Tensor, embedding2: torch.Tensor, device: torch.device, is_asymmetric: bool, temperature: float diff --git a/src/modalities/main.py b/src/modalities/main.py index d995b9168..59845376f 100644 --- a/src/modalities/main.py +++ b/src/modalities/main.py @@ -20,6 +20,7 @@ from modalities.logging_broker.subscriber import MessageSubscriberIF from modalities.registry.components import COMPONENTS from modalities.registry.registry import Registry +from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_num_parallel_ranks from modalities.trainer import Trainer from modalities.util import get_synced_experiment_id_of_run, get_total_number_of_trainable_parameters, print_rank_0 @@ -110,11 +111,20 @@ def run(self, components: TrainingComponentsInstantiationModel): ) # Trainer + # FIXME replace by get_parallel_degree + if components.device_mesh is None: + num_pipeline_parallel_ranks = 1 + num_data_parallel_ranks = 1 + else: + num_pipeline_parallel_ranks = get_num_parallel_ranks(components.device_mesh, ParallelismDegrees.PP) + num_data_parallel_ranks = get_num_parallel_ranks( + components.device_mesh, ParallelismDegrees.DP_SHARD + ) * get_num_parallel_ranks(components.device_mesh, ParallelismDegrees.DP_REPLICATE) global_num_tokens_per_train_step = ( components.settings.step_profile.local_train_micro_batch_size * components.settings.step_profile.sequence_length * components.settings.step_profile.gradient_accumulation_steps - * components.settings.cuda_env.world_size + * num_data_parallel_ranks ) trainer = Trainer( global_rank=components.settings.cuda_env.global_rank, @@ -128,6 +138,7 @@ def run(self, components: TrainingComponentsInstantiationModel): gradient_clipper=components.gradient_clipper, global_num_tokens_per_train_step=global_num_tokens_per_train_step, mfu_calculator=components.mfu_calculator, + num_pipeline_parallel_ranks=num_pipeline_parallel_ranks, ) # Evaluator @@ -143,7 +154,7 @@ def run(self, components: TrainingComponentsInstantiationModel): loss_fun=components.loss_fn, num_ranks=components.settings.cuda_env.world_size, ) - num_params = get_total_number_of_trainable_parameters(components.app_state.model) + num_params = get_total_number_of_trainable_parameters(components.app_state.model, components.device_mesh) components.evaluation_subscriber.consume_dict({"No. parameters": num_params}) logging.info(f"Training model with {num_params} parameters.") @@ -169,6 +180,7 @@ def run(self, components: TrainingComponentsInstantiationModel): checkpointing_interval_in_steps=components.settings.intervals.checkpointing_interval_in_steps, evaluation_interval_in_steps=components.settings.intervals.evaluation_interval_in_steps, training_log_interval_in_steps=components.settings.intervals.training_log_interval_in_steps, + scheduled_pipeline=components.scheduled_pipeline if components.scheduled_pipeline else None, ) def get_logging_publishers( diff --git a/src/modalities/models/gpt2/gpt2_model.py b/src/modalities/models/gpt2/gpt2_model.py index a2022d716..3e27ec5d5 100644 --- a/src/modalities/models/gpt2/gpt2_model.py +++ b/src/modalities/models/gpt2/gpt2_model.py @@ -2,7 +2,7 @@ import math from abc import abstractmethod from enum import Enum -from typing import Annotated, Optional +from typing import Annotated, Optional, overload import torch import torch.nn as nn @@ -319,7 +319,7 @@ class GPT2LLMConfig(BaseModel): ffn_norm_config (LayerNormWrapperConfig): Config for normalization of the feed-forward network. lm_head_norm_config (LayerNormWrapperConfig): Config for normalization of the language model head. use_weight_tying (bool): Whether to use weight tying. - + seed (int, optional): The seed for random number generation. Defaults to None. """ sample_key: str @@ -344,6 +344,7 @@ class GPT2LLMConfig(BaseModel): ffn_norm_config: LayerNormWrapperConfig lm_head_norm_config: LayerNormWrapperConfig use_weight_tying: bool + seed: Optional[int] = None @model_validator(mode="after") def check_divisibility(self) -> "GPT2LLMConfig": @@ -780,7 +781,7 @@ def __init__( ffn_norm_config: LayerNormWrapperConfig, lm_head_norm_config: LayerNormWrapperConfig, use_weight_tying: bool, - seed: int = None, + seed: int | None = None, ): """ Initializes the GPT2LLM object. @@ -804,8 +805,8 @@ def __init__( attention_norm_config (LayerNormWrapperConfig): Config for the attention normalization module. ffn_norm_config (LayerNormWrapperConfig): Config for the feed-forward network normalization module. lm_head_norm_config (LayerNormWrapperConfig): Config for the language model head normalization module. - seed (int, optional): The random seed. Defaults to None. use_weight_tying (bool): Whether to use weight tying. + seed (Optional[int]): The random seed. Defaults to None. """ weight_decay_groups = { "linear": [".attn", ".mlp", ".lm_head.weight"], @@ -844,9 +845,9 @@ def __init__( wte=nn.Embedding(num_embeddings=vocab_size, embedding_dim=n_embd), wpe=wpe, drop=nn.Dropout(dropout), - h=nn.ModuleList( - [ - GPT2Block( + h=nn.ModuleDict( + { + str(layer_idx): GPT2Block( n_embd=n_embd, bias=bias, n_head_q=n_head_q, @@ -862,8 +863,8 @@ def __init__( attention_norm=attention_norm_config.norm_type.value(**dict(attention_norm_config.config)), ffn_norm=ffn_norm_config.norm_type.value(**dict(ffn_norm_config.config)), ) - for _ in range(n_layer) - ] + for layer_idx in range(n_layer) + } ), lm_head_norm=lm_head_norm_config.norm_type.value(**dict(lm_head_norm_config.config)), # NOTE: If we make the bias configurable, we must update the number of parameters calculation @@ -880,9 +881,10 @@ def __init__( self.transformer.lm_head.weight ) # https://paperswithcode.com/method/weight-tying - def forward_impl(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + @overload + def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: """ - Forward pass implementation of the GPT2LLM module. + Forward pass of the GPT2LLM module. Args: inputs (dict[str, torch.Tensor]): A dictionary containing input tensors. @@ -892,42 +894,69 @@ def forward_impl(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tenso dict[str, torch.Tensor]: A dictionary containing output tensors. - prediction_key (str): Key for the output tensor containing logits. """ - input_ids = inputs[self.sample_key] - device = input_ids.device - _, t = input_ids.size() # batch size, sequence length - assert t <= self.sequence_length, f"Cannot forward sequence of length {t}, the model's maximum " - f"input sequence length is only {self.sequence_length}" + ... - # forward the GPT model itself - tok_emb = self.transformer.wte(input_ids) # token embeddings of shape (b, t, n_embd) - - if self.poe_type is PositionTypes.ABSOLUTE: - pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t) - pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd) - tok_emb = tok_emb + pos_emb + @overload + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the GPT2LLM module. - # TODO: use drop out also without absolute position embedding? - x = self.transformer.drop(tok_emb) + Args: + inputs (torch.Tensor): A tensor containing input token ids. - for block in self.transformer.h: - x = block(x) - x = self.transformer.lm_head_norm(x) - logits = self.transformer.lm_head(x) - return {self.prediction_key: logits} + Returns: + torch.Tensor: A tensor containing output logits. + """ + ... - def forward(self, inputs: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]: + def forward(self, inputs: dict[str, torch.Tensor] | torch.Tensor) -> dict[str, torch.Tensor] | torch.Tensor: """ Forward pass of the GPT2LLM module. Args: - inputs (dict[str, torch.Tensor]): A dictionary containing input tensors. - - sample_key (str): Key for the input tensor containing token ids. + inputs (dict[str, torch.Tensor] | torch.Tensor): Input data. Returns: - dict[str, torch.Tensor]: A dictionary containing output tensors. - - prediction_key (str): Key for the output tensor containing logits. + dict[str, torch.Tensor] | torch.Tensor: Model output. + """ + if isinstance(inputs, dict): + return {self.prediction_key: self.forward_impl(inputs[self.sample_key])} + else: + return self.forward_impl(inputs) + + def forward_impl(self, inputs: torch.Tensor) -> torch.Tensor: + """ + Forward pass implementation of the GPT2LLM module. + + Args: + inputs (torch.Tensor): A tensor containing input token ids. + + Returns: + torch.Tensor: A tensor containing output logits. """ - return self.forward_impl(inputs) + device = inputs.device + seq_len = inputs.size(1) + assert seq_len <= self.sequence_length, f"Cannot forward sequence of length {seq_len}, the model's maximum " + f"input sequence length is only {self.sequence_length}." + + # forward the GPT model itself + h = ( + self.transformer.wte(inputs) if hasattr(self.transformer, "wte") else inputs + ) # token embeddings of shape (b, seq_len, n_embd) + + if self.poe_type is PositionTypes.ABSOLUTE and hasattr(self.transformer, "wpe"): + pos = torch.arange(0, seq_len, dtype=torch.long, device=device) # shape (seq_len) + pos_emb = self.transformer.wpe(pos) # position embeddings of shape (seq_len, n_embd) + h = h + pos_emb + + # TODO: use drop out also without absolute position embedding? + h = self.transformer.drop(h) if hasattr(self.transformer, "drop") else h + + for layer_idx in self.transformer.h: + h = self.transformer.h[layer_idx](h) + h = self.transformer.lm_head_norm(h) if hasattr(self.transformer, "lm_head_norm") else h + h = self.transformer.lm_head(h) if hasattr(self.transformer, "lm_head") else h + return h def manual_scaled_dot_product_attention( diff --git a/src/modalities/models/model_factory.py b/src/modalities/models/model_factory.py index 877c9cbdc..7df9ba258 100644 --- a/src/modalities/models/model_factory.py +++ b/src/modalities/models/model_factory.py @@ -58,6 +58,15 @@ class ModelFactory: @staticmethod def _is_model_on_meta_device(model: nn.Module) -> bool: + """ + Checks if all parameters and buffers of the model are on the meta device. + + Args: + model (nn.Module): The model to check. + + Returns: + bool: True if all parameters and buffers are on meta device, False otherwise. + """ meta_counter = 0 param_counter = 0 for _, tensor in itertools.chain(model.named_parameters(), model.named_buffers()): @@ -567,7 +576,7 @@ def get_gpt2_model( lm_head_norm_config: LayerNormWrapperConfig, use_weight_tying: bool, use_meta_device: Optional[bool] = False, - seed: int = None, + seed: int | None = None, ) -> GPT2LLM: config = dict( sample_key=sample_key, @@ -631,7 +640,7 @@ def get_gpt2_tensor_parallelized_model(model: GPT2LLM, device_mesh: DeviceMesh) ), } - if isinstance(model.transformer.wpe, nn.Embedding): + if hasattr(model.transformer, "wpe") and isinstance(model.transformer.wpe, nn.Embedding): # If the position embedding is an nn.Embedding, we can shard it on the sequence dimension # to enable sequence parallelism in the downstream transformer blocks. # Note, for RoPE the wpe layer is an identity operation, which cannnot be sharded. @@ -640,11 +649,14 @@ def get_gpt2_tensor_parallelized_model(model: GPT2LLM, device_mesh: DeviceMesh) output_layouts=Shard(0), ) - parallelize_module( - module=model, - device_mesh=tp_mesh, - parallelize_plan=model_tp_plan, - ) + # only keep the relevant parts of the model parallel plan + model_tp_plan = {k: v for k, v in model_tp_plan.items() if hasattr(model.transformer, k.split(".")[1])} + if model_tp_plan: + parallelize_module( + module=model, + device_mesh=tp_mesh, + parallelize_plan=model_tp_plan, + ) transformer_block_tp_plan = { "attention_norm": SequenceParallel(), @@ -671,13 +683,13 @@ def get_gpt2_tensor_parallelized_model(model: GPT2LLM, device_mesh: DeviceMesh) desired_input_layouts=(Replicate(),), ), } - if isinstance(model.transformer.h[0].mlp, SwiGLU): + if isinstance(list(model.transformer.h.values())[0].mlp, SwiGLU): mlp_plan = { "mlp.W": ColwiseParallel(), "mlp.W_2": RowwiseParallel(output_layouts=Shard(1)), "mlp.V": ColwiseParallel(), } - elif isinstance(model.transformer.h[0].mlp, TransformerMLP): + elif isinstance(list(model.transformer.h.values())[0].mlp, TransformerMLP): mlp_plan = { "mlp.c_fc": ColwiseParallel(), "mlp.c_proj": RowwiseParallel(output_layouts=Shard(1)), @@ -689,7 +701,7 @@ def get_gpt2_tensor_parallelized_model(model: GPT2LLM, device_mesh: DeviceMesh) ) transformer_block_tp_plan.update(mlp_plan) - for transformer_block in model.transformer.h: + for transformer_block in model.transformer.h.values(): # override the number of q and kv heads if transformer_block.attn.n_head_q % tp_mesh.size() != 0: raise ValueError( diff --git a/src/modalities/models/parallelism/__init__.py b/src/modalities/models/parallelism/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/modalities/models/parallelism/pipeline_parallelism.py b/src/modalities/models/parallelism/pipeline_parallelism.py new file mode 100644 index 000000000..9d7e97718 --- /dev/null +++ b/src/modalities/models/parallelism/pipeline_parallelism.py @@ -0,0 +1,288 @@ +# Some portions of this implementation are inspired, adapted, or refactored +# from Meta's open-source project TorchTitan, +# licensed under the BSD 3-Clause License. + +import copy +import re +from enum import Enum +from typing import Any, Optional, Type, cast + +import torch +import torch.nn as nn +from torch.distributed.device_mesh import DeviceMesh +from torch.distributed.pipelining import PipelineStage +from torch.distributed.pipelining.schedules import PipelineScheduleSingle, get_schedule_class + +from modalities.loss_functions import Loss +from modalities.models.model import NNModel +from modalities.models.parallelism.stages_generator import StagesGenerator +from modalities.running_env.fsdp.device_mesh import ParallelismDegrees +from modalities.utils.logger_utils import get_logger + +logger = get_logger(__name__) + + +class Pipeline: + def __init__( + self, + pp_stage: PipelineStage, + model_part: nn.Module, + pp_schedule: Optional[PipelineScheduleSingle] = None, + ): + self._pp_stage = pp_stage + self._model_part = model_part + self._pp_schedule = pp_schedule + + @property + def is_first_pp_stage(self) -> bool: + return self._pp_stage.is_first + + @property + def is_last_pp_stage(self) -> bool: + return self._pp_stage.is_last + + @property + def pp_stage(self) -> PipelineStage: + return self._pp_stage + + @property + def model_part(self) -> nn.Module: + return self._model_part + + @property + def pp_schedule(self) -> Optional[PipelineScheduleSingle]: + return self._pp_schedule + + @pp_schedule.setter + def pp_schedule(self, schedule: PipelineScheduleSingle): + self._pp_schedule = schedule + + +class PipelineSelectionTypes(Enum): + """Enum for pipeline selection types.""" + + PP_STAGE = "PP_STAGE" + MODEL_PART = "MODEL_PART" + PP_SCHEDULE = "PP_SCHEDULE" + + +class ComponentSelectorFromPipeline: + @staticmethod + def select(pipeline: Pipeline, selection_type: PipelineSelectionTypes) -> Any: + """Selects a component from the pipeline based on the selection type.""" + if selection_type == PipelineSelectionTypes.PP_STAGE: + return pipeline.pp_stage + elif selection_type == PipelineSelectionTypes.MODEL_PART: + return pipeline.model_part + elif selection_type == PipelineSelectionTypes.PP_SCHEDULE: + return pipeline.pp_schedule + else: + raise ValueError(f"Unsupported selection type: {selection_type}") + + +class PipelineFactory: + """Pipeline factory class to create pipelined models.""" + + @staticmethod + def get_pipeline( + pp_stage: PipelineStage, model_part: NNModel, pp_schedule: Optional[PipelineScheduleSingle] = None + ) -> Pipeline: + return Pipeline(pp_stage=pp_stage, model_part=model_part, pp_schedule=pp_schedule) + + @staticmethod + def get_staged_pipeline( + whole_model: NNModel, + stages_generator: StagesGenerator, + device_mesh: DeviceMesh, + local_rank: int, + pp_schedule_name: str, + num_layers_per_stage: int, + ) -> Pipeline: + device = torch.device("cuda", local_rank) + pp_dims = device_mesh[ParallelismDegrees.PP.value].size() + + fqns_per_stage = stages_generator.get_stages( + num_layers_per_stage=num_layers_per_stage, + pp_dims=pp_dims, + ) + + pp_mesh = device_mesh[ParallelismDegrees.PP.value] + schedule_class = get_schedule_class(pp_schedule_name) + is_single_stage_schedule = issubclass(schedule_class, PipelineScheduleSingle) + if not is_single_stage_schedule: + raise ValueError( + f"Unsupported pipeline schedule: {pp_schedule_name}. We only support single-stage schedules." + ) + # torchtitan returns tuple of stages and models as depending on the schedule + # we might have multiple stages and model parts per rank. + # So far we don't support multi-stage schedules, which is why instead of tuples + # we work directly with the stage and model. + pp_stage, model_part = PipelineFactory._get_split_model( + whole_model=whole_model, + schedule_class=schedule_class, + pp_mesh=pp_mesh, + device=device, + fqns_per_stage=fqns_per_stage, + ) + + pipeline = Pipeline(pp_stage=pp_stage, model_part=model_part) + return pipeline + + @staticmethod + def _get_split_model( + whole_model: NNModel, + schedule_class: Type[PipelineScheduleSingle], + pp_mesh: DeviceMesh, + device: torch.device, + fqns_per_stage: list[list[str]], + ) -> tuple[PipelineStage, NNModel]: + def get_stage_id_of_pp_rank(pp_mesh: DeviceMesh): + # NOTE: torch titan a more complicated way to get the stage id of pp rank + # since they also allow for multi-stage schedules + pp_rank = pp_mesh.get_local_rank() + return pp_rank + + @staticmethod + def _get_fqn_tree(fqns: list[str]) -> dict[str, Any]: + fqn_tree = {} + fqns = set(fqns) # Ensure unique FQNs + for fqn in fqns: + parts = fqn.split(".") + current_level = fqn_tree + for part in parts[:-1]: + if part not in current_level: + current_level[part] = {} + elif len(current_level) == 0: + raise ValueError(f"Part {part} of {fqn} already exists " "in the tree as a leaf node.") + current_level = current_level[part] + if parts[-1] in current_level: + raise ValueError( + f" Leaf of {fqn} has already been defined in the tree as an intermediadate node or leaf! " + "Cannot replace the existing node as a leaf." + ) + current_level[parts[-1]] = {} + + return fqn_tree + + def _build_stage_from_modules( + fqn_tree: dict[str, Any], module: nn.Module, module_name: Optional[str] = None + ) -> nn.Module: + if isinstance(module, nn.ModuleDict): + if module_name not in fqn_tree: + dict_modules = nn.ModuleDict({}) + else: + if len(fqn_tree) == 0: + # If the module is a leaf node, we can directly use it + dict_modules = module + else: + # If the module is not a leaf node, we need to build a staged module + # recursively from the FQN tree + dict_modules = {} + dict_module_names = [name for name in module.keys() if name in fqn_tree[module_name]] + for dict_module_name in dict_module_names: + dict_modules[dict_module_name] = _build_stage_from_modules( + fqn_tree=fqn_tree[module_name], + module=module[dict_module_name], + module_name=dict_module_name, + ) + dict_modules = nn.ModuleDict(dict_modules) + # setattr(module, module_name, dict_modules) + return dict_modules + + elif isinstance(module, nn.ModuleList): + if module_name not in fqn_tree: + list_modules = nn.ModuleList([]) + else: + if len(fqn_tree[module_name]) == 0: + # If the module is a leaf node, we can directly use it + list_modules = module + else: + # If the module is not a leaf node, we need to build a staged module + # recursively from the FQN tree + list_modules = [] + list_indices = [i for i in range(len(module)) if str(i) in fqn_tree[module_name]] + for idx in list_indices: + list_modules.append( + _build_stage_from_modules( + fqn_tree=fqn_tree[module_name], module=module[idx], module_name=str(idx) + ) + ) + list_modules = nn.ModuleList(list_modules) + # setattr(module, module_name, list_modules) + return list_modules + + else: # normal nn.Module + if module_name is not None and module_name not in fqn_tree: + # If the module is not in the FQN tree, set it to None + return None + elif module_name is not None and len(fqn_tree[module_name]) == 0: + # If the module is a leaf node, we can directly use it + return module + else: + # If the module is in the FQN tree, we need to build a staged module + # recursively from the FQN tree + for module_name, module_value in module.named_children(): + # If the module is not a leaf node, we need to build a staged module + # recursively from the FQN tree + staged_module = _build_stage_from_modules( + fqn_tree=fqn_tree, module=module_value, module_name=module_name + ) + setattr(module, module_name, staged_module) + + return module + + if not issubclass(schedule_class, PipelineScheduleSingle): + raise NotImplementedError("Only single-stage schedules are supported for pipeline parallelism.") + + # NOTE: For multi-stage schedule, e.g., Interleaved 1F1B, we have multiple stages per pp rank. + # This would need to be adapted accordingly in this case. + stage_idx = get_stage_id_of_pp_rank(pp_mesh) + module_names = fqns_per_stage[stage_idx] + whole_model = copy.deepcopy(whole_model) + fqn_tree = _get_fqn_tree(module_names) + stage_modules = _build_stage_from_modules(fqn_tree, whole_model) + stage_modules = cast(NNModel, stage_modules) + PipelineFactory._filter_weight_decay_groups_(stage_modules) + stage = PipelineStage( + submodule=stage_modules, + stage_index=stage_idx, + num_stages=len(fqns_per_stage), + device=device, + group=pp_mesh.get_group("pp"), + ) + return stage, stage_modules + + @staticmethod + def _filter_weight_decay_groups_(stage_modules: NNModel): + params = {name for name, parameter in stage_modules.named_parameters() if parameter.requires_grad} + for group_list in stage_modules.weight_decay_groups.values(): + remove_from_group = [ + group_entry + for group_entry in group_list + if all([not bool(re.search(group_entry, name)) for name in params]) + ] + for remove in remove_from_group: + group_list.remove(remove) + empty_group_keys = [k for k, v in stage_modules.weight_decay_groups.items() if len(v) == 0] + for key in empty_group_keys: + del stage_modules.weight_decay_groups[key] + + @staticmethod + def get_scheduled_pipeline( + loss_fn: Loss, pp_schedule_name: str, batch_size: int, microbatch_size: int, pp_degree: int, pipeline: Pipeline + ) -> Pipeline: + # TODO: Addd validation in config that batch_size is divisible by microbatch_size + # and n_microbatches must be >= pp_degree + n_microbatches = batch_size // microbatch_size + num_total_stages = pp_degree + pp_schedule_class = get_schedule_class(pp_schedule_name) + pp_schedule = pp_schedule_class( + stage=pipeline.pp_stage, + n_microbatches=n_microbatches, + loss_fn=loss_fn, + ) + logger.info( + f"Using pipeline schedule {pp_schedule} with {n_microbatches} microbatches and {num_total_stages} stages." + ) + pipeline.pp_schedule = pp_schedule + return pipeline diff --git a/src/modalities/models/parallelism/pipeline_parallelism_configs.py b/src/modalities/models/parallelism/pipeline_parallelism_configs.py new file mode 100644 index 000000000..831a6e15e --- /dev/null +++ b/src/modalities/models/parallelism/pipeline_parallelism_configs.py @@ -0,0 +1,46 @@ +from typing import Annotated + +from pydantic import BaseModel, Field + +from modalities.config.pydantic_if_types import ( + PydanticDeviceMeshIFType, + PydanticLossIFType, + PydanticPipelineStageType, + PydanticPipelineType, + PydanticPytorchModuleType, + PydanticStagesGeneratorType, +) +from modalities.models.parallelism.pipeline_parallelism import PipelineSelectionTypes + + +class FQNsPerStageGeneratorConfig(BaseModel): # TODO duplicate + pass + + +class StagedPipelineConfig(BaseModel): + whole_model: PydanticPytorchModuleType + stages_generator: PydanticStagesGeneratorType + device_mesh: PydanticDeviceMeshIFType + local_rank: Annotated[int, Field(strict=True, ge=0)] + pp_schedule_name: str + num_layers_per_stage: Annotated[int, Field(strict=True, ge=1)] + + +class ScheduledPipelineConfig(BaseModel): + loss_fn: PydanticLossIFType + pp_schedule_name: str + batch_size: Annotated[int, Field(strict=True, ge=1)] + microbatch_size: Annotated[int, Field(strict=True, ge=1)] + pp_degree: Annotated[int, Field(strict=True, ge=2)] + pipeline: PydanticPipelineType + + +class ComponentSelectorFromPipelineConfig(BaseModel): + pipeline: PydanticPipelineType + selection_type: PipelineSelectionTypes + + +class PipelineConfig(BaseModel): + pp_stage: PydanticPipelineStageType + model_part: PydanticPytorchModuleType + pp_schedule: PydanticPipelineType | None = None diff --git a/src/modalities/models/parallelism/stages_generator.py b/src/modalities/models/parallelism/stages_generator.py new file mode 100644 index 000000000..0a212672a --- /dev/null +++ b/src/modalities/models/parallelism/stages_generator.py @@ -0,0 +1,120 @@ +# Some portions of this implementation are inspired, adapted, or refactored +# from Meta's open-source project TorchTitan, +# licensed under the BSD 3-Clause License. + +import math +from abc import ABC, abstractmethod + + +class StagesGenerator(ABC): + def __init__(self, num_model_layers: int, input_layer_equivalence: int = 1, output_layer_equivalence: int = 1): + self._num_model_layers = num_model_layers + self._input_layer_equivalence = input_layer_equivalence + self._output_layer_equivalence = output_layer_equivalence + + def get_stages(self, num_layers_per_stage: int, pp_dims: int) -> list[list[str]]: + """ + Generate FQNs for each stage in a GPT-2 model. + + Args: + num_layers_per_stage (int): Number of layers per stage. + pp_dims (int): Number of pipeline parallel dimensions. + + Returns: + list[list[str]]: A list containing FQNs for each stage. + """ + + # calculate the number of stages + num_virtual_stages = math.ceil( + (self._num_model_layers + self._input_layer_equivalence + self._output_layer_equivalence) + / num_layers_per_stage + ) + if num_virtual_stages % pp_dims != 0: + raise ValueError( + f"Number of virtual stages {num_virtual_stages} is not divisible by parallel dimensions {pp_dims}. " + f"For reference: {self._num_model_layers=} {self._input_layer_equivalence=} " + f"{self._output_layer_equivalence=} {num_layers_per_stage=}" + ) + + stages_per_rank = num_virtual_stages // pp_dims + if stages_per_rank != 1: + raise ValueError( + f"Stages per rank {stages_per_rank} must be 1 for single-stage schedules. " + f"Please adjust {num_layers_per_stage=} to ensure each PP rank has exactly one stage." + ) + + # Potential split points for GPT-2 model with each potential split point + # listing the FQNs of the modules in that stage and the computational weight. + # The computational weight of the input and output modules are estimated + # based on the number of layers they correspond to. + potential_split_points = self._get_potential_split_points() + # Calculate the weight per stage based on the total weight and number of stages + weight_per_stage = math.ceil(sum(weight for _, weight in potential_split_points) / num_virtual_stages) + # pack the stages with the layers + next_split_point = 0 + module_names_per_stage: list[list[str]] = [] + for _ in range(num_virtual_stages): + stage_fqns = [] + stage_weight = 0 + while next_split_point < len(potential_split_points): + fqns, weight = potential_split_points[next_split_point] + if weight > weight_per_stage: + raise ValueError( + f"Weight of {weight} for {fqns} exceeds weight per stage {weight_per_stage}. " + "Please adjust the number of stages or the weight distribution." + ) + if stage_weight + weight > weight_per_stage: + break + stage_fqns.extend(fqns) + stage_weight += weight + next_split_point += 1 + module_names_per_stage.append(stage_fqns) + + return module_names_per_stage + + @abstractmethod + def _get_potential_split_points(self) -> list[tuple[list[str], int]]: + """ + Returns a list of potential split points for the GPT-2 model. + + Args: + num_model_layers (int): Total number of layers in the model. + input_layer_equivalence (int): Number of layers corresponding to the input layer. + output_layer_equivalence (int): Number of layers corresponding to the output layer. + + Returns: + list[tuple[list[str], int]]: A list containing tuples of FQNs and their computational weights. + """ + raise NotImplementedError("This method should be implemented by subclasses.") + + +class GPT2LLMStagesGenerator(StagesGenerator): + def __init__(self, num_model_layers: int, input_layer_equivalence: int = 1, output_layer_equivalence: int = 1): + super().__init__(num_model_layers, input_layer_equivalence, output_layer_equivalence) + + def _get_potential_split_points( + self, + ) -> list[tuple[list[str], int]]: + """ + Returns a list of potential split points for the GPT-2 model. + + Args: + num_model_layers (int): Total number of layers in the model. + input_layer_equivalence (int): Number of layers corresponding to the input layer. + output_layer_equivalence (int): Number of layers corresponding to the output layer. + + Returns: + list[tuple[list[str], int]]: A list containing tuples of FQNs and their computational weights. + """ + + # Potential split points for GPT-2 model with each potential split point + # listing the FQNs of the modules in that stage and the computational weight. + # The computational weight of the input and output modules are estimated + # based on the number of layers they correspond to. + potential_split_points = [ + (["transformer.wte", "transformer.wpe", "transformer.drop"], self._input_layer_equivalence), + *[([f"transformer.h.{i}"], 1) for i in range(self._num_model_layers)], + (["transformer.lm_head_norm", "transformer.lm_head"], self._output_layer_equivalence), + ] + + return potential_split_points diff --git a/src/modalities/models/parallelism/stages_generator_configs.py b/src/modalities/models/parallelism/stages_generator_configs.py new file mode 100644 index 000000000..5d53f091d --- /dev/null +++ b/src/modalities/models/parallelism/stages_generator_configs.py @@ -0,0 +1,13 @@ +from typing import Annotated + +from pydantic import BaseModel, Field + + +class FQNsPerStageGeneratorConfig(BaseModel): # TODO duplicate + pass + + +class GPT2LLMStagesGeneratorConfig(BaseModel): + num_model_layers: Annotated[int, Field(strict=True, ge=1)] + input_layer_equivalence: Annotated[int, Field(strict=True, ge=1)] = 1 + output_layer_equivalence: Annotated[int, Field(strict=True, ge=1)] = 1 diff --git a/src/modalities/optimizers/optimizer_factory.py b/src/modalities/optimizers/optimizer_factory.py index c430e82a1..9d7af332b 100644 --- a/src/modalities/optimizers/optimizer_factory.py +++ b/src/modalities/optimizers/optimizer_factory.py @@ -12,6 +12,7 @@ from modalities.exceptions import OptimizerError from modalities.models.model import NNModel from modalities.util import get_local_number_of_trainable_parameters, print_rank_0 +from modalities.utils.logger_utils import get_logger from modalities.utils.typing_utils import FSDPX OptimizerGroups = list[dict[str, list[nn.Parameter] | float]] @@ -80,7 +81,7 @@ def get_optimizer_groups(model: FSDP, weight_decay: float, weight_decay_groups_e optimizer_groups_names = ["all"] else: # there will be N optimizer groups, i.e. one for each model parameter group - _assert_existence_of_weight_decay_groups_excluded(model, weight_decay_groups_excluded) + _check_existence_of_weight_decay_groups_excluded(model, weight_decay_groups_excluded) optimizer_groups, optimizer_groups_names = _create_optimizer_groups( model, weight_decay, weight_decay_groups_excluded ) @@ -90,9 +91,7 @@ def get_optimizer_groups(model: FSDP, weight_decay: float, weight_decay_groups_e return optimizer_groups -def _assert_existence_of_weight_decay_groups_excluded( - model: nn.Module, weight_decay_groups_excluded: list[str] -) -> None: +def _check_existence_of_weight_decay_groups_excluded(model: nn.Module, weight_decay_groups_excluded: list[str]) -> None: """ checks the existence of all groups that are to be excluded from weight decay @@ -113,9 +112,10 @@ def _assert_existence_of_weight_decay_groups_excluded( weight_decay_groups = nn_model.weight_decay_groups for group in weight_decay_groups_excluded: if group not in weight_decay_groups.keys(): - raise OptimizerError( + get_logger(name="optimizer_factory").warning( f"group = {group} specified in weight_decay_groups_excluded is not " - + f"in models optimizer_module_groups = {list(weight_decay_groups.keys())}" + + f"in models optimizer_module_groups = {list(weight_decay_groups.keys())}. " + + "(This might be due to pipeline parallelism and is not necessarily an error.)" ) @@ -156,14 +156,45 @@ def _create_optimizer_groups( f"model {type(model)} has no parameters with requires_grad=True (i.e., no traininable parameters)." ) - optimizer_groups = [ + optimizer_groups = _built_optimizer_groups_via_weight_decay_split( + weight_decay, weight_decay_groups_excluded, weight_decay_groups, params + ) + return optimizer_groups, ["with_weight_decay", "without_weight_decay"] + + +def _built_optimizer_groups_via_weight_decay_split( + weight_decay: float, + weight_decay_groups_excluded: list[str], + weight_decay_groups: dict[str, list[str]], + params: dict[str, nn.Parameter], +) -> OptimizerGroups: + params_per_weight_decay_groups: list[dict[str, object]] = [ { "params": _filter_params_for_weight_decay_group(params, regex_expressions=weight_decay_groups[group]), - "weight_decay": weight_decay if group not in weight_decay_groups_excluded else 0.0, + "exclude": group not in weight_decay_groups_excluded, } for group in weight_decay_groups.keys() ] - return optimizer_groups, weight_decay_groups.keys() + + optimizer_groups: OptimizerGroups = [ + { + "params": sum((p["params"] for p in params_per_weight_decay_groups if not p["exclude"]), []), + "weight_decay": weight_decay, + }, + { + "params": sum((p["params"] for p in params_per_weight_decay_groups if p["exclude"]), []), + "weight_decay": 0.0, + }, + ] + + if len(optimizer_groups[0]["params"]) == 0 or len(optimizer_groups[1]["params"]) == 0: + raise OptimizerError( + "One of the optimizer groups has zero parameters. " + "This indicates that the weight_decay_groups_excluded configuration is not compatible " + "with the configured pipeline stages." + ) + + return optimizer_groups def _filter_params_for_weight_decay_group( diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 28afab4bb..d56946060 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -86,6 +86,15 @@ from modalities.models.gpt2.gpt2_model import GPT2LLMConfig from modalities.models.huggingface.huggingface_model import HuggingFacePretrainedModel, HuggingFacePretrainedModelConfig from modalities.models.model_factory import GPT2ModelFactory, ModelFactory +from modalities.models.parallelism.pipeline_parallelism import ComponentSelectorFromPipeline, PipelineFactory +from modalities.models.parallelism.pipeline_parallelism_configs import ( + ComponentSelectorFromPipelineConfig, + PipelineConfig, + ScheduledPipelineConfig, + StagedPipelineConfig, +) +from modalities.models.parallelism.stages_generator import GPT2LLMStagesGenerator +from modalities.models.parallelism.stages_generator_configs import GPT2LLMStagesGeneratorConfig from modalities.nn.model_initialization.composed_initialization import ( ComposedInitializationRoutines, ComposedModelInitializationConfig, @@ -95,16 +104,16 @@ from modalities.running_env.fsdp.device_mesh import DeviceMeshConfig, get_device_mesh from modalities.tokenization.tokenizer_wrapper import PreTrainedHFTokenizer, PreTrainedSPTokenizer from modalities.training.gradient_clipping.fsdp_gradient_clipper import ( - DummyGradientClipper, FSDP1GradientClipper, FSDP1LoggingOnlyGradientClipper, FSDP2GradientClipper, FSDP2LoggingOnlyGradientClipper, ) from modalities.training.gradient_clipping.fsdp_gradient_clipper_config import ( - DummyGradientClipperConfig, - FSDPDummyGradientClipperConfig, - FSDPGradientClipperConfig, + FSDP1DummyGradientClipperConfig, + FSDP1GradientClipperConfig, + FSDP2DummyGradientClipperConfig, + FSDP2GradientClipperConfig, ) from modalities.utils.mfu import GPT2MFUCalculator from modalities.utils.number_conversion import ( @@ -174,6 +183,12 @@ class ComponentEntity: ComponentEntity( "model", "debugging_enriched", ModelFactory.get_debugging_enriched_model, DebuggingEnrichedModelConfig ), + ComponentEntity("pipeline", "staged", PipelineFactory.get_staged_pipeline, StagedPipelineConfig), + ComponentEntity("pipeline", "scheduled", PipelineFactory.get_scheduled_pipeline, ScheduledPipelineConfig), + ComponentEntity("pipeline", "selector", ComponentSelectorFromPipeline.select, ComponentSelectorFromPipelineConfig), + ComponentEntity("pipeline", "builder", PipelineFactory.get_pipeline, PipelineConfig), + # Pipeline Stages Generators + ComponentEntity("stages_generator", "gpt2_stages_generator", GPT2LLMStagesGenerator, GPT2LLMStagesGeneratorConfig), # Device mesh ComponentEntity("device_mesh", "default", get_device_mesh, DeviceMeshConfig), # weight initializers @@ -209,7 +224,6 @@ class ComponentEntity: # tokenizers ComponentEntity("tokenizer", "pretrained_hf_tokenizer", PreTrainedHFTokenizer, PreTrainedHFTokenizerConfig), ComponentEntity("tokenizer", "pretrained_sp_tokenizer", PreTrainedSPTokenizer, PreTrainedSPTokenizerConfig), - # ComponentEntity("tokenizer", "llama_tokenizer_fast", GPT2TokenizerFast, None), # TODO # datasets ComponentEntity("dataset", "mem_map_dataset", DatasetFactory.get_mem_map_dataset, MemMapDatasetConfig), ComponentEntity( @@ -311,15 +325,14 @@ class ComponentEntity: ComponentEntity("layer_norm", "rms_norm", RMSLayerNorm, RMSLayerNormConfig), ComponentEntity("layer_norm", "layer_norm", nn.LayerNorm, LayerNormConfig), # gradient clippers - ComponentEntity("gradient_clipper", "fsdp1", FSDP1GradientClipper, FSDPGradientClipperConfig), + ComponentEntity("gradient_clipper", "fsdp1", FSDP1GradientClipper, FSDP1GradientClipperConfig), ComponentEntity( - "gradient_clipper", "fsdp1_logging_only", FSDP1LoggingOnlyGradientClipper, FSDPDummyGradientClipperConfig + "gradient_clipper", "fsdp1_logging_only", FSDP1LoggingOnlyGradientClipper, FSDP1DummyGradientClipperConfig ), - ComponentEntity("gradient_clipper", "fsdp2", FSDP2GradientClipper, FSDPGradientClipperConfig), + ComponentEntity("gradient_clipper", "fsdp2", FSDP2GradientClipper, FSDP2GradientClipperConfig), ComponentEntity( - "gradient_clipper", "fsdp2_logging_only", FSDP2LoggingOnlyGradientClipper, FSDPDummyGradientClipperConfig + "gradient_clipper", "fsdp2_logging_only", FSDP2LoggingOnlyGradientClipper, FSDP2DummyGradientClipperConfig ), - ComponentEntity("gradient_clipper", "dummy", DummyGradientClipper, DummyGradientClipperConfig), # MFU calculators ComponentEntity("mfu_calculator", "gpt2", GPT2MFUCalculator, GPT2MFUCalculatorConfig), # Number conversion diff --git a/src/modalities/running_env/fsdp/device_mesh.py b/src/modalities/running_env/fsdp/device_mesh.py index 24e7d6e18..3a217c0e3 100644 --- a/src/modalities/running_env/fsdp/device_mesh.py +++ b/src/modalities/running_env/fsdp/device_mesh.py @@ -5,7 +5,6 @@ from torch.distributed.device_mesh import DeviceMesh, init_device_mesh from modalities.exceptions import ConfigError -from modalities.util import print_rank_0 from modalities.utils.logger_utils import get_logger logger = get_logger("model_factory") @@ -84,7 +83,8 @@ def get_device_mesh( enable_loss_parallel: bool, world_size: int, ) -> DeviceMesh: - """Gets the device mesh for the specified parallelism degrees. + """ + Gets the device mesh for the specified parallelism degrees. Args: device_type (str): The device type. @@ -118,12 +118,35 @@ def get_device_mesh( ], strict=True, ): - if dim > 1: + if dim > 1 or name == ParallelismDegrees.DP_SHARD.value: dims.append(dim) names.append(name) names = tuple(names) device_mesh = init_device_mesh(device_type, dims, mesh_dim_names=names) - print_rank_0(f"{device_mesh=} | {world_size=} | {enable_loss_parallel=}") + logger.info(f"{device_mesh=} | {world_size=} | {enable_loss_parallel=}") # TODO: Torch Titan had some more checks here. We need to check if we also need those: # https://github.com/pytorch/torchtitan/blob/b291ad662493b63d25b038a30a915082d3617baf/torchtitan/distributed/parallel_dims.py#L86-L104 return device_mesh + + +def get_num_parallel_ranks(device_mesh: DeviceMesh, parallelism_method: ParallelismDegrees) -> int: + """Gets the number of parallel ranks from the device mesh for a specific parallelism method. + + Args: + device_mesh (DeviceMesh): The device mesh. + parallelism_method (ParallelismDegrees): The parallelism method. + + Returns: + int: The number of parallel ranks for the specified parallelism method. + """ + if parallelism_method.value not in device_mesh.mesh_dim_names: + return 1 + else: + return device_mesh.size(device_mesh.mesh_dim_names.index(parallelism_method.value)) + + +def get_mesh_for_parallelism_method(device_mesh: DeviceMesh | None, parallelism_method: ParallelismDegrees): + if device_mesh is not None and parallelism_method.value in device_mesh.mesh_dim_names: + return device_mesh[parallelism_method.value] + else: + return None diff --git a/src/modalities/trainer.py b/src/modalities/trainer.py index b443c0ad3..f81407f02 100644 --- a/src/modalities/trainer.py +++ b/src/modalities/trainer.py @@ -14,6 +14,7 @@ from modalities.logging_broker.publisher import MessagePublisher from modalities.loss_functions import Loss from modalities.models.model import model_predict_batch +from modalities.models.parallelism.pipeline_parallelism import Pipeline from modalities.running_env.fsdp.reducer import Reducer from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF from modalities.training.training_progress import TrainingProgress @@ -30,6 +31,7 @@ class Trainer: def __init__( self, global_rank: int, + num_pipeline_parallel_ranks: int, progress_publisher: MessagePublisher[ProgressUpdate], evaluation_result_publisher: MessagePublisher[EvaluationResultBatch], gradient_acc_steps: int, @@ -45,23 +47,24 @@ def __init__( Initializes the Trainer object. Args: - global_rank (int): The global rank to which operates the trainer object. - progress_publisher (MessagePublisher[ProgressUpdate]): The publisher for progress updates. - evaluation_result_publisher (MessagePublisher[EvaluationResultBatch]): - The publisher for evaluation result batches. - gradient_acc_steps (int): The number of gradient accumulation steps. - global_num_tokens_per_train_step (int): The number of global tokens per training step. - num_seen_train_steps (int): The number of training steps already seen. - global_num_seen_tokens (int): The number of tokens already seen. - num_target_steps (int): The target number of training steps. - num_target_tokens (int): The target number of tokens. - gradient_clipper (GradientClipperIF): The gradient clipper. - mfu_calculator (Optional[MFUCalculatorABC]): The MFU calculator. + global_rank (int): The global rank. + num_pipeline_parallel_ranks (int): Number of pipeline parallel ranks. + progress_publisher (MessagePublisher[ProgressUpdate]): Progress publisher. + evaluation_result_publisher (MessagePublisher[EvaluationResultBatch]): Evaluation result publisher. + gradient_acc_steps (int): Gradient accumulation steps. + global_num_tokens_per_train_step (int): Global number of tokens per train step. + num_seen_train_steps (int): Number of seen train steps. + global_num_seen_tokens (int): Global number of seen tokens. + num_target_steps (int): Number of target steps. + num_target_tokens (int): Number of target tokens. + gradient_clipper (GradientClipperIF): Gradient clipper. + mfu_calculator (Optional[MFUCalculatorABC]): MFU calculator. Returns: None """ self.global_rank = global_rank + self.num_pipeline_parallel_ranks = num_pipeline_parallel_ranks self.progress_publisher = progress_publisher self.evaluation_result_publisher = evaluation_result_publisher self.gradient_acc_steps = gradient_acc_steps @@ -95,7 +98,8 @@ def _train_batch( scheduler: LRScheduler, loss_fun: Loss, micro_batch_id: int, - ) -> tuple[bool, int, torch.Tensor, Optional[torch.Tensor]]: + scheduled_pipeline: Optional[Pipeline] = None, + ) -> tuple[bool, int, Optional[torch.Tensor], Optional[torch.Tensor]]: """ Conducts a training step on batch of data. @@ -106,19 +110,39 @@ def _train_batch( scheduler (LRScheduler): The learning rate scheduler. loss_fun (Loss): The loss function used for training. micro_batch_id (int): The ID of the micro batch. + scheduled_pipeline (Optional[Pipeline], optional): In case of pipeline parallelism, this is used to + operate the model. Defaults to None. Returns: tuple[bool, int, torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]: A tuple containing the following: - step_performed (bool): Indicates whether a training step was performed. - num_train_steps_done (int): The number of training steps done. - - loss (torch.Tensor): The computed loss. + - loss (Optional[torch.Tensor]): The computed loss. + None, if a non-last stage was processes in pipeline parallelism. - gradient_norm_score (Optional[torch.Tensor]): The gradient norm score, if a training step was performed otherwise return None. """ - result_batch = model_predict_batch(model=model, batch=batch) - loss = loss_fun(result_batch) - (loss / self.gradient_acc_steps).backward() + if scheduled_pipeline is not None: + pp_schedule = scheduled_pipeline.pp_schedule + # Pipeline Parallel forward / backward inside step() call + # with self.train_context(optional_context_parallel_ctx): + targets, losses = ( + (batch.targets[loss_fun.target_key].contiguous(), []) + if scheduled_pipeline.is_last_pp_stage + else (None, None) + ) + + if scheduled_pipeline.is_first_pp_stage: + pp_schedule.step(batch.samples[model.sample_key].contiguous(), target=targets, losses=losses) + else: + pp_schedule.step(target=targets, losses=losses) + loss = torch.mean(torch.stack(losses)).to(losses[0].device) if scheduled_pipeline.is_last_pp_stage else None + else: + # else continue with loss calculation + result_batch = model_predict_batch(model=model, batch=batch) + loss = loss_fun(result_batch) + (loss / self.gradient_acc_steps).backward() if (micro_batch_id + 1) % self.gradient_acc_steps == 0: gradient_norm_score = self.gradient_clipper.clip_gradients() @@ -143,6 +167,7 @@ def train( training_log_interval_in_steps: int, evaluation_callback: Callable[[TrainingProgress], None], checkpointing_callback: Callable[[TrainingProgress], None], + scheduled_pipeline: Pipeline | None = None, ): """ Trains the model. @@ -154,6 +179,8 @@ def train( training_log_interval_in_steps (int): The interval at which training progress is logged. evaluation_callback (Callable[[TrainingProgress], None]): A callback function for evaluation. checkpointing_callback (Callable[[TrainingProgress], None]): A callback function for checkpointing. + scheduled_pipeline (Pipeline | None, optional): In case of pipeline parallelism, this is used to + operate the model. Defaults to None. Returns: None @@ -206,15 +233,18 @@ def train( scheduler=lr_scheduler, loss_fun=loss_fun, micro_batch_id=micro_batch_id, + scheduled_pipeline=scheduled_pipeline, ) forward_backward_time_recorder.stop() training_progress.num_seen_steps_current_run = num_train_steps_done training_progress.num_seen_tokens_current_run = self.global_num_tokens_per_train_step * num_train_steps_done - # Save the batch loss - cumulated_losses[0] += batch_loss.item() - # This works, because we always drop the last batch in case it has less samples than the batch size - cumulated_losses[-1] += 1 # number of local batches + # The batch_loss might be None if we use pipeline parallelism and are not the last stage. + if batch_loss is not None: + # Save the batch loss + cumulated_losses[0] += batch_loss.item() + # This works, because we always drop the last batch in case it has less samples than the batch size + cumulated_losses[-1] += 1 # number of local batches # gradient norm is already synced across all ranks if gradient_norm_score is not None: @@ -243,14 +273,17 @@ def train( synced_num_samples_per_second = synced_num_samples / synced_forward_backward_time # TODO: insert reducer from outside so Trainer is independent of FSDP # add the loss and gradient norm for the LAST batch - cumulated_losses[1] = batch_loss.item() + + cumulated_losses[1] = batch_loss.item() if batch_loss is not None else 0.0 reduced_losses = Reducer.reduce( tensor=cumulated_losses, operation=dist.ReduceOp.SUM, # 1.) summed batch loss / (num batches * world size) - # 2.) last batch loss / world size - post_processing_fun=lambda t: torch.stack([t[0] / t[-1], t[1] / dist.get_world_size()]), + # 2.) last batch loss / (world size / num_pipeline_parallel_ranks) + post_processing_fun=lambda t: torch.stack( + [t[0] / t[-1], t[1] / dist.get_world_size() * self.num_pipeline_parallel_ranks] + ), ) train_loss_avg, train_loss_last_batch = ( diff --git a/src/modalities/training/activation_checkpointing/activation_checkpointing.py b/src/modalities/training/activation_checkpointing/activation_checkpointing.py index 3cecf192d..0c194c350 100644 --- a/src/modalities/training/activation_checkpointing/activation_checkpointing.py +++ b/src/modalities/training/activation_checkpointing/activation_checkpointing.py @@ -135,8 +135,8 @@ def apply_activation_checkpointing_( raise ValueError(f"Unknown activation checkpointing variant: {ac_variant}") layers = model.get_submodule(layers_fqn) - if not isinstance(layers, nn.ModuleList): - raise ValueError(f"layers_fqn {layers_fqn} does not reference a ModuleList") + if not isinstance(layers, nn.ModuleDict): + raise ValueError(f"layers_fqn {layers_fqn} does not reference a ModuleDict") print_rank_0(f"Applying activation checkpointing to {len(list(layers.named_children()))} layers...") diff --git a/src/modalities/training/gradient_clipping/fsdp_gradient_clipper.py b/src/modalities/training/gradient_clipping/fsdp_gradient_clipper.py index f1adddfb3..129b8ad93 100644 --- a/src/modalities/training/gradient_clipping/fsdp_gradient_clipper.py +++ b/src/modalities/training/gradient_clipping/fsdp_gradient_clipper.py @@ -1,11 +1,15 @@ -from typing import Iterable, Optional +import math +from typing import Optional import torch +from torch import distributed as dist +from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp import FSDPModule as FSDP2 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP1 from torch.distributed.tensor import DTensor from modalities.config.lookup_enum import LookupEnum +from modalities.running_env.fsdp.device_mesh import ParallelismDegrees, get_mesh_for_parallelism_method from modalities.training.gradient_clipping.gradient_clipper import GradientClipperIF @@ -88,65 +92,25 @@ def clip_gradients(self) -> torch.Tensor: return gradient_norm_score -class FSDP2GradientClipper(GradientClipperIF): - """The FSDP2GradientClipper class that is responsible for clipping the gradients of a model wrapped with FSDP.""" - - def __init__(self, wrapped_model: FSDP2, max_norm: float, norm_type=GradientClippingMode) -> None: - """ - Initialize the FSDP2GradientClipper object. - - Args: - wrapped_model (FSDP2): The wrapped model. - max_norm (float): The maximum norm value for gradient clipping. - norm_type (GradientClippingMode, optional): The type of gradient clipping. Defaults to GradientClippingMode. - - Returns: - None - """ - self.wrapped_model = wrapped_model - self.max_norm = max_norm - self.norm_type = norm_type - - @torch.no_grad() - def clip_gradients(self) -> torch.Tensor: - """ - Clips the gradients of the wrapped model using the specified maximum norm and norm type. - - Returns: - torch.Tensor: The gradient norm after clipping. - """ - gradient_norm_score = FSDP2GradientClipper.clip_grad_norm_( - parameters=self.wrapped_model.parameters(), - max_norm=self.max_norm, - norm_type=self.norm_type.value, - error_if_nonfinite=True, - foreach=True, - ) - return gradient_norm_score +class FSDP2LoggingOnlyGradientClipper(GradientClipperIF): + """The FSDP2LoggingOnlyGradientClipper class that is responsible for logging the gradient + norms without actually clipping the gradients.""" - @staticmethod - def clip_grad_norm_( - parameters: torch.Tensor | Iterable[torch.Tensor], - max_norm: float, - norm_type: float = 2.0, + def __init__( + self, + wrapped_model: FSDP2, + norm_type: GradientClippingMode, + device_mesh: Optional[DeviceMesh] = None, error_if_nonfinite: bool = False, foreach: Optional[bool] = None, - ) -> torch.Tensor: + ) -> None: """ - Clip the gradient norm of an iterable of parameters. - - Gradient norm clipping requires computing the gradient norm over the entire model. - `torch.nn.utils.clip_grad_norm_` only computes gradient norm along DP/FSDP/TP dimensions. - - TODO: for pipeline parallelism, we need to implement it like here: - https://github.com/pytorch/torchtitan/blob/b291ad662493b63d25b038a30a915082d3617baf/torchtitan/distributed/utils.py#L245 - I removed all the code w.r.t. pipeline parallelism for now. + Initialize the FSDP2LoggingOnlyGradientClipper. Args: - parameters: an iterable of Tensors or a single Tensor that will have gradients normalized - max_norm (float): max norm of the gradients - norm_type (float): type of the used p-norm. Can be ``'inf'`` for - infinity norm. + wrapped_model (FSDP2): The wrapped FSDP2 model. + norm_type (GradientClippingMode): The type of gradient clipping. + device_mesh (DeviceMesh, optional): The device mesh used for distributed training. Defaults to None. error_if_nonfinite (bool): if True, an error is thrown if the total norm of the gradients from :attr:`parameters` is ``nan``, ``inf``, or ``-inf``. Default: False (will switch to True in the future) @@ -156,12 +120,28 @@ def clip_grad_norm_( Default: ``None`` Returns: - Total norm of the parameter gradients (viewed as a single vector). + None + """ + self.wrapped_model = wrapped_model + self.norm_type = norm_type + self.device_mesh = device_mesh + self.error_if_nonfinite = error_if_nonfinite + self.foreach = foreach + + @torch.no_grad() + def clip_gradients(self) -> torch.Tensor: + """ + Returns the gradient norm, but does not apply clipping since max_norm is set to inifinity. + Returns: + torch.Tensor: The gradient norms. """ - grads = [p.grad for p in parameters if p.grad is not None] + grads = [p.grad for p in self.wrapped_model.parameters() if p.grad is not None] total_norm = torch.nn.utils.get_total_norm( - tensors=grads, norm_type=norm_type, error_if_nonfinite=error_if_nonfinite, foreach=foreach + tensors=grads, + norm_type=self.norm_type.value, + error_if_nonfinite=self.error_if_nonfinite, + foreach=self.foreach, ) # Inspired by torch titan @@ -174,61 +154,72 @@ def clip_grad_norm_( # If only using PP, total_norm will be a local tensor. total_norm = total_norm.full_tensor() - torch.nn.utils.clip_grads_with_norm_( - parameters=parameters, max_norm=max_norm, total_norm=total_norm, foreach=foreach + pp_mesh = get_mesh_for_parallelism_method( + device_mesh=self.device_mesh, parallelism_method=ParallelismDegrees.PP ) + if pp_mesh is not None: + if math.isinf(self.norm_type.value): + dist.all_reduce(total_norm, op=dist.ReduceOp.MAX, group=pp_mesh.get_group()) + else: + total_norm **= self.norm_type.value + dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=pp_mesh.get_group()) + total_norm **= 1.0 / self.norm_type.value return total_norm -class FSDP2LoggingOnlyGradientClipper(GradientClipperIF): - """The FSDP2LoggingOnlyGradientClipper class that is responsible for logging the gradient - norms without actually clipping the gradients.""" +class FSDP2GradientClipper(FSDP2LoggingOnlyGradientClipper): + """The FSDP2GradientClipper class that is responsible for clipping the gradients of a model wrapped with FSDP.""" - def __init__(self, wrapped_model: FSDP2, norm_type=GradientClippingMode) -> None: + def __init__( + self, + wrapped_model: FSDP2, + max_norm: float, + norm_type: GradientClippingMode, + device_mesh: Optional[DeviceMesh] = None, + error_if_nonfinite: bool = False, + foreach: Optional[bool] = None, + ) -> None: """ - Initialize the FSDP2LoggingOnlyGradientClipper. + Initialize the FSDP2GradientClipper object. Args: wrapped_model (FSDP2): The wrapped FSDP2 model. - norm_type (GradientClippingMode, optional): The type of gradient clipping. Defaults to GradientClippingMode. + max_norm (float): The maximum norm value for gradient clipping. + norm_type (GradientClippingMode): The type of gradient clipping. + device_mesh (DeviceMesh, optional): The device mesh used for distributed training. Defaults to None. + error_if_nonfinite (bool): if True, an error is thrown if the total + norm of the gradients from :attr:`parameters` is ``nan``, + ``inf``, or ``-inf``. Default: False (will switch to True in the future) + foreach (bool): use the faster foreach-based implementation. + If ``None``, use the foreach implementation for CUDA and CPU native tensors and silently + fall back to the slow implementation for other device types. + Default: ``None`` Returns: None """ - self.wrapped_model = wrapped_model - self.norm_type = norm_type + super().__init__( + wrapped_model=wrapped_model, + norm_type=norm_type, + device_mesh=device_mesh, + error_if_nonfinite=error_if_nonfinite, + foreach=foreach, + ) + self.max_norm = max_norm @torch.no_grad() def clip_gradients(self) -> torch.Tensor: """ - Returns the gradient norm, but does not apply clipping since max_norm is set to inifinity. + Clips the gradients of the wrapped model using the specified maximum norm and norm type. Returns: - torch.Tensor: The gradient norms. + torch.Tensor: The gradient norm after clipping. """ - grads = [p.grad for p in self.wrapped_model.parameters() if p.grad is not None] - total_norm = torch.nn.utils.get_total_norm( - tensors=grads, norm_type=self.norm_type.value, error_if_nonfinite=False, foreach=True + total_norm = super().clip_gradients() + torch.nn.utils.clip_grads_with_norm_( + parameters=self.wrapped_model.parameters(), + max_norm=self.max_norm, + total_norm=total_norm, + foreach=self.foreach, ) - if isinstance(total_norm, DTensor): - # Will reach here if any non-PP parallelism is used. - # If only using PP, total_norm will be a local tensor. - total_norm = total_norm.full_tensor() return total_norm - - -class DummyGradientClipper(GradientClipperIF): - """The DummyGradientClipper class that does not apply gradient clipping.""" - - def __init__(self) -> None: - pass - - def clip_gradients(self) -> torch.Tensor: - """ - Returns a tensor with value -1.0 indicating that DummyGradientClipper does not actually apply gradient clipping. - - Returns: - torch.Tensor: Tensor with value -1.0 - """ - gradient_norm_score = torch.Tensor([-1.0]) - return gradient_norm_score diff --git a/src/modalities/training/gradient_clipping/fsdp_gradient_clipper_config.py b/src/modalities/training/gradient_clipping/fsdp_gradient_clipper_config.py index 4b4dd807d..b19971d69 100644 --- a/src/modalities/training/gradient_clipping/fsdp_gradient_clipper_config.py +++ b/src/modalities/training/gradient_clipping/fsdp_gradient_clipper_config.py @@ -2,11 +2,11 @@ from pydantic import BaseModel, Field -from modalities.config.pydantic_if_types import PydanticPytorchModuleType +from modalities.config.pydantic_if_types import PydanticDeviceMeshIFType, PydanticPytorchModuleType from modalities.training.gradient_clipping.fsdp_gradient_clipper import GradientClippingMode -class FSDPGradientClipperConfig(BaseModel): +class FSDP1GradientClipperConfig(BaseModel): """ Configuration class for FSDP gradient clipper. @@ -26,7 +26,27 @@ class FSDPGradientClipperConfig(BaseModel): wrapped_model: PydanticPytorchModuleType -class FSDPDummyGradientClipperConfig(BaseModel): +class FSDP2GradientClipperConfig(FSDP1GradientClipperConfig): + """ + Configuration class for FSDP gradient clipper. + + Args: + max_norm (float): The maximum norm value for gradient clipping. + norm_type (GradientClippingMode): The type of gradient clipping to be applied. + wrapped_model (PydanticPytorchModuleType): The wrapped PyTorch model. + device_mesh (PydanticDeviceMeshIFType | None): The device mesh configuration. + + Attributes: + max_norm (float): The maximum norm value for gradient clipping. + norm_type (GradientClippingMode): The type of gradient clipping to be applied. + wrapped_model (PydanticPytorchModuleType): The wrapped PyTorch model. + device_mesh (PydanticDeviceMeshIFType | None): The device mesh configuration. + """ + + device_mesh: PydanticDeviceMeshIFType | None = None + + +class FSDP1DummyGradientClipperConfig(BaseModel): """ Configuration class for FSDP dummy gradient clipper. @@ -43,17 +63,21 @@ class FSDPDummyGradientClipperConfig(BaseModel): norm_type: GradientClippingMode -class DummyGradientClipperConfig(BaseModel): +class FSDP2DummyGradientClipperConfig(FSDP1DummyGradientClipperConfig): """ - Configuration class for dummy gradient clipper. + Configuration class for FSDP dummy gradient clipper. - This class is a placeholder and does not have any specific functionality. + Args: + wrapped_model (PydanticPytorchModuleType): The wrapped PyTorch model. + norm_type (GradientClippingMode): The type of gradient clipping to be applied. + device_mesh (PydanticDeviceMeshIFType | None): The device mesh configuration. Attributes: - None - - Methods: - None + wrapped_model (PydanticPytorchModuleType): The wrapped PyTorch model. + norm_type (GradientClippingMode): The type of gradient clipping to be applied. + device_mesh (PydanticDeviceMeshIFType | None): The device mesh configuration. """ - pass + wrapped_model: PydanticPytorchModuleType + norm_type: GradientClippingMode + device_mesh: PydanticDeviceMeshIFType | None = None diff --git a/src/modalities/util.py b/src/modalities/util.py index eee5ff108..7c479da1d 100644 --- a/src/modalities/util.py +++ b/src/modalities/util.py @@ -11,12 +11,14 @@ import torch.distributed as dist import torch.nn as nn from pydantic import ValidationError +from torch.distributed.device_mesh import DeviceMesh from torch.distributed.fsdp import FSDPModule as FSDP2 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP1 from torch.distributed.tensor import DTensor from torch.types import Number from modalities.exceptions import TimeRecorderStateError +from modalities.running_env.fsdp.device_mesh import ParallelismDegrees from modalities.running_env.fsdp.reducer import Reducer from modalities.utils.typing_utils import FSDPX @@ -164,12 +166,13 @@ def get_local_number_of_trainable_parameters(model: nn.Module) -> int: return num_params -def get_total_number_of_trainable_parameters(model: FSDPX) -> Number: +def get_total_number_of_trainable_parameters(model: FSDPX, device_mesh: DeviceMesh | None) -> Number: """Returns the total number of trainable parameters across all ranks. The model must be sharded with FSDP1 or FSDP2. Args: model (FSDPX): The model for which to calculate the number of trainable parameters. + device_mesh (DeviceMesh | None): The device mesh used for distributed training. Returns: Number: The total number of trainable parameters across all ranks. @@ -214,8 +217,13 @@ def get_total_number_of_trainable_parameters(model: FSDPX) -> Number: # >>> parameter_tensor.shape[0] * parameter_tensor.shape[1] # 6438912 - total_num_params = sum(p.numel() for p in model.parameters() if p.requires_grad) - return total_num_params + num_params_tensor = sum(p.numel() for p in model.parameters() if p.requires_grad) + if device_mesh is not None and ParallelismDegrees.PP.value in device_mesh.mesh_dim_names: + num_params_tensor = torch.tensor(num_params_tensor).cuda() + pp_mesh = device_mesh[ParallelismDegrees.PP.value] + dist.all_reduce(num_params_tensor, op=dist.ReduceOp.SUM, group=pp_mesh.get_group()) + return num_params_tensor.item() + return num_params_tensor else: raise ValueError( f"Model type {type(model)} is not supported. " diff --git a/tests/checkpointing/checkpointing_test_utils.py b/tests/checkpointing/checkpointing_test_utils.py index 21c4caabe..c350ccbc8 100644 --- a/tests/checkpointing/checkpointing_test_utils.py +++ b/tests/checkpointing/checkpointing_test_utils.py @@ -1,5 +1,6 @@ import torch from pydantic import BaseModel +from torch.distributed.tensor import DTensor from torch.nn import CrossEntropyLoss from torch.optim import Optimizer @@ -15,10 +16,17 @@ class CheckpointingTestUtils: @staticmethod def generate_batch(gpt2_model_config: dict): # prepare input and targets + if "settings" in gpt2_model_config: + batch_size = gpt2_model_config["settings"]["step_profile"]["local_train_micro_batch_size"] + else: + batch_size = 8 data = torch.randint( 0, # lowest token_id gpt2_model_config["model_raw"]["config"]["vocab_size"], # highest token_id + 1, i.e. vocab_size - (8, gpt2_model_config["model_raw"]["config"]["sequence_length"] + 1), # (batch_size, sequence_length + 1) + ( + batch_size, + gpt2_model_config["model_raw"]["config"]["sequence_length"] + 1, + ), # (batch_size, sequence_length + 1) ).cuda() batch_input_ids_dict = {gpt2_model_config["model_raw"]["config"]["sample_key"]: data[:, :-1]} batch_target_ids = data[:, 1:] @@ -49,6 +57,33 @@ def forward_backward_pass( optimizer.step() return loss + @staticmethod + def forward_backward_pp_pass( + scheduled_pipeline, + optimizer: Optimizer, + batch_input_ids_dict: dict, + batch_target_ids: torch.Tensor, + ): + pp_schedule = scheduled_pipeline.pp_schedule + # Pipeline Parallel forward / backward inside step() call + # with self.train_context(optional_context_parallel_ctx): + targets, losses = (batch_target_ids.contiguous(), []) if scheduled_pipeline.is_last_pp_stage else (None, None) + + if scheduled_pipeline.is_first_pp_stage: + pp_schedule.step( + batch_input_ids_dict[scheduled_pipeline.model_part.sample_key].contiguous(), + target=targets, + losses=losses, + ) + else: + pp_schedule.step(target=targets, losses=losses) + loss = torch.mean(torch.stack(losses)).to(losses[0].device) if scheduled_pipeline.is_last_pp_stage else None + optimizer.step() + # clear the gradients + optimizer.zero_grad() + + return loss + @staticmethod def get_gpt2_model_from_config(gpt2_model_config_dict: dict) -> GPT2LLM: class GPT2InstantationModel(BaseModel): @@ -94,19 +129,32 @@ def assert_equality_optimizer_state( state_2 = optimizer_2_state[param_group_id] assert set(state_1.keys()) == set(state_2.keys()) for state_key in state_1.keys(): - if must_be_equal: - assert torch.equal( - state_1[state_key], state_2[state_key] - ), "_assert_equality_optimizer_state failed (must_be_equal = True)" - else: - assert not torch.equal( - state_1[state_key], state_2[state_key] - ), "_assert_equality_optimizer_state failed (must_be_equal = False)" + CheckpointingTestUtils.assert_equality_two_tensors( + tensor_1=state_1[state_key], + tensor_2=state_2[state_key], + must_be_equal=must_be_equal, + msg_on_failure="_assert_equality_optimizer_state failed", + ) @staticmethod def assert_equality_two_models(params_1: list[torch.Tensor], params_2: list[torch.Tensor], must_be_equal: bool): for p1, p2 in zip(params_1, params_2): - if must_be_equal: - assert torch.equal(p1, p2), "_assert_equality_two_models failed (must_be_equal = True)" - else: - assert not torch.equal(p1, p2), "_assert_equality_two_models failed (must_be_equal = False)" + CheckpointingTestUtils.assert_equality_two_tensors( + tensor_1=p1, + tensor_2=p2, + must_be_equal=must_be_equal, + msg_on_failure="_assert_equality_two_models failed", + ) + + @staticmethod + def assert_equality_two_tensors( + tensor_1: torch.Tensor, tensor_2: torch.Tensor, must_be_equal: bool, msg_on_failure: str = "" + ): + if isinstance(tensor_1, DTensor): + assert isinstance(tensor_2, DTensor), f"{msg_on_failure} (type mismatch with DTensor)" + tensor_1 = tensor_1.to_local() + tensor_2 = tensor_2.to_local() + if must_be_equal: + assert torch.equal(tensor_1, tensor_2), f"{msg_on_failure} (must_be_equal = True)" + else: + assert not torch.equal(tensor_1, tensor_2), f"{msg_on_failure} (must_be_equal = False)" diff --git a/tests/checkpointing/fsdp2_pp_gpt2_config.yaml b/tests/checkpointing/fsdp2_pp_gpt2_config.yaml new file mode 100644 index 000000000..4a02aa6b2 --- /dev/null +++ b/tests/checkpointing/fsdp2_pp_gpt2_config.yaml @@ -0,0 +1,194 @@ +settings: + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + step_profile: + local_train_micro_batch_size: 8 + +app_state: + component_key: app_state + variant_key: raw + config: + model: + instance_key: fsdp_model + pass_type: BY_REFERENCE + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + +scheduled_pipeline: + component_key: pipeline + variant_key: scheduled + config: + loss_fn: + instance_key: loss_fn + pass_type: BY_REFERENCE + pp_schedule_name: gpipe + batch_size: ${settings.step_profile.local_train_micro_batch_size} + microbatch_size: 4 + pp_degree: ${device_mesh.config.pipeline_parallel_degree} + # maybe better to use the fsdp model and the schedule here + # instead of passing in the staged pipeline? + # If fsdp_model creates a copy then this is not in the scope of + # the staged pipeline. + pipeline: + component_key: pipeline + variant_key: builder + config: + pp_stage: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: PP_STAGE + model_part: + instance_key: fsdp_model + pass_type: BY_REFERENCE + +fsdp_model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: model_part + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + block_names: [GPT2Block] + +model_part: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL_PART + +staged_pipeline: + component_key: pipeline + variant_key: staged + config: + whole_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + stages_generator: + component_key: stages_generator + variant_key: gpt2_stages_generator + config: + num_model_layers: ${model_raw.config.n_layer} + input_layer_equivalence: 1 + output_layer_equivalence: 1 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + local_rank: ${cuda_env:LOCAL_RANK} + pp_schedule_name: gpipe + num_layers_per_stage: 2 + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + instance_key: model_raw + pass_type: BY_REFERENCE + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.n_layer} + +model_raw: + component_key: model + variant_key: gpt2 + config: + sample_key: "input_ids" # TODO reference this + poe_type: NOPE + prediction_key: "logits" # TODO reference this + sequence_length: 256 # TODO reference this (same as sequence length) + vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: 2 + n_head_q: 4 + n_head_kv: 4 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster + attention_config: + qkv_transforms: + - type_hint: RotaryTransform + config: + n_embd: ${model_raw.config.n_embd} + n_head: ${model_raw.config.n_head_q} #it has to be head_q here + seq_length_dim: -2 + base_freq: 10000 + attention_implementation: manual + activation_type: gelu + attention_norm_config: + norm_type: rms_norm + config: + ndim: ${model_raw.config.n_embd} + bias: true + epsilon: 1e-5 + ffn_norm_config: + norm_type: rms_norm + config: + ndim: ${model_raw.config.n_embd} + bias: true + epsilon: 1e-5 + lm_head_norm_config: + norm_type: rms_norm + config: + ndim: ${model_raw.config.n_embd} + bias: true + epsilon: 1e-5 + use_weight_tying: false + use_meta_device: true + +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 0.0003 + betas: + - 0.9 + - 0.95 + eps: 1.0e-08 + weight_decay: 0.1 + weight_decay_groups_excluded: + - embedding + - layernorm + wrapped_model: + instance_key: fsdp_model + pass_type: BY_REFERENCE + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + pipeline_parallel_degree: 2 + data_parallel_shard_degree: -1 + world_size: ${cuda_env:WORLD_SIZE} diff --git a/tests/checkpointing/test_fsdp2_dcp_checkpoint_loading_and_saving.py b/tests/checkpointing/test_fsdp2_dcp_checkpoint_loading_and_saving.py index e1f0f349c..bcbdbd32b 100644 --- a/tests/checkpointing/test_fsdp2_dcp_checkpoint_loading_and_saving.py +++ b/tests/checkpointing/test_fsdp2_dcp_checkpoint_loading_and_saving.py @@ -1,6 +1,10 @@ import json +import logging +import multiprocessing as py_mp import os import tempfile +import time +import traceback from copy import deepcopy from pathlib import Path @@ -16,7 +20,7 @@ from modalities.checkpointing.fsdp.fsdp_checkpoint_saving import DCPCheckpointSaving from modalities.checkpointing.stateful.app_state import AppState from modalities.config.config import ProcessGroupBackendType, load_app_config_dict -from modalities.config.pydantic_if_types import PydanticAppStateType +from modalities.config.pydantic_if_types import PydanticAppStateType, PydanticPipelineType from modalities.training.training_progress import TrainingProgress from tests.checkpointing.checkpointing_test_utils import CheckpointingTestUtils from tests.end2end_tests.custom_components import MultiProcessingCudaEnv @@ -41,37 +45,68 @@ def get_gpt2_model_config_dict(gpt2_model_config_path: Path) -> dict: @pytest.mark.skipif( - torch.cuda.device_count() < 2, - reason="This e2e test requires 2 GPUs", + torch.cuda.device_count() < 4, + reason="This e2e test requires 4 GPUs", ) class TestFSDP2DCPCheckpointing: @staticmethod - def _get_app_state(config_file_path: Path) -> AppState: - class ComponentsInstantiationModel(BaseModel): - app_state: PydanticAppStateType + def _get_app_state(config_file_path: Path, use_pp: bool = False) -> AppState: + if use_pp: + + class ComponentsInstantiationModel(BaseModel): + app_state: PydanticAppStateType + scheduled_pipeline: PydanticPipelineType + + else: + + class ComponentsInstantiationModel(BaseModel): + app_state: PydanticAppStateType main_obj = Main(config_file_path) components: ComponentsInstantiationModel = main_obj.build_components( components_model_type=ComponentsInstantiationModel ) - return components.app_state + app_state = components.app_state + if use_pp: + app_state.scheduled_pipeline = components.scheduled_pipeline + return app_state @staticmethod - def test_save_checkpoint_after_backward_pass(temporary_checkpoint_folder_path: Path, gpt2_model_config_path: Path): - world_size = 2 - mp.spawn( + @pytest.mark.parametrize( + "config_filename,world_size,use_pp", + [ + ("fsdp2_gpt2_config.yaml", 2, False), + ("fsdp2_pp_gpt2_config.yaml", 2, True), + ], + ) + def test_save_checkpoint_after_backward_pass( + temporary_checkpoint_folder_path: Path, config_filename: str, world_size: int, use_pp: bool + ): + working_dir = Path(os.path.dirname(__file__)) + config_file_path = working_dir / config_filename + # Use a Manager queue so child processes can report exceptions to the parent. + manager = py_mp.Manager() + error_queue = manager.Queue() + + # Start child processes without joining so the parent can monitor a shared queue + # and terminate remaining workers immediately if any child fails. + proc_ctx = mp.spawn( TestFSDP2DCPCheckpointing._test_save_checkpoint_after_backward_pass_impl_wrapper, - args=(world_size, temporary_checkpoint_folder_path, gpt2_model_config_path), + args=(world_size, temporary_checkpoint_folder_path, config_file_path, use_pp, error_queue), nprocs=world_size, - join=True, + join=False, ) + TestFSDP2DCPCheckpointing._monitor_child_processes(manager, error_queue, proc_ctx) + @staticmethod def _test_save_checkpoint_after_backward_pass_impl_wrapper( process_id: int, world_size: int, temporary_checkpoint_folder_path: Path, gpt2_model_config_path: Path, + use_pp: bool, + error_queue: "py_mp.managers.SyncManager.Queue", ): # wraps the actual test function to be able to run it in a distributed multiprocessing setup with MultiProcessingCudaEnv( @@ -79,31 +114,41 @@ def _test_save_checkpoint_after_backward_pass_impl_wrapper( global_rank=process_id, local_rank=process_id, world_size=world_size, - rdvz_port=22356, + rdvz_port=22354, ): - # build all the components for the test - app_state1 = TestFSDP2DCPCheckpointing._get_app_state(config_file_path=gpt2_model_config_path) - app_state2 = TestFSDP2DCPCheckpointing._get_app_state(config_file_path=gpt2_model_config_path) - - gpt2_model_config_dict = get_gpt2_model_config_dict(gpt2_model_config_path=gpt2_model_config_path) - experiment_id = "0" - checkpoint_loading = DCPCheckpointLoading(global_rank=process_id) - checkpoint_saving = DCPCheckpointSaving( - checkpoint_path=temporary_checkpoint_folder_path, - experiment_id=experiment_id, - global_rank=process_id, - ) + try: + # build all the components for the test + app_state1 = TestFSDP2DCPCheckpointing._get_app_state(gpt2_model_config_path, use_pp) + app_state2 = TestFSDP2DCPCheckpointing._get_app_state(gpt2_model_config_path, use_pp) - # run the test - TestFSDP2DCPCheckpointing._test_save_checkpoint_after_backward_pass_impl( - app_state1=app_state1, - app_state2=app_state2, - gpt2_model_config_dict=gpt2_model_config_dict, - checkpoint_loading=checkpoint_loading, - checkpoint_saving=checkpoint_saving, - temporary_checkpoint_folder_path=temporary_checkpoint_folder_path, - experiment_id=experiment_id, - ) + gpt2_model_config_dict = get_gpt2_model_config_dict(gpt2_model_config_path=gpt2_model_config_path) + experiment_id = "0" + checkpoint_loading = DCPCheckpointLoading(global_rank=process_id) + checkpoint_saving = DCPCheckpointSaving( + checkpoint_path=temporary_checkpoint_folder_path, + experiment_id=experiment_id, + global_rank=process_id, + ) + + # run the test + TestFSDP2DCPCheckpointing._test_save_checkpoint_after_backward_pass_impl( + app_state1=app_state1, + app_state2=app_state2, + gpt2_model_config_dict=gpt2_model_config_dict, + checkpoint_loading=checkpoint_loading, + checkpoint_saving=checkpoint_saving, + temporary_checkpoint_folder_path=temporary_checkpoint_folder_path, + experiment_id=experiment_id, + ) + except Exception as e: + tb = traceback.format_exc() + logging.error(f"Process {process_id} encountered an error:\n{e}") + logging.error(tb) + try: + error_queue.put((process_id, tb)) + except Exception: + logging.error("Failed to put exception info into error queue.") + os._exit(1) @staticmethod def _test_save_checkpoint_after_backward_pass_impl( @@ -139,13 +184,21 @@ def _test_save_checkpoint_after_backward_pass_impl( # run backward pass batch_input_ids_dict, batch_target_ids = CheckpointingTestUtils.generate_batch(gpt2_model_config_dict) - loss_0 = CheckpointingTestUtils.forward_backward_pass( - prediction_key=prediction_key, - model=app_state1.model, - optimizer=app_state1.optimizer, - batch_input_ids_dict=batch_input_ids_dict, - batch_target_ids=batch_target_ids, - ) + if hasattr(app_state1, "scheduled_pipeline"): + loss_0 = CheckpointingTestUtils.forward_backward_pp_pass( + scheduled_pipeline=app_state1.scheduled_pipeline, + optimizer=app_state1.optimizer, + batch_input_ids_dict=batch_input_ids_dict, + batch_target_ids=batch_target_ids, + ) + else: + loss_0 = CheckpointingTestUtils.forward_backward_pass( + prediction_key=prediction_key, + model=app_state1.model, + optimizer=app_state1.optimizer, + batch_input_ids_dict=batch_input_ids_dict, + batch_target_ids=batch_target_ids, + ) # save the updated model and optimizer states for later comparisons updated_model_parameters = CheckpointingTestUtils.clone_parameters(app_state1.model) @@ -195,26 +248,46 @@ def _test_save_checkpoint_after_backward_pass_impl( ) loaded_and_updated_model_parameters = CheckpointingTestUtils.clone_parameters(app_state1.model) - loaded_and_updated_optimizer_state_dict = deepcopy(app_state1.optimizer.state_dict()) - + loaded_and_updated_optimizer_state_dict = deepcopy(app_state1.optimizer.state_dict()) + # perform another forward pass and backward pass for the previous and the loaded model - loss_1 = CheckpointingTestUtils.forward_backward_pass( - prediction_key=prediction_key, - model=app_state1.model, - optimizer=app_state1.optimizer, - batch_input_ids_dict=batch_input_ids_dict, - batch_target_ids=batch_target_ids, - ) + if hasattr(app_state1, "scheduled_pipeline"): + try: + loss_1 = CheckpointingTestUtils.forward_backward_pp_pass( + scheduled_pipeline=app_state1.scheduled_pipeline, + optimizer=app_state1.optimizer, + batch_input_ids_dict=batch_input_ids_dict, + batch_target_ids=batch_target_ids, + ) + loss_2 = CheckpointingTestUtils.forward_backward_pp_pass( + scheduled_pipeline=app_state2.scheduled_pipeline, + optimizer=app_state2.optimizer, + batch_input_ids_dict=batch_input_ids_dict, + batch_target_ids=batch_target_ids, + ) + except Exception as e: + print(f"Exception in _forward_step_with_pp: {e}") + traceback.print_exc() + raise + else: + loss_1 = CheckpointingTestUtils.forward_backward_pass( + prediction_key=prediction_key, + model=app_state1.model, + optimizer=app_state1.optimizer, + batch_input_ids_dict=batch_input_ids_dict, + batch_target_ids=batch_target_ids, + ) - loss_2 = CheckpointingTestUtils.forward_backward_pass( - prediction_key=prediction_key, - model=app_state2.model, - optimizer=app_state2.optimizer, - batch_input_ids_dict=batch_input_ids_dict, - batch_target_ids=batch_target_ids, - ) + loss_2 = CheckpointingTestUtils.forward_backward_pass( + prediction_key=prediction_key, + model=app_state2.model, + optimizer=app_state2.optimizer, + batch_input_ids_dict=batch_input_ids_dict, + batch_target_ids=batch_target_ids, + ) assert loss_1 == loss_2, f"loss_1 = {loss_1} does not equal loss_2 = {loss_2}" - assert loss_1 < loss_0, f"loss_1 = {loss_1} is not less than loss_0 = {loss_0}" + if loss_1 is not None: + assert loss_1 < loss_0, f"loss_1 = {loss_1} is not less than loss_0 = {loss_0}" # check that the model and optimizer states after each backward pass are as expected # model weights @@ -251,3 +324,104 @@ def _test_save_checkpoint_after_backward_pass_impl( CheckpointingTestUtils.assert_equality_optimizer_state( app_state1.optimizer.state_dict(), updated_optimizer_state_dict, must_be_equal=False ) + + @staticmethod + def _monitor_child_processes(manager, error_queue, proc_ctx): + # Normalize the return value from mp.spawn. When join=False it often + # returns a ProcessContext-like object that may expose a `processes` + # attribute. Other implementations may return an iterable of Process + # objects. Build a `processes` list defensively so we can monitor and + # terminate child processes below without assuming a particular type. + processes = [] + if proc_ctx is None: + processes = [] + else: + # common attribute names that might hold the list of processes + candidate_attrs = ["processes", "_processes", "workers", "process_list", "processes_"] + found = False + for attr in candidate_attrs: + if hasattr(proc_ctx, attr): + ps = getattr(proc_ctx, attr) + try: + processes = list(ps) + except Exception: + processes = [ps] + found = True + break + if not found: + # If proc_ctx itself is iterable, exhaust it into a list + try: + processes = list(proc_ctx) + except Exception: + # Fallback: if proc_ctx behaves like a single process-like + # object (has terminate/is_alive/join), wrap it in a list. + if hasattr(proc_ctx, "terminate") or hasattr(proc_ctx, "is_alive") or hasattr(proc_ctx, "join"): + processes = [proc_ctx] + else: + processes = [] + + # Monitor the error queue and child processes. If any child reports an exception, + # terminate the other workers and raise the error in the parent to fail the test fast. + try: + # Loop until all processes finished or an error is reported + while True: + # If an error was reported by any child process, terminate remaining children + if not error_queue.empty(): + proc_id, tb = error_queue.get() + # terminate and join all processes (or the proc_ctx wrapper) + for p in processes: + try: + if hasattr(p, "is_alive"): + alive = p.is_alive() + elif hasattr(p, "exitcode"): + alive = getattr(p, "exitcode") is None + else: + alive = True + if alive and hasattr(p, "terminate"): + p.terminate() + except Exception: + pass + # If we didn't find individual process objects but proc_ctx + # exposes a terminate method, call it as a fallback. + try: + if not processes and hasattr(proc_ctx, "terminate"): + proc_ctx.terminate() + except Exception: + pass + + for p in processes: + try: + if hasattr(p, "join"): + p.join(timeout=5) + except Exception: + pass + try: + if hasattr(proc_ctx, "join"): + proc_ctx.join(timeout=1) + except Exception: + pass + raise AssertionError(f"Child process {proc_id} raised an exception:\n{tb}") + + # If all processes have finished, break + all_finished = all((not p.is_alive()) for p in processes) + if all_finished: + # join them to collect exitcodes + for p in processes: + try: + p.join() + except Exception: + pass + # If we have a ProcessContext, call its join to clean up as well + try: + if hasattr(proc_ctx, "join"): + proc_ctx.join(timeout=1) + except Exception: + pass + break + + time.sleep(0.05) + finally: + try: + manager.shutdown() + except Exception: + pass diff --git a/tests/conftest.py b/tests/conftest.py index bc92e004b..9bcc5f1d6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -196,6 +196,7 @@ def trainer(progress_publisher_mock, gradient_clipper_mock): global_num_seen_tokens=0, num_target_tokens=100, num_target_steps=10, + num_pipeline_parallel_ranks=1, ) diff --git a/tests/conversion/gpt2/helper.py b/tests/conversion/gpt2/helper.py index 328633ccb..99adbacbc 100644 --- a/tests/conversion/gpt2/helper.py +++ b/tests/conversion/gpt2/helper.py @@ -6,14 +6,14 @@ def check_same_weight_model(converted_model: GPT2ForCausalLM, modalities_model: GPT2LLM): - converted_model.to(device=modalities_model.transformer.h[0].attn.q_attn.weight.device) + converted_model.to(device=modalities_model.transformer.h["0"].attn.q_attn.weight.device) assert torch.equal(converted_model.model.embed_tokens.weight, modalities_model.transformer.wte.weight) - for i, (llama_layer, modalities_layer) in enumerate( + for i, (llama_layer, modalities_layer_idx) in enumerate( zip(converted_model.model.layers, modalities_model.transformer.h) ): - check_same_weight_attention(llama_layer, modalities_layer) - check_same_weight_mlp(llama_layer, modalities_layer) - check_same_weight_layer_norms(llama_layer, modalities_layer) + check_same_weight_attention(llama_layer, modalities_model.transformer.h[modalities_layer_idx]) + check_same_weight_mlp(llama_layer, modalities_model.transformer.h[modalities_layer_idx]) + check_same_weight_layer_norms(llama_layer, modalities_model.transformer.h[modalities_layer_idx]) check_same_weight_base_modules(converted_model.lm_head, modalities_model.transformer.lm_head) check_same_weight_base_modules(converted_model.model.norm, modalities_model.transformer.lm_head_norm) diff --git a/tests/dataloader/distributed/mocks.py b/tests/dataloader/distributed/mocks.py new file mode 100644 index 000000000..cc3f044e2 --- /dev/null +++ b/tests/dataloader/distributed/mocks.py @@ -0,0 +1,42 @@ +import os + + +class MultiProcessingCudaEnvMock: + """Context manager to set the CUDA environment for distributed training.""" + + def __init__( + self, + global_rank: int, + local_rank: int, + world_size: int, + rdvz_port: int, + ) -> None: + self.global_rank = global_rank + self.local_rank = local_rank + self.world_size = world_size + self.rdvz_port = rdvz_port + self._original_env: dict[str, str | None] = {} + + def __enter__(self): + # Store original values + for key in ["MASTER_ADDR", "MASTER_PORT", "RANK", "LOCAL_RANK", "WORLD_SIZE"]: + self._original_env[key] = os.environ.get(key) + + # Set new environment variables + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(self.rdvz_port) + os.environ["RANK"] = str(self.global_rank) + os.environ["LOCAL_RANK"] = str(self.local_rank) + os.environ["WORLD_SIZE"] = str(self.world_size) + + # torch.cuda.set_device(local_rank) + + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + # Restore original environment variables + for key, value in self._original_env.items(): + if value is None: + os.environ.pop(key, None) + else: + os.environ[key] = value diff --git a/tests/dataloader/distributed/test_distributed_multidim_dataloader.py b/tests/dataloader/distributed/test_distributed_multidim_dataloader.py new file mode 100644 index 000000000..e3546b00c --- /dev/null +++ b/tests/dataloader/distributed/test_distributed_multidim_dataloader.py @@ -0,0 +1,84 @@ +import os +from unittest.mock import MagicMock + +import pytest +from torch.utils.data import BatchSampler + +from modalities.dataloader.dataloader_factory import DataloaderFactory +from modalities.dataloader.sampler_factory import SamplerFactory +from modalities.running_env.fsdp.device_mesh import ParallelismDegrees +from tests.dataloader.distributed.mocks import MultiProcessingCudaEnvMock +from tests.dataloader.dummy_sequential_dataset import TestDataset + + +@pytest.mark.parametrize("world_size, dp_degree", [(4, 2)]) +def test_distributed_multidim_dataloader_produces_same_data_on_connected_non_dp_ranks(world_size: int, dp_degree: int): + batches_on_rank = _build_batch_for_each_rank_combination(world_size, dp_degree) + + for dp_rank in range(dp_degree): + assert all( + batches_on_rank[(dp_rank, 0)] == batches_on_rank[(dp_rank, other_rank)] + for other_rank in range(1, world_size // dp_degree) + ), f"Batches on dp_rank {dp_rank} differ across other ranks." + + +@pytest.mark.parametrize("world_size, dp_degree", [(4, 2)]) +def test_distributed_multidim_dataloader_produces_different_data_on_different_dp_ranks(world_size: int, dp_degree: int): + batches_on_rank = _build_batch_for_each_rank_combination(world_size, dp_degree) + + for dp_rank1 in range(dp_degree): + for dp_rank2 in range(dp_rank1 + 1, dp_degree): + samples_dp_rank1 = sum(batches_on_rank[(dp_rank1, 0)], []) + samples_dp_rank2 = sum(batches_on_rank[(dp_rank2, 0)], []) + assert ( + len(set(samples_dp_rank1).intersection(samples_dp_rank2)) == 0 + ), f"Data samples on different data parallel ranks {dp_rank1} and {dp_rank2} should be disjoint." + + +def _build_batch_for_each_rank_combination(world_size: int, dp_degree: int): + return { + (dp_rank, other_rank): _load_data_for_ranks(dp_rank, other_rank, world_size, dp_degree) + for dp_rank, other_rank in _get_rank_combinations(world_size, dp_degree) + } + + +def _get_rank_combinations(world_size: int, dp_degree: int): + other_degree = world_size // dp_degree + return [(dp_rank, other_rank) for dp_rank in range(dp_degree) for other_rank in range(other_degree)] + + +def _load_data_for_ranks(dp_rank: int, other_rank: int, world_size: int, dp_degree: int): + global_rank = dp_rank * 2 + other_rank + with MultiProcessingCudaEnvMock( + global_rank=global_rank, + local_rank=other_rank, + world_size=world_size, + rdvz_port=22350, + ): + device_mesh = _build_device_mesh_mock(world_size, dp_degree, dp_rank, other_rank) + dataset = TestDataset(8) + sampler = SamplerFactory.create_resumable_distributed_multi_dim_sampler( + dataset=dataset, device_mesh=device_mesh, data_parallel_key=ParallelismDegrees.DP_SHARD + ) + batch_sampler = BatchSampler(sampler, batch_size=2, drop_last=True) + train_dataloader = DataloaderFactory.get_dataloader( + dataloader_tag="train", + dataset=dataset, + batch_sampler=batch_sampler, + collate_fn=None, + num_workers=2, + pin_memory=False, + ) + return [batch.tolist() for batch in train_dataloader] + + +def _build_device_mesh_mock(world_size: int, dp_degree: int, dp_rank: int, other_rank: int): + dp_device_mesh = MagicMock() + dp_device_mesh.size.return_value = dp_degree + dp_device_mesh.get_coordinate.return_value = [dp_rank] + other_device_mesh = MagicMock() + other_degree = world_size // dp_degree + other_device_mesh.size.return_value = int(os.environ["WORLD_SIZE"]) // other_degree + other_device_mesh.get_coordinate.return_value = [other_rank] + device_mesh_mock = {ParallelismDegrees.DP_SHARD.value: dp_device_mesh, "other": other_device_mesh} + return device_mesh_mock diff --git a/tests/end2end_tests/gpt2_train_num_steps_7_pp_tp.yaml b/tests/end2end_tests/gpt2_train_num_steps_7_pp_tp.yaml new file mode 100644 index 000000000..00a93a9d0 --- /dev/null +++ b/tests/end2end_tests/gpt2_train_num_steps_7_pp_tp.yaml @@ -0,0 +1,364 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + checkpoint_saving_path: tmp/checkpoints + train_dataset_path: tests/end2end_tests/lorem_ipsum.pbin + intervals: + training_log_interval_in_steps: 1 + checkpointing_interval_in_steps: 4 + evaluation_interval_in_steps: 1 + consistency_enforcement: + enforce_tokens_per_step_consistency: false + enforce_last_step_logged: false + enforce_last_step_evaluated: false + enforce_last_step_checkpointed: false + step_profile: + gradient_accumulation_steps: 1 + local_train_micro_batch_size: 2 + sequence_length: 256 + training_target: + num_target_tokens: + component_key: number_conversion + variant_key: num_tokens_from_packed_mem_map_dataset_continuous + config: + dataset_path: ${settings.paths.train_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + num_ranks: 2 # FIXME: adapt to dp_parallel_degree + local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + num_target_steps: # for the batch progress subscriber + component_key: number_conversion + variant_key: num_steps_from_num_tokens + config: + num_ranks: 2 # FIXME: adapt to dp_parallel_degree + local_micro_batch_size: ${settings.step_profile.local_train_micro_batch_size} + global_num_tokens: ${settings.training_target.num_target_tokens} + sequence_length: ${settings.step_profile.sequence_length} + gradient_accumulation_steps: ${settings.step_profile.gradient_accumulation_steps} + training_progress: + global_num_seen_tokens: 0 + num_seen_steps: 0 + num_seen_samples: 0 + last_step: -1 + +collate_fn: + component_key: collate_fn + variant_key: gpt_2_llm_collator + config: + sample_key: ${settings.referencing_keys.sample_key} + target_key: ${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: ${settings.paths.train_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +train_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + dataloader_tag: train + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.step_profile.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: resumable_distributed_multi_dim_sampler + config: + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + data_parallel_key: dp_shard + shuffle: true + seed: 42 + drop_last: true + skip_num_global_samples: ${settings.training_progress.num_seen_samples} + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: [] + +checkpoint_saving: + component_key: checkpoint_saving + variant_key: default + config: + checkpoint_saving_strategy: + component_key: checkpoint_saving_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpoint_saving_execution: + component_key: checkpoint_saving_execution + variant_key: dcp + config: + checkpoint_path: ${settings.paths.checkpoint_saving_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + pipeline_parallel_degree: 2 + tensor_parallel_degree: 2 + data_parallel_shard_degree: -1 + world_size: ${settings.cuda_env.world_size} + +app_state: + component_key: app_state + variant_key: raw + config: + model: + instance_key: initialized_model + pass_type: BY_REFERENCE + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + lr_scheduler: + instance_key: lr_scheduler + pass_type: BY_REFERENCE + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: scheduled_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL_PART + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.n_layer} + +scheduled_pipeline: + component_key: pipeline + variant_key: scheduled + config: + loss_fn: + instance_key: loss_fn + pass_type: BY_REFERENCE + pp_schedule_name: gpipe + batch_size: ${settings.step_profile.local_train_micro_batch_size} + microbatch_size: 1 + pp_degree: ${device_mesh.config.pipeline_parallel_degree} + # maybe better to use the fsdp model and the schedule here + # instead of passing in the staged pipeline? + # If fsdp_model creates a copy then this is not in the scope of + # the staged pipeline. + pipeline: + component_key: pipeline + variant_key: builder + config: + pp_stage: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: PP_STAGE + model_part: + instance_key: fsdp_model + pass_type: BY_REFERENCE + +fsdp_model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: gpt2_tp_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + block_names: [GPT2Block] + +gpt2_tp_model: + component_key: model + variant_key: gpt2_tp + config: + model: + instance_key: model_part + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +model_part: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL_PART + +staged_pipeline: + component_key: pipeline + variant_key: staged + config: + whole_model: + instance_key: model_raw + pass_type: BY_REFERENCE + stages_generator: + component_key: stages_generator + variant_key: gpt2_stages_generator + config: + num_model_layers: ${model_raw.config.n_layer} + input_layer_equivalence: 1 + output_layer_equivalence: 1 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + local_rank: ${settings.cuda_env.local_rank} + pp_schedule_name: gpipe + num_layers_per_stage: 2 + +model_raw: + component_key: model + variant_key: gpt2 + config: + seed: 42 + use_meta_device: true + use_weight_tying: false + sample_key: ${settings.referencing_keys.sample_key} + poe_type: NOPE + sequence_length: ${settings.step_profile.sequence_length} + prediction_key: ${loss_fn.config.prediction_key} + vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: 2 + n_head_q: 8 + n_head_kv: 8 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster + attention_config: + qkv_transforms: + - type_hint: RotaryTransform + config: + n_embd: ${model_raw.config.n_embd} + n_head: ${model_raw.config.n_head_q} #it has to be head_q here + seq_length_dim: -2 + base_freq: 10000 + attention_implementation: manual + activation_type: swiglu + attention_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + ffn_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + lm_head_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + +lr_scheduler: + component_key: scheduler + variant_key: onecycle_lr + config: + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + max_lr: 6e-4 + div_factor: 10 + final_div_factor: 1 + total_steps: ${settings.training_target.num_target_steps} + pct_start: 0.01 + anneal_strategy: cos + last_epoch: ${settings.training_progress.last_step} + +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 0.0001 + betas: [0.9, 0.95] + eps: 1e-8 + weight_decay: 1e-1 + weight_decay_groups_excluded: [embedding, layernorm] + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + +gradient_clipper: + component_key: gradient_clipper + variant_key: fsdp2 + config: + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + norm_type: P2_NORM + max_norm: 1.0 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + global_rank: ${settings.cuda_env.global_rank} + num_seen_steps: ${settings.training_progress.num_seen_steps} + num_target_steps: ${settings.training_target.num_target_steps} + train_dataloader_tag: ${train_dataloader.config.dataloader_tag} + eval_dataloaders: + instance_key: eval_dataloaders + pass_type: BY_REFERENCE + +evaluation_subscriber: + component_key: results_subscriber + variant_key: save_all + config: {} diff --git a/tests/end2end_tests/gpt2_warm_start_from_step_4_pp_tp.yaml b/tests/end2end_tests/gpt2_warm_start_from_step_4_pp_tp.yaml new file mode 100644 index 000000000..caea5ba49 --- /dev/null +++ b/tests/end2end_tests/gpt2_warm_start_from_step_4_pp_tp.yaml @@ -0,0 +1,393 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + paths: + checkpoint_saving_path: tmp/checkpoints + train_dataset_path: tests/end2end_tests/lorem_ipsum.pbin + intervals: + training_log_interval_in_steps: 1 + checkpointing_interval_in_steps: 2 + evaluation_interval_in_steps: 1 + consistency_enforcement: + enforce_tokens_per_step_consistency: false + enforce_last_step_logged: false + enforce_last_step_evaluated: false + enforce_last_step_checkpointed: false + step_profile: + gradient_accumulation_steps: 1 + local_train_micro_batch_size: 2 + sequence_length: 256 + training_target: + num_target_tokens: + component_key: number_conversion + variant_key: global_num_target_tokens_from_checkpoint_path + config: + checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path} + num_target_steps: # for the batch progress subscriber + component_key: number_conversion + variant_key: num_target_steps_from_checkpoint_path + config: + checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path} + training_progress: + global_num_seen_tokens: # used below + component_key: number_conversion + variant_key: global_num_seen_tokens_from_checkpoint_path + config: + checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path} + num_seen_steps: # for the batch progress subscriber + component_key: number_conversion + variant_key: num_seen_steps_from_checkpoint_path + config: + checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path} + num_seen_samples: + component_key: number_conversion + variant_key: num_samples_from_num_tokens + config: + num_tokens: ${settings.training_progress.global_num_seen_tokens} + sequence_length: ${settings.step_profile.sequence_length} + last_step: # for the scheduler + component_key: number_conversion + variant_key: last_step_from_checkpoint_path + config: + checkpoint_path: ${settings.warmstart_checkpoint_paths.model_checkpoint_path} + warmstart_checkpoint_paths: + # we pass in the checkpoint paths as filenames such that the num_target_tokens and num_target_steps can be calculated and correctly passed to the training loop + # Within the test is replaced with the actual path to the checkpoint. + model_checkpoint_path: eid_0-seen_steps_4-seen_tokens_4096-target_steps_7-target_tokens_7168 + optimizer_checkpoint_path: eid_0-seen_steps_4-seen_tokens_4096-target_steps_7-target_tokens_7168 + +collate_fn: + component_key: collate_fn + variant_key: gpt_2_llm_collator + config: + sample_key: ${settings.referencing_keys.sample_key} + target_key: ${settings.referencing_keys.target_key} + +train_dataset: + component_key: dataset + variant_key: packed_mem_map_dataset_continuous + config: + raw_data_path: ${settings.paths.train_dataset_path} + sequence_length: ${settings.step_profile.sequence_length} + sample_key: ${settings.referencing_keys.sample_key} + +train_dataloader: + component_key: data_loader + variant_key: default + config: + num_workers: 2 + pin_memory: true + dataloader_tag: train + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + batch_sampler: + component_key: batch_sampler + variant_key: default + config: + batch_size: ${settings.step_profile.local_train_micro_batch_size} + drop_last: true + sampler: + component_key: sampler + variant_key: resumable_distributed_multi_dim_sampler + config: + dataset: + instance_key: train_dataset + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + data_parallel_key: dp_shard + shuffle: true + seed: 42 + drop_last: true + skip_num_global_samples: ${settings.training_progress.num_seen_samples} + collate_fn: + instance_key: collate_fn + pass_type: BY_REFERENCE + +eval_dataloaders: [] + +# checkpoint_loading: +# component_key: checkpoint_loading +# variant_key: dcp +# config: +# global_rank: ${settings.cuda_env.global_rank} + +checkpoint_saving: + component_key: checkpoint_saving + variant_key: default + config: + checkpoint_saving_strategy: + component_key: checkpoint_saving_strategy + variant_key: save_k_most_recent_checkpoints_strategy + config: + k: -1 # -1 to save all checkpoints + checkpoint_saving_execution: + component_key: checkpoint_saving_execution + variant_key: dcp + config: + checkpoint_path: ${settings.paths.checkpoint_saving_path} + global_rank: ${settings.cuda_env.global_rank} + experiment_id: ${settings.experiment_id} + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + pipeline_parallel_degree: 2 + tensor_parallel_degree: 2 + data_parallel_shard_degree: -1 + world_size: ${settings.cuda_env.world_size} + +app_state: + component_key: app_state + variant_key: dcp + config: + raw_app_state: + instance_key: app_state_raw + pass_type: BY_REFERENCE + checkpoint_dir_path: checkpoint/path/will/be/set/in/code + +app_state_raw: + component_key: app_state + variant_key: raw + config: + model: + instance_key: initialized_model + pass_type: BY_REFERENCE + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + lr_scheduler: + instance_key: lr_scheduler + pass_type: BY_REFERENCE + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: scheduled_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL_PART + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.n_layer} + +scheduled_pipeline: + component_key: pipeline + variant_key: scheduled + config: + loss_fn: + instance_key: loss_fn + pass_type: BY_REFERENCE + pp_schedule_name: gpipe + batch_size: ${settings.step_profile.local_train_micro_batch_size} + microbatch_size: 1 + pp_degree: ${device_mesh.config.pipeline_parallel_degree} + # maybe better to use the fsdp model and the schedule here + # instead of passing in the staged pipeline? + # If fsdp_model creates a copy then this is not in the scope of + # the staged pipeline. + pipeline: + component_key: pipeline + variant_key: builder + config: + pp_stage: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: PP_STAGE + model_part: + instance_key: fsdp_model + pass_type: BY_REFERENCE + +fsdp_model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: gpt2_tp_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + block_names: [GPT2Block] + +gpt2_tp_model: + component_key: model + variant_key: gpt2_tp + config: + model: + instance_key: model_part + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +model_part: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL_PART + +staged_pipeline: + component_key: pipeline + variant_key: staged + config: + whole_model: + instance_key: model_raw + pass_type: BY_REFERENCE + stages_generator: + component_key: stages_generator + variant_key: gpt2_stages_generator + config: + num_model_layers: ${model_raw.config.n_layer} + input_layer_equivalence: 1 + output_layer_equivalence: 1 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + local_rank: ${settings.cuda_env.local_rank} + pp_schedule_name: gpipe + num_layers_per_stage: 2 + +model_raw: + component_key: model + variant_key: gpt2 + config: + seed: 42 + use_meta_device: true + use_weight_tying: false + sample_key: ${settings.referencing_keys.sample_key} + poe_type: NOPE + sequence_length: ${settings.step_profile.sequence_length} + prediction_key: ${loss_fn.config.prediction_key} + vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: 2 + n_head_q: 8 + n_head_kv: 8 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster + attention_config: + qkv_transforms: + - type_hint: RotaryTransform + config: + n_embd: ${model_raw.config.n_embd} + n_head: ${model_raw.config.n_head_q} #it has to be head_q here + seq_length_dim: -2 + base_freq: 10000 + attention_implementation: manual + activation_type: swiglu + attention_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + ffn_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + lm_head_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + +lr_scheduler: + component_key: scheduler + variant_key: onecycle_lr + config: + optimizer: + instance_key: optimizer + pass_type: BY_REFERENCE + max_lr: 6e-4 + div_factor: 10 + final_div_factor: 1 + total_steps: ${settings.training_target.num_target_steps} + pct_start: 0.01 + anneal_strategy: cos + # last_epoch: ${settings.training_progress.last_step} # Not required. App state will take care of the correct initialization. + +optimizer: + component_key: optimizer + variant_key: adam_w + config: + lr: 0.0001 + betas: [0.9, 0.95] + eps: 1e-8 + weight_decay: 1e-1 + weight_decay_groups_excluded: [embedding, layernorm] + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + +gradient_clipper: + component_key: gradient_clipper + variant_key: fsdp2 + config: + wrapped_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + norm_type: P2_NORM + max_norm: 1.0 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +progress_subscriber: + component_key: progress_subscriber + variant_key: rich + config: + global_rank: ${settings.cuda_env.global_rank} + num_seen_steps: ${settings.training_progress.num_seen_steps} + num_target_steps: ${settings.training_target.num_target_steps} + train_dataloader_tag: ${train_dataloader.config.dataloader_tag} + eval_dataloaders: + instance_key: eval_dataloaders + pass_type: BY_REFERENCE + +evaluation_subscriber: + component_key: results_subscriber + variant_key: save_all + config: {} diff --git a/tests/end2end_tests/test_fsdp2_warmstart_pp_tp.py b/tests/end2end_tests/test_fsdp2_warmstart_pp_tp.py new file mode 100644 index 000000000..63d20d8ba --- /dev/null +++ b/tests/end2end_tests/test_fsdp2_warmstart_pp_tp.py @@ -0,0 +1,276 @@ +import json +import logging +import os +import re +import shutil +import traceback +from pathlib import Path + +import pytest +import torch +import torch.distributed as dist +from pydantic import BaseModel + +from modalities.__main__ import Main, load_app_config_dict +from modalities.batch import EvaluationResultBatch +from modalities.config.config import ProcessGroupBackendType, PydanticLLMDataLoaderIFType +from modalities.config.instantiation_models import TrainingComponentsInstantiationModel +from modalities.dataloader.dataloader import LLMDataLoader +from modalities.logging_broker.messages import Message +from modalities.running_env.cuda_env import CudaEnv +from tests.end2end_tests.custom_components import SaveAllResultSubscriber, SaveAllResultSubscriberConfig + + +def extract_seen_steps_and_tokens(filename: str) -> tuple[int, int]: + pattern = r"seen_steps_(\d+)-seen_tokens_(\d+)" + match = re.search(pattern, filename) + return int(match.group(1)), int(match.group(2)) + + +# NOTE: We need to run the tests in a torch distributed environment with at eight GPUs. +# CUDA_VISIBLE_DEVICES=0,1,2,4,5,6,7 torchrun --rdzv-endpoint localhost:29502 --nnodes 1 --nproc_per_node 8 \ +# $(which pytest) path/to/test_fsdp_to_disc_checkpointing.py + +# NOTE that we can only run one test at time due to NCCL issues with multiple tests in parallel. +# You can specify the test to run with the -k flag, e.g.: -k test_warm_start + + +working_dir = Path(os.path.dirname(__file__)) +tmp_folder = working_dir / "../tmp/fsdp2_warmstart_pp_tp" + + +class TrainDataloaderInstantiationModel(BaseModel): + settings: TrainingComponentsInstantiationModel.Settings + train_dataloader: PydanticLLMDataLoaderIFType + + +@pytest.mark.skipif( + "RANK" not in os.environ or torch.cuda.device_count() < 8, + reason="This e2e test requires 8 GPUs and a torchrun distributed environment.", +) +class TestWarmstart: + @staticmethod + def get_loss_scores(messages: list[Message[EvaluationResultBatch]], loss_key: str) -> list[float]: + return [message.payload.losses[loss_key].value.item() for message in messages] + + def test_warm_start(self): + # We want to verify that the training continues after starting from checkpoint (i.e, warm start) + # exactly the same way, as if we trained it from scratch. + # To do so, we have two configs. The first config trains a model for 8 steps and + # saves multiple intermediary checkpoints. + # The second config starts from the 4th step and trains the model for 4 more steps. + # We compare the loss values of the two models after 4 steps and expect them to be the same. + + try: + if tmp_folder.exists(): + shutil.rmtree(tmp_folder) + tmp_folder.mkdir(parents=False, exist_ok=False) + # config for two steps model + gpt2_8_steps_config_file_path = working_dir / "gpt2_train_num_steps_7_pp_tp.yaml" + gpt2_8_steps_config_dict = load_app_config_dict(gpt2_8_steps_config_file_path, experiment_id="0") + + # adopt the checkpoint path + checkpoint_path = str(tmp_folder) + gpt2_8_steps_config_dict["checkpoint_saving"]["config"]["checkpoint_saving_execution"]["config"][ + "checkpoint_path" + ] = checkpoint_path + gpt2_8_steps_config_dict["settings"]["paths"]["checkpoint_saving_path"] = checkpoint_path + loss_values_experiment_0_path = checkpoint_path + "/experiment_0_loss_scores.txt" + + # config for one step model + gpt2_warm_start_after_4_steps_config_file_path = working_dir / "gpt2_warm_start_from_step_4_pp_tp.yaml" + gpt2_warm_start_after_4_steps_dict = load_app_config_dict( + gpt2_warm_start_after_4_steps_config_file_path, experiment_id="1" + ) + + # adopt the checkpoint path + gpt2_warm_start_after_4_steps_dict["app_state"]["config"]["checkpoint_dir_path"] = ( + checkpoint_path + "/0/eid_0-seen_steps_4-seen_tokens_4096-target_steps_7-target_tokens_7168" + ) + gpt2_warm_start_after_4_steps_dict["checkpoint_saving"]["config"]["checkpoint_saving_execution"]["config"][ + "checkpoint_path" + ] = checkpoint_path + gpt2_warm_start_after_4_steps_dict["settings"]["paths"]["checkpoint_saving_path"] = checkpoint_path + loss_values_experiment_1_path = checkpoint_path + "/experiment_1_loss_scores.txt" + + with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl): + try: + main_obj_0 = Main(gpt2_8_steps_config_file_path) + main_obj_0.config_dict = gpt2_8_steps_config_dict + main_obj_0.add_custom_component( + component_key="results_subscriber", + variant_key="save_all", + custom_component=SaveAllResultSubscriber, + custom_config=SaveAllResultSubscriberConfig, + ) + print( + main_obj_0.config_dict["settings"]["training_target"]["num_target_tokens"]["config"][ + "dataset_path" + ] + ) + components_0 = main_obj_0.build_components( + components_model_type=TrainingComponentsInstantiationModel + ) + main_obj_0.run(components_0) + + # we collect the loss values from rank 0 and store them in the temporary experiment folder + if dist.get_rank() == 0: + messages_0: list[ + Message[EvaluationResultBatch] + ] = components_0.evaluation_subscriber.message_list + loss_scores_0 = TestWarmstart.get_loss_scores(messages_0, "train loss avg") + with open(loss_values_experiment_0_path, "w") as f: + json.dump(loss_scores_0, f) + + # make sure that the checkpoints have been written and checkpoint info file has been updated + checkpoint_info_file_path = Path(checkpoint_path) / "0/last_checkpoint_info.json" + assert checkpoint_info_file_path.exists() + with open(checkpoint_info_file_path, "r") as f: + checkpoint_info = json.load(f) + assert checkpoint_info["checkpoint_folder_path"] == ( + checkpoint_path + "/0/eid_0-seen_steps_4-seen_tokens_4096-target_steps_7-target_tokens_7168" + ) + assert Path(checkpoint_info["checkpoint_folder_path"]).exists() + + checkpoint_paths = list(Path(checkpoint_path).glob("**/*seen_steps_*-seen_tokens_*")) + assert len(checkpoint_paths) > 0 + max_seen_steps = -1 + max_seen_tokens = -1 + for checkpoint_path in checkpoint_paths: + seen_steps, seen_tokens = extract_seen_steps_and_tokens(checkpoint_path.name) + max_seen_steps = max(max_seen_steps, seen_steps) + max_seen_tokens = max(max_seen_tokens, seen_tokens) + + cp_info_model_seen_steps, cp_info_model_seen_tokens = extract_seen_steps_and_tokens( + checkpoint_info["checkpoint_folder_path"] + ) + + assert cp_info_model_seen_steps == max_seen_steps + assert cp_info_model_seen_tokens == max_seen_tokens + + main_obj_1 = Main(gpt2_warm_start_after_4_steps_config_file_path) + main_obj_1.config_dict = gpt2_warm_start_after_4_steps_dict + + main_obj_1.add_custom_component( + component_key="results_subscriber", + variant_key="save_all", + custom_component=SaveAllResultSubscriber, + custom_config=SaveAllResultSubscriberConfig, + ) + components_1 = main_obj_1.build_components( + components_model_type=TrainingComponentsInstantiationModel + ) + + assert ( + components_0.app_state.lr_scheduler.base_lrs == components_1.app_state.lr_scheduler.base_lrs + ) # make sure that the initial learning rates are the same + assert components_1.app_state.lr_scheduler.last_epoch == 4 # we start from step 4 + + main_obj_1.run(components_1) + + # we collect the loss values from rank 0 for the warmstart model + # and store them in the temporary experiment folder + if dist.get_rank() == 0: + messages_1: list[ + Message[EvaluationResultBatch] + ] = components_1.evaluation_subscriber.message_list + loss_scores_1 = TestWarmstart.get_loss_scores(messages_1, "train loss avg") + with open(loss_values_experiment_1_path, "w") as f: + json.dump(loss_scores_1, f) + + # read the losses from disc + # note that the temporary directory is only correct for the rank 0. + # rank 1 has a different one and we don't store anything in there + with open(loss_values_experiment_0_path, "r") as f: + loaded_loss_values_0 = json.load(f) + + with open(loss_values_experiment_1_path, "r") as f: + loaded_loss_values_1 = json.load(f) + + # we check if the losses for the model from scratch + # and the warm start model have the same loss values + assert loaded_loss_values_0[4:] == pytest.approx(loaded_loss_values_1, abs=1e-16) + + # assert that the scheduler state is the same for both models + assert ( + components_1.app_state.lr_scheduler.last_epoch == components_0.app_state.lr_scheduler.last_epoch + ) + assert ( + components_0.app_state.lr_scheduler.get_last_lr() + == components_1.app_state.lr_scheduler.get_last_lr() + ) + except Exception as e: + tb = traceback.format_exc() + logging.error(f"Exception in rank {os.environ.get('RANK', -1)}: {e}") + logging.error(tb) + raise + finally: + logging.info(f"Rank {os.environ.get('RANK', -1)} cleaning up.") + if int(os.environ.get("RANK", -1)) == 0: + try: + if tmp_folder.exists(): + shutil.rmtree(tmp_folder) + except Exception as e: + logging.warning(f"Rank {os.environ.get('RANK', -1)}: failed to remove tmp folder {tmp_folder}: {e}") + + def test_warmstart_dataloader(self): + # non-skipped config + gpt2_two_steps_config_file_path = working_dir / "gpt2_train_num_steps_8.yaml" + gpt2_two_steps_config_dict = load_app_config_dict(gpt2_two_steps_config_file_path, experiment_id="0") + + # skipped config + gpt2_warm_start_from_step_1_config_file_path = working_dir / "gpt2_warm_start_from_step_4.yaml" + gpt2_warm_start_from_step_1_dict = load_app_config_dict( + gpt2_warm_start_from_step_1_config_file_path, experiment_id="1" + ) + + with CudaEnv(process_group_backend=ProcessGroupBackendType.nccl): + main_obj_1 = Main(gpt2_two_steps_config_file_path) + main_obj_1.config_dict = gpt2_two_steps_config_dict + + main_obj_2 = Main(gpt2_warm_start_from_step_1_config_file_path) + main_obj_2.config_dict = gpt2_warm_start_from_step_1_dict + + main_obj_1.add_custom_component( + component_key="results_subscriber", + variant_key="save_all", + custom_component=SaveAllResultSubscriber, + custom_config=SaveAllResultSubscriberConfig, + ) + components_1: TrainDataloaderInstantiationModel = main_obj_1.build_components( + components_model_type=TrainDataloaderInstantiationModel + ) + dataloader_1: LLMDataLoader = components_1.train_dataloader + dl_1_samples = [s for s in dataloader_1] + + main_obj_2.add_custom_component( + component_key="results_subscriber", + variant_key="save_all", + custom_component=SaveAllResultSubscriber, + custom_config=SaveAllResultSubscriberConfig, + ) + components_2 = main_obj_2.build_components(components_model_type=TrainDataloaderInstantiationModel) + dataloader_2: LLMDataLoader = components_2.train_dataloader + dl_2_samples = [s for s in dataloader_2] + + # fast forward the first dataloader + + num_skip_steps = components_2.settings.training_progress.num_seen_steps + + # make sure that we actually skip as defined in the config + assert num_skip_steps == 4 + assert len(dl_1_samples) == num_skip_steps + len(dl_2_samples) + + # make sure that the first dataloader is not skipped + assert components_1.settings.training_progress.num_seen_steps == 0 + + # iterate through both sample lists from the dataloaders + # and assert the equality of the samples + + for i in range(len(dataloader_2)): + assert dl_1_samples[i + num_skip_steps].samples["input_ids"].equal(dl_2_samples[i].samples["input_ids"]) + + dl_1_samples[i + num_skip_steps].samples["input_ids"][-1] = 0 + assert not ( + dl_1_samples[i + num_skip_steps].samples["input_ids"].equal(dl_2_samples[i].samples["input_ids"]) + ) diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/__init__.py b/tests/fsdp2_parallelization/pipeline_parallelism/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_fwd_bwd_pass.yaml b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_fwd_bwd_pass.yaml new file mode 100644 index 000000000..988e70eba --- /dev/null +++ b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_fwd_bwd_pass.yaml @@ -0,0 +1,108 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + step_profile: + gradient_accumulation_steps: 1 + local_train_micro_batch_size: 4 + sequence_length: 256 + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + data_parallel_shard_degree: -1 + world_size: ${settings.cuda_env.world_size} + +fsdp_model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: initialized_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + block_names: [GPT2Block] + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + instance_key: model_raw + pass_type: BY_REFERENCE + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.n_layer} + +model_raw: + component_key: model + variant_key: gpt2 + config: + use_meta_device: true + use_weight_tying: false + sample_key: ${settings.referencing_keys.sample_key} + poe_type: NOPE + sequence_length: ${settings.step_profile.sequence_length} + prediction_key: ${loss_fn.config.prediction_key} + vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: 6 + n_head_q: 8 + n_head_kv: 4 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster + attention_config: + qkv_transforms: + - type_hint: RotaryTransform + config: + n_embd: ${model_raw.config.n_embd} + n_head: ${model_raw.config.n_head_q} #it has to be head_q here + seq_length_dim: -2 + base_freq: 10000 + attention_implementation: manual + activation_type: swiglu + attention_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + ffn_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + lm_head_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml new file mode 100644 index 000000000..f41e912bc --- /dev/null +++ b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml @@ -0,0 +1,166 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + step_profile: + gradient_accumulation_steps: 1 + local_train_micro_batch_size: 4 + sequence_length: 256 + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + pipeline_parallel_degree: 2 + data_parallel_shard_degree: -1 + world_size: ${settings.cuda_env.world_size} + +scheduled_pipeline: + component_key: pipeline + variant_key: scheduled + config: + loss_fn: + instance_key: loss_fn + pass_type: BY_REFERENCE + pp_schedule_name: gpipe + batch_size: ${settings.step_profile.local_train_micro_batch_size} + microbatch_size: 2 + pp_degree: ${device_mesh.config.pipeline_parallel_degree} + pipeline: + component_key: pipeline + variant_key: builder + config: + pp_stage: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: PP_STAGE + model_part: + instance_key: fsdp_model + pass_type: BY_REFERENCE + +fsdp_model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: model_part + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + block_names: [GPT2Block] + +model_part: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL_PART + +staged_pipeline: + component_key: pipeline + variant_key: staged + config: + whole_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + stages_generator: + component_key: stages_generator + variant_key: gpt2_stages_generator + config: + num_model_layers: ${model_raw.config.n_layer} + input_layer_equivalence: 1 + output_layer_equivalence: 1 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + local_rank: ${settings.cuda_env.local_rank} + pp_schedule_name: gpipe + num_layers_per_stage: 4 + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + instance_key: model_raw + pass_type: BY_REFERENCE + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.n_layer} + +model_raw: + component_key: model + variant_key: gpt2 + config: + use_meta_device: true + use_weight_tying: false + sample_key: ${settings.referencing_keys.sample_key} + poe_type: NOPE + sequence_length: ${settings.step_profile.sequence_length} + prediction_key: ${loss_fn.config.prediction_key} + vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: 6 + n_head_q: 8 + n_head_kv: 4 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster + attention_config: + qkv_transforms: + - type_hint: RotaryTransform + config: + n_embd: ${model_raw.config.n_embd} + n_head: ${model_raw.config.n_head_q} #it has to be head_q here + seq_length_dim: -2 + base_freq: 10000 + attention_implementation: manual + activation_type: swiglu + attention_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + ffn_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + lm_head_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml new file mode 100644 index 000000000..fb8ee5f7d --- /dev/null +++ b/tests/fsdp2_parallelization/pipeline_parallelism/configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml @@ -0,0 +1,177 @@ +settings: + experiment_id: ${modalities_env:experiment_id} + config_file_path: ${modalities_env:config_file_path} + referencing_keys: + sample_key: input_ids + target_key: target_ids + prediction_key: logits + cuda_env: + local_rank: ${cuda_env:LOCAL_RANK} + global_rank: ${cuda_env:RANK} + world_size: ${cuda_env:WORLD_SIZE} + step_profile: + gradient_accumulation_steps: 1 + local_train_micro_batch_size: 4 + sequence_length: 256 + +loss_fn: + component_key: loss + variant_key: clm_cross_entropy_loss + config: + target_key: ${settings.referencing_keys.target_key} + prediction_key: ${settings.referencing_keys.prediction_key} + +device_mesh: + component_key: device_mesh + variant_key: default + config: + device_type: cuda + data_parallel_replicate_degree: 1 + pipeline_parallel_degree: 2 + data_parallel_shard_degree: -1 + world_size: ${settings.cuda_env.world_size} + +scheduled_pipeline: + component_key: pipeline + variant_key: scheduled + config: + loss_fn: + instance_key: loss_fn + pass_type: BY_REFERENCE + pp_schedule_name: gpipe + batch_size: ${settings.step_profile.local_train_micro_batch_size} + microbatch_size: 2 + pp_degree: ${device_mesh.config.pipeline_parallel_degree} + pipeline: + component_key: pipeline + variant_key: builder + config: + pp_stage: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: PP_STAGE + model_part: + instance_key: fsdp_model + pass_type: BY_REFERENCE + +fsdp_model: + component_key: model + variant_key: fsdp2_wrapped + config: + model: + instance_key: gpt2_tp_model + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + mixed_precision_settings: + param_dtype: BF_16 + reduce_dtype: BF_16 + block_names: [GPT2Block] + +gpt2_tp_model: + component_key: model + variant_key: gpt2_tp + config: + model: + instance_key: model_part + pass_type: BY_REFERENCE + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + +model_part: + component_key: pipeline + variant_key: selector + config: + pipeline: + instance_key: staged_pipeline + pass_type: BY_REFERENCE + selection_type: MODEL_PART + +staged_pipeline: + component_key: pipeline + variant_key: staged + config: + whole_model: + instance_key: initialized_model + pass_type: BY_REFERENCE + stages_generator: + component_key: stages_generator + variant_key: gpt2_stages_generator + config: + num_model_layers: ${model_raw.config.n_layer} + input_layer_equivalence: 1 + output_layer_equivalence: 1 + device_mesh: + instance_key: device_mesh + pass_type: BY_REFERENCE + local_rank: ${settings.cuda_env.local_rank} + pp_schedule_name: gpipe + num_layers_per_stage: 4 + +initialized_model: + component_key: model + variant_key: model_initialized + config: + model: + instance_key: model_raw + pass_type: BY_REFERENCE + model_initializer: + component_key: model_initialization + variant_key: composed + config: + model_type: gpt2 + weight_init_type: scaled + mean: 0.0 + std: 0.02 + num_layers: ${model_raw.config.n_layer} + +model_raw: + component_key: model + variant_key: gpt2 + config: + use_meta_device: true + use_weight_tying: false + sample_key: ${settings.referencing_keys.sample_key} + poe_type: NOPE + sequence_length: ${settings.step_profile.sequence_length} + prediction_key: ${loss_fn.config.prediction_key} + vocab_size: 50304 # GPT-2 vocab_size of 50257, padded up to nearest multiple of 64 for efficiency + n_layer: 6 + n_head_q: 8 + n_head_kv: 4 + ffn_hidden: 128 + n_embd: 128 + dropout: 0.0 + bias: true # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster + attention_config: + qkv_transforms: + - type_hint: RotaryTransform + config: + n_embd: ${model_raw.config.n_embd} + n_head: ${model_raw.config.n_head_q} #it has to be head_q here + seq_length_dim: -2 + base_freq: 10000 + attention_implementation: manual + activation_type: swiglu + attention_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + ffn_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + lm_head_norm_config: + norm_type: layer_norm + config: + normalized_shape: ${model_raw.config.n_embd} + eps: 1e-5 + diff --git a/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py b/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py new file mode 100644 index 000000000..d255d62e0 --- /dev/null +++ b/tests/fsdp2_parallelization/pipeline_parallelism/test_pp_fwd_bwd_pass.py @@ -0,0 +1,166 @@ +import os +import tempfile +from pathlib import Path + +import pytest +import torch +import torch.multiprocessing as mp +import yaml +from pydantic import BaseModel + +from modalities.__main__ import Main +from modalities.batch import InferenceResultBatch +from modalities.config.config import ProcessGroupBackendType +from modalities.config.pydantic_if_types import PydanticFSDP2ModuleType, PydanticLossIFType, PydanticPipelineType +from modalities.models.parallelism.pipeline_parallelism import Pipeline +from tests.end2end_tests.custom_components import MultiProcessingCudaEnv + + +@pytest.fixture +def temp_file_path() -> Path: + with tempfile.NamedTemporaryFile() as tf: + yield tf.name + + +class ComponentsInstantiationPPModel(BaseModel): + scheduled_pipeline: PydanticPipelineType + + +class ComponentsInstantiationModel(BaseModel): + fsdp_model: PydanticFSDP2ModuleType + loss_fn: PydanticLossIFType + + +@pytest.mark.skipif( + torch.cuda.device_count() < 8, + reason="This test requires 8 GPUs", +) +class TestPipelineParallelism: + def _get_tmp_sharding_config_path( + self, sharding_degree: int, tp_degree: int, pp_degree: int, temp_file_path: Path + ) -> Path: + working_dir = Path(os.path.dirname(__file__)) + if tp_degree > 1: + config_file_path = working_dir / "configs/config_lorem_ipsum_long_fsdp2_pp_tp_fwd_bwd_pass.yaml" + else: + config_file_path = working_dir / "configs/config_lorem_ipsum_long_fsdp2_pp_fwd_bwd_pass.yaml" + + with open(config_file_path, "r") as file: + config_string = file.read() + config_dict = yaml.safe_load(config_string) + config_dict["device_mesh"]["config"]["data_parallel_shard_degree"] = sharding_degree + config_dict["device_mesh"]["config"]["tensor_parallel_degree"] = tp_degree + config_dict["device_mesh"]["config"]["pipeline_parallel_degree"] = pp_degree + + # save to temporary file + with open(temp_file_path, "w") as file: + yaml.dump(config_dict, file) + + return temp_file_path + + def _get_components( + self, config_file_path: Path, use_pp: bool + ) -> ComponentsInstantiationPPModel | ComponentsInstantiationModel: + torch.manual_seed(42) + main_obj = Main(config_file_path) + components_model_type = ComponentsInstantiationPPModel if use_pp else ComponentsInstantiationModel + components = main_obj.build_components(components_model_type=components_model_type) + assert isinstance(components, components_model_type) + return components + + @pytest.mark.parametrize( + "sharding_degree, tp_degree, pp_degree, world_size", + [ + (2, 1, 2, 4), + (2, 2, 2, 8), + ], + ) + def test_pp(self, sharding_degree: int, tp_degree: int, pp_degree: int, world_size: int, temp_file_path: Path): + tmp_sharding_config_path = self._get_tmp_sharding_config_path( + sharding_degree=sharding_degree, + tp_degree=tp_degree, + pp_degree=pp_degree, + temp_file_path=temp_file_path, + ) + mp.spawn( + self._test_pp_impl, + args=(world_size, tmp_sharding_config_path), + nprocs=world_size, + join=True, + ) + + def _test_pp_impl( + self, + process_id: int, + world_size: int, + pp_model_config_path: Path, + ): + # wraps the actual test function to be able to run it in a distributed multiprocessing setup + with MultiProcessingCudaEnv( + process_group_backend=ProcessGroupBackendType.nccl, + global_rank=process_id, + local_rank=process_id, + world_size=world_size, + rdvz_port=22359, + ): + vocab_size = 50304 + sequence_length = 256 + batch_size = 4 + torch.manual_seed(42) + sequences = torch.randint(0, vocab_size, (batch_size, sequence_length + 1)) + targets = sequences[:, 1:].contiguous() + inputs = sequences[:, :-1].contiguous() + + is_last_pp_stage, loss_pp = self._forward_step_with_pp(pp_model_config_path, inputs, targets) + fsdp2_loss = self._forward_step_without_pp(inputs, targets) + + if is_last_pp_stage: + assert torch.allclose( + loss_pp, fsdp2_loss, atol=1e-6, rtol=1e-5 + ), f"Losses do not match.\nLoss with PP: {loss_pp.item()}, Loss without PP: {fsdp2_loss.item()}" + + def _forward_step_with_pp( + self, pp_model_config_path: Path, inputs: torch.Tensor, targets: torch.Tensor + ) -> tuple[bool, torch.Tensor]: + try: + components = self._get_components(pp_model_config_path, use_pp=True) + scheduled_pipeline = components.scheduled_pipeline + loss_pp = self._forward_step(scheduled_pipeline, inputs, targets) + except Exception as e: + import traceback + + print(f"Exception in _forward_step_with_pp: {e}") + traceback.print_exc() # <-- Add this line to print the full stack trace + raise e + return scheduled_pipeline.is_last_pp_stage, loss_pp + + def _forward_step(self, scheduled_pipeline: Pipeline, inputs: torch.Tensor, targets: torch.Tensor): + """Runs a forward step on the model.""" + os.environ["MODEL_TYPE"] = "PP" + pp_schedule = scheduled_pipeline.pp_schedule + targets, losses = (targets, []) if scheduled_pipeline.is_last_pp_stage else (None, None) + if scheduled_pipeline.is_first_pp_stage: + pp_schedule.step(inputs, target=targets, losses=losses) + else: + pp_schedule.step(target=targets, losses=losses) + + # accumulate losses across pipeline microbatchess + return ( + torch.mean(torch.stack(losses)).to(losses[0].device) + if scheduled_pipeline.is_last_pp_stage + else torch.tensor([-1.0], device=inputs.device) + ) + + def _forward_step_without_pp(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + os.environ["MODEL_TYPE"] = "NOPP" + working_dir = Path(os.path.dirname(__file__)) + fsdp2_model_config_path = working_dir / "configs/config_lorem_ipsum_long_fsdp2_fwd_bwd_pass.yaml" + fsdp2_components = self._get_components(fsdp2_model_config_path, use_pp=False) + fsdp2_model = fsdp2_components.fsdp_model + fsdp2_loss_fn = fsdp2_components.loss_fn + + input_dict = {"input_ids": inputs} + fsdp2_out = fsdp2_model(input_dict) + forward_batch = InferenceResultBatch(predictions=fsdp2_out, targets={fsdp2_loss_fn.target_key: targets}) + fsdp2_loss = fsdp2_loss_fn(forward_batch) + return fsdp2_loss diff --git a/tests/fsdp2_parallelization/test_tensor_parallelism.py b/tests/fsdp2_parallelization/test_tensor_parallelism.py index 449fdb996..ac6554124 100644 --- a/tests/fsdp2_parallelization/test_tensor_parallelism.py +++ b/tests/fsdp2_parallelization/test_tensor_parallelism.py @@ -117,11 +117,11 @@ def _test_tp_sharding_impl( # Ensure models use the correct MLP if activation_type == "gelu": - assert isinstance(fsdp2_model.transformer.h[0].mlp, TransformerMLP) - assert isinstance(tp_model.transformer.h[0].mlp, TransformerMLP) + assert isinstance(fsdp2_model.transformer.h["0"].mlp, TransformerMLP) + assert isinstance(tp_model.transformer.h["0"].mlp, TransformerMLP) elif activation_type == "swiglu": - assert isinstance(fsdp2_model.transformer.h[0].mlp, SwiGLU) - assert isinstance(tp_model.transformer.h[0].mlp, SwiGLU) + assert isinstance(fsdp2_model.transformer.h["0"].mlp, SwiGLU) + assert isinstance(tp_model.transformer.h["0"].mlp, SwiGLU) # Ensure models are sharded correctly assert "tp" in tp_model.transformer.wte.weight.device_mesh.mesh_dim_names diff --git a/tests/test_torch_compile.py b/tests/test_torch_compile.py index fab2ed217..59ae6ecb9 100644 --- a/tests/test_torch_compile.py +++ b/tests/test_torch_compile.py @@ -1,3 +1,6 @@ + +import copy + import pytest import torch.nn as nn @@ -57,7 +60,7 @@ def gpt2_model(): def test_get_compiled_model_compiles_blocks(gpt2_model): - original_blocks = list(gpt2_model.transformer.h) + original_model = copy.deepcopy(gpt2_model) original_wte = gpt2_model.transformer.wte original_lm_head = gpt2_model.transformer.lm_head @@ -65,9 +68,9 @@ def test_get_compiled_model_compiles_blocks(gpt2_model): result_model = ModelFactory.get_compiled_model(gpt2_model, block_names, fullgraph=True) assert len(result_model.transformer.h) == 4, "Should still have four blocks" - for i, (original_block, new_block) in enumerate(zip(original_blocks, result_model.transformer.h)): - assert new_block is not original_block, f"Block {i} should be a compiled version" - assert isinstance(new_block, nn.Module), f"Block {i} should be an nn.Module" + for i, (original_block_idx, new_block_idx) in enumerate(zip(original_model.transformer.h, result_model.transformer.h)): + assert result_model.transformer.h[new_block_idx] is not original_model.transformer.h[original_block_idx], f"Block {i} should be a compiled version" + assert isinstance(result_model.transformer.h[new_block_idx], nn.Module), f"Block {i} should be an nn.Module" assert result_model.transformer.wte is original_wte, "Embedding layer should remain unchanged" assert result_model.transformer.lm_head is original_lm_head, "LM head should remain unchanged" assert result_model is gpt2_model, "Should return the same model instance" diff --git a/tests/training/__init__.py b/tests/training/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/training/gradient_clipping/__init__.py b/tests/training/gradient_clipping/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_gradient_clipping.py b/tests/training/gradient_clipping/test_fsdp_gradient_clipper.py similarity index 52% rename from tests/test_gradient_clipping.py rename to tests/training/gradient_clipping/test_fsdp_gradient_clipper.py index 14ff0e7d8..edc797151 100644 --- a/tests/test_gradient_clipping.py +++ b/tests/training/gradient_clipping/test_fsdp_gradient_clipper.py @@ -1,16 +1,21 @@ +import tempfile import types +from multiprocessing import Queue from unittest.mock import MagicMock import torch +import torch.distributed as dist +import torch.multiprocessing as mp +import torch.nn as nn from modalities.training.gradient_clipping.fsdp_gradient_clipper import ( - DummyGradientClipper, FSDP1GradientClipper, FSDP1LoggingOnlyGradientClipper, FSDP2GradientClipper, FSDP2LoggingOnlyGradientClipper, GradientClippingMode, ) +from tests.utility import find_free_port class MockFSDPModel: @@ -93,14 +98,14 @@ def test_fsdp2_clip_grad_norm(): # Test case 1: max_norm > total_norm (no clipping) max_norm = expected_norm + 1 # 3.0 - norm = FSDP2GradientClipper.clip_grad_norm_(parameters=mock_model.parameters(), max_norm=max_norm, norm_type=2.0) + norm = FSDP2GradientClipper(mock_model, max_norm=max_norm, norm_type=GradientClippingMode.P2_NORM).clip_gradients() assert torch.allclose(norm, torch.tensor(expected_norm)), "Norm should match expected total norm" assert torch.allclose(mock_model.param1.grad, torch.tensor([1.0, 1.0])), "Gradients should not be clipped" assert torch.allclose(mock_model.param2.grad, torch.tensor([1.0, 1.0])), "Gradients should not be clipped" # Test case 2: max_norm < total_norm (clipping occurs) max_norm = expected_norm / 2 # 1.0 - norm = FSDP2GradientClipper.clip_grad_norm_(parameters=mock_model.parameters(), max_norm=max_norm, norm_type=2.0) + norm = FSDP2GradientClipper(mock_model, max_norm=max_norm, norm_type=GradientClippingMode.P2_NORM).clip_gradients() assert torch.allclose(norm, torch.tensor(expected_norm)), "Norm should match pre-clipping total norm" scale = max_norm / expected_norm # 1.0 / 2.0 = 0.5 expected_grad = torch.tensor([1.0 * scale, 1.0 * scale]) @@ -151,10 +156,159 @@ def test_fsdp2_logging_only_gradient_clipper(): assert torch.allclose(param.grad, torch.tensor([1.0, 1.0])), "Gradients should not be modified" -def test_dummy_gradient_clipper(): - """ - Test that DummyGradientClipper returns a tensor with -1.0 and does not affect gradients. - """ - clipper = DummyGradientClipper() - norm = clipper.clip_gradients() - assert torch.allclose(norm, torch.tensor([-1.0])), "Norm should be -1.0 indicating no clipping" +def test_pipeline_parallelized_clipping_equivalent_to_single_stage_clipping(): + max_norm = 0.1 + # create full model and initialize deterministically + torch.manual_seed(42) + full = FullModel() + + # create an input and compute gradients on the full model + x = torch.randn(2, 4) + out = full(x) + loss = out.pow(2).sum() + loss.backward() + + # save full model state and grads to a temporary file for workers + state = {} + for name, p in full.named_parameters(): + # store parameter data and grads on CPU + state[name] = p.data.cpu().clone() + grads = {} + for name, p in full.named_parameters(): + grads[name] = p.grad.cpu().clone() + + with tempfile.NamedTemporaryFile() as tmp: + store_path = tmp.name + torch.save({"state": state, "grads": grads}, store_path) + + # set up multiprocessing to simulate 2 pipeline stages + world_size = 2 + port = find_free_port() + q = mp.get_context("spawn").Queue() + mp.spawn(_worker, args=(world_size, store_path, port, max_norm, q), nprocs=world_size, join=True) + + # collect results + results = {} + for _ in range(world_size): + rank, coll = q.get() + results[rank] = coll + + # perform clipping on the full model (single-stage) + FSDP2GradientClipper( + wrapped_model=full, + max_norm=max_norm, + norm_type=GradientClippingMode.P2_NORM, + device_mesh=None, + error_if_nonfinite=True, + foreach=True, + ).clip_gradients() + + # compare full model parts to the per-stage results + full_a_params = [p.data.cpu() for p in full.a.parameters()] + full_b_params = [p.data.cpu() for p in full.b.parameters()] + + # ranks: 0 -> partA, 1 -> partB + assert 0 in results and 1 in results + + for p_full, p_pp in zip(full_a_params, results[0]): + t_pp = torch.as_tensor(p_pp, dtype=p_full.dtype) + assert torch.allclose(p_full, t_pp, atol=1e-6, rtol=1e-5) + + for p_full, p_pp in zip(full_b_params, results[1]): + t_pp = torch.as_tensor(p_pp, dtype=p_full.dtype) + assert torch.allclose(p_full, t_pp, atol=1e-6, rtol=1e-5) + + +class PartA(nn.Module): + def __init__(self): + super().__init__() + self.lin = nn.Linear(4, 5, bias=False) + + def forward(self, x: torch.Tensor): + return self.lin(x) + + +class PartB(nn.Module): + def __init__(self): + super().__init__() + self.lin = nn.Linear(5, 3, bias=False) + + def forward(self, x: torch.Tensor): + return self.lin(x) + + +class FullModel(nn.Module): + def __init__(self): + super().__init__() + self.a = PartA() + self.b = PartB() + + def forward(self, x: torch.Tensor): + return self.b(self.a(x)) + + +def _worker(rank: int, world_size: int, store_path: str, port: int, max_norm: float, q: Queue): + # initialize distributed + dist.init_process_group(backend="gloo", init_method=f"tcp://127.0.0.1:{port}", rank=rank, world_size=world_size) + + # load saved full model state and grads + data = torch.load(store_path) + state = data["state"] + grads = data["grads"] + + # create the corresponding part for this rank and load weights + if rank == 0: + part = PartA() + # map parameters from full model: a.lin.weight + part.lin.weight.data.copy_(state["a.lin.weight"]) + # assign gradients + for name, p in part.named_parameters(): + full_name = f"a.{name}" + if full_name in grads: + p.grad = grads[full_name].clone() + else: + part = PartB() + part.lin.weight.data.copy_(state["b.lin.weight"]) + for name, p in part.named_parameters(): + full_name = f"b.{name}" + if full_name in grads: + p.grad = grads[full_name].clone() + + # create a dummy device_mesh-like object that matches the parts of DeviceMesh + # expected by get_mesh_for_parallelism_method and FSDP2GradientClipper. + class DummyPPMesh: + def __init__(self, group): + self._group = group + + def get_group(self): + return self._group + + class DummyDeviceMesh: + def __init__(self, group): + # include the PP mesh name so get_mesh_for_parallelism_method finds it + self.mesh_dim_names = ("pp",) + self._pp = DummyPPMesh(group) + + def __getitem__(self, name: str): + if name == "pp": + return self._pp + raise KeyError(name) + + mesh = DummyDeviceMesh(dist.group.WORLD) + + # call the clipping function which will perform all_reduce across the pp group + FSDP2GradientClipper( + wrapped_model=part, + max_norm=max_norm, + norm_type=GradientClippingMode.P2_NORM, + device_mesh=mesh, + error_if_nonfinite=True, + foreach=True, + ).clip_gradients() + + # collect clipped parameter tensors (cpu) and serialize to plain Python lists + # to avoid multiprocessing shared-storage pickling issues. + collected = [p.data.cpu().numpy().tolist() for p in part.parameters()] + q.put((rank, collected)) + + dist.destroy_process_group() diff --git a/tests/utility.py b/tests/utility.py new file mode 100644 index 000000000..c839bc76f --- /dev/null +++ b/tests/utility.py @@ -0,0 +1,9 @@ +import socket + + +def find_free_port(): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + s.bind(("127.0.0.1", 0)) + port = s.getsockname()[1] + s.close() + return port