diff --git a/Makefile b/Makefile index 0ab20da62..37aecc96e 100644 --- a/Makefile +++ b/Makefile @@ -20,3 +20,15 @@ test: --color=yes \ --verbose \ examples/llama/tests/ + +install-moe: + pip install --no-build-isolation git+https://github.com/fanshiqing/grouped_gemm@main + +test-moe: + pytest --color=yes --verbose tests/test_moe_dispatcher.py + pytest --color=yes --verbose tests/test_moe.py + pytest --color=yes --verbose tests/test_distributed_primitives.py::test_all_to_all + +run-sanity-moe: + CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=8 run_train.py --config-file /fsx/phuc/new_workspace/snippets/experiment_configs/qwen_moe/exp0a0_sanity_dense.yaml + CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=8 run_train.py --config-file /fsx/phuc/new_workspace/snippets/experiment_configs/qwen_moe/exp0b0_sanity_moe_ep8.yaml diff --git a/README.md b/README.md index 719d0720b..1b5df079e 100644 --- a/README.md +++ b/README.md @@ -98,7 +98,7 @@ CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node=8 run_train.py --config- The model will be saved in the `checkpoints` directory as specified in the config file. > [!NOTE] -> You can use `examples/config_tiny_llama.py` to generate your own training config +> You can use `examples/config_tiny_llama.py` to generate your own training config For detailed instructions on training your first model, check out our [Your First Training guide](docs/your-first-training.md). For multi-node training with Slurm, see our [Multi-Node Training guide](docs/multi-node-training.md). @@ -173,6 +173,7 @@ We currently support the following features: - [x] Custom module checkpointing for large models - [x] Spectral µTransfer parametrization for scaling up neural networks - [x] Mamba example +- [x] CUDA event-based timing for accurate GPU performance measurement And we have on our roadmap: - [ ] FP8 training diff --git a/docs/cuda_event_timing.md b/docs/cuda_event_timing.md new file mode 100644 index 000000000..9f8f49351 --- /dev/null +++ b/docs/cuda_event_timing.md @@ -0,0 +1,92 @@ +# CUDA Event-Based Timing in Nanotron + +## Overview + +Nanotron now uses CUDA events for timing GPU operations instead of CPU-based timing with `time.time()`. This change provides several benefits: + +1. **More accurate measurement of GPU execution time**: CUDA events are recorded directly on the GPU timeline, providing more precise timing of GPU operations. +2. **Reduced need for explicit CUDA synchronization**: CPU-based timing requires synchronization between CPU and GPU to get accurate measurements, which can introduce overhead and affect performance. +3. **Lower overhead**: CUDA event-based timing has minimal impact on the execution of GPU operations. +4. **Better performance monitoring**: More accurate timing leads to better performance analysis and optimization. + +## Implementation Details + +The implementation uses `torch.cuda.Event` with `enable_timing=True` to create start and end events that are recorded on the GPU timeline. The elapsed time is then calculated using `start_event.elapsed_time(end_event)`, which returns the time in milliseconds. + +### Key Changes + +1. **Default Timer Type**: The default timer type in `nanotron/src/nanotron/logging/timers.py` has been changed from `TimerType.CPU` to `TimerType.CUDA`. + +2. **Iteration Timing**: The iteration timing in `trainer.py` now uses CUDA events instead of `time.time()`. + +3. **Synchronization Control**: By default, CUDA event-based timers do not force synchronization unless explicitly requested with `cuda_sync=True`. + +## Usage + +### Basic Usage + +```python +# Create and use a CUDA timer (default) +with nanotron_timer("my_operation"): + # Your GPU operation here + ... + +# Explicitly specify CUDA timing +with nanotron_timer("my_operation", timer_type="cuda"): + # Your GPU operation here + ... + +# For CPU-only operations, you can still use CPU-based timing +with nanotron_timer("cpu_operation", timer_type="cpu"): + # Your CPU operation here + ... + +# As a decorator with default CUDA timing +@nanotron_timer +def my_function(): + # Your GPU operation here + ... + +# As a decorator with custom name +@nanotron_timer("custom_name") +def my_function(): + # Your GPU operation here + ... + +# As a decorator with CPU timing +@nanotron_timer(timer_type=TimerType.CPU) +def my_cpu_function(): + # Your CPU operation here + ... +``` + +### Advanced Usage + +```python +# Start and end a timer manually +timer = nanotron_timer("my_operation") +timer.start() +# Your operation here +timer.end() + +# Get the elapsed time in seconds +elapsed_time = timer.elapsed + +# Get the total time across all calls +total_time = timer.total_time + +# Get the average time per call +avg_time = timer.average_time +``` + +## Considerations + +1. **Synchronization**: By default, CUDA event-based timers do not force synchronization to avoid overhead. If you need more accurate timing at the cost of performance, you can set `cuda_sync=True`. + +2. **Units**: CUDA events measure time in milliseconds, but the timer API converts this to seconds for consistency with the previous CPU-based timing. + +3. **Fallback**: If CUDA is not available, the timer will automatically fall back to CPU-based timing. + +## Performance Impact + +Using CUDA events for timing instead of CPU-based timing with synchronization can significantly reduce overhead, especially in distributed training scenarios with thousands of GPUs. diff --git a/examples/OLMoE-1B-7B-0924-test.yml b/examples/OLMoE-1B-7B-0924-test.yml new file mode 100644 index 000000000..f203c878e --- /dev/null +++ b/examples/OLMoE-1B-7B-0924-test.yml @@ -0,0 +1,138 @@ +checkpoints: + checkpoint_interval: 1000 + checkpoints_path: /fsx/nouamane/checkpoints + checkpoints_path_is_shared_file_system: false + load_lr_scheduler: true + load_optimizer: true + resume_checkpoint_path: null + save_final_state: true + save_initial_state: false +data_stages: +- data: + # dataset: + # dataset_folder: + # - /fsx/loubna/datasets/llama_tokenized/fineweb-edu/merged + # dataset_max_tokens: null + # dataset_read_path: null + # dataset_weights: null + # pad_samples_to_global_batch_size: false + # return_positions: true + # shuffle_files: false + # skip_in_stream: false + # token_size_in_bytes: 4 + # tokenizer_name: meta-llama/Llama-3.2-1B + # use_old_brrr_dataloader: false + # vocab_size: 128256 + num_loading_workers: 1 + seed: 6198 + name: Stable Training Stage + start_training_step: 1 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: true + project: olmoe + run: olmoe-test + seed: 6198 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +metrics_logging: null +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + std: 0.02 + # scaling_method: NONE + make_vocab_size_divisible_by: 1 + model_config: + _attn_implementation: flash_attention_2 + _fused_rms_norm: true + _fused_rotary_emb: true + _use_doc_masking: true + _use_qkv_packed: true + attention_bias: false + bos_token_id: 1 + eos_token_id: 0 + flex_attention_mask: null + hidden_act: silu + hidden_size: 2048 + initializer_range: 0.02 + intermediate_size: 2048 + is_qwen2_config: true + max_position_embeddings: 4096 + no_rope_layer: null + num_attention_heads: 16 + num_hidden_layers: 2 + num_key_value_heads: 16 + pad_token_id: 1 + pretraining_tp: 1 + rms_norm_eps: 1.0e-06 + rope_interleaved: false + rope_scaling: null + rope_theta: 10000.0 + sliding_window_size: 20 + tie_word_embeddings: false + use_cache: true + vocab_size: 128256 + z_loss_enabled: false + moe_config: + num_experts: 8 + top_k: 2 + moe_hidden_size: 2048 + moe_intermediate_size: 1024 # output_multiplier=0.5 for swiglu + # shared_expert_hidden_size: 2048 + # shared_expert_intermediate_size: 1024 + # router_aux_loss_coef: 0.01 + # enable_shared_expert: false + # token_dispatcher_type: allgather +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 1.0e-4 + lr_decay_starting_step: null + lr_decay_steps: 31998 + lr_decay_style: cosine + lr_warmup_steps: 500 + lr_warmup_style: linear + min_decay_lr: 4.0e-5 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-8 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.1 + weight_decay_exclude_named_params: [] + zero_stage: 0 +parallelism: + context_parallel_size: 1 + dp: 2 + expert_parallel_size: 1 + expert_data_parallel_size: 2 + pp: 1 + pp_engine: 1f1b + recompute_layer: false + tp: 1 + tp_linear_async_communication: true + tp_mode: REDUCE_SCATTER + tp_recompute_allgather: true + enabled_moe: true +profiler: null +s3_upload: null +# tokenizer: +# tokenizer_max_length: null +# tokenizer_name_or_path: meta-llama/Llama-3.2-1B +# tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 1 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 4 + sequence_length: 4096 + train_steps: 1000 + val_check_interval: -1 diff --git a/examples/config_qwen.py b/examples/config_qwen.py index 639ed2d6b..a5d901b24 100644 --- a/examples/config_qwen.py +++ b/examples/config_qwen.py @@ -30,7 +30,7 @@ "410m": (24, 1024, 16, 16, 4096), # ~410M params # Small to medium models "1b": (16, 2048, 16, 16, 5632), # ~1B params - "3b": (28, 2048, 16, 2, 11008), # ~3B params + "3b": (36, 2048, 16, 4, 11008), # ~3B params # Standard sizes "7b": (32, 4096, 32, 32, 11008), # ~7B params "13b": (40, 5120, 40, 40, 13824), # ~13B params @@ -47,7 +47,7 @@ def get_args(): parser.add_argument( "--model", choices=MODEL_SIZES.keys(), - default="custom", + default="3b", help="Model size to generate config for (e.g., 7b, 13b)", ) parser.add_argument( @@ -76,6 +76,10 @@ def get_args(): tokens_group.add_argument("--mbs", type=int, default=3, help="Micro batch size") tokens_group.add_argument("--acc", type=int, default=1, help="Batch accumulation per replica") + # checkpoints + checkpoints_group = parser.add_argument_group("checkpoints") + checkpoints_group.add_argument("--ckpt-save", type=int, default=10, help="Checkpoint save interval") + args = parser.parse_args() return args @@ -108,7 +112,7 @@ def get_model_config(model_size: str) -> Qwen2Config: is_qwen2_config=True, pad_token_id=None, _attn_implementation="flash_attention_2", - sliding_window_size=20, + _use_doc_masking=True, ) @@ -154,7 +158,7 @@ def calculate_parameters(model_config: Qwen2Config) -> str: def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config: learning_rate = LRSchedulerArgs( - learning_rate=3e-4, lr_warmup_steps=2, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=1e-5 + learning_rate=3e-4, lr_warmup_steps=2000, lr_warmup_style="linear", lr_decay_style="cosine", min_decay_lr=0 ) parallelism = ParallelismArgs( dp=args.dp, @@ -175,7 +179,7 @@ def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config ) optimizer = OptimizerArgs( zero_stage=args.zero, - weight_decay=0.01, + weight_decay=0.1, clip_grad=1.0, accumulate_grad_in_fp32=True, learning_rate_scheduler=learning_rate, @@ -192,7 +196,7 @@ def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config return Config( general=GeneralArgs(project="debug", run=args.run, seed=seed, ignore_sanity_checks=args.no_sanity), - checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=10), + checkpoints=CheckpointsArgs(checkpoints_path=checkpoints_path, checkpoint_interval=args.ckpt_save), parallelism=parallelism, model=ModelArgs(init_method=RandomInit(std=0.025), model_config=model_config), # tokenizer=TokenizerArgs("HuggingFaceTB/cosmo2-tokenizer"), @@ -219,7 +223,11 @@ def create_config(model_config: Qwen2Config, args: argparse.Namespace) -> Config world_size = args.dp * args.tp * args.pp * args.cp if world_size <= 8: print( - f"CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node={world_size} run_train.py --config-file {args.out}" + f"ENABLE_TIMERS=1 DEBUG_CPU=1 STATS_SAMPLING_INTERVAL_IN_SEC=1 CUDA_DEVICE_MAX_CONNECTIONS=1 torchrun --nproc_per_node={world_size} run_train.py --config-file {args.out}" ) + print("You can also use environment variables for more debugging:") + print(" - ENABLE_TIMERS=1: Enable detailed timing information") + print(" - DEBUG_CPU=1: Log CPU and memory usage statistics") + print(" - STATS_SAMPLING_INTERVAL_IN_SEC=1: Set sampling interval for metrics collection") else: print("Checkout slurm_launcher.py to launch a multi-node job") diff --git a/examples/config_qwen.yaml b/examples/config_qwen.yaml index 5fc8e48ea..cf6f40fac 100644 --- a/examples/config_qwen.yaml +++ b/examples/config_qwen.yaml @@ -1,5 +1,5 @@ checkpoints: - checkpoint_interval: 10 + checkpoint_interval: 100000 checkpoints_path: checkpoints checkpoints_path_is_shared_file_system: false load_lr_scheduler: true @@ -30,9 +30,9 @@ data_stages: general: benchmark_csv_path: null consumed_train_samples: null - ignore_sanity_checks: false + ignore_sanity_checks: true project: debug - run: qwen_20250410_014907_16027793 + run: qwen_20250424_120835_16423158 seed: 42 step: null lighteval: null @@ -45,6 +45,7 @@ model: ddp_bucket_cap_mb: 25 dtype: bfloat16 init_method: + scaling_method: NUM_LAYERS std: 0.025 make_vocab_size_divisible_by: 1 model_config: @@ -58,15 +59,15 @@ model: eos_token_id: 2 flex_attention_mask: null hidden_act: silu - hidden_size: 256 + hidden_size: 2048 initializer_range: 0.02 - intermediate_size: 768 + intermediate_size: 11008 is_qwen2_config: true max_position_embeddings: 4096 moe_config: null no_rope_layer: null - num_attention_heads: 4 - num_hidden_layers: 12 + num_attention_heads: 16 + num_hidden_layers: 36 num_key_value_heads: 4 pad_token_id: null pretraining_tp: 1 @@ -74,7 +75,7 @@ model: rope_interleaved: false rope_scaling: null rope_theta: 10000.0 - sliding_window_size: 20 + sliding_window_size: null tie_word_embeddings: true use_cache: true vocab_size: 128256 @@ -104,11 +105,10 @@ parallelism: context_parallel_size: 1 dp: 2 expert_parallel_size: 1 - moe_layer_recompute: false pp: 1 pp_engine: 1f1b recompute_layer: false - tp: 1 + tp: 2 tp_linear_async_communication: true tp_mode: REDUCE_SCATTER tp_recompute_allgather: true diff --git a/examples/config_qwen_with_moe.yaml b/examples/config_qwen_with_moe.yaml new file mode 100644 index 000000000..5e51307ff --- /dev/null +++ b/examples/config_qwen_with_moe.yaml @@ -0,0 +1,132 @@ +checkpoints: + checkpoint_interval: 1000 + checkpoints_path: /fsx/phuc/new_workspace/experiments/qwen2_moe_test + checkpoints_path_is_shared_file_system: false + load_lr_scheduler: true + load_optimizer: true + resume_checkpoint_path: null + save_final_state: true + save_initial_state: false +data_stages: +- data: + dataset: + dataset_folder: + - /fsx/loubna/datasets/llama_tokenized/fineweb-edu/merged + dataset_max_tokens: null + dataset_read_path: null + dataset_weights: null + pad_samples_to_global_batch_size: false + return_positions: true + shuffle_files: false + skip_in_stream: false + token_size_in_bytes: 4 + tokenizer_name: meta-llama/Llama-3.2-1B + use_old_brrr_dataloader: false + vocab_size: 128256 + num_loading_workers: 1 + seed: 42 + name: Stable Training Stage + start_training_step: 1 +general: + benchmark_csv_path: null + consumed_train_samples: null + ignore_sanity_checks: false + project: qwen_moe + run: qwen_20250410_014907_16027793 + seed: 42 + step: null +lighteval: null +logging: + iteration_step_info_interval: 1 + log_level: info + log_level_replica: info +metrics_logging: null +model: + ddp_bucket_cap_mb: 25 + dtype: bfloat16 + init_method: + std: 0.025 + make_vocab_size_divisible_by: 1 + model_config: + _attn_implementation: flash_attention_2 + _fused_rms_norm: true + _fused_rotary_emb: true + _use_doc_masking: true + _use_qkv_packed: true + attention_bias: false + bos_token_id: 1 + eos_token_id: 2 + flex_attention_mask: null + hidden_act: silu + hidden_size: 256 + initializer_range: 0.02 + intermediate_size: 768 + is_qwen2_config: true + max_position_embeddings: 4096 + moe_config: null + no_rope_layer: null + num_attention_heads: 4 + num_hidden_layers: 12 + num_key_value_heads: 4 + pad_token_id: null + pretraining_tp: 1 + rms_norm_eps: 1.0e-06 + rope_interleaved: false + rope_scaling: null + rope_theta: 10000.0 + sliding_window_size: 20 + tie_word_embeddings: true + use_cache: true + vocab_size: 128256 + z_loss_coefficient: 0.0001 + z_loss_enabled: false + moe_config: + num_experts: 8 + top_k: 1 + enable_shared_expert: true + token_dispatcher_type: alltoall +optimizer: + accumulate_grad_in_fp32: true + clip_grad: 1.0 + learning_rate_scheduler: + learning_rate: 0.0003 + lr_decay_starting_step: null + lr_decay_steps: 31998 + lr_decay_style: cosine + lr_warmup_steps: 2 + lr_warmup_style: linear + min_decay_lr: 1.0e-05 + optimizer_factory: + adam_beta1: 0.9 + adam_beta2: 0.95 + adam_eps: 1.0e-08 + name: adamW + torch_adam_is_fused: true + weight_decay: 0.01 + weight_decay_exclude_named_params: [] + zero_stage: 0 +parallelism: + context_parallel_size: 1 + dp: 2 + expert_parallel_size: 1 + pp: 1 + pp_engine: 1f1b + recompute_layer: false + tp: 1 + tp_linear_async_communication: true + tp_mode: REDUCE_SCATTER + tp_recompute_allgather: true +profiler: null +s3_upload: null +tokenizer: + tokenizer_max_length: null + tokenizer_name_or_path: meta-llama/Llama-3.2-1B + tokenizer_revision: null +tokens: + batch_accumulation_per_replica: 1 + limit_test_batches: 0 + limit_val_batches: 0 + micro_batch_size: 3 + sequence_length: 4096 + train_steps: 32000 + val_check_interval: -1 diff --git a/examples/inference/qwen_moe/README.md b/examples/inference/qwen_moe/README.md new file mode 100644 index 000000000..2d4ee3974 --- /dev/null +++ b/examples/inference/qwen_moe/README.md @@ -0,0 +1,27 @@ +# Qwen-MoE Inference + +This guide explains how to convert Hugging face Qwen-MoE models to Nanotron format and run inference with them. + +## Convert Qwen-MoE to Nanotron Format + +Navigate to the `inference/qwen_moe` directory and run: + +```bash +torchrun --nproc-per-node 1 examples/inference/qwen_moe/convert.py \ + --nanotron-checkpoint-path nanotron_checkpoints/Qwen1.5-MoE-A2.7B \ + --pretrained-model-name-or-path Qwen/Qwen1.5-MoE-A2.7B +``` + +This command will save the converted model weights to the specified path in `nanotron_checkpoints` + +## Run Inference + +From the root directory of Nanotron, run: + +```bash +torchrun --rdzv_endpoint=localhost:29700 --rdzv-backend=c10d --nproc_per_node=1 \ + run_generate.py \ + --ckpt-path examples/inference/qwen_moe/nanotron_checkpoints/Qwen1.5-MoE-A2.7B +``` + +This command will load the converted model weights and run inference. diff --git a/examples/inference/qwen_moe/convert.py b/examples/inference/qwen_moe/convert.py new file mode 100644 index 000000000..419495da2 --- /dev/null +++ b/examples/inference/qwen_moe/convert.py @@ -0,0 +1,329 @@ +""" +torchrun --nproc-per-node 1 convert.py --nanotron-checkpoint-path nanotron_checkpoints/Qwen1.5-MoE-A2.7B --pretrained-model-name-or-path Qwen/Qwen1.5-MoE-A2.7B +""" +import argparse +import json +from dataclasses import asdict +from pathlib import Path + +import torch +import yaml +from nanotron import logging +from nanotron.config import Config, GeneralArgs, LoggingArgs, ModelArgs, ParallelismArgs, TokenizerArgs +from nanotron.config.models_config import ExistingCheckpointInit, MoEConfig, Qwen2Config +from nanotron.logging import log_rank, set_ranks_logging_level +from nanotron.models import build_model +from nanotron.models.qwen import Qwen2ForTraining +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import sanity_check +from nanotron.serialize import TrainingMetadata, save_meta, save_weights +from nanotron.serialize.metadata import DataStageMetadata +from nanotron.trainer import mark_tied_parameters +from tqdm import tqdm +from transformers import AutoModelForCausalLM, AutoTokenizer, Qwen2MoeConfig + +logger = logging.get_logger(__name__) + +# NOTE: We need to initialize the model on gpu, because RotaryEmbedding +# requires its buffer to be on gpu +DEVICE = torch.device("cuda") +TORCH_DTYPE = torch.bfloat16 + + +def get_args(): + parser = argparse.ArgumentParser() + group = parser.add_argument_group(title="Nanotron Model") + group.add_argument( + "--nanotron-checkpoint-path", + type=str, + required=True, + help="A path to a directory to store the converted Nanotron Checkpoint", + ) + + group = parser.add_argument_group(title="HuggingFace Model") + group.add_argument( + "--pretrained-model-name-or-path", + type=str, + required=True, + help="A path to a directory containing model weights saved using save_pretrained() or the model id of a pretrained model hosted inside a model repo on the Hugging Face Hub", + ) + + args = parser.parse_args() + + return args + + +def main(args): + # Init Nanotron Parallel Utilities + parallel_config = ParallelismArgs(dp=1, pp=1, tp=1) + + parallel_context = ParallelContext( + data_parallel_size=parallel_config.dp, + pipeline_parallel_size=parallel_config.pp, + tensor_parallel_size=parallel_config.tp, + ) + + set_ranks_logging_level(parallel_context=parallel_context, logging_config=LoggingArgs()) + + # Load Qwen-MoE HF model + log_rank( + f"Loading pretrained qwen moe Model: {args.pretrained_model_name_or_path}", + logger=logger, + level=logging.INFO, + rank=0, + ) + hf_model = AutoModelForCausalLM.from_pretrained( + args.pretrained_model_name_or_path, torch_dtype=TORCH_DTYPE, attn_implementation="flash_attention_2" + ).to(DEVICE) + hf_config: Qwen2MoeConfig = hf_model.config + + # Set Nanotron Qwen2Config + nanotron_config = Qwen2Config( + bos_token_id=hf_config.bos_token_id, + eos_token_id=hf_config.eos_token_id, + hidden_act=hf_config.hidden_act, + hidden_size=hf_config.hidden_size, + initializer_range=hf_config.initializer_range, + intermediate_size=hf_config.intermediate_size, + is_qwen2_config=True, + max_position_embeddings=hf_config.max_position_embeddings, + num_attention_heads=hf_config.num_attention_heads, + num_hidden_layers=hf_config.num_hidden_layers, + num_key_value_heads=hf_config.num_key_value_heads, + pad_token_id=None, + attention_bias=True, # qwen-moe uses attention bias + rms_norm_eps=hf_config.rms_norm_eps, + rope_scaling=hf_config.rope_scaling, + rope_theta=hf_config.rope_theta, + rope_interleaved=False, + tie_word_embeddings=hf_config.tie_word_embeddings, + use_cache=hf_config.use_cache, + vocab_size=hf_config.vocab_size, + moe_config=MoEConfig( + top_k=hf_config.num_experts_per_tok, + num_experts=hf_config.num_experts, + moe_intermediate_size=hf_config.moe_intermediate_size, + shared_expert_intermediate_size=hf_config.shared_expert_intermediate_size, + router_aux_loss_coef=hf_config.router_aux_loss_coef, + enable_shared_expert=True, + ), + ) + + # Init Nanotron Qwen-MoE model + log_rank("Init empty Nanotron Qwen Moe Model", logger=logger, level=logging.INFO, rank=0) + nanotron_model = build_model( + model_builder=lambda: Qwen2ForTraining( + config=nanotron_config, + parallel_context=parallel_context, + parallel_config=parallel_config, + ), + parallel_context=parallel_context, + dtype=TORCH_DTYPE, + device=DEVICE, + ) + + mark_tied_parameters(model=nanotron_model, parallel_context=parallel_context, parallel_config=parallel_config) + sanity_check(root_module=nanotron_model) + + # Copy params from HF to Nanotron + log_rank("Copying weights from HF model to Nanotron model...", logger=logger, level=logging.INFO, rank=0) + # Token embeddings + log_rank("Copying Token Embeddings...", logger=logger, level=logging.INFO, rank=0) + + with torch.no_grad(): + # token embeddings + assert ( + nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.shape + == hf_model.model.embed_tokens.weight.shape + ) + + nanotron_model.model.token_position_embeddings.pp_block.token_embedding.weight.copy_( + hf_model.model.embed_tokens.weight + ) + + # Decoder layers + for i in tqdm( + range(nanotron_config.num_hidden_layers), + desc="Copying Hidden Layers", + total=nanotron_config.num_hidden_layers, + ): + # Input layer norm + assert ( + hf_model.model.layers[i].input_layernorm.weight.shape + == nanotron_model.model.decoder[i].pp_block.input_layernorm.weight.shape + ) + + nanotron_model.model.decoder[i].pp_block.input_layernorm.weight.copy_( + hf_model.model.layers[i].input_layernorm.weight + ) + + # Self attn + ## QKV + tmp_qkv_proj = torch.cat( + [ + hf_model.model.layers[i].self_attn.q_proj.weight, + hf_model.model.layers[i].self_attn.k_proj.weight, + hf_model.model.layers[i].self_attn.v_proj.weight, + ], + dim=0, + ) + assert tmp_qkv_proj.shape == nanotron_model.model.decoder[i].pp_block.attn.qkv_proj.weight.shape + nanotron_model.model.decoder[i].pp_block.attn.qkv_proj.weight.copy_(tmp_qkv_proj) + + ## QKV bias + tmp_qkv_bias = torch.cat( + [ + hf_model.model.layers[i].self_attn.q_proj.bias, + hf_model.model.layers[i].self_attn.k_proj.bias, + hf_model.model.layers[i].self_attn.v_proj.bias, + ], + dim=0, + ) + assert tmp_qkv_bias.shape == nanotron_model.model.decoder[i].pp_block.attn.qkv_proj.bias.shape + nanotron_model.model.decoder[i].pp_block.attn.qkv_proj.bias.copy_(tmp_qkv_bias) + + ## O + assert ( + hf_model.model.layers[i].self_attn.o_proj.weight.shape + == nanotron_model.model.decoder[i].pp_block.attn.o_proj.weight.shape + ) + nanotron_model.model.decoder[i].pp_block.attn.o_proj.weight.copy_( + hf_model.model.layers[i].self_attn.o_proj.weight + ) + + # MLP + ## Router + assert ( + hf_model.model.layers[i].mlp.gate.weight.shape + == nanotron_model.model.decoder[i].pp_block.mlp.router.weight.shape + ) + nanotron_model.model.decoder[i].pp_block.mlp.router.weight.copy_(hf_model.model.layers[i].mlp.gate.weight) + + ## shared expert: Gate Up Proj + tmp_shared_expert = torch.cat( + [ + hf_model.model.layers[i].mlp.shared_expert.gate_proj.weight, + hf_model.model.layers[i].mlp.shared_expert.up_proj.weight, + ], + dim=0, + ) + assert ( + tmp_shared_expert.shape + == nanotron_model.model.decoder[i].pp_block.mlp.shared_expert.gate_up_proj.weight.shape + ) + nanotron_model.model.decoder[i].pp_block.mlp.shared_expert.gate_up_proj.weight.copy_(tmp_shared_expert) + + ## shared expert: Down Proj + assert ( + hf_model.model.layers[i].mlp.shared_expert.down_proj.weight.shape + == nanotron_model.model.decoder[i].pp_block.mlp.shared_expert.down_proj.weight.shape + ) + nanotron_model.model.decoder[i].pp_block.mlp.shared_expert.down_proj.weight.copy_( + hf_model.model.layers[i].mlp.shared_expert.down_proj.weight + ) + + ## shared expert: Gate + assert ( + hf_model.model.layers[i].mlp.shared_expert_gate.weight.shape + == nanotron_model.model.decoder[i].pp_block.mlp.shared_expert_gate.weight.shape + ) + nanotron_model.model.decoder[i].pp_block.mlp.shared_expert_gate.weight.copy_( + hf_model.model.layers[i].mlp.shared_expert_gate.weight + ) + + ## experts: + # concatenate all gate_up_proj and down_proj weights for experts into merged_gate_up_proj and merged_down_proj + tmp_merged_gate_up_proj = torch.zeros( + nanotron_config.moe_config.num_experts, + nanotron_config.hidden_size, + 2 * nanotron_config.moe_config.moe_intermediate_size, + ) + tmp_merged_down_proj = torch.zeros( + nanotron_config.moe_config.num_experts, + nanotron_config.moe_config.moe_intermediate_size, + nanotron_config.hidden_size, + ) + + for j in range(nanotron_config.moe_config.num_experts): + ## Gate Up Proj + tmp_merged_gate_up_proj[j, :, : nanotron_config.moe_config.moe_intermediate_size] = ( + hf_model.model.layers[i].mlp.experts[j].gate_proj.weight.T + ) + tmp_merged_gate_up_proj[j, :, nanotron_config.moe_config.moe_intermediate_size :] = ( + hf_model.model.layers[i].mlp.experts[j].up_proj.weight.T + ) + + ## Down Proj + tmp_merged_down_proj[j] = hf_model.model.layers[i].mlp.experts[j].down_proj.weight.T + + # copy to merged_gate_up_proj and merged_down_proj + nanotron_model.model.decoder[i].pp_block.mlp.experts.merged_gate_up_proj.copy_(tmp_merged_gate_up_proj) + nanotron_model.model.decoder[i].pp_block.mlp.experts.merged_down_proj.copy_(tmp_merged_down_proj) + + # Post attn layer norm + assert ( + hf_model.model.layers[i].post_attention_layernorm.weight.shape + == nanotron_model.model.decoder[i].pp_block.post_attention_layernorm.weight.shape + ) + nanotron_model.model.decoder[i].pp_block.post_attention_layernorm.weight.copy_( + hf_model.model.layers[i].post_attention_layernorm.weight + ) + + # Last layer norm + log_rank("Copying Final Layer Norm...", logger=logger, level=logging.INFO, rank=0) + assert nanotron_model.model.final_layer_norm.pp_block.weight.shape == hf_model.model.norm.weight.shape + nanotron_model.model.final_layer_norm.pp_block.weight.copy_(hf_model.model.norm.weight) + + # LM_Head + log_rank("Copying LM Head...", logger=logger, level=logging.INFO, rank=0) + assert nanotron_model.model.lm_head.pp_block.weight.shape == hf_model.lm_head.weight.shape + nanotron_model.model.lm_head.pp_block.weight.copy_(hf_model.lm_head.weight) + + log_rank("Copied weights from HF model to Nanotron model!", logger=logger, level=logging.INFO, rank=0) + # Store weights + nanotron_checkpoint_path = Path(args.nanotron_checkpoint_path) + save_weights(model=nanotron_model, parallel_context=parallel_context, root_folder=nanotron_checkpoint_path) + + # Store metadata + log_rank("Storing Nanotron model Configs and Metadata!", logger=logger, level=logging.INFO, rank=0) + training_metadata = TrainingMetadata( + last_train_step=0, + consumed_train_samples=0, + data_stages=[DataStageMetadata(name="Empty", consumed_train_samples=0, start_training_step=0)], + ) + save_meta( + root_folder=nanotron_checkpoint_path, parallel_context=parallel_context, training_metadata=training_metadata + ) + # Store Tokenizer into Nanotron Checkpoint folder + tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path) + tokenizer.save_pretrained(nanotron_checkpoint_path) + + # Store Config and Model Config files + with open(nanotron_checkpoint_path / "config.yaml", "w") as f: + config = Config( + general=GeneralArgs(project="Nanotron", run="Qwen2-MoE"), + parallelism=parallel_config, + model=ModelArgs( + init_method=ExistingCheckpointInit(nanotron_checkpoint_path), + model_config=nanotron_config, + ), + tokenizer=TokenizerArgs(tokenizer_name_or_path=args.pretrained_model_name_or_path), + ) + log_rank("Saving config ...", logger=logger, level=logging.INFO, rank=0) + yaml.dump(config.as_dict(), f) + + with open(nanotron_checkpoint_path / "model_config.json", "w") as f: + log_rank("Saving model config ...", logger=logger, level=logging.INFO, rank=0) + json.dump(asdict(nanotron_config), f) + + log_rank( + f"Checkpoint conversion finished, check {args.nanotron_checkpoint_path}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + +if __name__ == "__main__": + _args = get_args() + main(_args) diff --git a/examples/llama/tests/test_conversion.py.orig b/examples/llama/tests/test_conversion.py.orig deleted file mode 100644 index af0688371..000000000 --- a/examples/llama/tests/test_conversion.py.orig +++ /dev/null @@ -1,264 +0,0 @@ -# ruff: noqa: E402 -import json -<<<<<<< HEAD -from pathlib import Path -======= ->>>>>>> main - -import pytest -import torch -from transformers import LlamaForCausalLM -from utils import set_system_path - -set_system_path() - -import nanotron -from nanotron.config import LlamaConfig as NanotronLlamaConfig -from nanotron.models.base import init_on_device_and_dtype -from nanotron.models.llama import LlamaForTraining -from nanotron.parallel import ParallelContext - -from examples.llama.convert_hf_to_nanotron import convert_checkpoint_and_save as convert_hf_to_nt_and_save -<<<<<<< HEAD -from examples.llama.convert_nanotron_to_hf import convert_checkpoint_and_save as convert_nt_to_hf_and_save -from examples.llama.convert_hf_to_nanotron import convert_hf_to_nt -from examples.llama.convert_nanotron_to_hf import convert_nt_to_hf, get_hf_config -from examples.llama.convert_weights import load_nanotron_model -from tests.helpers.context import TestContext -from tests.helpers.utils import init_distributed -======= -from examples.llama.convert_hf_to_nanotron import convert_hf_to_nt -from examples.llama.convert_nanotron_to_hf import convert_checkpoint_and_save as convert_nt_to_hf_and_save -from examples.llama.convert_nanotron_to_hf import convert_nt_to_hf, get_hf_config -from examples.llama.convert_weights import load_nanotron_model, make_parallel_config -from tests.helpers.context import TestContext -from tests.helpers.utils import init_distributed, rerun_if_address_is_in_use ->>>>>>> main - -CONFIG = NanotronLlamaConfig( - **{ - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_act": "silu", - "hidden_size": 512, - "initializer_range": 0.02, - "intermediate_size": 1024, - "is_llama_config": True, - "max_position_embeddings": 128, - "num_attention_heads": 8, - "num_hidden_layers": 4, - "num_key_value_heads": 4, - "pad_token_id": None, - "pretraining_tp": 1, - "rms_norm_eps": 1e-06, - "rope_scaling": None, - "tie_word_embeddings": False, - "use_cache": True, - "vocab_size": 4096, - } -) - - -BATCH_SIZE = 3 -SEQUENCE_LENGTH = 5 -ATOL = 0.02 - - -def create_nanotron_model(pp: int = 1, tp: int = 1, dp: int = 1) -> LlamaForTraining: - parallel_config = make_parallel_config(dp, pp, tp) - return load_nanotron_model(parallel_config, CONFIG, torch.device("cuda"), torch.bfloat16) - - -def create_huggingface_model() -> LlamaForCausalLM: - config_hf = get_hf_config(CONFIG) - with init_on_device_and_dtype(torch.device("cuda"), torch.bfloat16): - model_hf = LlamaForCausalLM._from_config(config_hf) - return model_hf - - -@pytest.fixture(autouse=True, scope="module") -def fix_seed(): - torch.manual_seed(0) - yield - - -@pytest.fixture -def input_ids() -> torch.Tensor: - return torch.randint(0, CONFIG.vocab_size, size=(BATCH_SIZE, SEQUENCE_LENGTH), device="cuda") - - -def _test_nt_to_hf(parallel_context: ParallelContext, input_ids: torch.Tensor): - model_nt = create_nanotron_model() - model_hf = create_huggingface_model() - convert_nt_to_hf(model_nt, model_hf, CONFIG) - input_mask = torch.ones_like(input_ids) - logits_nt = model_nt.model(input_ids, input_mask).permute(1, 0, 2) - logits_hf = model_hf(input_ids).logits - assert logits_nt.size() == logits_hf.size() - assert torch.allclose(logits_nt, logits_hf, atol=ATOL), torch.mean(torch.abs(logits_nt - logits_hf)) - - -def test_nt_to_hf(input_ids: torch.Tensor): - init_distributed(tp=1, dp=1, pp=1)(_test_nt_to_hf)(input_ids=input_ids) - - -def _test_nt_to_hf_with_files(parallel_context: ParallelContext, input_ids: torch.Tensor, test_context: TestContext): - # Create and save nanotron model. - model_nt = create_nanotron_model() - root = test_context.get_auto_remove_tmp_dir() - nt_path = root / "nanotron" - hf_path = root / "hf" - nanotron.serialize.save_weights(model=model_nt, parallel_context=parallel_context, root_folder=nt_path) - with open(nt_path / "model_config.json", "w+") as f: - json.dump(vars(CONFIG), f) - input_mask = torch.ones_like(input_ids) - logits_nt = model_nt.model(input_ids, input_mask).permute(1, 0, 2) - del model_nt - # Perform conversion. - convert_nt_to_hf_and_save(nt_path, hf_path) - # Load huggingface and get logits. - model_hf = LlamaForCausalLM.from_pretrained(hf_path).cuda() - logits_hf = model_hf(input_ids).logits - assert logits_nt.size() == logits_hf.size() - assert torch.allclose(logits_nt, logits_hf, atol=ATOL), torch.mean(torch.abs(logits_nt - logits_hf)) - - -def test_nt_to_hf_with_files(input_ids: torch.Tensor): - init_distributed(tp=1, dp=1, pp=1)(_test_nt_to_hf_with_files)(input_ids=input_ids, test_context=TestContext()) - - -def _test_hf_to_nt(parallel_context: ParallelContext, input_ids: torch.Tensor): - model_nt = create_nanotron_model() - model_hf = create_huggingface_model() - convert_hf_to_nt(model_hf, model_nt, CONFIG) - input_mask = torch.ones_like(input_ids) - logits_nt = model_nt.model(input_ids, input_mask).permute(1, 0, 2) - logits_hf = model_hf(input_ids).logits - assert logits_nt.size() == logits_hf.size() - assert torch.allclose(logits_nt, logits_hf, atol=ATOL), torch.mean(torch.abs(logits_nt - logits_hf)) - - -def test_hf_to_nt(input_ids: torch.Tensor): - init_distributed(tp=1, dp=1, pp=1)(_test_hf_to_nt)(input_ids=input_ids) - - -def _test_hf_to_nt_with_files(parallel_context: ParallelContext, input_ids: torch.Tensor, test_context: TestContext): - # Create and save hf model. - model_hf = create_huggingface_model() - root = test_context.get_auto_remove_tmp_dir() - nt_path = root / "nanotron" - hf_path = root / "hf" - model_hf.save_pretrained(hf_path) - logits_hf = model_hf(input_ids).logits - del model_hf - # Perform conversion. - convert_hf_to_nt_and_save(hf_path, nt_path) - # Load nanotron and get logits. - input_mask = torch.ones_like(input_ids) - model_nt = load_nanotron_model(checkpoint_path=nt_path) - logits_nt = model_nt.model(input_ids, input_mask).permute(1, 0, 2) - assert logits_nt.size() == logits_hf.size() - assert torch.allclose(logits_nt, logits_hf, atol=ATOL) - - -def test_hf_to_nt_with_files(input_ids: torch.Tensor): - init_distributed(tp=1, dp=1, pp=1)(_test_hf_to_nt_with_files)(input_ids=input_ids, test_context=TestContext()) - - -def _test_composed_conversion(parallel_context: ParallelContext): - # Get HF statedict. - model_hf = create_huggingface_model() - hf_sd = {key: val.clone() for key, val in model_hf.state_dict().items()} - # Convert once to nanotron, save its statedict. - model_nt = create_nanotron_model() - convert_hf_to_nt(model_hf, model_nt, CONFIG) - nt_sd = {key: val.clone() for key, val in model_nt.state_dict().items()} - # Convert back to HF, compare statedicts. - del model_hf - model_hf = create_huggingface_model() - convert_nt_to_hf(model_nt, model_hf, CONFIG) - hf_sd_new = model_hf.state_dict() - assert set(hf_sd_new) == set(hf_sd) - assert all(torch.all(hf_sd[key] == hf_sd_new[key]) for key in hf_sd_new) - # Convert to nanotron one more time, compare statedicts. - del model_nt - model_nt = create_nanotron_model() - convert_hf_to_nt(model_hf, model_nt, CONFIG) - nt_sd_new = model_nt.state_dict() - assert set(nt_sd_new) == set(nt_sd) - assert all(torch.all(nt_sd[key] == nt_sd_new[key]) for key in nt_sd_new) - - -def test_composed_conversion(): - init_distributed(tp=1, dp=1, pp=1)(_test_composed_conversion)() - - -<<<<<<< HEAD -def _save_parallel_nanotron(parallel_context: ParallelContext, input_ids: torch.Tensor, nt_path: Path): - # Create and save a parallel model. - model_nt = create_nanotron_model(tp=parallel_context.tensor_parallel_size, pp=parallel_context.pipeline_parallel_size) - # print(torch.distributed.get_rank(), "model_nt", set(p.device for p in model_nt.parameters())) - nanotron.serialize.save_weights(model=model_nt, parallel_context=parallel_context, root_folder=nt_path) - with open(nt_path/"model_config.json", "w+") as f: - json.dump(vars(CONFIG), f) - - # Get parallel predictions. - input_ids = input_ids.cuda() # Move them to the current device index. - input_mask = torch.ones_like(input_ids) - # print(torch.distributed.get_rank(), "input_ids", input_ids.device) - logits_nt = model_nt.model(input_ids, input_mask).permute(1, 0, 2) - if torch.distributed.get_rank() == 0: - torch.save(logits_nt.detach().cpu(), nt_path/"logits.pt") - # print(torch.distributed.get_rank(), logits_nt.shape) - - # Convert nanotron to hf, load it and compare logits. - # hf_path = root/"hf" - # convert_nt_to_hf_and_save(nt_path, hf_path) - # model_hf = LlamaForCausalLM.from_pretrained(hf_path).cuda() - # logits_hf = model_hf(input_ids).logits - - # assert logits_nt.size() == logits_hf.size() - # assert torch.allclose(logits_nt, logits_hf, atol=ATOL), torch.mean(torch.abs(logits_nt - logits_hf)) - - -def _convert_from_parallel(parallel_context: ParallelContext, input_ids: torch.Tensor, nt_path: Path, hf_path: Path): - # Convert parallel nanotron to hf, get and save huggingface predictions. - convert_nt_to_hf_and_save(nt_path, hf_path) - model_hf = LlamaForCausalLM.from_pretrained(hf_path).cuda() - logits_hf = model_hf(input_ids).logits - torch.save(logits_hf.detach().cpu(), hf_path/"logits.pt") - -def test_tensor_parallel_conversion(input_ids: torch.Tensor): - # Set up test. - test_context = TestContext() - root = test_context.get_auto_remove_tmp_dir() - nt_path =root/"nanotron" - hf_path =root/"nanotron" - - # Launch both parts. - init_distributed(tp=2, dp=1, pp=1)(_save_parallel_nanotron)(input_ids=input_ids, nt_path=nt_path) - assert (nt_path/"logits.pt").exists() - init_distributed(tp=1, dp=1, pp=1)(_convert_from_parallel)(input_ids=input_ids, nt_path=nt_path, hf_path=hf_path) - assert (hf_path/"logits.pt").exists() - - # Load logits and verify they match. - logits_nt = torch.load(nt_path/"logits.pt") - logits_hf = torch.load(hf_path/"logits.pt") - assert logits_nt.size() == logits_hf.size() - assert torch.allclose(logits_nt, logits_hf, atol=ATOL), torch.mean(torch.abs(logits_nt - logits_hf)) -======= -def _test_tensor_parallel_conversion(parallel_context: ParallelContext): - model_nt = create_nanotron_model(tp=2) - model_hf = create_huggingface_model() - convert_nt_to_hf(model_nt, model_hf, CONFIG) - input_mask = torch.ones_like(input_ids) - logits_nt = model_nt.model(input_ids, input_mask).permute(1, 0, 2) - logits_hf = model_hf(input_ids).logits - assert logits_nt.size() == logits_hf.size() - assert torch.allclose(logits_nt, logits_hf, atol=ATOL), torch.mean(torch.abs(logits_nt - logits_hf)) - - -@rerun_if_address_is_in_use() -def test_tensor_parallel_conversion(): - init_distributed(tp=2, dp=1, pp=1)(_test_tensor_parallel_conversion)() ->>>>>>> main diff --git a/examples/mamba/trainer.py b/examples/mamba/trainer.py index 0ca66d12e..a0100df0e 100644 --- a/examples/mamba/trainer.py +++ b/examples/mamba/trainer.py @@ -50,7 +50,6 @@ def _mark_tied_parameters( target, ( parallel_context.world_rank_matrix[ - dist.get_rank(parallel_context.ep_pg), get_pp_rank_of(target, module=model), dist.get_rank(parallel_context.dp_pg), dist.get_rank(parallel_context.tp_pg), diff --git a/examples/moe/README.md b/examples/moe/README.md index 43a7e3551..8141d250c 100644 --- a/examples/moe/README.md +++ b/examples/moe/README.md @@ -2,34 +2,12 @@ library_name: nanotron --- -# LlaMoE - -Modeling code for LlaMoE to use with [Nanotron](https://github.com/huggingface/nanotron/) - -## 🚀 Quickstart +### Benchmark ```bash -# Generate a config file -python examples/moe/config_llamoe.py - -# Install megablocks -pip install megablocks - -# Run training -export CUDA_DEVICE_MAX_CONNECTIONS=1 # important for some distributed operations -torchrun --nproc_per_node=4 examples/moe/train_moe.py --config-file examples/moe/config_llamoe.yaml +./examples/moe/benchmark_moe.sh /fsx/phuc/new_workspace/experiments/qwen_moe/benchmark/exp0a0_benhmark_num_experts_topk_and_ep_in_a_node ``` -## 🚀 Use your custom model -- Update the `LlaMoEConfig` class in `config_llamoe.py` to match your model's configuration -- Update the `LlaMoEForTraining` class in `modeling_llamoe.py` to match your model's architecture -- Pass the previous to the `DistributedTrainer` class in `train_moe.py`: -```python -trainer = DistributedTrainer(config_file, model_class=LlaMoEForTraining, model_config_class=LlaMoEConfig) -``` -- Run training as usual - - ## Credits Credits to the following repositories from which the code was adapted: - https://github.com/huggingface/transformers/blob/main/src/transformers/models/mixtral/modeling_mixtral.py diff --git a/examples/moe/benchmark_moe.py b/examples/moe/benchmark_moe.py new file mode 100644 index 000000000..9ccff1bda --- /dev/null +++ b/examples/moe/benchmark_moe.py @@ -0,0 +1,652 @@ +# fmt: off +# Generates configs and SLURM scripts for scaling benchmarks used in https://huggingface.co/spaces/nanotron/ultrascale-playbook +import argparse +import math +import os +from datetime import datetime +from typing import Optional + +import pandas as pd +import yaml +from nanotron.logging import human_format +from tqdm import tqdm + +ACCUMULATE_GRAD_IN_FP32 = True +NUM_KEY_VALUE_HEADS = None # Although it's necessary to reduce 450B to 400B model, we can't bench tp>8 if we use 8 KV heads +# NUM_KEY_VALUE_HEADS = 8 # Although it's necessary to reduce 450B to 400B model, we can't bench tp>8 if we use 8 KV heads + + +from dataclasses import dataclass + + +@dataclass +class BenchmarkConfig: + dp: int + tp: int + pp: int + ep: int + etp: int + edp: int + + +def estimate_num_params( + layers, + hidden_size, + heads, + intermediate_size, + tie_word_embeddings, + vocab, + kv_heads=None, +): + # params = 2*V*h + l(3*h*H + (2 + 2*q/kv_ratio)*h*h) + # For GQA with 8 KV heads and 32 attention heads (4x ratio), it's: 2*V*h + l(3*h*H + (2 + 2/4)*h*h) + vocab = vocab * hidden_size if tie_word_embeddings else 2 * vocab * hidden_size + kv_ratio = kv_heads / heads if kv_heads is not None else 1 + qkv_params = (2 + 2 * kv_ratio) * hidden_size * hidden_size # Account for GQA + return vocab + layers * (3 * hidden_size * intermediate_size + qkv_params) + + +def create_config( + base_config_path: str, + dp: int, + tp: int, + pp: int, + expert_parallel_size: int, + expert_tensor_parallel_size: int, + expert_data_parallel_size: int, + num_experts: Optional[int] = None, + topk: Optional[int] = None, + batch_accum: Optional[int] = None, + seq_len: Optional[int] = None, + benchmark_csv_path: Optional[str] = None, + micro_batch_size: Optional[int] = None, + zero_stage: Optional[int] = None, + num_layers: Optional[int] = None, + hidden_size: Optional[int] = None, + num_attention_heads: Optional[int] = None, + intermediate_size: Optional[int] = None, + tp_mode: Optional[str] = None, + vocab_size: Optional[int] = None, + profile: bool = False, +) -> dict: + """Create a config with the specified parallelism settings.""" + with open(base_config_path) as f: + config = yaml.safe_load(f) + + # Parallelism settings (required args) + config["parallelism"].update({ + "dp": dp, + "tp": tp, + "pp": pp, + "expert_parallel_size": expert_parallel_size, + "expert_tensor_parallel_size": expert_tensor_parallel_size, + "expert_data_parallel_size": expert_data_parallel_size + }) + + # Optional parameter updates + optional_updates = { + "tokens": { + "batch_accumulation_per_replica": batch_accum, + "sequence_length": seq_len, + "micro_batch_size": micro_batch_size + }, + "model": { + "model_config": { + "num_hidden_layers": num_layers, + "hidden_size": hidden_size, + "num_attention_heads": num_attention_heads, + "intermediate_size": intermediate_size, + "vocab_size": vocab_size, + "moe_config": { + "num_experts": num_experts, + "topk": topk + } + } + }, + "optimizer": {"zero_stage": zero_stage}, + "parallelism": {"tp_mode": tp_mode} + } + + for section, values in optional_updates.items(): + for key, value in values.items(): + if value is not None: + if isinstance(value, dict): # Handle nested model_config + for sub_key, sub_value in value.items(): + if sub_key == "hidden_size" or sub_value == "hidden_size": + assert 1 == 1 + + if sub_value is not None: + config[section][key][sub_key] = sub_value + else: + config[section][key] = value + # except: + # assert 1 == 1 + + # Handle special cases + if seq_len is not None: + config["model"]["model_config"]["max_position_embeddings"] = seq_len + + if num_attention_heads is not None: + config["model"]["model_config"]["num_key_value_heads"] = ( + NUM_KEY_VALUE_HEADS if NUM_KEY_VALUE_HEADS is not None + else num_attention_heads + ) + + if intermediate_size is None and hidden_size is not None: + config["model"]["model_config"]["intermediate_size"] = 4 * hidden_size + + # config["model"]["model_config"]["tie_word_embeddings"] = ( + # config["model"]["model_config"]["intermediate_size"] < 10_000 + # ) + config["model"]["model_config"]["tie_word_embeddings"] = True + + # Build run name + run_name_parts = [ + human_format(estimate_num_params( + config["model"]["model_config"]["num_hidden_layers"], + config["model"]["model_config"]["hidden_size"], + config["model"]["model_config"]["num_attention_heads"], + config["model"]["model_config"]["intermediate_size"], + config["model"]["model_config"]["tie_word_embeddings"], + config["model"]["model_config"]["vocab_size"], + )), + f"dp{dp}", f"tp{tp}", f"pp{pp}", + f"ep{expert_parallel_size}", f"etp{expert_tensor_parallel_size}", f"edp{expert_data_parallel_size}", + f"num_experts{num_experts}", f"topk{topk}", + f"acc{batch_accum}" if batch_accum is not None else "", + f"mbs{micro_batch_size}" if micro_batch_size is not None else "", + f"seq{seq_len}" if seq_len is not None else "", + f"zero{zero_stage}" if zero_stage is not None else "", + f"tpmode{tp_mode[:3]}" if tp_mode else "", + f"vocab{config['model']['model_config']['vocab_size']//1000}k" + ] + + config["general"]["run"] = "_".join(filter(None, run_name_parts)) + if NUM_KEY_VALUE_HEADS is not None: + config["general"]["run"] += f"_gqa{NUM_KEY_VALUE_HEADS}" + if profile: + config["general"]["run"] += "_prof" + config["profiler"] = {"profiler_export_path": "./tb_logs"} + config["tokens"]["train_steps"] = 10 + + config["general"]["benchmark_csv_path"] = benchmark_csv_path + return config + + +def generate_slurm_script( + config: dict, + dp: int, + tp: int, + pp: int, + time: str = "00:02:00", + partition: str = "hopper-prod", + base_script_path: str = "run_multinode.sh", + use_bash: bool = False, +) -> str: + """Generate a SLURM script for the given configuration.""" + # Check if base script exists + if not os.path.exists(base_script_path): + raise FileNotFoundError(f"Base script file not found: {base_script_path}") + + # Load base script + with open(base_script_path) as f: + script = f.read() + + # Calculate required number of nodes + total_gpus_needed = dp * tp * pp + gpus_per_node = min(8, total_gpus_needed) + num_nodes = math.ceil(total_gpus_needed / gpus_per_node) + + # Replace SLURM parameters + replacements = { + "--nodes=2": f"--nodes={num_nodes}", + # "export SLURM_NNODES=2": f"export SLURM_NNODES={num_nodes}", + "--time=00:02:00": f"--time={time}", + "--partition=hopper-prod": f"--partition={partition}", + "--job-name=smolm2-bench": f"--job-name=bench_{config['general']['run']}", + 'JOBNAME="smolm2-bench"': f'JOBNAME="bench_{config["general"]["run"]}"', + "examples/config_tiny_llama.yaml": f"benchmark/configs/config_{config['general']['run']}.yaml", + "export GPUS_PER_NODE=8": f"export GPUS_PER_NODE={gpus_per_node}", + } + + for old, new in replacements.items(): + if old not in script: + raise ValueError(f"Could not find '{old}' in base script") + script = script.replace(old, new) + + return script + + +def check_params(model_configs): + for model_name, ( + num_layers, + hidden_size, + num_heads, + intermediate_size, + ) in model_configs.items(): + print(f"{model_name} model parameters:") + tie = True if intermediate_size < 10_000 else False + print(f" Embedding params: {human_format(estimate_num_params(num_layers, hidden_size, num_heads, intermediate_size, tie, 131072, 8))}") + print() + + exit() + + +def save_experiment_configs(configs, output_path, job_ids=None): + """Save core experiment configurations for tracking""" + records = [] + + for i, config in enumerate(configs): + # Calculate total params + tie_word_embeddings = True if config["model"]["model_config"]["intermediate_size"] < 10_000 else False + estimate_num_params( + config["model"]["model_config"]["num_hidden_layers"], + config["model"]["model_config"]["hidden_size"], + config["model"]["model_config"]["num_attention_heads"], + config["model"]["model_config"]["intermediate_size"], + tie_word_embeddings, + config["model"]["model_config"]["vocab_size"], + NUM_KEY_VALUE_HEADS, + ) + record = { + "date": datetime.now().strftime("%Y-%m-%d %H:%M:%S"), + "name": config["general"]["run"], + "nodes": config["parallelism"]["dp"] * config["parallelism"]["tp"] * config["parallelism"]["pp"] / 8, + "seq_len": config["tokens"]["sequence_length"], + "mbs": config["tokens"]["micro_batch_size"], + "batch_accum": config["tokens"]["batch_accumulation_per_replica"], + "gbs": config["tokens"]["sequence_length"] * config["tokens"]["micro_batch_size"] * config["tokens"]["batch_accumulation_per_replica"] * config["parallelism"]["dp"], + "dp": config["parallelism"]["dp"], + "pp": config["parallelism"]["pp"], + "tp": config["parallelism"]["tp"], + "tp_mode": f"{config['parallelism']['tp_mode']}", + "hidden_size": config["model"]["model_config"]["hidden_size"], + "num_layers": config["model"]["model_config"]["num_hidden_layers"], + "num_heads": config["model"]["model_config"]["num_attention_heads"], + "vocab_size": config["model"]["model_config"]["vocab_size"], + "zero_stage": config["optimizer"]["zero_stage"], + "job_id": job_ids[i] if job_ids else None, + } + records.append(record) + + # Save to CSV + if os.path.exists(output_path): + # Read existing data and append new records + existing_df = pd.read_csv(output_path) + df = pd.DataFrame(records) + df = pd.concat([existing_df, df], ignore_index=True) + else: + df = pd.DataFrame(records) + + df.to_csv(output_path, index=False) + print(f"Saved {len(records)} experiment configurations to {output_path}") + + +def main(): + parser = argparse.ArgumentParser(description="Run scaling benchmarks with different parallelism configurations") + parser.add_argument( + "--configs-dir", + type=str, + default="benchmark/configs", + help="Directory to store generated configs", + ) + parser.add_argument( + "--scripts-dir", + type=str, + default="benchmark/scripts", + help="Directory to store generated SLURM scripts", + ) + parser.add_argument("--partition", type=str, default="hopper-prod", help="SLURM partition to use") + parser.add_argument("--time", type=str, default="00:40:00", help="Time limit for each job") + parser.add_argument( + "--base-config", + type=str, + default="examples/config_tiny_llama_bench.yaml", + help="Base configuration file to use", + ) + parser.add_argument( + "--base-script", + type=str, + default="run_multinode.sh", + help="Base SLURM script to use", + ) + parser.add_argument( + "--pending-csv", + type=str, + default="benchmark/results/pending_experiments2.csv", + help="CSV file to store pending experiments", + ) + parser.add_argument( + "--benchmark-csv", + type=str, + default="benchmark/results/bench_final2.csv", + help="CSV file to store benchmark results", + ) + parser.add_argument( + "--run", + action="store_true", + help="Automatically submit all generated SLURM scripts", + ) + parser.add_argument("--debug", action="store_true", help="Debug mode") + parser.add_argument( + "--limit", + type=str, + default=None, + help="Limit the number of configurations to run (e.g. 100:200)", + ) + parser.add_argument("--profile", action="store_true", help="Enable profiling") + parser.add_argument("--use-bash", action="store_true", help="Use bash instead of sbatch") + args = parser.parse_args() + + # Parse limit argument if provided + if args.limit is not None: + if ":" in args.limit: + start, end = args.limit.split(":") + start = int(start) if start else None + end = int(end) if end else None + args.limit = slice(start, end) + else: + args.limit = slice(int(args.limit)) + + # Validate input files exist + if not os.path.exists(args.base_config): + raise FileNotFoundError(f"Base config file not found: {args.base_config}") + if not os.path.exists(args.base_script): + raise FileNotFoundError(f"Base script file not found: {args.base_script}") + + # Create directories if they don't exist + for directory in [args.configs_dir, args.scripts_dir]: + os.makedirs(directory, exist_ok=True) + + # # Define model configurations + # model_configs = { + # # (layers, hidden_size, heads, intermediate_size) + # # "1B": (16, 2048, 32, 8192), # 1.2G + # "3B": (28, 3072, 32, 8192), # 3.57G 24heads -> 32heads + # # "4B": (30, 3072, 32, 8192), # 30 layers distributed among PP-2 + # # "8B": (32, 4096, 32, 14336), # 8.0G + # # "70B": (80, 8192, 64, 28672), # 70G + # # "405B": (126, 16384, 128, 53248), # 406G + # } + + + # NOTE: dp, tp, pp, ep, etp, edp + parallel_configs = [ + (1, 1, 1, 1, 1, 1), + (4, 1, 1, 4, 1, 1), + (2, 1, 1, 2, 1, 1), + (8, 1, 1, 8, 1, 1), + ] + + num_experts_configs = [4, 16, 32, 64] + # NOTE: qwen use has a topk of 4, qwen3 use has a topk of 8 + topk_configs = [1, 2, 4, 8] + + # Define configurations to test + configurations = [] + + for dp, tp, pp, ep, etp, edp in parallel_configs: + for num_experts in num_experts_configs: + for topk in topk_configs: + configurations.append((dp, tp, pp, ep, etp, edp, num_experts, topk)) + + # For each model size, test different GPU configurations + # for model_name, (num_layers, hidden_size, num_heads, intermediate_size) in model_configs.items(): + # vocab_size = 32768 + # zero_stage = 0 + # tp_mode = "REDUCE_SCATTER" + # configs = [ # 64 nodes max + # # 2k, 8k, 32k + # # GBS: 1M, 4M + # # Format: (dp, tp, pp, batch_accum, seq_len, mbs, ...) + # # Using SP what's the biggest seqlen we can fit? + # # (1, 8, 1, 1, 2048, 1, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode), + # # (1, 8, 1, 1, 2048, 2, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode), + # # (1, 8, 1, 1, 2048, 8, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode), + # # (1, 8, 1, 1, 2048, 32, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode), + # # best run + # # (1, 8, 1, 1, 2048, 64, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, 0, tp_mode), + + # # test zero + # # (3, 8, 1, 1, 2048, 64, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, 0, tp_mode), + # # (3, 8, 1, 1, 2048, 64, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, 1, tp_mode), + # # (24, 1, 1, 1, 2048, 8, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, 0, tp_mode), + # # (24, 1, 1, 1, 2048, 8, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, 1, tp_mode), + # # test tp mode + # # (1, 8, 1, 1, 2048, 64, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, "ALL_REDUCE"), + # # test pp + # # (1, 1, 8, 1, 2048, 64, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode), + # # (1, 8, 2, 1, 2048, 64, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode), + # # (1, 1, 8, 8, 2048, 8, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode), + # # (1, 2, 8, 8, 2048, 8, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode), + # # (1, 2, 64, 8, 2048, 8, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode), + # # (1, 2, 16, 8, 2048, 8, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode), + # ] + # configurations.extend(configs) + + + # TP scaling tests with corresponding max batch sizes + # tp_mbs_configs = [ + # # Format: (tp, mbs) + # # TP=1 OOMs + # (2, 3), # 363.66 TFLOPs, 14167.25 tok/s/gpu + # (4, 9), # 345.51 TFLOPs, 13460.16 tok/s/gpu (-5%) + # (8, 18), # 279.50 TFLOPs, 10888.53 tok/s/gpu (-19%) + # (16, 40), # 158.10 TFLOPs, 6159.30 tok/s/gpu (-43%) + # (32, 90), # 92.66 TFLOPs, 3609.73 tok/s/gpu (-41%) + # ] + # TP_, MBS_ = tp_mbs_configs[4] + + + # Method 2: Parameter combinations + # PARALLEL_CONFIGS = [(1, 2, 1) + # for dp in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] + # for tp in [1, 2, 4, 8, 16, 32] + # for pp in [2]] + # # Sort PARALLEL_CONFIGS by total GPU count (dp*tp*pp) ascending + # PARALLEL_CONFIGS = sorted(PARALLEL_CONFIGS, key=lambda x: x[0] * x[1] * x[2]) + # SEQUENCE_LENGTHS = [4096] + # MBS = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] # ~1M, 4M + # MBS = [2] # ~1M, 4M + # GRAD_ACCUM_STEPS = [1] # ~1M, 4M + # VOCAB_SIZES = [131072] # 49152 131072 + # ZERO_STAGES = [0] # 0 if dp>=32 and model<80 / if no need for memory + # TP_MODES = ["REDUCE_SCATTER"] + # # TP_MODES = ["ALL_REDUCE"] + # GBS = [512 * 2048] # 1M + # MIN_NODES = 0 + # MAX_NODES = 10 + + # time = 0 + # TIME_PER_CONFIG = 2 # 2 minutes per config + # counter = 0 + # configurations = [] + # for pp, tp, dp in PARALLEL_CONFIGS: + # for model_name, (num_layers, hidden_size, num_heads, intermediate_size) in model_configs.items(): + # for seq_len in SEQUENCE_LENGTHS: + # for mbs in MBS: + # for batch_accum in GRAD_ACCUM_STEPS: + # for vocab_size in VOCAB_SIZES: + # for zero_stage in ZERO_STAGES: + # for tp_mode in TP_MODES: + # # batch_accum = pp-1 + + # # Optional: Add conditions to filter out unwanted combinations + # total_gpus = dp * tp * pp + # if not MIN_NODES <= total_gpus / 8 <= MAX_NODES: + # print(f"Skipping config - nodes {total_gpus/8} not in range [{MIN_NODES}, {MAX_NODES}]") + # continue + + # tokens_per_step = dp * mbs * batch_accum * seq_len + # # if tokens_per_step not in GBS: + # # continue + # # if batch_accum > 1: + # # print(f"Skipping config - batch_accum {batch_accum} > 1") + # # continue + + # # if dp=1 skip zero stage 1 + # if dp == 1 and zero_stage == 1: + # print(f"Skipping config - dp=1 with zero stage 1") + # continue + + # # if tp=1 skip tp_mode=ALL_REDUCE + # # if tp == 1 and tp_mode == "ALL_REDUCE": + # # print(f"Skipping config - tp=1 with ALL_REDUCE") + # # continue + + # if batch_accum < pp - 1: + # print(f"Skipping config - batch_accum {batch_accum} < pp-1 ({pp-1})") + # continue + + # if model_name == "1B" and pp > 21: # too many pp for numlayers + # print(f"Skipping config - 1B model with pp {pp} > 21") + # continue + + # config = (dp, tp, pp, batch_accum, seq_len, mbs, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode) + # if config not in configurations: + # counter += 1 + # time += total_gpus * TIME_PER_CONFIG / 8 / MAX_NODES # 2 minutes per config + # configurations.append(config) + + # # print(f"experiments: {counter}") + # # print(f"time (days): {time/60/24} | {time/60:.2f} hours") + # # print(len(configurations)) + + # # # Load configs from pickle file + # # import pickle + # # with open('configs.pkl', 'rb') as f: + # # configurations = pickle.load(f) + + # # validate configs + # new_configs = [] + # for config in configurations: + # # config = (dp, tp, pp, batch_accum, seq_len, mbs, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode) + # dp, tp, pp, batch_accum, seq_len, mbs, num_layers, hidden_size, num_heads, intermediate_size, vocab_size, zero_stage, tp_mode = config + # tokens_per_step = dp * mbs * batch_accum * seq_len + # # if tokens_per_step not in GBS: + # # print(f"Invalid config: {config} | tokens_per_step: {tokens_per_step}") + # # continue + # if dp == 1 and zero_stage == 1: + # print(f"Invalid config: {config} | dp: {dp} | zero_stage: {zero_stage}") + # continue + # # if tp == 1 and tp_mode == "ALL_REDUCE": + # # print(f"Invalid config: {config} | tp: {tp} | tp_mode: {tp_mode}") + # # continue + # if batch_accum < pp - 1: + # print(f"Invalid config: {config} | batch_accum: {batch_accum} | pp: {pp}") + # continue + # new_configs.append(config) + # configurations = new_configs + + print(len(configurations)) + + if args.debug: + print("Debug mode: only running 1 configuration") + configurations = configurations[:1] + + if isinstance(args.limit, slice): + print(f"Limiting to {args.limit} configurations") + configurations = configurations[args.limit] + elif isinstance(args.limit, int): + print(f"Limiting to {args.limit} configurations") + configurations = configurations[: args.limit] + + # run first 100 configurations + # configurations = configurations[:120+5000] + + + # load data + # import pandas as pd + # old_results_df = pd.read_csv('benchmark/results/bench_final2_mfu2.csv') + # old_results_df = old_results_df[old_results_df['status'].isin(['Success', 'OOM'])] + + # Generate configs and scripts + generated_scripts = [] + configs = [] + for dp, tp, pp, ep, etp, edp, num_experts, topk in tqdm(configurations, desc="Generating configs and scripts"): + try: + # Create config + config = create_config( + dp=dp, + tp=tp, + pp=pp, + expert_parallel_size=ep, + expert_tensor_parallel_size=etp, + expert_data_parallel_size=edp, + num_experts=num_experts, + topk=topk, + base_config_path=args.base_config, + profile=args.profile, + benchmark_csv_path=args.benchmark_csv, + ) + + # if config['general']['run'] in old_results_df['name'].values: + # #job_id < 14097150 + # if pp==1: + # print(f"Skipping {config['general']['run']} because it already exists in old_results_df") + # continue + # elif int(old_results_df[old_results_df['name']==config['general']['run']]['job_id'].values[0]) >= 14097150: + # print(f"Skipping {config['general']['run']} because it already exists in old_results_df") + # continue + + # Save config + config_path = os.path.join(args.configs_dir, f"config_{config['general']['run']}.yaml") + with open(config_path, "w") as f: + yaml.dump(config, f, default_flow_style=False) + + # Generate and save SLURM script + script = generate_slurm_script(config, dp, tp, pp, time=args.time, partition=args.partition, base_script_path=args.base_script, use_bash=args.use_bash) + + script_path = os.path.join(args.scripts_dir, f"run_{config['general']['run']}.sh") + with open(script_path, "w") as f: + f.write(script) + + # Make script executable + os.chmod(script_path, 0o755) + + generated_scripts.append(script_path) + configs.append(config) + + except Exception as e: + print(f"Error processing configuration (dp={dp}, tp={tp}, pp={pp}): {str(e)}") + + # Submit jobs if requested + job_ids = [] + if args.run: + import subprocess + + print("\nSubmitting jobs...") + for script_path, config in tqdm(zip(generated_scripts, configs), desc="Submitting jobs"): + try: + if args.use_bash: + env = os.environ.copy() + salloc_jobid = os.environ.get("SALLOC_JOBID") + if not salloc_jobid: + raise ValueError("SALLOC_JOBID environment variable is required but not set. Please define it in your environment.") + env["SALLOC_JOBID"] = os.environ.get("SALLOC_JOBID") + env["NNODES"] = str(config["parallelism"]["dp"] * config["parallelism"]["tp"] * config["parallelism"]["pp"] // 8) + result = subprocess.run(["bash", script_path], check=True, env=env) + job_id = None # No job ID for bash execution + print(f"bash {script_path}") + else: + result = subprocess.run(["sbatch", script_path], check=True, capture_output=True, text=True) + # Extract job ID from sbatch output (format: "Submitted batch job 123456") + job_id = result.stdout.strip().split()[-1] + print(f"sbatch {script_path}: {result.stdout.strip()}") + job_ids.append(job_id) + except subprocess.CalledProcessError as e: + print(f"Error {'running' if args.use_bash else 'submitting'} {script_path}: {e.stderr}") + job_ids.append(None) + + # Save configs with job IDs + save_experiment_configs(configs, args.pending_csv, job_ids=job_ids) + + else: + print("\nTo run individual jobs:") + for script_path in generated_scripts: + print(f"sbatch {script_path}") + job_ids.append(None) + + +if __name__ == "__main__": + main() diff --git a/examples/moe/benchmark_moe.sh b/examples/moe/benchmark_moe.sh new file mode 100755 index 000000000..f3b850a92 --- /dev/null +++ b/examples/moe/benchmark_moe.sh @@ -0,0 +1,38 @@ +#!/bin/bash + +# Create benchmark runner script +cat > run_benchmark.sh << 'EOL' +#!/bin/bash + +# Set base directory (default to current directory if not specified) +BASE_DIR="${1:-.}" + +echo "1. Creating directory structure under: $BASE_DIR" +mkdir -p "${BASE_DIR}/benchmark/configs" || exit 1 +mkdir -p "${BASE_DIR}/benchmark/scripts" || exit 1 +mkdir -p "${BASE_DIR}/benchmark/results" || exit 1 + +echo "2. Directory structure created:" +tree "${BASE_DIR}/benchmark" + +echo "3. Running MoE benchmark script..." +python examples/moe/benchmark_moe.py \ + --configs-dir "${BASE_DIR}/benchmark/configs" \ + --scripts-dir "${BASE_DIR}/benchmark/scripts" \ + --pending-csv "${BASE_DIR}/benchmark/results/pending_experiments2.csv" \ + --benchmark-csv "${BASE_DIR}/benchmark/results/bench_final2.csv" \ + --base-config examples/config_tiny_llama_bench.yaml \ + --partition hopper-prod \ + --time 01:00:00 \ + --run + +echo "4. Benchmark jobs submitted! Results will be saved to:" +echo " - Configs: ${BASE_DIR}/benchmark/configs" +echo " - Results: ${BASE_DIR}/benchmark/results" +EOL + +# Make the script executable +chmod +x run_benchmark.sh + +echo "Setup complete! Run the benchmark with:" +echo "./run_benchmark.sh [optional-base-directory]" diff --git a/examples/moe/run_benchmark.sh b/examples/moe/run_benchmark.sh new file mode 100755 index 000000000..f739eec48 --- /dev/null +++ b/examples/moe/run_benchmark.sh @@ -0,0 +1,27 @@ +#!/bin/bash + +# Set base directory (default to current directory if not specified) +BASE_DIR="${1:-.}" + +echo "1. Creating directory structure under: $BASE_DIR" +mkdir -p "${BASE_DIR}/benchmark/configs" || exit 1 +mkdir -p "${BASE_DIR}/benchmark/scripts" || exit 1 +mkdir -p "${BASE_DIR}/benchmark/results" || exit 1 + +echo "2. Directory structure created:" +tree "${BASE_DIR}/benchmark" + +echo "3. Running MoE benchmark script..." +python examples/moe/benchmark_moe.py \ + --configs-dir "${BASE_DIR}/benchmark/configs" \ + --scripts-dir "${BASE_DIR}/benchmark/scripts" \ + --pending-csv "${BASE_DIR}/benchmark/results/pending_experiments2.csv" \ + --benchmark-csv "${BASE_DIR}/benchmark/results/bench_final2.csv" \ + --base-config examples/config_tiny_llama_bench.yaml \ + --partition hopper-prod \ + --time 01:00:00 \ + --run + +echo "4. Benchmark jobs submitted! Results will be saved to:" +echo " - Configs: ${BASE_DIR}/benchmark/configs" +echo " - Results: ${BASE_DIR}/benchmark/results" diff --git a/examples/moe/run_multinode.sh b/examples/moe/run_multinode.sh new file mode 100644 index 000000000..be55c9bb0 --- /dev/null +++ b/examples/moe/run_multinode.sh @@ -0,0 +1,163 @@ +#!/bin/bash +#SBATCH --job-name=smolm2-bench # Job name +#SBATCH --time=00:02:00 +#SBATCH --partition=hopper-prod +#SBATCH --qos=low + +#SBATCH -o /fsx/phuc/new_workspace/experiments/qwen_moe/benchmark/exp0a0_benhmark_num_experts_topk_and_ep_in_a_node/logs/%j-%x.out + +#SBATCH --nodes=2 # Number of nodes (modify as needed) +#SBATCH --ntasks-per-node=1 # Number of tasks per node +#SBATCH --cpus-per-task=60 # CPU cores per task +#SBATCH --gres=gpu:8 # Number of GPUs per node +#SBATCH --exclusive # Exclusive use of nodes +#SBATCH --wait-all-nodes=1 # fail if any node is not ready + +# run using +# sbatch --nodes=1 run_multinode.sh +# or +# SALLOC_JOBID=13482276 NNODES=1 bash run_multinode.sh + +set -x -e +echo "Running script: $0" + + +# If not running under SLURM, set default SLURM environment variables +if [ -z "${SLURM_JOB_ID}" ]; then + if [ -z "${SALLOC_JOBID}" ]; then + echo "Error: SALLOC_JOBID environment variable is required but not set. Please run this script within an salloc session." + exit 1 + fi + if [ -z "${NNODES}" ]; then + echo "Error: NNODES environment variable is required but not set. Please run this script within an salloc session." + exit 1 + fi + export SALLOC_MODE=1 + export SLURM_JOB_ID=$SALLOC_JOBID + export SLURM_NNODES=$NNODES + export SLURM_JOB_NODELIST=$(squeue -j $SALLOC_JOBID -h -o "%N") +fi + +# Load any necessary modules for your system +source /etc/profile.d/modules.sh # for some reason module isn't loaded +module load cuda/12.1 +# Unset FI_PROVIDER to avoid potential libfabric provider issues +# unset FI_PROVIDER + + +# Activate your conda environment if needed +source /admin/home/phuc_nguyen/.bashrc +source /admin/home/phuc_nguyen/miniconda3/etc/profile.d/conda.sh +conda activate /fsx/phuc/temp/env_for_qwen_moe/env/ +# conda activate 2-1-cu121 +# export PATH=/fsx/nouamane/miniconda/envs/2-1-cu121/bin:$PATH +# export PATH=/fsx/phuc/temp/env_for_qwen_moe/env//bin:$PATH + +# Get the node names from SLURM +if [ -z "${SALLOC_MODE}" ]; then # sbatch mode + export NODELIST=`scontrol show hostnames $SLURM_JOB_NODELIST` + +else # srun mode + export NODELIST=`scontrol show hostnames $SLURM_JOB_NODELIST | head -n$SLURM_NNODES` +fi +export MASTER_NODE=`scontrol show hostnames $SLURM_JOB_NODELIST | head -n1` +export MASTER_PORT=12356 + +# Calculate total number of processes +export NNODES=$SLURM_NNODES +export GPUS_PER_NODE=8 +export WORLD_SIZE=$(($NNODES * $GPUS_PER_NODE)) + +# Set some environment variables for better distributed training +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NCCL_DEBUG=WARN # INFO, WARN +# export NCCL_DEBUG_SUBSYS=ALL +# export CUDA_LAUNCH_BLOCKING=1 + +# Nanotron specific +export NANOTRON_BENCHMARK=1 +export WANDB_MODE=disabled + +# export TORCH_NCCL_USE_COMM_NONBLOCKING=1 + +# Trying to avoid hangs +export TORCH_NCCL_ASYNC_ERROR_HANDLING=1 + +# debug +export TORCH_DISTRIBUTED_DEBUG=DETAIL + +# export NCCL_P2P_LEVEL=NVL +# export CUDA_LAUNCH_BLOCKING=1 +# export NCCL_IB_CUDA_SUPPORT=0 # Disable RDMA +# export NCCL_NET_GDR_LEVEL=LOC +# Test Script - save as test_comm.sh + +# Test 1 - Force TCP +# echo "Running with TCP only..." +# export NCCL_P2P_LEVEL=LOC + +# # Match bandwidth patterns +# export NCCL_MAX_NCHANNELS=2 +# export NCCL_MIN_NCHANNELS=2 + + +# export NCCL_NET_GDR_LEVEL=LOC # Disable RDMA +# export NCCL_SHM_DISABLE=0 # disables the Shared Memory (SHM) transport +# export NCCL_IB_DISABLE=0 # disables the InfiniBand (IB) transport +# export NCCL_IB_TIMEOUT=60 # 20 = ~4 seconds , 21 = ~8 seconds , 22 = ~16 seconds +# export NCCL_IB_RETRY_CNT=7 # Increase retry count as well + +# Force SHM +# export NCCL_NET_PLUGIN=none # fixes hang but doesnt work multinode +# export NCCL_SOCKET_NTHREADS=1 +# export FI_PROVIDER="tcp" + +# Print GPU topology information +if [ -z "${SALLOC_MODE}" ]; then + echo "=== GPU Topology ===" + nvidia-smi topo -m + echo "==================" + export SRUN_ALLOC_ARGS="" +else + export JOBNAME="smolm2-bench" + export OUTPUT_FILE="/fsx/phuc/new_workspace/experiments/qwen_moe/benchmark/exp0a0_benhmark_num_experts_topk_and_ep_in_a_node/logs/$SLURM_JOB_ID-$(date +%Y-%m-%d-%H-%M-%S)-$JOBNAME.out" + export SRUN_ALLOC_ARGS="--jobid=$SLURM_JOB_ID --nodes=$NNODES --gres=gpu:$GPUS_PER_NODE --time=01:02:00 --job-name=$JOBNAME" +fi + + +# Print some debugging information +echo "Master node: $MASTER_NODE" +echo "All nodes: $NODELIST" +echo "World size: $WORLD_SIZE" + +# Launch the training script using srun in background +if [ -n "${SALLOC_MODE}" ]; then # srun mode + srun $SRUN_ALLOC_ARGS --wait=0 --kill-on-bad-exit=1 torchrun \ + --nnodes=$NNODES \ + --nproc_per_node=$GPUS_PER_NODE \ + --rdzv_id=$SLURM_JOB_ID \ + --rdzv_backend=c10d \ + --rdzv_endpoint=$MASTER_NODE:$MASTER_PORT \ + --max_restarts 0 \ + --rdzv_conf timeout=60 \ + /fsx/phuc/temp/env_for_qwen_moe/nanotron/run_train.py \ + --config-file examples/config_tiny_llama.yaml > $OUTPUT_FILE 2>&1 & + # Store the process ID + SRUN_PID=$! + echo "Job started in background with PID: $SRUN_PID" | tee -a $OUTPUT_FILE + + # Optionally, you can add: + echo "To check job status: ps -p $SRUN_PID" | tee -a $OUTPUT_FILE + echo "To kill the job: kill $SRUN_PID" | tee -a $OUTPUT_FILE + +else # sbatch mode + srun $SRUN_ALLOC_ARGS --wait=0 --kill-on-bad-exit=1 torchrun \ + --nnodes=$NNODES \ + --nproc_per_node=$GPUS_PER_NODE \ + --rdzv_id=$SLURM_JOB_ID \ + --rdzv_backend=c10d \ + --rdzv_endpoint=$MASTER_NODE:$MASTER_PORT \ + --max_restarts 0 \ + --rdzv_conf timeout=60 \ + /fsx/phuc/temp/env_for_qwen_moe/nanotron/run_train.py \ + --config-file examples/config_tiny_llama.yaml diff --git a/pyproject.toml b/pyproject.toml index 390c32c40..26171e9f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ "dacite", "tqdm", "datasets", + "torchtyping" ] [tool.setuptools.packages.find] diff --git a/run_generate.py b/run_generate.py index e21fe7e22..9525cd109 100644 --- a/run_generate.py +++ b/run_generate.py @@ -63,6 +63,10 @@ def get_args(): parser.add_argument("--dp", type=int, default=1) parser.add_argument("--pp", type=int, default=0) parser.add_argument("--tp", type=int, default=0) + parser.add_argument("--expert_parallel_size", type=int, default=None) + parser.add_argument("--expert_tensor_parallel_size", type=int, default=None) + parser.add_argument("--expert_data_parallel_size", type=int, default=None) + parser.add_argument("--enabled_moe", action="store_true", help="Enable MoE model") parser.add_argument("--max-new-tokens", type=int, default=128, help="Maximum number of new tokens to generate") parser.add_argument("--use-cache", action="store_true", help="Use KV cache to speed up generation") return parser.parse_args() @@ -78,9 +82,19 @@ def main(): tokenizer_path = config.tokenizer.tokenizer_name_or_path parallel_config = ParallelismArgs( - dp=args.dp or config.parallelism.dp, + dp=config.parallelism.dp, pp=args.pp or config.parallelism.pp, tp=args.tp or config.parallelism.tp, + expert_parallel_size=args.expert_parallel_size + if args.expert_parallel_size is not None + else config.parallelism.expert_parallel_size, + expert_tensor_parallel_size=args.expert_tensor_parallel_size + if args.expert_tensor_parallel_size is not None + else config.parallelism.expert_tensor_parallel_size, + expert_data_parallel_size=args.expert_data_parallel_size + if args.expert_data_parallel_size is not None + else config.parallelism.expert_data_parallel_size, + enabled_moe=args.enabled_moe if args.enabled_moe else config.parallelism.enabled_moe, pp_engine=OneForwardOneBackwardPipelineEngine(), tp_mode=TensorParallelLinearMode.ALL_REDUCE, tp_linear_async_communication=False, @@ -91,7 +105,12 @@ def main(): data_parallel_size=parallel_config.dp, pipeline_parallel_size=parallel_config.pp, tensor_parallel_size=parallel_config.tp, + expert_parallel_size=parallel_config.expert_parallel_size, + expert_tensor_parallel_size=parallel_config.expert_tensor_parallel_size, + expert_data_parallel_size=parallel_config.expert_data_parallel_size, + enabled_moe=parallel_config.enabled_moe, ) + # log_rank(f"[run_generate.parallel_context]", logger=logger, level=logging.INFO) # Set log levels logging_config = LoggingArgs( @@ -171,12 +190,20 @@ def main(): dummy_inputs = [ # "The future of AI is", "Passage: Daniel went back to the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:", - "def fib(n)", + "Passage: Daniel went back to the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:", + "Passage: Daniel went back to the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:", + "Passage: Daniel went back to the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:", + # "Passage: Daniel went back to the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:", + # "def fib(n)", + # "Passage: Daniel went back to the garden. Mary travelled to the kitchen. Sandra journeyed to the kitchen. Sandra went to the hallway. John went to the bedroom. Mary went back to the garden. Where is Mary?\nAnswer:", + # "def fib(n)", # 'Here is an extract from a webpage: "Have you ever experienced heel pain after a heavy physical activity, or even right after a long period of standing? If you regard this as something usual and normal, then think again. Miscalled as heel pain, plantar fasciitis causes these frequent mild pains experienced in the soles of the feet. It is the inflammation and enlargement the plantar fascia tissue that is located in the heels of the feet, stretching to the base of the toes. This tissue is responsible for absorbing shock in the feet and for supporting the arches. It also plays a vital role in foot movements during walking and standing. Many factors such as excessive walking, standing, and running trigger heel pain and plantar fasciitis. A sudden increase in intensity of activities, increase in weight, and abrupt change of footwear also cause the swelling of the ligament. Non-supportive footwear lacking arch cushions and improper and worn out running or training can also lead to the problem. It is also most evident among those". Write an extensive and detailed course unit suitable for a textbook targeted at college students, related to the given extract, within the context of "Medicine". Do not just list concepts, but develop each one in detail before moving to the next, as we prioritize depth of understanding and comprehensive exploration of the subject matter over breadth. Focus on: - Rigor: Ensure in-depth coverage of the concepts/sections. - Engagement: Write with an academic, professional and engaging tone that captivates interest. - Application: Incorporate specific, practical examples, such as proofs in calculus or critical dates and figures in history. Do not include a title or an introduction, simply write the content without headlines and introductory phrases. Do not use images.', # "Advancements in technology will lead to", # "Tomorrow's world is shaped by", ] + # log_rank(f"[run_generate.main.before_decode_text]", logger=logger, level=logging.INFO) + outputs = decode_text( input_iter=(GenerationInput(text=text) for text in dummy_inputs), tokenizer=tokenizer, diff --git a/run_train.py b/run_train.py index d00ef2118..1f262abd7 100644 --- a/run_train.py +++ b/run_train.py @@ -10,18 +10,18 @@ import argparse import time from pprint import pformat -from typing import Dict, Optional, cast +from typing import Dict, Optional import nanotron.distributed as dist +import torch.multiprocessing as mp from nanotron import logging from nanotron.config import ( DataArgs, - DatasetStageArgs, NanosetDatasetsArgs, PretrainDatasetsArgs, - Qwen2Config, SFTDatasetsArgs, ) +from nanotron.config.models_config import Qwen2Config from nanotron.data.dataloader import ( dummy_infinite_data_generator, get_train_dataloader, @@ -33,12 +33,11 @@ from nanotron.data.sft_processing import prepare_sft_dataset from nanotron.helpers import ( compute_remain_train_steps_of_a_data_stage_from_ckp, - get_consumed_train_samples_of_a_data_stage_from_ckp, ) from nanotron.logging import log_rank from nanotron.parallel.pipeline_parallel.utils import get_input_output_pp_ranks from nanotron.sanity_checks import sanity_check_dataloader -from nanotron.trainer import DistributedTrainer +from nanotron.trainer import DataStageMetadata, DistributedTrainer from nanotron.utils import main_rank_first from torch.utils.data import DataLoader @@ -56,12 +55,83 @@ # lt.monkey_patch() +import numpy as np +import torch +from torch.utils.data import Dataset + + +class SimpleTokenDataset(Dataset): + """A simple dataset that reads tokens from a file and returns sequences of a specified length. + Example usage: + dataset = SimpleTokenDataset("path/to/tokens.bin", seq_len=512) + dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True) + + for batch in dataloader: + input_ids = batch["input_ids"] # Shape: [batch_size, seq_len] + position_ids = batch["position_ids"] # Shape: [seq_len] + # ... use both tensors + Args: + file_path (str): Path to the file containing tokens + seq_len (int): Length of sequences to return + token_size (int): Size of each token in bytes (2 for uint16, 4 for uint32) + """ + + def __init__(self, file_path: str, seq_len: int, token_size: int = 2): + self.file_path = file_path + self.seq_len = seq_len + self.token_size = token_size + + # Open file and get total size + with open(file_path, "rb") as f: + f.seek(0, 2) # Seek to end + file_size = f.tell() + + # Calculate number of tokens and sequences + self.num_tokens = file_size // token_size + self.num_sequences = self.num_tokens // seq_len + + self._f = None + + def _get_input_ids(self, item): + if self._f is None: + self._f = open(self.file_path, "rb") + + chunk_size = self.token_size * self.seq_len + self._f.seek(item * chunk_size) + + # Read and convert to tensor + tokens = np.frombuffer(self._f.read(chunk_size), np.uint16 if self.token_size == 2 else np.uint32).astype( + np.int64 + ) + + return torch.as_tensor(tokens, dtype=torch.long) + + def __getitem__(self, item): + input_ids = self._get_input_ids(item) + position_ids = torch.arange(self.seq_len, dtype=torch.long) + + # Create label_ids by shifting input_ids right by 1 + label_ids = torch.roll(input_ids, shifts=-1, dims=0) + + # Create label_mask (all ones) + label_mask = torch.ones(self.seq_len, dtype=torch.long) + + return {"input_ids": input_ids, "position_ids": position_ids, "label_ids": label_ids, "label_mask": label_mask} + + def __len__(self): + return self.num_sequences + + def __del__(self): + if self._f: + self._f.close() + def get_dataloader_from_data_stage( trainer: DistributedTrainer, data: DataArgs, - consumed_train_samples: int, + consumed_train_samples_stage: int, consumed_tokens_per_dataset_folder: Dict[str, int], + last_stages_consumed_tokens_per_dataset_folder: Dict[str, int], num_remaining_train_steps: int, sanity_check_dataloader_interval: Optional[int] = None, ): @@ -69,10 +139,11 @@ def get_dataloader_from_data_stage( Returns a dataloader for a given data stage. data: The data configuration for the current stage. - consumed_train_samples: The number of samples consumed by the model in the this stage (each stage starts from zero). + consumed_train_samples_stage: The number of samples consumed by the model in the this stage (each stage starts from zero). + consumed_tokens_per_dataset_folder: The number of tokens consumed by the model in previous stages to avoid reseeing them, because the sampler has restarted for this stage. num_remaining_train_steps: The number of remaining training steps for this stage. """ - assert consumed_train_samples >= 0, "consumed_train_samples should be greater than 0" + assert consumed_train_samples_stage >= 0, "consumed_train_samples_stage should be greater than 0" assert num_remaining_train_steps >= 0, "num_remaining_train_steps should be greater than 0" # First, we need to know which ranks to feed the dataloader to @@ -164,7 +235,7 @@ def get_dataloader_from_data_stage( input_pp_rank=input_pp_rank, output_pp_rank=output_pp_rank, micro_batch_size=trainer.micro_batch_size, - consumed_train_samples=consumed_train_samples, + consumed_train_samples_stage=consumed_train_samples_stage, dataloader_num_workers=data.num_loading_workers, seed_worker=data.seed, dataloader_drop_last=True, @@ -185,7 +256,6 @@ def get_dataloader_from_data_stage( # Case 3: Nanosets elif isinstance(data.dataset, NanosetDatasetsArgs): log_rank("Using TokenizedBytes Dataloader", logger=logger, level=logging.INFO, rank=0) - from nanotron.data.tokenized_bytes import get_tb_dataloader, get_tb_datasets tokenizer_path = trainer.config.tokenizer.tokenizer_name_or_path tokenizer = AutoTokenizer.from_pretrained(tokenizer_path) @@ -198,20 +268,26 @@ def get_dataloader_from_data_stage( level=logging.INFO, rank=0, ) - tokenizer.pad_token = tokenizer.eos_token + if tokenizer.pad_token is None: + tokenizer.pad_token = tokenizer.eos_token tokenizer.padding_side = "left" + from nanotron.data.tokenized_bytes import get_tb_dataloader, get_tb_datasets + start_time = time.time() train_dataset, data_log = get_tb_datasets( config=data.dataset, global_batch_size=trainer.global_batch_size, sequence_length=trainer.sequence_length, train_steps=trainer.config.tokens.train_steps, + current_iteration=trainer.iteration_step, parallel_context=trainer.parallel_context, shuffle=data.dataset.shuffle_files, eos_token_id=tokenizer.eos_token_id, seed=data.seed, + consumed_samples=consumed_train_samples_stage, consumed_tokens_per_dataset_folder=consumed_tokens_per_dataset_folder, + last_stages_consumed_tokens_per_dataset_folder=last_stages_consumed_tokens_per_dataset_folder, ) dataloader = get_tb_dataloader( dataset=train_dataset, @@ -220,8 +296,9 @@ def get_dataloader_from_data_stage( global_batch_size=trainer.global_batch_size, num_workers=data.num_loading_workers, cfg=data.dataset, - consumed_samples=consumed_train_samples, - num_samples=trainer.config.tokens.train_steps * trainer.global_batch_size, + consumed_samples=consumed_train_samples_stage, + num_samples=trainer.config.tokens.train_steps + * trainer.global_batch_size, # TODO: this overshoots what's needed by the current stage, but it doesnt matter? parallel_context=trainer.parallel_context, input_pp_rank=input_pp_rank, output_pp_rank=output_pp_rank, @@ -236,6 +313,13 @@ def get_dataloader_from_data_stage( level=logging.INFO, rank=0, ) + + # log_rank( + # f"[TokenizedBytes] Time taken to create TokenizedBytes: {time.strftime('%M:%S', time.gmtime(time.time() - start_time))} (MM:SS)", + # logger=logger, + # level=logging.INFO, + # rank=0, + # ) dist.barrier() # Create Nanoset @@ -315,46 +399,76 @@ def get_dataloader( full_log_message = f"There are {len(trainer.config.data_stages)} training stages \n{stages_info}" log_rank(full_log_message, logger=logger, level=logging.INFO, rank=0) - for stage_idx, stage in enumerate(trainer.config.data_stages): - # NOTE: we only create the dataloader for the first stage, - # then we lazy initialize the dataloader for the other stages - stage = cast(DatasetStageArgs, stage) - ( - consumed_train_samples, - consumed_tokens_per_dataset_folder, - ) = get_consumed_train_samples_of_a_data_stage_from_ckp(stage, trainer.metadata) - - num_remaining_train_steps = compute_remain_train_steps_of_a_data_stage_from_ckp( - stage, trainer.config, trainer.metadata - ) - log_rank( - f"Stage {stage.name} has {num_remaining_train_steps} remaining training steps and has consumed {consumed_train_samples} samples" - f"Consumed tokens per dataset folder: {pformat(consumed_tokens_per_dataset_folder)}", - logger=logger, - level=logging.INFO, - rank=0, + current_stage = None + # WARNING: we assume we train on last stage + stage_idx = len(trainer.config.data_stages) - 1 + stage_args = trainer.config.data_stages[stage_idx] + if trainer.iteration_step + 1 == stage_args.start_training_step: + log_rank(f"Starting new stage {stage_args.name}", logger=logger, level=logging.INFO, rank=0) + # we start a new stage + if stage_idx >= len(trainer.metadata.data_stages): + trainer.metadata.data_stages.append( + DataStageMetadata( + name=stage_args.name, + start_training_step=stage_args.start_training_step, + consumed_train_samples=0, + consumed_tokens_per_dataset_folder={}, + sequence_length=trainer.sequence_length, + ) + ) + elif len(trainer.metadata.data_stages) < len(trainer.config.data_stages): + raise ValueError( + f"If you're trying to start a new stage, you need to set `start_training_step` to the step after the last stage's: {trainer.iteration_step+1}" ) + current_stage = trainer.metadata.data_stages[stage_idx] + cur_stage_consumed_train_samples = current_stage.consumed_train_samples + consumed_tokens_per_dataset_folder = current_stage.consumed_tokens_per_dataset_folder + stage_args_data = trainer.config.data_stages[stage_idx].data + + num_remaining_train_steps = compute_remain_train_steps_of_a_data_stage_from_ckp( + current_stage, trainer.config, trainer.metadata + ) # TODO: check this + log_rank( + f"Current stage: {current_stage.name} has {num_remaining_train_steps} remaining training steps and has consumed {cur_stage_consumed_train_samples} samples" + f"Consumed tokens per dataset folder: {pformat(consumed_tokens_per_dataset_folder)}", + logger=logger, + level=logging.INFO, + rank=0, + ) - dataloader = ( - get_dataloader_from_data_stage( - trainer, - stage.data, - consumed_train_samples=consumed_train_samples, - consumed_tokens_per_dataset_folder=consumed_tokens_per_dataset_folder, - num_remaining_train_steps=num_remaining_train_steps, - sanity_check_dataloader_interval=sanity_check_dataloader_interval, - ) - if stage_idx == 0 - else lambda stage=stage: get_dataloader_from_data_stage( - trainer, - stage.data, - consumed_train_samples=consumed_train_samples, - consumed_tokens_per_dataset_folder=consumed_tokens_per_dataset_folder, - num_remaining_train_steps=num_remaining_train_steps, - sanity_check_dataloader_interval=sanity_check_dataloader_interval, + # warn that if seqlen of stage - 1 has changed, consumed_train_samples=0 so we'll assume we're reading from new folder (so that we can resume training) + if current_stage.sequence_length != trainer.metadata.data_stages[-1].sequence_length: + raise NotImplementedError("We don't support changing sequence length between stages yet") + if current_stage.consumed_train_samples == 0: + log_rank( + f"Warning: The sequence length of the last stage has changed from {trainer.metadata.data_stages[-1].sequence_length} to {current_stage.sequence_length}. We'll assume we're reading from the beginning of the dataset folders.", + logger=logger, + level=logging.WARNING, + rank=0, ) - ) - dataloaders[stage.name] = dataloader + else: + # we're resuming training, so that's fine + pass + cur_stage_consumed_train_samples = current_stage.consumed_train_samples + + else: + # Prepare last_stages_consumed_tokens_per_dataset_folder which will be used to offset BlendableDataset to avoid reseeing consumed tokens even when sampler has restarted for this stage + last_stages_consumed_tokens_per_dataset_folder = {} + for stage in trainer.metadata.data_stages[:-1]: + for folder_path, consumed_tokens in stage.consumed_tokens_per_dataset_folder.items(): + last_stages_consumed_tokens_per_dataset_folder[folder_path] = ( + last_stages_consumed_tokens_per_dataset_folder.get(folder_path, 0) + consumed_tokens + ) + + dataloaders[current_stage.name] = get_dataloader_from_data_stage( + trainer, + stage_args_data, + consumed_train_samples_stage=cur_stage_consumed_train_samples, + consumed_tokens_per_dataset_folder=consumed_tokens_per_dataset_folder, + last_stages_consumed_tokens_per_dataset_folder=last_stages_consumed_tokens_per_dataset_folder, + num_remaining_train_steps=num_remaining_train_steps, + sanity_check_dataloader_interval=sanity_check_dataloader_interval, + ) return dataloaders @@ -371,6 +485,7 @@ def get_args(): if __name__ == "__main__": + mp.set_start_method("spawn") # debuggy fails args = get_args() config_file = args.config_file diff --git a/scripts/scaling_moe_benchmark.py b/scripts/scaling_moe_benchmark.py new file mode 100644 index 000000000..035a90b4b --- /dev/null +++ b/scripts/scaling_moe_benchmark.py @@ -0,0 +1,217 @@ +import argparse +import os +import subprocess +from copy import deepcopy + +import numpy as np +import pandas as pd +from nanotron.config import Config + +SUBMIT_JOB_PATH = "/fsx/phuc/new_workspace/snippets/runner/submit_job.py" + + +def generate_scaling_configs( + seq_len=4096, + mbs=2, + gpus_per_node=8, + target_gbs_min=4_000_000, + target_gbs_max=8_000_000, + num_layers=[9, 10, 11], + learning_rates=[0.00002, 0.00006, 0.0001, 0.0002], +): + # Initialize lists to store configurations + configs = [] + + # Define node counts to explore + node_counts = [1, 2, 4, 8, 16, 32, 64, 128] + + # Target the lower end of the range + target_gbs = target_gbs_min + + for nodes in node_counts: + for nl in num_layers: + for lr in learning_rates: + # Calculate data parallel size based on number of nodes + world_size = gpus_per_node * nodes + dp_replicas = world_size + + # Calculate accumulation steps needed to approach target GBS + # GBS = seq_len * mbs * dp_size * accum + ideal_accum = target_gbs / (seq_len * mbs * dp_replicas) + + # Round to nearest integer, minimum of 1 + # For smaller batch sizes, round up to ensure we meet minimum target + accum = max(1, int(np.ceil(ideal_accum))) + + # Calculate actual GBS with this configuration + actual_gbs = seq_len * mbs * dp_replicas * accum + + # Calculate batch size per replica + bs_per_replica = seq_len * mbs + + # Calculate ep and edp values + ep = 8 + edp = world_size // ep + + # Store configuration + configs.append( + { + "Nodes": nodes, + "GPUs": world_size, + "MicroBatchSize": mbs, + "SequenceLength": seq_len, + "AccumSteps": accum, + "BatchSizePerReplica": bs_per_replica, + "GlobalBatchSize": actual_gbs, + "GlobalBatchSizeMillions": actual_gbs / 1_000_000, + "ep": ep, + "edp": edp, + "NumLayers": nl, + "LearningRate": lr, + } + ) + + # Create DataFrame + df = pd.DataFrame(configs) + + # Format numbers for better readability + df["GlobalBatchSizeMillions"] = df["GlobalBatchSizeMillions"].round(2) + + return df + + +def create_experiment_names(df): + # Define mapping for node counts to alphabets + node_to_alphabet = {1: "a", 2: "b", 4: "c", 8: "d", 16: "e", 32: "f", 64: "g", 128: "h"} + + # Create experiment names with alphabet code based on node count + df["ExperimentName"] = df.apply( + lambda row: f"exp19{node_to_alphabet[row['Nodes']]}a1_like_exp18aa1_and_{int(row['Nodes'])}_node_OLMoE-1B-7B_te_and_seq_len_{int(row['SequenceLength'])}_and_batch_accum{int(row['AccumSteps'])}_and_mbs{int(row['MicroBatchSize'])}_and_gbs{int(row['AccumSteps']*row['MicroBatchSize'])}_with_{int(row['GlobalBatchSizeMillions'])}m_and_elie_training_config_and_fineweb_numlayer{int(row['NumLayers'])}_and_seed_312_but_dp{int(row['GPUs'])}_tp1_ep{int(row['ep'])}_edp{int(row['edp'])}_and_lr{row['LearningRate']:.6f}_and_groupedgemm_and_allgather", + axis=1, + ) + + return df + + +def create_scaled_configs( + base_config: Config, + scaling_df: pd.DataFrame, + output_base_dir: str, + benchmark_csv_path: str, + brrr_repo_path: str, + uv_env_path: str, + script_path: str, + reservation_name: str, + launch_config: bool = False, +): + """Create scaled config files and optionally launch jobs""" + os.makedirs(output_base_dir, exist_ok=True) + + for _, row in scaling_df.iterrows(): + print(f"Generating config for {row['ExperimentName']}") + new_config = deepcopy(base_config) + + # Config generation remains the same + new_config.general.benchmark_csv_path = benchmark_csv_path + new_config.general.run = row["ExperimentName"] + new_config.model.model_config.num_hidden_layers = row["NumLayers"] + new_config.parallelism.dp = row["GPUs"] + new_config.parallelism.expert_parallel_size = row["ep"] + new_config.parallelism.expert_data_parallel_size = row["edp"] + new_config.tokens.sequence_length = row["SequenceLength"] + new_config.tokens.micro_batch_size = row["MicroBatchSize"] + new_config.tokens.batch_accumulation_per_replica = row["AccumSteps"] + new_config.optimizer.learning_rate_scheduler.learning_rate = row["LearningRate"] + + config_path = os.path.join(output_base_dir, f"{row['ExperimentName']}.yaml") + new_config.save_as_yaml(config_path) + + # Build launch command + launch_command = [ + "python3", + SUBMIT_JOB_PATH, + "--config", + config_path, + "--nproc_per_node", + "8", + "--brrr_repo_path", + brrr_repo_path, + "--uv_env_path", + uv_env_path, + "--nodes", + str(row["Nodes"]), + "--script_path", + script_path, + "--is_brrr_config", + "false", + "--reservation_name", + reservation_name, + ] + + if launch_config: + print("Launching:", " ".join(launch_command)) + subprocess.run(launch_command, check=True) + else: + print("Would launch:", " ".join(launch_command)) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Generate scaled config files and launch jobs") + parser.add_argument("--base-config", type=str, required=True) + parser.add_argument( + "--output-base-dir", type=str, required=True, help="Shared directory for all configs and benchmark CSV" + ) + parser.add_argument( + "--benchmark-csv-name", + type=str, + default="benchmark_results.csv", + help="Filename for EMPTY benchmark CSV in output-base-dir", + ) + + # Add new argument to parser + parser.add_argument("--launch-config", action="store_true", help="Actually launch the jobs (dry run by default)") + + # Add new arguments for job submission + parser.add_argument( + "--brrr_repo_path", type=str, required=False, default="", help="[Optional] Path to BRRR repository" + ) + parser.add_argument( + "--uv_env_path", type=str, required=False, default="", help="[Optional] Path to UV environment" + ) + parser.add_argument( + "--script_path", type=str, default="run_train.py", help="Training script path (default: run_train.py)" + ) + parser.add_argument( + "--reservation_name", type=str, required=False, default="", help="[Optional] Cluster reservation name" + ) + + args = parser.parse_args() + + base_config = Config.load_from_yaml(args.base_config) + os.makedirs(args.output_base_dir, exist_ok=True) + + # Path definitions + benchmark_csv_path = os.path.join(args.output_base_dir, args.benchmark_csv_name) + scaling_configs_path = os.path.join(args.output_base_dir, "scaling_configs.csv") + + # Generate configurations and names + scaling_configs = generate_scaling_configs(seq_len=4096, mbs=2, num_layers=[10], learning_rates=[0.00002]) + scaling_config_names = create_experiment_names(scaling_configs.copy()) + + # Save reference configurations with names + scaling_config_names.to_csv(scaling_configs_path, index=False) + + # Create empty benchmark file + open(benchmark_csv_path, "a").close() # Creates empty file if it doesn't exist + + create_scaled_configs( + base_config=base_config, + scaling_df=scaling_config_names, + output_base_dir=args.output_base_dir, + benchmark_csv_path=benchmark_csv_path, + brrr_repo_path=args.brrr_repo_path, + uv_env_path=args.uv_env_path, + script_path=args.script_path, + reservation_name=args.reservation_name, + launch_config=args.launch_config, + ) diff --git a/src/nanotron/config/config.py b/src/nanotron/config/config.py index c16f076c1..67b56a5f1 100644 --- a/src/nanotron/config/config.py +++ b/src/nanotron/config/config.py @@ -14,7 +14,12 @@ from yaml.loader import SafeLoader from nanotron.config.lighteval_config import LightEvalConfig -from nanotron.config.models_config import ExistingCheckpointInit, NanotronConfigs, RandomInit, SpectralMupInit +from nanotron.config.models_config import ( + ExistingCheckpointInit, + NanotronConfigs, + RandomInit, + SpectralMupInit, +) from nanotron.config.parallelism_config import ParallelismArgs from nanotron.config.utils_config import ( InitScalingMethod, @@ -226,6 +231,7 @@ class DatasetStageArgs: name: str start_training_step: int data: DataArgs + sequence_length: Optional[int] = None # if None, we use the sequence length from the config def __post_init__(self): if self.start_training_step < 0: @@ -317,10 +323,30 @@ def __post_init__(self): if isinstance(self.dtype, str): self.dtype = cast_str_to_torch_dtype(self.dtype) - self.model_config._is_using_mup = isinstance(self.init_method, SpectralMupInit) + # Convert model_config to proper type if it's a dict + if isinstance(self.model_config, dict): + # First convert moe_config if it exists + if "moe_config" in self.model_config and isinstance(self.model_config["moe_config"], dict): + from nanotron.config.models_config import MoEConfig + + self.model_config["moe_config"] = MoEConfig(**self.model_config["moe_config"]) + + # Then convert the main config + if self.model_config.get("is_qwen2_config", False): + from nanotron.config.models_config import Qwen2Config - # if self.model_config.max_position_embeddings is None: - # self.model_config.max_position_embeddings = 0 + self.model_config = Qwen2Config(**self.model_config) + elif self.model_config.get("is_llama_config", False): + from nanotron.config.models_config import LlamaConfig + + self.model_config = LlamaConfig(**self.model_config) + elif self.model_config.get("is_starcoder2_config", False): + from nanotron.config.models_config import Starcoder2Config + + self.model_config = Starcoder2Config(**self.model_config) + + # Now we can safely set _is_using_mup + self.model_config._is_using_mup = isinstance(self.init_method, SpectralMupInit) @dataclass @@ -542,6 +568,17 @@ def __post_init__(self): self.model.model_config.num_attention_heads % self.model.model_config.num_key_value_heads == 0 ), f"num_attention_heads ({self.model.model_config.num_attention_heads}) must be divisible by num_key_value_heads ({self.model.model_config.num_key_value_heads})" + if self.model.model_config.moe_config is not None: + assert ( + self.model.model_config.moe_config.num_experts % self.parallelism.expert_parallel_size == 0 + ), f"num_experts ({self.model.model_config.moe_config.num_experts}) must be divisible by expert_parallel_size ({self.parallelism.expert_parallel_size})" + + # data_stages + if self.data_stages is not None: + for stage in self.data_stages: + if stage.sequence_length is None: + stage.sequence_length = self.tokens.sequence_length + @property def global_batch_size(self): return self.tokens.micro_batch_size * self.tokens.batch_accumulation_per_replica * self.parallelism.dp diff --git a/src/nanotron/config/lighteval_config.py b/src/nanotron/config/lighteval_config.py index 363ee9887..0806acffe 100644 --- a/src/nanotron/config/lighteval_config.py +++ b/src/nanotron/config/lighteval_config.py @@ -109,8 +109,13 @@ class LightEvalConfig: logging: Optional[LightEvalLoggingArgs] = None wandb: Optional[LightEvalWandbLoggerConfig] = None slurm: Optional[LightEvalSlurm] = None - s3_save_path: Optional[str] = None # should not be dependent of the run_name - output_dir: Optional[str] = None # we should sanity check that it's the same as the one in the eval_config_override + s3_save_path: Optional[str] = None # should not be dependent of the run_name + upload_to_wandb: Optional[bool] = False + wandb_project: Optional[str] = None + wandb_entity: Optional[str] = None + output_dir: Optional[ + str + ] = None # we should sanity check that it's the same as the one in the eval_config_override nanotron_path: Optional[str] = "./" eval_config_override: str = None eval_config_override: Path = None # Previously hardcoded in run_slurm_one_job @@ -127,6 +132,12 @@ def __post_init__(self): if self.slurm is None: self.slurm = LightEvalSlurm() self.local_checkpoint_dir = str(Path(self.local_checkpoint_dir).expanduser()) + if self.upload_to_wandb: + assert ( + self.s3_save_path is not None + ), " We should have a s3_save_path if we want to upload to wandb" # todo: add the option to read from local folder i guess + assert self.wandb_project is not None, "wandb_project must be specified if upload_to_wandb is True" + assert self.wandb_entity is not None, "wandb_entity must be specified if upload_to_wandb is True" if self.eval_interval_file is not None and Path(self.eval_interval_file).exists(): logger.warning( f"Eval interval file {self.eval_interval_file} exists. `eval_interval` will be replaced by the value in the file upon the next evaluation. You should probably delete this file if that's not what you want." diff --git a/src/nanotron/config/models_config.py b/src/nanotron/config/models_config.py index dd575e399..ad0480ded 100644 --- a/src/nanotron/config/models_config.py +++ b/src/nanotron/config/models_config.py @@ -1,6 +1,6 @@ from dataclasses import dataclass, field from pathlib import Path -from typing import Any, List, Optional, Union +from typing import Any, List, Literal, Optional, Union from nanotron.config.utils_config import InitScalingMethod from nanotron.nn.attention import ALL_ATTENTION_FUNCTIONS, AttentionImplementation @@ -36,13 +36,58 @@ class ExistingCheckpointInit: class MoEConfig: """Configuration for Mixture of Experts layers""" - num_experts: int = 8 # Total number of experts - top_k: int = 2 # Number of experts to route each token to + num_experts: int # Total number of experts + top_k: int # Number of experts to route each token to + moe_hidden_size: int # Hidden size of the MoE layer + moe_intermediate_size: int # Intermediate size of the MoE layer + enable_shared_expert: bool = False # Whether to use a shared expert alongside specialized experts + shared_expert_hidden_size: int = 4096 # Hidden size of the shared expert + shared_expert_intermediate_size: int = 11008 # Intermediate size of the shared expert + router_aux_loss_coef: float = ( + 0.01 # Scaling coefficient for the aux loss. A starting value of 1e-2 is recommended. + ) layers: List[int] = field( default_factory=lambda: [-1] ) # Indices of layers that use MoE. -1 means all layers. Default is all layers - enable_shared_expert: bool = False # Whether to use a shared expert alongside specialized experts token_dispatcher_type: str = "alltoall" # Communication pattern for MoE ("alltoall" or "allgather") + use_torch_permute: bool = True # Whether to use Haojun's permute + + moe_impl: str = "transformer_engine" + grouped_gemm_imple: Literal["transformer_engine", "megablock_grouped_gemm"] = "transformer_engine" + + # Transformer-Engine specific config + num_shared_experts: int = None + rotary_base: int = None + rotary_scaling_factor: int = None + max_position_embeddings: int = None + + moe_z_loss_coeff: Optional[ + float + ] = None # Scaling coefficient for the z-loss. A starting value of 1e-3 is recommended. + gradient_accumulation_fusion: bool = False # + disable_parameter_transpose_cache: bool = ( + False # When set to true, the parameter transposes are not cached for subsequent iterations. + ) + bias_activation_fusion: bool = True + permute_fusion: bool = False + input_jitter_eps: float = None # Add noise to the input tensor. https://arxiv.org/abs/2101.03961 + # The load balancing strategy for the router. "aux_loss" corresponds to the load balancing loss + # used in GShard and SwitchTransformer; "seq_aux_loss" corresponds to the loss used in DeepSeekV2, + # which computes the loss for each individual sample; "sinkhorn" corresponds to the balancing + # algorithm used in S-BASE, and "none" implies no load balancing. The default is "aux_loss". + router_load_balancing_type: str = "aux_loss" + moe_expert_capacity_factor: Optional[float] = None + moe_pad_expert_input_to_capacity: bool = False + moe_token_drop_policy: str = "probs" + moe_router_pre_softmax: bool = False + moe_router_num_groups: Optional[int] = None + moe_router_group_topk: Optional[int] = None + # Scaling factor for routing score in top-k selection, only works when moe_router_pre_softmax enabled. Defaults to None, which means no scaling + moe_router_topk_scaling_factor: Optional[float] = None + moe_router_score_function: str = "softmax" # Score function for MoE routing. Can be "softmax" or "sigmoid" + moe_router_expert_bias: Optional[bool] = None + moe_router_dtype: Optional[str] = None + # TODO: add docs https://github.com/NVIDIA/Megatron-LM/blob/dab7723821fc326564634b398a809d43740a6c8d/megatron/core/transformer/transformer_config.py def __post_init__(self): # Validate the configuration @@ -54,6 +99,20 @@ def __post_init__(self): f"token_dispatcher_type must be one of ['alltoall', 'allgather'], got {self.token_dispatcher_type}" ) + assert self.grouped_gemm_imple in [ + "transformer_engine", + "megablock_grouped_gemm", + ], f"Invalid grouped gemm implementation: {self.grouped_gemm_imple}. Available options are: ['transformer_engine', 'megablock_grouped_gemm']" + + if ( + self.top_k == 1 + and self.moe_router_score_function == "softmax" + and not self.moe_router_pre_softmax + and self.router_load_balancing_type != "sinkhorn" + ): + # https://github.com/NVIDIA/Megatron-LM/blob/28118fcdc22e42621776a021af568ae39c198418/megatron/core/transformer/transformer_config.py#L805-L813 + raise ValueError("Please use --moe-router-pre-softmax when topk is 1.") + @dataclass class LlamaConfig: @@ -134,6 +193,9 @@ class Qwen2Config: rope_scaling: Optional[dict] = None rope_theta: float = 10000.0 rope_interleaved: bool = False + rope_seq_len_interpolation_factor: Optional[ + float + ] = None # if not None, discrete positions will be interpolated by this factor via the trick in https://arxiv.org/abs/2306.15595 tie_word_embeddings: bool = False use_cache: bool = True vocab_size: int = 32000 @@ -193,6 +255,12 @@ def __post_init__(self): self.num_hidden_layers % self.no_rope_layer == 0 ), "no_rope_layer must be a multiple of num_hidden_layers" + # rope_seq_len_interpolation_factor = seqlen / 4096 + if self.max_position_embeddings > 4096: + assert ( + self.rope_seq_len_interpolation_factor == self.max_position_embeddings / 4096 + ), f"rope_seq_len_interpolation_factor must be equal to max_position_embeddings / 4096 = {self.max_position_embeddings / 4096}" + @property def is_using_mup(self) -> bool: return self._is_using_mup diff --git a/src/nanotron/config/parallelism_config.py b/src/nanotron/config/parallelism_config.py index 48aa941e8..0d50bf7a5 100644 --- a/src/nanotron/config/parallelism_config.py +++ b/src/nanotron/config/parallelism_config.py @@ -19,7 +19,11 @@ class ParallelismArgs: dp: Number of DP replicas pp: Number of PP stages tp: Number of TP replicas + expert_parallel_size: Number of expert parallel replicas (used only for MoEs) + expert_tensor_parallel_size: The degree that we shard the experts across GPUs + expert_data_parallel_size: The number of expert replicas in data parallel + pp_engine: Pipeline engine to use between "1f1b" and "afab" tp_mode: TP mode to use between "all_reduce" and "reduce_scatter": all_reduce is normal, reduce_scatter activate sequence parallelism tp_linear_async_communication: Whether to use async communication in TP linear layers @@ -33,11 +37,14 @@ class ParallelismArgs: tp_mode: Optional[TensorParallelLinearMode] = None tp_linear_async_communication: Optional[bool] = None recompute_layer: bool = False - moe_layer_recompute: bool = False - tp_recompute_allgather: bool = True + # NOTE: moe-specific parallelism expert_parallel_size: int = 1 + expert_tensor_parallel_size: int = 1 + expert_data_parallel_size: int = 1 + enabled_moe: bool = False + context_parallel_size: int = 1 def __post_init__(self): @@ -53,3 +60,6 @@ def __post_init__(self): self.pp_engine = cast_str_to_pipeline_engine(self.pp_engine) if isinstance(self.tp_mode, str): self.tp_mode = TensorParallelLinearMode[self.tp_mode.upper()] + + if self.expert_parallel_size > 1: + assert self.enabled_moe, "expert_parallel_size > 1 requires enabled_moe to be True" diff --git a/src/nanotron/constants.py b/src/nanotron/constants.py index 580bd99df..4eae3c807 100644 --- a/src/nanotron/constants.py +++ b/src/nanotron/constants.py @@ -2,7 +2,7 @@ from packaging.version import Version, parse -CHECKPOINT_VERSION = Version("1.4") +CHECKPOINT_VERSION = Version("1.5") PY_VERSION = parse(platform.python_version()) @@ -10,3 +10,14 @@ CHECKPOINT_FILE_NAME = "checkpoint_metadata.json" MODEL_CONFIG_FILE_NAME = "model_config.json" + +# MoE specific +EXPERT_PARAM_NAMES = [ + # NOTE: nanotron's moe modeling + "mlp.experts.merged_down_proj", + "mlp.experts.merged_gate_up_proj", + # NOTE: TE's moe modeling + "experts.linear_fc1", + "experts.linear_fc2", + "mlp.router.weight", +] diff --git a/src/nanotron/data/dataloader.py b/src/nanotron/data/dataloader.py index 0a6185163..930a6142b 100644 --- a/src/nanotron/data/dataloader.py +++ b/src/nanotron/data/dataloader.py @@ -137,6 +137,30 @@ def data_generator() -> Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]: seed * (1 + dist.get_rank(parallel_context.dp_pg)) * (1 + dist.get_rank(parallel_context.pp_pg)) ) + # debugging + import joblib + + sample_batch = joblib.load("/fsx/nouamane/projects/OLMoE/sample_batch.pkl") + if sample_batch["input_ids"].shape[0] != micro_batch_size: + sample_batch["input_ids"] = sample_batch["input_ids"][:micro_batch_size] + + assert sample_batch["input_ids"].shape == ( + micro_batch_size, + sequence_length, + ), f"input_ids.shape: {sample_batch['input_ids'].shape}, micro_batch_size: {micro_batch_size}, sequence_length: {sequence_length}" + input_ids = sample_batch["input_ids"].to(device="cuda") + position_ids = torch.arange(sequence_length, device="cuda").repeat(micro_batch_size, 1) + # shift label_ids + label_ids = torch.cat([input_ids[:, 1:], torch.zeros_like(input_ids[:, :1])], dim=1) + label_mask = torch.ones_like(label_ids, dtype=torch.bool) + while True: + yield { + "input_ids": input_ids, + "position_ids": position_ids, + "label_ids": label_ids, + "label_mask": label_mask, + } + if use_position_ids: document_lengths = [[4, 6, sequence_length - 10]] + [[sequence_length]] * (micro_batch_size - 1) position_ids = torch.full( diff --git a/src/nanotron/data/nemo_dataset/blendable_dataset.py b/src/nanotron/data/nemo_dataset/blendable_dataset.py index e4e999169..96c49fe0a 100644 --- a/src/nanotron/data/nemo_dataset/blendable_dataset.py +++ b/src/nanotron/data/nemo_dataset/blendable_dataset.py @@ -25,6 +25,7 @@ from nanotron.logging import log_rank from nanotron.parallel import ParallelContext from nanotron.utils import main_rank_first +from pprint import pformat if TYPE_CHECKING: from . import GPTDataset, SubsetSplitLog @@ -48,6 +49,7 @@ def __init__( parallel_context: ParallelContext, seed: int, consumed_tokens_per_dataset_folder: Optional[Dict[str, int]] = None, + offsets_in_samples: Optional[Dict[str, int]] = None, ): self.datasets = datasets num_datasets = len(datasets) @@ -114,16 +116,26 @@ def __init__( # self.last_item_idx = np.full(self.history_size, -1, dtype=np.int64) # Initialize consumption tracking - self.consumed_tokens = {idx: 0 for idx in range(len(datasets))} + self.consumed_tokens = {idx: 0 for idx in range(len(datasets))} # current stage's consumed_tokens_per_dataset_folder if consumed_tokens_per_dataset_folder is not None: # find idx of dataset that matches the folder path for idx, dataset in enumerate(datasets): for folder_path, consumed_tokens in consumed_tokens_per_dataset_folder.items(): if dataset.folder_path == folder_path: self.consumed_tokens[idx] = consumed_tokens - break + log_rank(f"[BlendableDataset] Setting consumed_tokens for dataset {idx} ({dataset.folder_path}) to {consumed_tokens}", logger=logger, level=logging.INFO, rank=0) + self.sequence_length = None # Will be set when first batch is processed + # Setup offsets for already consumed tokens from previous stages + self.offsets_in_samples = {idx: 0 for idx in range(len(datasets))} # last stage's consumed_tokens_per_dataset_folder + if offsets_in_samples is not None: + for idx, dataset in enumerate(datasets): + for folder_path, offset in offsets_in_samples.items(): + if dataset.folder_path == folder_path: + self.offsets_in_samples[idx] = offset + log_rank(f"[BlendableDataset] Applying offset {offset} samples to dataset {idx} ({dataset.folder_path})", logger=logger, level=logging.INFO, rank=0) + def __len__(self): return self.size @@ -131,16 +143,7 @@ def __getitem__(self, idx): dataset_idx = self.dataset_index[idx] sample_idx = self.dataset_sample_index[idx] - # Shift history arrays and add new values at the end - # self.last_item_idx = np.roll(self.last_item_idx, -1) - # self.last_dataset_idx = np.roll(self.last_dataset_idx, -1) - # self.last_dataset_sample_idx = np.roll(self.last_dataset_sample_idx, -1) - - # self.last_item_idx[-1] = idx - # self.last_dataset_idx[-1] = dataset_idx - # self.last_dataset_sample_idx[-1] = sample_idx - - return self.datasets[dataset_idx][sample_idx] + return self.datasets[dataset_idx][sample_idx + self.offsets_in_samples[dataset_idx]] # TODO: is it okay to not respect dataset_sample_index? Since it's sequential it's okay for now # @property # def last_file_idx(self): @@ -181,6 +184,9 @@ def get_consumption_stats(self): """ stats = {} for dataset_idx, dataset in enumerate(self.datasets): + assert ( + "s3" in dataset.folder_path + ), "Only S3 paths are supported for consumption stats" # TODO: remove this stats[dataset.folder_path] = {"tokens": self.consumed_tokens[dataset_idx]} return stats diff --git a/src/nanotron/data/tokenized_bytes.py b/src/nanotron/data/tokenized_bytes.py index 4f9063eb6..3b8a5437f 100644 --- a/src/nanotron/data/tokenized_bytes.py +++ b/src/nanotron/data/tokenized_bytes.py @@ -369,12 +369,13 @@ def __init__( ) from datatrove.utils.dataset import url_to_fs - fs_folder, folder_path = url_to_fs(folder_path) + fs_folder, stripped_folder_path = url_to_fs(folder_path) matched_files = ( - fs_folder.find(folder_path, detail=False, maxdepth=1 if not recursive else None) + fs_folder.find(stripped_folder_path, detail=False, maxdepth=1 if not recursive else None) if not filename_pattern else fs_folder.glob( - os.path.join(folder_path, filename_pattern), maxdepth=1 if not recursive else None + os.path.join(stripped_folder_path, filename_pattern), + maxdepth=1 if not recursive else None, ) ) matched_files = sorted(matched_files) @@ -505,7 +506,7 @@ def build_dataset( seq_len=seq_length, recursive=False, token_size=token_size, - max_tokens=max_tokens, + max_tokens=max_tokens, # TODO: remove shuffle=shuffle, return_positions=return_positions, # if set to True, the position ids are directly read from datatrove eos_token_id=eos_token_id, @@ -521,11 +522,14 @@ def get_tb_datasets( sequence_length: int, global_batch_size: int, train_steps: int, + current_iteration: int, parallel_context: ParallelContext, eos_token_id: Optional[int] = None, shuffle: bool = False, seed: int = 6, + consumed_samples: int = 0, consumed_tokens_per_dataset_folder: Optional[Dict[str, int]] = None, + last_stages_consumed_tokens_per_dataset_folder: Optional[Dict[str, int]] = None, ) -> Tuple[DataLoader, TrainDataLog]: """Build TokenizedBytes datasets @@ -541,6 +545,7 @@ def get_tb_datasets( if dataset_max_tokens is None: dataset_max_tokens = [None] * len(config.dataset_folder) train_num_samples = train_steps * global_batch_size + last_stages_consumed_samples_per_dataset_folder = {k: v // sequence_length for k, v in last_stages_consumed_tokens_per_dataset_folder.items()} datasets = [ build_dataset( @@ -560,6 +565,49 @@ def get_tb_datasets( for i, (dataset_folder, max_tokens) in enumerate(zip(config.dataset_folder, dataset_max_tokens)) ] + # in case of dataset_read_path check we have enough files locally for the training + if config.dataset_read_path: + + weights = config.dataset_weights + if not weights: + weights = [1] * len(datasets) + + # Normalize weights + weights = np.array(weights, dtype=np.float64) + sum_weights = np.sum(weights) + assert sum_weights > 0.0 + weights /= sum_weights + + # check we have enough files locally for the training + for i, dataset in enumerate(datasets): + # warmup datasets + estimate_current_sample = int(consumed_samples * weights[i]) + last_stages_consumed_samples_per_dataset_folder.get(dataset.folder_path, 0) + _ = dataset[estimate_current_sample] + # print which file we're currently reading from + log_rank(f"Dataset {i} ({dataset.folder_path}) is reading from file {dataset.current_file_path}", logger=logger, level=logging.INFO, rank=0) + # estimate number of tokens needed for this dataset + needed_num_samples_dataset = int((train_steps - current_iteration) * global_batch_size * weights[i]) + needed_num_tokens_dataset = needed_num_samples_dataset * sequence_length + needed_size_tokens_dataset = human_format(needed_num_tokens_dataset * config.token_size_in_bytes) + log_rank(f"Dataset {i} ({dataset.folder_path}) needs {needed_num_tokens_dataset} tokens (size: {needed_size_tokens_dataset}) for current stage", logger=logger, level=logging.INFO, rank=0) + + # NOTE: let's assume that s3 folder keep the same old files when resuming + # check that sum of lens of files in dataset is greater than needed_num_samples_dataset (use dataset.lens) + total_num_samples_dataset = int(train_steps * global_batch_size * weights[i]) + log_rank(f"Dataset {i} ({dataset.folder_path}) on s3 has {len(dataset) * sequence_length} tokens (size: {human_format(len(dataset) * sequence_length * config.token_size_in_bytes)}) and needs {total_num_samples_dataset * sequence_length} tokens (size: {human_format(total_num_samples_dataset * sequence_length * config.token_size_in_bytes)}) for all stages", logger=logger, level=logging.INFO, rank=0) + assert total_num_samples_dataset <= len(dataset), f"Not enough files on s3 for dataset {i} ({dataset.folder_path})" + # check that local files exist for the needed_num_samples_dataset + estimate_end_sample = estimate_current_sample + needed_num_samples_dataset + for file_idx, file in enumerate(dataset.files): + # intersection [start_sample, end_sample] with [dataset.lens[file_idx], dataset.lens[file_idx+1]] + a, b, c, d = estimate_current_sample, estimate_end_sample, dataset.lens[file_idx], dataset.lens[file_idx+1] + if max(a, c) < min(b, d): # ranges overlap + assert os.path.exists(file.file_path), f"Dataset {i} ({dataset.folder_path}) will need file {file.file_path} but it does not exist" + log_rank(f"Dataset {i} ({dataset.folder_path}) will need file {file.file_path} from sample {max(a, c)} to {min(b, d)} (offset: {last_stages_consumed_samples_per_dataset_folder.get(dataset.folder_path, 0)})", logger=logger, level=logging.INFO, rank=0) + else: + log_rank(f"Dataset {i} ({dataset.folder_path}) will not need file {file.file_path} to train from sample {estimate_current_sample} to {estimate_end_sample} (offset: {last_stages_consumed_samples_per_dataset_folder.get(dataset.folder_path, 0)})", logger=logger, level=logging.INFO, rank=0) + + if len(datasets) == 1 and False: outputs_dataset = datasets[0] else: @@ -582,6 +630,7 @@ def get_tb_datasets( parallel_context=parallel_context, seed=seed, consumed_tokens_per_dataset_folder=consumed_tokens_per_dataset_folder, + offsets_in_samples=last_stages_consumed_samples_per_dataset_folder, ) log_rank("Streamable datasets ready.", logger=logger, level=logging.INFO, rank=0) @@ -625,7 +674,7 @@ def get_tb_dataloader( dataset = EmptyInfiniteDataset(length=len(dataset)) log_rank( - f"Building dataloader with consumed samples: {consumed_samples}", logger=logger, level=logging.INFO, rank=0 + f"Building dataloader with consumed samples for current datastage: {consumed_samples}", logger=logger, level=logging.INFO, rank=0 ) # Megatron sampler # batch_sampler = MegatronPretrainingRandomSampler( diff --git a/src/nanotron/eval/one_job_runner.py b/src/nanotron/eval/one_job_runner.py index 43d1a7653..6567ec94a 100644 --- a/src/nanotron/eval/one_job_runner.py +++ b/src/nanotron/eval/one_job_runner.py @@ -60,13 +60,18 @@ def eval_single_checkpoint(self, uploaded_files: List[dict]) -> Tuple[str, str]: logger.warning( f"Lighteval Runner got {len(uploaded_files)} files. Using {checkpoint_path} as checkpoint path." ) - - slurm_job_id, slurm_log = run_slurm_one_job( - config=self.config, - lighteval_config=self.lighteval_config, - model_checkpoint_path=checkpoint_path, - current_step=self.config.general.step, - ) + if self.config.general.step % self.lighteval_config.eval_interval == 0: + slurm_job_id, slurm_log = run_slurm_one_job( + config=self.config, + lighteval_config=self.lighteval_config, + model_checkpoint_path=checkpoint_path, + current_step=self.config.general.step, + ) + else: + logger.warning( + f"Skipping evaluation at step {self.config.general.step} because it's not a multiple of {self.lighteval_config.eval_interval}" + ) + return None, None return slurm_job_id, slurm_log @@ -130,7 +135,8 @@ def run_slurm_one_job( #SBATCH --exclusive #SBATCH --qos={slurm_config.qos} #SBATCH --time={slurm_config.time} -#SBATCH --output={eval_logs_path}/%j-{timestamp}.out""" +#SBATCH --output={eval_logs_path}/%j-{timestamp}.out +#SBATCH --requeue""" if slurm_config.reservation: slurm_script += f"\n#SBATCH --reservation={slurm_config.reservation}" @@ -250,7 +256,23 @@ def run_slurm_one_job( --cache-dir {slurm_config.hf_cache}""" if lighteval_config.output_dir is not None and lighteval_config.s3_save_path is not None: slurm_script += f""" -s5cmd cp --if-size-differ "{lighteval_config.output_dir}*" {lighteval_config.s3_save_path} +s5cmd cp --if-size-differ "{lighteval_config.output_dir}*" {lighteval_config.s3_save_path}/ +""" + if lighteval_config.upload_to_wandb: + gbs_tok = ( + config.parallelism.dp + * config.tokens.micro_batch_size + * config.tokens.sequence_length + * config.tokens.batch_accumulation_per_replica + ) + slurm_script += f""" +python {nanotron_path}/src/nanotron/eval/upload_to_wandb.py \\ + --wandb_project {lighteval_config.wandb_project} \\ + --wandb_entity {lighteval_config.wandb_entity} \\ + --model_name {general_run_name} \\ + --results_path {lighteval_config.s3_save_path}/results/results/{general_run_name}/{current_step}/ \\ + --train_step {current_step} \\ + --consumed_tokens {current_step*gbs_tok} """ slurm_script += """ echo "Cleaning up downloaded checkpoints..." diff --git a/src/nanotron/eval/upload_to_wandb.py b/src/nanotron/eval/upload_to_wandb.py new file mode 100644 index 000000000..aa8c12d41 --- /dev/null +++ b/src/nanotron/eval/upload_to_wandb.py @@ -0,0 +1,87 @@ +import json +import s3fs +import wandb +import re +import argparse +from wandb.sdk.lib.runid import generate_id + + +def push_to_wandb(wandb_project, wandb_entity, model_name, results_path, train_step, consumed_tokens): + s3 = s3fs.S3FileSystem(anon=False) + all_metrics = { + # basic X axis replacements for all metrics + "consumed_tokens": consumed_tokens, + "train_step": train_step, + } + + for result_file in sorted(s3.ls(results_path)): + if not result_file.endswith(".json"): + continue + + with s3.open(result_file, "r") as f: + results = json.loads(f.read())["results"] + + for benchmark, metrics in results.items(): + if benchmark == "all": + continue + + # extract dataset and config name + match = re.search(r"\|(.*?)(?::(.*?))?\|", benchmark) + if match: + dataset, subtask = match.groups() + + for metric_name, metric_value in metrics.items(): + if "_stderr" in metric_name: + continue + # wandb-friendly metric name + wandb_metric = f"{dataset}/{subtask}/{metric_name}" if subtask else f"{dataset}/{metric_name}" + all_metrics[wandb_metric] = metric_value + + run_id = f"{model_name}-{generate_id()}" + + # try to find the run in wandb and resume it + api = wandb.Api() + runs = api.runs(f"{wandb_entity}/{wandb_project}") + for run in runs: + if run.name == model_name: + run_id = run.id + break + + wandb.init( + project=wandb_project, + entity=wandb_entity, + name=model_name, + id=run_id, + config={ + "model_name": model_name, + }, + resume="allow", + ) + + # log all metrics for this checkpoint + wandb.log(all_metrics) + + wandb.finish() + +if __name__ == "__main__": + # Setup argument parser + parser = argparse.ArgumentParser(description="Upload evaluation results to Weights & Biases.") + parser.add_argument("--wandb_project", type=str, required=True, help="WandB project name.") + parser.add_argument("--wandb_entity", type=str, required=True, help="WandB entity name.") + parser.add_argument("--model_name", type=str, required=True, help="Name of the model.") + parser.add_argument("--results_path", type=str, required=True, help="S3 path to the results directory.") + parser.add_argument("--train_step", type=int, required=True, help="Training step corresponding to the checkpoint.") + parser.add_argument("--consumed_tokens", type=int, required=True, help="Total consumed tokens up to this checkpoint.") + + # Parse arguments + args = parser.parse_args() + + # Call the main function with parsed arguments + push_to_wandb( + wandb_project=args.wandb_project, + wandb_entity=args.wandb_entity, + model_name=args.model_name, + results_path=args.results_path, + train_step=args.train_step, + consumed_tokens=args.consumed_tokens + ) diff --git a/src/nanotron/generation/decode.py b/src/nanotron/generation/decode.py index 338801100..af89163eb 100644 --- a/src/nanotron/generation/decode.py +++ b/src/nanotron/generation/decode.py @@ -241,6 +241,7 @@ def decode_text( pipeline_state = PipelineEvalBatchState() with attach_pipeline_state_to_model(model=model, pipeline_state=pipeline_state): # We query the first `pipeline_size` batches + for batches in chunks( iterable=micro_batcher( input_iter=input_iter, @@ -306,6 +307,7 @@ def decode_text( batch_generated_ids = state.new_input_ids batch_generated_mask = state.new_input_mask position_ids = get_position_ids(batch_generated_ids, tokenizer) + sharded_logits = model( input_ids=batch_generated_ids, position_ids=position_ids, # [batch_size, seq_len] diff --git a/src/nanotron/helpers.py b/src/nanotron/helpers.py index e83355e15..4ec0ade4b 100644 --- a/src/nanotron/helpers.py +++ b/src/nanotron/helpers.py @@ -447,6 +447,7 @@ def grad_optimizer_builder(named_param_groups): model.register_comm_hook( state=FP32GradBucketManager( dp_pg=parallel_context.dp_pg, + ep_dp_pg=parallel_context.ep_dp_pg if parallel_context.enabled_moe else None, accumulator=grad_accumulator, param_id_to_name={ id(param): param.get_tied_info().get_full_name_from_module_id_to_prefix( @@ -818,7 +819,7 @@ def is_resume_from_training(): return 0 else: last_train_steps = metadata.last_train_step if is_resume_from_training() else stage.start_training_step - return total_train_steps - last_train_steps + return total_train_steps - last_train_steps + 1 def get_consumed_train_samples_of_a_data_stage_from_ckp( @@ -826,6 +827,7 @@ def get_consumed_train_samples_of_a_data_stage_from_ckp( ) -> Optional[int]: start_training_step = stage.start_training_step + # find the stage in the metadata using the start_training_step actual_stage = next( (s for s in metadata.data_stages if s.start_training_step == start_training_step), None, diff --git a/src/nanotron/logging/base.py b/src/nanotron/logging/base.py index b14b94aab..8e2f94a74 100644 --- a/src/nanotron/logging/base.py +++ b/src/nanotron/logging/base.py @@ -229,6 +229,7 @@ def log_rank( rank: Optional[int] = None, category: Optional[str] = None, is_separator: bool = False, + main_rank_only: bool = False, **kwargs, ): """Log only if the current process is the rank specified.""" @@ -246,6 +247,8 @@ def log_rank( kwargs["extra"] = kwargs.get("extra", {}) kwargs["extra"]["separator"] = True + if main_rank_only: + rank = 0 # rank is None means everyone logs if rank is None or dist.get_rank(group) == rank: if is_separator: @@ -429,9 +432,9 @@ def log_libraries_versions(logger: logging.Logger): log_rank(f"datasets version: {datasets.__version__}", logger=logger, level=logging.INFO, rank=0) log_rank(f"flash-attn version: {flash_attn.__version__}", logger=logger, level=logging.INFO, rank=0) log_rank(f"numpy version: {numpy.__version__}", logger=logger, level=logging.INFO, rank=0) - log_rank( - f"\ntorch.utils.collect_env: {torch.utils.collect_env.main()}", logger=logger, level=logging.INFO, rank=0 - ) + # log_rank( + # f"\ntorch.utils.collect_env: {torch.utils.collect_env.main()}", logger=logger, level=logging.INFO, rank=0 + # ) _configure_library_root_logger() diff --git a/src/nanotron/logging/timers.py b/src/nanotron/logging/timers.py index 1129b9c6c..e3603f118 100644 --- a/src/nanotron/logging/timers.py +++ b/src/nanotron/logging/timers.py @@ -19,15 +19,23 @@ class TimerType(Enum): @dataclass class TimerRecord: - """Records timing information for a single timer.""" + """ + Records timing information for a single timer. + + By default, uses CUDA events for timing GPU operations, which provides more accurate + measurements of GPU execution time without forcing CPU-GPU synchronization. + + For CPU-only operations, you can use CPU-based timing by specifying timer_type=TimerType.CPU. + """ name: str - timer_type: TimerType = TimerType.CPU + timer_type: TimerType = TimerType.CUDA start_time: float = 0.0 end_time: float = 0.0 running: bool = False call_count: int = 0 cuda_sync: bool = False # Option to add CUDA synchronization for more accurate timings + enabled: bool = True # Allow individual timer to be enabled/disabled # For CPU timers we still track total_time _cpu_total_time: float = 0.0 @@ -48,7 +56,7 @@ def __exit__(self, exc_type, exc_val, exc_tb): def start(self) -> "TimerRecord": """Start the timer.""" - if self.name == "dummy": # disabled + if self.name == "dummy" or not self.enabled: # disabled return self if self.running: @@ -75,7 +83,7 @@ def start(self) -> "TimerRecord": def end(self) -> None: """End the timer, but don't compute elapsed time yet.""" - if self.name == "dummy": # disabled + if self.name == "dummy" or not self.enabled: # disabled return if not self.running: @@ -175,7 +183,17 @@ def average_time(self) -> float: class Timers: - """A collection of timers for tracking execution time in Nanotron.""" + """ + A collection of timers for tracking execution time in Nanotron. + + By default, timers use CUDA events for timing GPU operations, which provides several benefits: + 1. More accurate measurement of GPU execution time + 2. Reduced need for explicit CUDA synchronization + 3. Lower overhead compared to CPU-based timing with synchronization + 4. Better performance monitoring for distributed training + + For CPU-only operations, you can still use CPU-based timing by specifying timer_type=TimerType.CPU. + """ _instance = None _enabled = os.environ.get("ENABLE_TIMERS", "0") == "1" # Add global enable/disable flag @@ -202,39 +220,56 @@ def is_enabled(cls) -> bool: return cls._enabled def __call__( - self, name: str, timer_type: Union[TimerType, str] = TimerType.CPU, cuda_sync: bool = True + self, + name: str, + timer_type: Union[TimerType, str] = TimerType.CUDA, + cuda_sync: bool = False, + enabled: bool = bool(int(os.environ.get("ENABLE_TIMERS", "0"))), ) -> TimerRecord: """Get or create a timer with the given name. Can be used as a decorator, context manager, or directly: - - @nanotron_timer("name") # As decorator + - @nanotron_timer # As decorator with default CUDA timing + - @nanotron_timer("my_function") # As decorator with custom name + - @nanotron_timer(timer_type=TimerType.CPU) # As decorator with CPU timing - with nanotron_timer("name"): ... # As context manager - nanotron_timer("name").start(); ...; nanotron_timer("name").end() # Direct use Args: name: Name of the timer - timer_type: Type of timer, either TimerType.CPU or TimerType.CUDA - (or 'cpu'/'cuda' strings) - cuda_sync: Whether to perform torch.cuda.synchronize() for more accurate CUDA timing + timer_type: Type of timer, either TimerType.CUDA (default) or TimerType.CPU + (or 'cuda'/'cpu' strings) + cuda_sync: Whether to perform torch.cuda.synchronize() for more accurate CUDA timing. + Default is False to avoid unnecessary synchronization overhead. + enabled: Override default enabled setting from environment variable + + Raises: + ValueError: If a timer with the same name already exists with different settings """ - if not self._enabled: - # Return a dummy timer that does nothing when timing is disabled - return TimerRecord(name="dummy", timer_type=TimerType.CPU) - if isinstance(timer_type, str): timer_type = TimerType(timer_type) - if callable(name) and timer_type == TimerType.CPU: - # Being used as a decorator with default settings + if callable(name): + # Being used as a decorator with specified or default settings func = name timer_name = func.__name__ - return self._create_timer_decorator(timer_name, TimerType.CPU, cuda_sync)(func) + return self._create_timer_decorator(timer_name, timer_type, cuda_sync, enabled)(func) - if name not in self._timers: - self._timers[name] = TimerRecord(name=name, timer_type=timer_type, cuda_sync=cuda_sync) - else: - # Update the cuda_sync option if the timer already exists - self._timers[name].cuda_sync = cuda_sync + if name in self._timers: + existing_timer = self._timers[name] + if ( + existing_timer.timer_type != timer_type + or existing_timer.cuda_sync != cuda_sync + or existing_timer.enabled != enabled + ): + raise ValueError( + f"Timer '{name}' already exists with different settings.\n" + f"Existing: type={existing_timer.timer_type}, cuda_sync={existing_timer.cuda_sync}, enabled={existing_timer.enabled}\n" + f"New: type={timer_type}, cuda_sync={cuda_sync}, enabled={enabled}" + ) + return existing_timer + + self._timers[name] = TimerRecord(name=name, timer_type=timer_type, cuda_sync=cuda_sync, enabled=enabled) # Check if we're being called as a decorator if not callable(name): @@ -243,9 +278,9 @@ def __call__( return timer_record # If we get here, we're being called as @nanotron_timer("name", timer_type) - return self._create_timer_decorator(name, timer_type, cuda_sync) + return self._create_timer_decorator(name, timer_type, cuda_sync, enabled) - def _create_timer_decorator(self, name, timer_type, cuda_sync=False): + def _create_timer_decorator(self, name, timer_type=TimerType.CUDA, cuda_sync=False, enabled=None): """Create a decorator that times the execution of a function.""" def decorator(func): @@ -253,7 +288,7 @@ def decorator(func): @functools.wraps(func) def wrapper(*args, **kwargs): - with self(name, timer_type, cuda_sync): + with self(name, timer_type, cuda_sync, enabled): return func(*args, **kwargs) return wrapper diff --git a/src/nanotron/models/base.py b/src/nanotron/models/base.py index af26c6da0..6bb4d6472 100644 --- a/src/nanotron/models/base.py +++ b/src/nanotron/models/base.py @@ -1,3 +1,4 @@ +import threading from abc import ABCMeta, abstractmethod from contextlib import contextmanager from typing import TYPE_CHECKING, Callable, Dict, Iterator, List, Optional, Tuple @@ -238,7 +239,34 @@ def build_model( return model -# TODO @thomasw21: Should this option override user defined options? Maybe not ... right now it does. +@contextmanager +def ignore_init_on_device_and_dtype(): + """ + A context manager that temporarily disables dtype enforcement from init_on_device_and_dtype. + + Example: + ```python + with init_on_device_and_dtype(device=torch.device("cuda"), dtype=torch.float32): + with ignore_init_on_device_and_dtype(): + # This parameter will keep its specified dtype (float32) + self.weight = nn.Parameter(torch.randn(..., dtype=torch.float32)) + ``` + """ + # Create a thread-local storage for the ignore flag + if not hasattr(ignore_init_on_device_and_dtype, "_ignore_flag"): + ignore_init_on_device_and_dtype._ignore_flag = threading.local() + + # Set the ignore flag + old_value = getattr(ignore_init_on_device_and_dtype._ignore_flag, "value", False) + ignore_init_on_device_and_dtype._ignore_flag.value = True + + try: + yield + finally: + # Restore the previous value + ignore_init_on_device_and_dtype._ignore_flag.value = old_value + + @contextmanager def init_on_device_and_dtype( device: torch.device = torch.device("cpu"), @@ -250,35 +278,30 @@ def init_on_device_and_dtype( device (`torch.device` defaults to `cpu`): Device to initialize all parameters on. dtype (`torch.dtype` defaults to `torch.float`): - Dtype to initialize all parameters on. - include_buffers (`bool`, defaults to `False`): - Whether or not to also default all buffers constructors given previous arguments. - Example: - ```python - import torch.nn as nn - from accelerate import init_on_device - with init_on_device_and_dtype(device=torch.device("cuda")): - tst = nn.Liner(100, 100) # on `cuda` device - ``` + Dtype to initialize all parameters on. If specified, will override any dtype + set in parameter initialization with a warning, unless within an ignore_init_on_device_and_dtype context. """ old_register_parameter = nn.Module.register_parameter old_register_buffer = nn.Module.register_buffer + def should_ignore_init_on_device_and_dtype(): + if not hasattr(ignore_init_on_device_and_dtype, "_ignore_flag"): + return False + return getattr(ignore_init_on_device_and_dtype._ignore_flag, "value", False) + def register_empty_parameter(module, name, param): old_register_parameter(module, name, param) if param is not None: - if isinstance(param, DTypeInvariantTensor): - # if param is DTypeInvariantTensor we should avoid updating it - param.data = param.data.to(device) + if should_ignore_init_on_device_and_dtype(): + pass else: param.data = param.data.to(device, dtype) def register_empty_buffer(module, name, buffer, persistent=True): old_register_buffer(module, name, buffer, persistent=persistent) if buffer is not None: - if isinstance(buffer, DTypeInvariantTensor): - # if buffer is DTypeInvariantTensor we should avoid updating it - buffer.data = buffer.data.to(device) + if should_ignore_init_on_device_and_dtype(): + pass else: module._buffers[name] = module._buffers[name].to(device, dtype) diff --git a/src/nanotron/models/inference_qwen.py b/src/nanotron/models/inference_qwen.py new file mode 100644 index 000000000..a9509332e --- /dev/null +++ b/src/nanotron/models/inference_qwen.py @@ -0,0 +1,942 @@ +from typing import Dict, List, Optional, Tuple, Union + +import torch +from flash_attn.modules.mha import flash_attn_varlen_kvpacked_func +from torch import nn +from torch.utils.checkpoint import CheckpointFunction + +from nanotron import distributed as dist +from nanotron import logging +from nanotron.config import Config, ParallelismArgs +from nanotron.config.models_config import Qwen2Config, RandomInit, SpectralMupInit +from nanotron.logging import log_rank +from nanotron.models import NanotronModel +from nanotron.nn.activations import ACT2FN +from nanotron.nn.attention import ALL_ATTENTION_FUNCTIONS, get_attention_mask +from nanotron.nn.layer_norm import LlamaRMSNorm as RMSNorm +from nanotron.nn.layer_norm import TritonRMSNorm +from nanotron.nn.rotary import RotaryEmbedding +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import NanotronParameter +from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer +from nanotron.parallel.pipeline_parallel.p2p import P2P +from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy +from nanotron.parallel.tensor_parallel.nn import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelLinearMode, + TensorParallelRowLinear, +) +from nanotron.random import RandomStates +from nanotron.scaling.parametrization import SpectralMupParametrizator, StandardParametrizator + +logger = logging.get_logger(__name__) + + +class CoreAttention(nn.Module): + """Core attention module that can use different attention implementations""" + + def __init__( + self, + config: Qwen2Config, + tp_pg: dist.ProcessGroup, + cp_pg: dist.ProcessGroup, + layer_idx: int = 0, + ): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.head_dim = config.hidden_size // config.num_attention_heads + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.local_num_heads = self.num_heads // tp_pg.size() + self.local_num_kv_heads = self.num_kv_heads // tp_pg.size() + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) # Important for transformers's `sdpa_attention_forward` + self._attn_implementation = config._attn_implementation + self.cp_pg = cp_pg + self.sliding_window_size = config.sliding_window_size + self.simple_causal_mask = True # Use simple causal mask instead of computing custom attention mask if not document masking / sliding window + self.flex_attention_mask = config.flex_attention_mask if hasattr(config, "flex_attention_mask") else None + + def forward( + self, + query_states: torch.Tensor, # [b*s, num_heads, head_dim] + key_states: torch.Tensor, # [b*s, num_kv_heads, head_dim] + value_states: torch.Tensor, # [b*s, num_kv_heads, head_dim] + position_ids: torch.Tensor, # [b*s] + seq_length: Optional[int], + attention_mask: Optional[torch.Tensor] = None, + dropout: float = 0.0, + **kwargs, + ): + """Forward pass applying the chosen attention implementation""" + # Get the appropriate attention function + attention_func = ALL_ATTENTION_FUNCTIONS[self._attn_implementation] + + # Initialize variables for attention parameters + cu_seqlens = kwargs.get("cu_seqlens", None) + + # Shape tensors according to attention implementation + if self._attn_implementation == "ring_flash_triton": + query_states = query_states.view(-1, seq_length, self.local_num_heads, self.head_dim) + key_states = key_states.view(-1, seq_length, self.local_num_kv_heads, self.head_dim) + value_states = value_states.view(-1, seq_length, self.local_num_kv_heads, self.head_dim) + elif self._attn_implementation == "ring": + # Warning: Since this uses _flash_attn_varlen_forward make sure we count padding tokens in cu_seqlens + query_states = query_states.view(-1, self.local_num_heads, self.head_dim) + key_states = key_states.view(-1, self.local_num_kv_heads, self.head_dim) + value_states = value_states.view(-1, self.local_num_kv_heads, self.head_dim) + else: + # Process attention mask based on implementation + if self.simple_causal_mask: + assert attention_mask is None, "Simple causal mask is not supported with custom attention mask" + assert self.sliding_window_size is None, "Simple causal mask is not supported with sliding window" + elif attention_mask is None and position_ids is not None: + # Determine if we need to create an attention mask from position_ids + if self._attn_implementation == "flex_attention" and self.sliding_window_size is not None: + # For FlexAttention with sliding window, we don't need an explicit mask + # The mask_mod function will handle it + pass + else: + # For other implementations, generate the attention mask if needed + # Only calculate if cu_seqlens wasn't passed + if cu_seqlens is None: + attention_mask, cu_seqlens = get_attention_mask(position_ids, seq_length=seq_length) + + if attention_mask is not None: + # Add batch and head dimensions for proper broadcasting + attention_mask = attention_mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_length, seq_length] + + attn_output = attention_func( + self, + query_states, # [b, num_heads, seq_len, head_dim] + key_states, # [b, num_kv_heads, seq_len, head_dim] + value_states, # [b, num_kv_heads, seq_len, head_dim] + attention_mask, # [b, num_heads, seq_len, seq_len] + max_seqlen=seq_length, + dropout=dropout, + scaling=None, # by default, scaling is head_dim**-0.5 + sliding_window=self.sliding_window_size, + ring_pg=self.cp_pg, + position_ids=position_ids if self._attn_implementation == "flex_attention" else None, + document_ids=kwargs.get("document_ids", None) if self._attn_implementation == "flex_attention" else None, + flex_attention_mask=self.flex_attention_mask if self._attn_implementation == "flex_attention" else None, + **kwargs, # Pass remaining kwargs + )[0] + + return attn_output.view( + -1, self.local_num_heads * self.head_dim + ) # [b*s, num_heads, head_dim] -> [b*s, num_heads*head_dim] + + +class Qwen2Attention(nn.Module): + def __init__( + self, + config: Qwen2Config, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + cp_pg: dist.ProcessGroup, + layer_idx: int, + ): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.tp_pg_size = tp_pg.size() + + # Head configuration + self.num_heads = config.num_attention_heads + self.local_num_heads = self.num_heads // self.tp_pg_size + + # KV head configuration + self.num_kv_heads = config.num_key_value_heads + self.local_num_kv_heads = self.num_kv_heads // self.tp_pg_size + + # Dimensions + self.head_dim = config.hidden_size // self.num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.local_q_size = self.local_num_heads * self.head_dim + self.local_kv_size = self.local_num_kv_heads * self.head_dim + + # TP mode configuration + tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + tp_linear_async_communication = ( + parallel_config.tp_linear_async_communication if parallel_config is not None else False + ) + + qkv_contiguous_chunks = ( + self.q_size, # Q chunk size + self.kv_size, # K chunk size + self.kv_size, # V chunk size + ) + self.qkv_proj = TensorParallelColumnLinear( + self.hidden_size, + self.q_size + 2 * self.kv_size, + pg=tp_pg, + mode=tp_mode, + bias=config.attention_bias, # Qwen2 uses bias for QKV, Llama doesn't + async_communication=tp_linear_async_communication, + contiguous_chunks=qkv_contiguous_chunks, + tp_recompute_allgather=parallel_config.tp_recompute_allgather, + ) + self.o_proj = TensorParallelRowLinear( + self.num_heads * self.head_dim, + self.hidden_size, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication, + ) + if config._use_qkv_packed: + from nanotron.nn.rotary import FlashRotaryEmbedding + + self.rotary_emb = FlashRotaryEmbedding( + dim=self.head_dim, + base=config.rope_theta, + interleaved=config.rope_interleaved, + ) + else: + self.rotary_emb = RotaryEmbedding( + dim=self.head_dim, + max_seq_len=config.max_position_embeddings, + base=config.rope_theta, + interleaved=config.rope_interleaved, + seq_len_scaling_factor=None, + fused=config._fused_rotary_emb, + ) + self.attention = CoreAttention(config, tp_pg, cp_pg, layer_idx) + self.simple_causal_mask = True + self._use_qkv_packed = config._use_qkv_packed + + # TODO: support doc masking / SWA / SFT / inference + + def forward( + self, + hidden_states: torch.Tensor, # [batch_size*seq_length, hidden_size] + position_ids: torch.Tensor, # [batch_size, seq_length] where -1 is padding + cu_seqlens: Optional[torch.Tensor] = None, # Added cu_seqlens argument + inference_max_seqlen: Optional[int] = None, + ): + # [0, 1, 2, 3, 4, 0, 1, 2, -1, -1, -1] # 2 documents with 5 and 3 tokens then padding + # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # 1 document with 11 tokens + # [0, 1, 2, 3, 4, 5, 6, 7, 8, -1, -1] # 1 document with 10 tokens then padding + # Replace -1 with 0 in position_ids to mark every padding token as a separate sequence. Ideally we want to get rid of padding tokens from qkv + # position_ids = position_ids.masked_fill(position_ids == -1, 0) + if position_ids.ndim == 2: + seq_length = position_ids.shape[1] + flat_position_ids = position_ids.view(-1) # [batch_size*seq_length] + else: + assert ( + inference_max_seqlen is not None + ), "inference_max_seqlen must be provided if position_ids is a 1D tensor" + seq_length = inference_max_seqlen + flat_position_ids = position_ids + qkv = self.qkv_proj(hidden_states) + + if self._use_qkv_packed: + attn_output = self._forward_packed(qkv, seq_length, flat_position_ids, cu_seqlens) + else: + q, k, v = qkv.split( + [self.local_q_size, self.local_kv_size, self.local_kv_size], dim=-1 + ) # [batch_size*seq_length, q_size], [batch_size*seq_length, kv_size] + q = q.view(-1, self.local_num_heads, self.head_dim) # [b*s, num_heads, head_dim] + k = k.view(-1, self.local_num_kv_heads, self.head_dim) # [b*s, num_kv_heads, head_dim] + v = v.view(-1, self.local_num_kv_heads, self.head_dim) # [b*s, num_kv_heads, head_dim] + if self.config.no_rope_layer is None or (self.layer_idx + 1) % self.config.no_rope_layer != 0: + rotary_pos_emb = self.rotary_emb( + position_ids=flat_position_ids if not self.simple_causal_mask else None, seq_length=seq_length + ) # [b*s, dim] or [seq_length, dim] + q = self.rotary_emb.apply_rotary_pos_emb( + q, rotary_pos_emb, seq_length=seq_length + ) # [b*s, num_heads, head_dim] + k = self.rotary_emb.apply_rotary_pos_emb( + k, rotary_pos_emb, seq_length=seq_length + ) # [b*s, num_kv_heads, head_dim] + else: + log_rank(f"skipping rotary for layer {self.layer_idx + 1}", logger=logger, level=logging.DEBUG, rank=0) + attn_output = self.attention( + q, k, v, position_ids=flat_position_ids, seq_length=seq_length, cu_seqlens=cu_seqlens + ) + output = self.o_proj(attn_output) + return {"hidden_states": output, "position_ids": position_ids} + + def _forward_packed(self, qkv, seq_length, position_ids, cu_seqlens): + assert cu_seqlens is not None, "cu_seqlens must be provided for packed attention" + q = qkv[..., : self.local_num_heads * self.head_dim] # Not contiguous, similar to flash_attn + kv = qkv[..., self.local_num_heads * self.head_dim :] # Not contiguous, similar to flash_attn + + if self.config.no_rope_layer is None or (self.layer_idx + 1) % self.config.no_rope_layer != 0: + if self.training: + q = q.view(-1, seq_length, self.local_num_heads, self.head_dim) + kv = kv.view(-1, seq_length, 2, self.local_num_kv_heads, self.head_dim) + q, kv = self.rotary_emb( + q, kv, seqlen_offset=0, max_seqlen=None + ) # TODO: should we use position_ids here? flash_attn doesn't + else: + # TODO: support seqlen_offsets in case of use_cache + # qkv = qkv.view(-1, self.local_num_heads + 2 * self.local_num_kv_heads, self.head_dim) + # self.rotary_emb.varlen_forward(qkv, seqlen_offsets=0, cu_seqlens=cu_seqlens, max_seqlen=seq_length) + # qkv = qkv.view(-1, (self.local_num_heads + 2 * self.local_num_kv_heads) * self.head_dim) + # q = qkv[..., : self.local_num_heads * self.head_dim] + # kv = qkv[..., self.local_num_heads * self.head_dim :] + q = q.view(-1, self.local_num_heads, self.head_dim) + kv = kv.view(-1, 2, self.local_num_kv_heads, self.head_dim) + k = kv[:, 0] + self.rotary_emb.varlen_forward(q, seqlen_offsets=0, cu_seqlens=cu_seqlens, max_seqlen=seq_length) + self.rotary_emb.varlen_forward(k, seqlen_offsets=0, cu_seqlens=cu_seqlens, max_seqlen=seq_length) + else: + log_rank(f"skipping rotary for layer {self.layer_idx + 1}", logger=logger, level=logging.DEBUG, rank=0) + q = q.view(-1, self.local_num_heads, self.head_dim) + kv = kv.view(-1, 2, self.local_num_kv_heads, self.head_dim) + max_seqlen = seq_length # TODO: should this be max position_ids? As long as it doesn't change often it and not too big should be fine + + assert cu_seqlens.dtype == torch.int32 + assert max_seqlen is not None + assert isinstance(max_seqlen, int) + attn_output = flash_attn_varlen_kvpacked_func( + q, + kv, + cu_seqlens, + cu_seqlens, + max_seqlen, + max_seqlen, + 0.0, + softmax_scale=None, + causal=True, # TODO: double check + alibi_slopes=None, + window_size=(-1, -1), # TODO: fix + deterministic=False, + ) # Not contiguous, similar to flash_attn + # flash_attn use rearrange instead of reshape https://github.com/Dao-AILab/flash-attention/blob/1a58058a6da83bd7baaf4c512e8a1abe0240bb77/flash_attn/modules/mha.py#L730 + return attn_output.reshape(-1, self.local_num_heads * self.head_dim) # [b*s, num_heads*head_dim] + + +class Qwen2MLP(nn.Module): + def __init__( + self, + config: Qwen2Config, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + intermediate_size: int, + ) -> None: + super().__init__() + + # Get TP mode and communication settings + tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + tp_linear_async_communication = ( + parallel_config.tp_linear_async_communication if parallel_config is not None else False + ) + + gate_up_contiguous_chunks = ( + intermediate_size, # shape of gate_linear + intermediate_size, # shape of up_linear + ) + + self.gate_up_proj = TensorParallelColumnLinear( + config.hidden_size, + 2 * intermediate_size, + pg=tp_pg, + mode=tp_mode, + bias=False, # Qwen2 doesn't use bias for gate_up_proj + async_communication=tp_linear_async_communication, + contiguous_chunks=gate_up_contiguous_chunks, + tp_recompute_allgather=parallel_config.tp_recompute_allgather, + ) + + # Define down projection + self.down_proj = TensorParallelRowLinear( + intermediate_size, + config.hidden_size, + pg=tp_pg, + mode=tp_mode, + bias=False, # Qwen2 doesn't use bias for down_proj + async_communication=tp_linear_async_communication, + ) + + # Define activation function (silu followed by multiplication) + self.act = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + # Apply gate_up_proj to get gate and up projections + merged_states = self.gate_up_proj(hidden_states) + + # Apply activation function (SiLU and Mul) + gate_states, up_states = torch.split(merged_states, merged_states.shape[-1] // 2, dim=-1) + hidden_states = self.act(gate_states) * up_states + + # Apply down projection + hidden_states = self.down_proj(hidden_states) + + return {"hidden_states": hidden_states} + + +class Qwen2DecoderLayer(nn.Module): + def __init__( + self, + config: Qwen2Config, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + cp_pg: dist.ProcessGroup, + layer_idx: int, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Use fused RMSNorm if configured + norm_class = TritonRMSNorm if config._fused_rms_norm else RMSNorm + self.input_layernorm = norm_class(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = norm_class(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attn = Qwen2Attention( + config=config, + parallel_config=parallel_config, + tp_pg=tp_pg, + cp_pg=cp_pg, + layer_idx=layer_idx, + ) + self.post_attention_layernorm = norm_class(config.hidden_size, eps=config.rms_norm_eps) + + # Use MoE layer if this layer is in the MoE layers list + if config.moe_config and layer_idx in config.moe_config.layers: + from nanotron.nn.moe import Qwen2MoELayer + + self.mlp = Qwen2MoELayer( + config=config, + parallel_config=parallel_config, + tp_pg=tp_pg, + layer_idx=layer_idx, + ) + else: + self.mlp = Qwen2MLP( + config=config, + parallel_config=parallel_config, + tp_pg=tp_pg, + intermediate_size=config.intermediate_size, + ) + + self.recompute_layer = parallel_config.recompute_layer + + def _core_forward( + self, + hidden_states: Union[torch.Tensor, TensorPointer], # [batch_size*seq_length, hidden_size] + position_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] where -1 is padding + cu_seqlens: Union[torch.Tensor, TensorPointer], + inference_max_seqlen: Optional[int] = None, + ) -> List[Union[torch.Tensor, TensorPointer]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + output = self.attn( + hidden_states=hidden_states, + position_ids=position_ids, + cu_seqlens=cu_seqlens, + inference_max_seqlen=inference_max_seqlen, + ) + hidden_states = output["hidden_states"] + hidden_states = hidden_states + residual + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"] + hidden_states = hidden_states + residual + + return hidden_states, position_ids, cu_seqlens + + def _checkpointed_forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return CheckpointFunction.apply(self._core_forward, True, hidden_states, position_ids, cu_seqlens) + + def forward( + self, + hidden_states: Union[torch.Tensor, TensorPointer], + position_ids: Union[torch.Tensor, TensorPointer], + cu_seqlens: Union[torch.Tensor, TensorPointer], + inference_max_seqlen: Optional[int] = None, + ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + if self.recompute_layer and not isinstance(hidden_states, TensorPointer): + hidden_states, position_ids, cu_seqlens = self._checkpointed_forward( + hidden_states, position_ids, cu_seqlens, inference_max_seqlen + ) + else: + hidden_states, position_ids, cu_seqlens = self._core_forward( + hidden_states, position_ids, cu_seqlens, inference_max_seqlen + ) + + return { + "hidden_states": hidden_states, + "position_ids": position_ids, + "cu_seqlens": cu_seqlens, + "inference_max_seqlen": inference_max_seqlen, + } + + +class Embedding(nn.Module): + def __init__(self, tp_pg: dist.ProcessGroup, config: Qwen2Config, parallel_config: Optional[ParallelismArgs]): + super().__init__() + self.token_embedding = TensorParallelEmbedding( + num_embeddings=config.vocab_size, + embedding_dim=config.hidden_size, + padding_idx=config.pad_token_id, + pg=tp_pg, + mode=parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE, + ) + self.pg = tp_pg + + def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): # [...] + input_embeds = self.token_embedding(input_ids) # [..., hidden_size] + return {"input_embeds": input_embeds, "position_ids": position_ids} + + +class Qwen2Model(nn.Module): + """Build pipeline graph for Qwen2 model""" + + def __init__( + self, + config: Qwen2Config, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + ): + super().__init__() + + # Declare all the nodes + self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda")) + self.config = config + self.parallel_config = parallel_config + self.parallel_context = parallel_context + self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + tp_linear_async_communication = ( + parallel_config.tp_linear_async_communication if parallel_config is not None else False + ) + + self.token_position_embeddings = PipelineBlock( + p2p=self.p2p, + module_builder=Embedding, + module_kwargs={ + "config": config, + "parallel_config": parallel_config, + "tp_pg": parallel_context.tp_pg, + }, + module_input_keys={"input_ids", "position_ids"}, + module_output_keys={"input_embeds", "position_ids"}, + ) + + # Create decoder layers + self.decoder = nn.ModuleList( + [ + PipelineBlock( + p2p=self.p2p, + module_builder=Qwen2DecoderLayer, + module_kwargs={ + "config": config, + "parallel_config": parallel_config, + "tp_pg": parallel_context.tp_pg, + "cp_pg": parallel_context.cp_pg, + "layer_idx": layer_idx, + }, + module_input_keys={"hidden_states", "position_ids", "cu_seqlens", "inference_max_seqlen"}, + module_output_keys={"hidden_states", "position_ids", "cu_seqlens", "inference_max_seqlen"}, + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + self.final_layer_norm = PipelineBlock( + p2p=self.p2p, + module_builder=TritonRMSNorm if config._fused_rms_norm else RMSNorm, + module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps}, + module_input_keys={"input"}, + module_output_keys={"hidden_states"}, + ) + + self.lm_head = PipelineBlock( + p2p=self.p2p, + # Return sharded logits that will need to be gathered + module_builder=TensorParallelColumnLinear, + module_kwargs={ + "in_features": config.hidden_size, + "out_features": config.vocab_size, + "pg": parallel_context.tp_pg, + "bias": False, + "mode": self.tp_mode, + "async_communication": tp_linear_async_communication, + "tp_recompute_allgather": parallel_config.tp_recompute_allgather, + }, + module_input_keys={"x"}, + module_output_keys={"logits"}, + ) + + def forward( + self, + input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + position_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] where -1 is padding + ): + if not self.training: # inference case: + assert ( + position_ids.ndim == 2 + ), "position_ids must be 2D for inference, otherwise how do we know when to separate samples?" + inference_max_seqlen = position_ids.shape[1] + inference_batch_size = position_ids.shape[0] + # This gives the number of non-padding tokens per sequence in the batch + seqlens_in_batch = (position_ids != -1).sum(dim=-1, dtype=torch.int32) + input_ids = input_ids.view(-1) + position_ids = position_ids.view(-1) + # Find indices of non-padding tokens using the flattened position_ids + unpad_indices = torch.nonzero(position_ids != -1, as_tuple=False).flatten() + cu_seqlens = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=seqlens_in_batch.device), + torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), + ] + ) + unpadded_input_ids = input_ids[unpad_indices] # (total_tokens) = (total_unpadded_tokens) + unpadded_position_ids = position_ids[unpad_indices] # (total_tokens) + + # TODO: compute this in dataloader to avoid cpu-gpu sync + cu_seqlens = cu_seqlens.to(input_ids.device) + + output = self.token_position_embeddings( + input_ids=unpadded_input_ids, position_ids=unpadded_position_ids + ) # (total_tokens, hidden_size) + decoder_states = { + "hidden_states": output["input_embeds"], # Unpadded embeds [total_tokens, hidden_size] + "position_ids": output["position_ids"], # Unpadded pos_ids [total_tokens] + "cu_seqlens": cu_seqlens, # cu_seqlens for unpadded sequence [batch_size + 1] + "inference_max_seqlen": inference_max_seqlen, # original seq_length using for inference + } + + else: # Training case (handles potential packing) + # Get embeddings for the original (potentially padded/packed) sequence + output = self.token_position_embeddings(input_ids=input_ids, position_ids=position_ids) + # output["position_ids"] is the original position_ids [batch_size, seq_length] + + # Compute cu_seqlens based on document starts (position_id == 0) if data is packed + if position_ids.numel() > 0: + start_indices = torch.where(position_ids.view(-1) == 0)[0] + cu_seqlens = torch.cat( + [ + start_indices, + torch.tensor([position_ids.numel()], dtype=torch.int32, device=start_indices.device), + ] + ).to(torch.int32) + else: + cu_seqlens = None # Or handle empty tensor case appropriately + + # Prepare state for decoder layers using original/padded/packed data + decoder_states = { + "hidden_states": output["input_embeds"], # Padded embeds [batch*seq_len, hidden_size] + "position_ids": output["position_ids"], # Original pos_ids [batch_size, seq_len] + "cu_seqlens": cu_seqlens, # Based on packing, might be None + } + + # Pass the prepared decoder_states dictionary to the decoder layers + for decoder_layer in self.decoder: + # Decoder layers need to handle both inference (unpadded) and training (padded/packed) states + decoder_states = decoder_layer(**decoder_states) + + # Final layer norm and LM head operate on the output hidden_states from the last decoder layer + hidden_states = self.final_layer_norm(input=decoder_states["hidden_states"])["hidden_states"] + sharded_logits = self.lm_head(x=hidden_states)["logits"] + + # Pad logits back to original shape if in inference mode + if not self.training: + assert inference_batch_size is not None and inference_max_seqlen is not None and unpad_indices is not None + # Create zero tensor with the full padded shape (flattened batch/seq) + padded_sharded_logits = torch.zeros( + inference_batch_size * inference_max_seqlen, + sharded_logits.shape[-1], # vocab_shard_size + dtype=sharded_logits.dtype, + device=sharded_logits.device, + ) + # Scatter the unpadded logits back into the zero tensor + padded_sharded_logits[unpad_indices] = sharded_logits + # Reshape to (batch_size, sequence_length, vocab_shard_size) + sharded_logits = padded_sharded_logits.view(inference_batch_size, inference_max_seqlen, -1) + + return sharded_logits + + def get_block_compute_costs(self): + """Computes the compute cost of each block in the model for load balancing.""" + model_config = self.config + d_ff = model_config.intermediate_size + d_qkv = model_config.hidden_size // model_config.num_attention_heads + block_compute_costs = { + # Self-attention (qkv proj + attn out) + MLP + Qwen2DecoderLayer: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size + + 3 * d_ff * model_config.hidden_size, + # Final LM head + TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size, + } + return block_compute_costs + + def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size): + """Get flops per second for the model""" + world_size = self.parallel_context.world_pg.size() + + # Get number of KV heads, accounting for potential absence in config + try: + num_key_value_heads = self.config.num_key_value_heads + except AttributeError: + num_key_value_heads = self.config.num_attention_heads + + model_flops, hardware_flops = get_flops( + num_layers=self.config.num_hidden_layers, + hidden_size=self.config.hidden_size, + num_heads=self.config.num_attention_heads, + num_key_value_heads=num_key_value_heads, + vocab_size=self.config.vocab_size, + ffn_hidden_size=self.config.intermediate_size, + seq_len=sequence_length, + batch_size=global_batch_size, + ) + + model_flops_per_s = model_flops / (iteration_time_in_sec * world_size * 1e12) + hardware_flops_per_s = hardware_flops / (iteration_time_in_sec * world_size * 1e12) + return model_flops_per_s, hardware_flops_per_s + + +@torch.jit.script +def masked_mean(loss, label_mask, dtype): + # type: (Tensor, Tensor, torch.dtype) -> Tensor + return (loss * label_mask).sum(dtype=dtype) / label_mask.sum() + + +class Loss(nn.Module): + def __init__(self, tp_pg: dist.ProcessGroup): + super().__init__() + self.tp_pg = tp_pg + + def forward( + self, + sharded_logits: torch.Tensor, # [batch_size*seq_length, logits] + label_ids: torch.Tensor, # [batch_size, seq_length] + label_mask: torch.Tensor, # [batch_size, seq_length] + ) -> Dict[str, torch.Tensor]: + sharded_logits = sharded_logits.view(label_ids.shape[0], label_ids.shape[1], -1) + loss = sharded_cross_entropy(sharded_logits, label_ids.contiguous(), group=self.tp_pg, dtype=torch.float) + loss = masked_mean(loss, label_mask, dtype=torch.float) + return {"loss": loss} + + +class LossWithZLoss(Loss): + def __init__(self, tp_pg: dist.ProcessGroup, z_loss_coefficient: float): + super().__init__(tp_pg) + self.z_loss_coef = z_loss_coefficient + + def forward( + self, + sharded_logits: torch.Tensor, # [batch_size*seq_length, logits] + label_ids: torch.Tensor, # [batch_size, seq_length] + label_mask: torch.Tensor, # [batch_size, seq_length] + ) -> Dict[str, torch.Tensor]: + sharded_logits = sharded_logits.view(label_ids.shape[0], label_ids.shape[1], -1) + loss, z_loss = sharded_cross_entropy( + sharded_logits, label_ids.contiguous(), group=self.tp_pg, dtype=torch.float, z_loss_coef=self.z_loss_coef + ) + loss = masked_mean(loss, label_mask, dtype=torch.float) + z_loss = masked_mean(z_loss.detach(), label_mask, dtype=torch.float) + return {"loss": loss, "z_loss": z_loss} + + +class Qwen2ForTraining(NanotronModel): + def __init__( + self, + config: Qwen2Config, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: Optional[RandomStates] = None, + ): + super().__init__() + self.model = Qwen2Model(config=config, parallel_context=parallel_context, parallel_config=parallel_config) + + # Choose the appropriate loss class based on config + loss_kwargs = { + "tp_pg": parallel_context.tp_pg, + } + if config.z_loss_enabled: + loss_kwargs["z_loss_coefficient"] = config.z_loss_coefficient + + self.loss = PipelineBlock( + p2p=self.model.p2p, + module_builder=LossWithZLoss if config.z_loss_enabled else Loss, + module_kwargs=loss_kwargs, + module_input_keys={ + "sharded_logits", + "label_ids", + "label_mask", + }, + module_output_keys={"loss", "z_loss"} if config.z_loss_enabled else {"loss"}, + ) + self.parallel_context = parallel_context + self.config = config + self.parallel_config = parallel_config + + def forward( + self, + input_ids: Union[torch.Tensor, TensorPointer], + position_ids: Union[torch.Tensor, TensorPointer], + label_ids: Union[torch.Tensor, TensorPointer], + label_mask: Union[torch.Tensor, TensorPointer], + ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + sharded_logits = self.model( + input_ids=input_ids, + position_ids=position_ids, + ) + loss = self.loss( + sharded_logits=sharded_logits, + label_ids=label_ids, + label_mask=label_mask, + ) + if self.config.z_loss_enabled: + return {"loss": loss["loss"], "z_loss": loss["z_loss"]} + else: + return {"loss": loss["loss"]} + + @torch.no_grad() + def init_model_randomly(self, config: Config): + """Initialize model parameters randomly.""" + init_method = config.model.init_method + if isinstance(init_method, RandomInit): + parametrizator_cls = StandardParametrizator + elif isinstance(init_method, SpectralMupInit): + parametrizator_cls = SpectralMupParametrizator + else: + raise ValueError(f"Unknown init method {init_method}") + + parametrizator = parametrizator_cls(config=config) + + log_rank( + f"Parametrizing model parameters using {parametrizator.__class__.__name__}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + model = self + initialized_parameters = set() + # Handle tensor parallelism + module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()} + # Fix the root_model + module_id_to_prefix[id(model)] = "" + + for param_name, param in model.named_parameters(): + assert isinstance(param, NanotronParameter) + + module_name, param_name = param_name.rsplit(".", 1) + + if param.is_tied: + tied_info = param.get_tied_info() + full_param_name = tied_info.get_full_name_from_module_id_to_prefix( + module_id_to_prefix=module_id_to_prefix + ) + else: + full_param_name = f"{module_name}.{param_name}" + + if full_param_name in initialized_parameters: + # Already initialized + continue + + module = model.get_submodule(module_name) + parametrizator.parametrize(full_param_name, module) + + assert full_param_name not in initialized_parameters + initialized_parameters.add(full_param_name) + + assert initialized_parameters == { + param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) + if param.is_tied + else name + for name, param in model.named_parameters() + }, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}" + + def get_embeddings_lm_head_tied_names(self): + """Get the names of the tied embeddings and lm_head weights""" + if self.config.tie_word_embeddings is True: + # Should be similar to ["model.token_position_embeddings.pp_block.token_embedding.weight", "model.lm_head.pp_block.weight"] + return ["model.token_position_embeddings.pp_block.token_embedding.weight", "model.lm_head.pp_block.weight"] + else: + return [] + + def get_block_compute_costs(self): + """Computes the compute cost of each block in the model so that we can do a better job of load balancing.""" + return self.model.get_block_compute_costs() + + def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size): + """Get flops per second for a given model""" + return self.model.get_flops_per_sec(iteration_time_in_sec, sequence_length, global_batch_size) + + +def get_flops( + num_layers, + hidden_size, + num_heads, + num_key_value_heads, + vocab_size, + seq_len, + ffn_hidden_size, + batch_size=1, +): + """Counts flops in an decoder-only model + Args: + num_layers: number of decoder layers + hidden_size: hidden size of the model + num_heads: number of heads in the model + num_key_value_heads: number of key/value heads in the model + ffn_hidden_size: hidden size of the FFN + vocab_size: size of the vocabulary + seq_len: sequence length of the decoder + batch_size: batch size + Returns: + model_flops: flops in the model (should be independent of the hardware and model implementation) + hardware_flops: flops in the hardware (actual flops performed on the hardware). Check 6.3 in https://arxiv.org/pdf/2205.05198.pdf + """ + if num_key_value_heads is None: + num_key_value_heads = num_heads + hidden_size_per_head = hidden_size // num_heads + # In the following we mark the reduced dimension with parentheses + # decoder + # self attention + ## qkv projection + decoder_qkv_proj_flops_fwd = ( + 2 * num_layers * batch_size * seq_len * (hidden_size) * num_heads * hidden_size_per_head + + 2 * num_layers * batch_size * seq_len * (hidden_size) * 2 * num_key_value_heads * hidden_size_per_head + ) + ## qk logits + decoder_qk_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (hidden_size_per_head) * seq_len + ## v logits + decoder_v_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (seq_len) * hidden_size_per_head + ## attn out + decoder_attn_out_flops_fwd = ( + 2 * num_layers * batch_size * num_heads * seq_len * (hidden_size_per_head) * hidden_size + ) + # FF + ## 1st layer + decoder_ffn_1_flops_fwd = 4 * num_layers * batch_size * seq_len * (hidden_size) * ffn_hidden_size + ## 2nd layer + decoder_ffn_2_flops_fwd = 2 * num_layers * batch_size * seq_len * (ffn_hidden_size) * hidden_size + + decoder_flops_fwd = ( + decoder_qkv_proj_flops_fwd + + decoder_qk_logits_flops_fwd + + decoder_v_logits_flops_fwd + + decoder_attn_out_flops_fwd + + decoder_ffn_1_flops_fwd + + decoder_ffn_2_flops_fwd + ) + + # lm head + lm_head_flops_fwd = 2 * batch_size * seq_len * (hidden_size) * vocab_size + + # the bwd pass requires double the flops in case of matmuls to calculate the gradients with respect to + # both input and weight tensors + model_flops = 3 * (decoder_flops_fwd + lm_head_flops_fwd) # 1 for fwd + 2 for bwd + + hardware_flops = model_flops # TODO: This is a placeholder for now + + return model_flops, hardware_flops diff --git a/src/nanotron/models/llama.py b/src/nanotron/models/llama.py index db8206448..6224042fa 100644 --- a/src/nanotron/models/llama.py +++ b/src/nanotron/models/llama.py @@ -1126,7 +1126,7 @@ def init_model_randomly(self, config: Config): continue module = model.get_submodule(module_name) - parametrizator.parametrize(param_name, module) + parametrizator.parametrize(full_param_name, module) assert full_param_name not in initialized_parameters initialized_parameters.add(full_param_name) diff --git a/src/nanotron/models/qwen.py b/src/nanotron/models/qwen.py index eee5cba38..a5574d044 100644 --- a/src/nanotron/models/qwen.py +++ b/src/nanotron/models/qwen.py @@ -3,7 +3,6 @@ import torch from flash_attn.modules.mha import flash_attn_varlen_kvpacked_func from torch import nn -from torch.nn import functional as F from torch.utils.checkpoint import CheckpointFunction from nanotron import distributed as dist @@ -11,6 +10,7 @@ from nanotron.config import Config, ParallelismArgs from nanotron.config.models_config import Qwen2Config, RandomInit, SpectralMupInit from nanotron.logging import log_rank +from nanotron.logging.timers import nanotron_timer from nanotron.models import NanotronModel from nanotron.nn.activations import ACT2FN from nanotron.nn.attention import ALL_ATTENTION_FUNCTIONS, get_attention_mask @@ -193,12 +193,13 @@ def __init__( async_communication=tp_linear_async_communication, ) if config._use_qkv_packed: - from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding + from nanotron.nn.rotary import FlashRotaryEmbedding self.rotary_emb = FlashRotaryEmbedding( dim=self.head_dim, base=config.rope_theta, interleaved=config.rope_interleaved, + seq_len_interpolation_factor=config.rope_seq_len_interpolation_factor, ) else: self.rotary_emb = RotaryEmbedding( @@ -206,14 +207,14 @@ def __init__( max_seq_len=config.max_position_embeddings, base=config.rope_theta, interleaved=config.rope_interleaved, - seq_len_scaling_factor=None, + # seq_len_scaling_factor=config.rope_seq_len_scaling_factor, fused=config._fused_rotary_emb, ) self.attention = CoreAttention(config, tp_pg, cp_pg, layer_idx) self.simple_causal_mask = True self._use_qkv_packed = config._use_qkv_packed - - # TODO: support doc masking / SWA / SFT / inference + self.sliding_window_size = config.sliding_window_size + # TODO: support SFT def forward( self, @@ -279,6 +280,7 @@ def _forward_packed(self, qkv, seq_length, position_ids, cu_seqlens): assert cu_seqlens.dtype == torch.int32 assert max_seqlen is not None assert isinstance(max_seqlen, int) + attn_output = flash_attn_varlen_kvpacked_func( q, kv, @@ -288,9 +290,9 @@ def _forward_packed(self, qkv, seq_length, position_ids, cu_seqlens): max_seqlen, 0.0, softmax_scale=None, - causal=True, # TODO: double check + causal=True, alibi_slopes=None, - window_size=(-1, -1), # TODO: fix + window_size=(self.sliding_window_size - 1, 0) if self.sliding_window_size is not None else (-1, -1), deterministic=False, ) # Not contiguous, similar to flash_attn # flash_attn use rearrange instead of reshape https://github.com/Dao-AILab/flash-attention/blob/1a58058a6da83bd7baaf4c512e8a1abe0240bb77/flash_attn/modules/mha.py#L730 @@ -303,6 +305,8 @@ def __init__( config: Qwen2Config, parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, + hidden_size: int, + intermediate_size: int, ) -> None: super().__init__() @@ -312,14 +316,14 @@ def __init__( parallel_config.tp_linear_async_communication if parallel_config is not None else False ) - # Define gate_up_proj as a merged layer for gate and up projections gate_up_contiguous_chunks = ( - config.intermediate_size, # shape of gate_linear - config.intermediate_size, # shape of up_linear + intermediate_size, # shape of gate_linear + intermediate_size, # shape of up_linear ) + self.gate_up_proj = TensorParallelColumnLinear( - config.hidden_size, - 2 * config.intermediate_size, + hidden_size, + 2 * intermediate_size, pg=tp_pg, mode=tp_mode, bias=False, # Qwen2 doesn't use bias for gate_up_proj @@ -330,8 +334,8 @@ def __init__( # Define down projection self.down_proj = TensorParallelRowLinear( - config.intermediate_size, - config.hidden_size, + intermediate_size, + hidden_size, pg=tp_pg, mode=tp_mode, bias=False, # Qwen2 doesn't use bias for down_proj @@ -355,183 +359,6 @@ def forward(self, hidden_states): return {"hidden_states": hidden_states} -class Qwen2MoELayer(nn.Module): - """Mixture of experts Layer for Qwen2 models.""" - - def __init__( - self, - config: Qwen2Config, - parallel_config: Optional[ParallelismArgs], - tp_pg: dist.ProcessGroup, - layer_idx: int = 0, - ) -> None: - super().__init__() - self.hidden_size = config.hidden_size - self.intermediate_size = config.intermediate_size - - # MoE specific configurations - self.num_experts = config.moe_config.num_experts # Total number of experts - self.num_experts_per_token = config.moe_config.top_k # Number of experts used per token (top-k) - self.expert_parallel_size = getattr(parallel_config, "expert_parallel_size", 1) - self.num_local_experts = self.num_experts // self.expert_parallel_size # Experts per device - - # Get TP mode configuration - tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE - tp_linear_async_communication = ( - parallel_config.tp_linear_async_communication if parallel_config is not None else False - ) - - # Router for selecting experts - self.router = TensorParallelColumnLinear( - self.hidden_size, - self.num_experts, - pg=tp_pg, - mode=tp_mode, - bias=False, - async_communication=tp_linear_async_communication, - ) - - # Enable shared experts if configured - self.enable_shared_expert = getattr(config.moe_config, "enable_shared_expert", False) - if self.enable_shared_expert: - self.shared_expert = Qwen2MLP( - config=config, - parallel_config=parallel_config, - tp_pg=tp_pg, - ) - self.shared_expert_gate = TensorParallelColumnLinear( - self.hidden_size, - 1, - pg=tp_pg, - mode=tp_mode, - bias=False, - async_communication=tp_linear_async_communication, - ) - - # Create the expert MLPs - self.experts = nn.ModuleList( - [ - Qwen2MLP( - config=config, - parallel_config=parallel_config, - tp_pg=tp_pg, - ) - for _ in range(self.num_local_experts) - ] - ) - - # Whether to recompute MoE layer during backward pass for memory efficiency - self.recompute_layer = getattr(parallel_config, "recompute_layer", False) - - # Token dispatcher type - determines communication pattern - self.token_dispatcher_type = getattr(config.moe_config, "token_dispatcher_type", "alltoall") - # For more sophisticated implementations, we would add token dispatcher logic here - - def _compute_router_probabilities(self, hidden_states): - """Compute routing probabilities for each token to each expert.""" - router_logits = self.router(hidden_states) # [batch_size*seq_length, num_experts] - - # Get the top-k experts per token - routing_weights, routing_indices = torch.topk(router_logits, k=self.num_experts_per_token, dim=-1) - - # Apply softmax on the top-k values - routing_weights = F.softmax(routing_weights, dim=-1) - - return routing_weights, routing_indices - - def _dispatch_tokens(self, hidden_states, routing_weights, routing_indices): - """ - Dispatches tokens to their selected experts. - In a full implementation, this would handle the actual token routing logic - including communication between devices. - """ - # Simplified implementation - in a complete version this would handle - # all-to-all or all-gather communications for distributed experts - - hidden_states.shape[0] - dispatched_inputs = [] - expert_counts = [] - - # For each expert, gather the tokens assigned to it - for expert_idx in range(self.num_local_experts): - # Find tokens that have this expert in their top-k - expert_mask = (routing_indices == expert_idx).any(dim=-1) - tokens_for_expert = hidden_states[expert_mask] - - # Get the routing weights for this expert - expert_positions = (routing_indices == expert_idx).nonzero(as_tuple=True) - token_positions, k_positions = expert_positions - expert_weights = routing_weights[token_positions, k_positions].unsqueeze(-1) - - # Scale inputs by routing weights - scaled_inputs = tokens_for_expert * expert_weights - - dispatched_inputs.append(scaled_inputs) - expert_counts.append(len(tokens_for_expert)) - - return dispatched_inputs, expert_counts - - def _combine_expert_outputs(self, expert_outputs, routing_indices, original_shape): - """ - Combines outputs from different experts back to the original tensor layout. - """ - # Initialize output tensor with zeros - combined_output = torch.zeros(original_shape, device=expert_outputs[0].device) - - for expert_idx, expert_output in enumerate(expert_outputs): - if expert_output.shape[0] == 0: # Skip if no tokens were routed to this expert - continue - - # Find positions where this expert was in the top-k - expert_mask = (routing_indices == expert_idx).any(dim=-1) - combined_output[expert_mask] += expert_output - - return combined_output - - def _core_forward(self, hidden_states): - """Core forward logic for MoE layer.""" - # Get router probabilities - routing_weights, routing_indices = self._compute_router_probabilities(hidden_states) - - # Dispatch tokens to experts - dispatched_inputs, expert_counts = self._dispatch_tokens(hidden_states, routing_weights, routing_indices) - - # Process tokens with their assigned experts - expert_outputs = [] - for expert_idx, (inputs, count) in enumerate(zip(dispatched_inputs, expert_counts)): - if count == 0: # Skip computation if no tokens assigned - expert_outputs.append(torch.tensor([], device=hidden_states.device)) - continue - - # Forward through the expert - output = self.experts[expert_idx](hidden_states=inputs)["hidden_states"] - expert_outputs.append(output) - - # Combine expert outputs - output = self._combine_expert_outputs(expert_outputs, routing_indices, hidden_states.shape) - - # Add shared expert contribution if enabled - if self.enable_shared_expert: - shared_expert_output = self.shared_expert(hidden_states=hidden_states)["hidden_states"] - shared_gate = torch.sigmoid(self.shared_expert_gate(hidden_states)) - output = output + shared_gate * shared_expert_output - - return output - - def _checkpointed_forward(self, hidden_states): - """Apply gradient checkpointing to save memory during training.""" - return CheckpointFunction.apply(self._core_forward, True, hidden_states) - - def forward(self, hidden_states): - """Forward pass for the MoE layer.""" - if self.recompute_layer and self.training: - hidden_states = self._checkpointed_forward(hidden_states) - else: - hidden_states = self._core_forward(hidden_states) - - return {"hidden_states": hidden_states} - - class Qwen2DecoderLayer(nn.Module): def __init__( self, @@ -539,6 +366,7 @@ def __init__( parallel_config: Optional[ParallelismArgs], tp_pg: dist.ProcessGroup, cp_pg: dist.ProcessGroup, + parallel_context: ParallelContext, layer_idx: int, ) -> None: super().__init__() @@ -547,7 +375,6 @@ def __init__( norm_class = TritonRMSNorm if config._fused_rms_norm else RMSNorm self.input_layernorm = norm_class(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = norm_class(config.hidden_size, eps=config.rms_norm_eps) - self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.attn = Qwen2Attention( config=config, parallel_config=parallel_config, @@ -555,21 +382,35 @@ def __init__( cp_pg=cp_pg, layer_idx=layer_idx, ) - self.post_attention_layernorm = norm_class(config.hidden_size, eps=config.rms_norm_eps) # Use MoE layer if this layer is in the MoE layers list if config.moe_config and layer_idx in config.moe_config.layers: - self.mlp = Qwen2MoELayer( - config=config, - parallel_config=parallel_config, - tp_pg=tp_pg, - layer_idx=layer_idx, - ) + + if config.moe_config.moe_impl == "nanotron": + from nanotron.nn.moe import Qwen2MoEMLPLayer + + self.mlp = Qwen2MoEMLPLayer( + config=config, + parallel_config=parallel_config, + parallel_context=parallel_context, + layer_idx=layer_idx, + ) + elif config.moe_config.moe_impl == "transformer_engine": + from nanotron.nn.te_moe import Qwen2MoEMLPLayer + + self.mlp = Qwen2MoEMLPLayer( + config=config, + parallel_config=parallel_config, + parallel_context=parallel_context, + layer_idx=layer_idx, + ) else: self.mlp = Qwen2MLP( config=config, parallel_config=parallel_config, tp_pg=tp_pg, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, ) self.recompute_layer = parallel_config.recompute_layer @@ -684,6 +525,8 @@ def __init__( "parallel_config": parallel_config, "tp_pg": parallel_context.tp_pg, "cp_pg": parallel_context.cp_pg, + # TODO: directly pass the ep_pg process group instead of the parallel_context + "parallel_context": parallel_context, "layer_idx": layer_idx, }, module_input_keys={"hidden_states", "position_ids", "cu_seqlens"}, @@ -723,7 +566,10 @@ def forward( input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] position_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] where -1 is padding ): + nanotron_timer("Token position embeddings", timer_type="cuda", cuda_sync=True).start() output = self.token_position_embeddings(input_ids=input_ids, position_ids=position_ids) + nanotron_timer("Token position embeddings", timer_type="cuda", cuda_sync=True).end() + # Compute cu_seqlens if position_ids.numel() > 0: start_indices = torch.where(position_ids.view(-1) == 0)[0] @@ -739,12 +585,18 @@ def forward( "cu_seqlens": cu_seqlens, } + nanotron_timer("Decoder layer", timer_type="cuda", cuda_sync=True).start() for decoder_layer in self.decoder: decoder_states = decoder_layer(**decoder_states) + nanotron_timer("Decoder layer", timer_type="cuda", cuda_sync=True).end() + nanotron_timer("Final layer norm", timer_type="cuda", cuda_sync=True).start() hidden_states = self.final_layer_norm(input=decoder_states["hidden_states"])["hidden_states"] + nanotron_timer("Final layer norm", timer_type="cuda", cuda_sync=True).end() + nanotron_timer("LM head", timer_type="cuda", cuda_sync=True).start() sharded_logits = self.lm_head(x=hidden_states)["logits"] + nanotron_timer("LM head", timer_type="cuda", cuda_sync=True).end() return sharded_logits @@ -871,6 +723,13 @@ def forward( label_ids: Union[torch.Tensor, TensorPointer], label_mask: Union[torch.Tensor, TensorPointer], ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + # import torch.distributed as dist + # if dist.get_rank() == 0: + # from pprint import pformat + # print(pformat([(n,p.shape) for n,p in self.named_parameters()])) + # print(pformat([(n, round(p.mean().item(), 4), round(p.std().item(), 4)) for n,p in self.named_parameters()])) + # assert False + sharded_logits = self.model( input_ids=input_ids, position_ids=position_ids, @@ -880,10 +739,12 @@ def forward( label_ids=label_ids, label_mask=label_mask, ) + outputs = {"loss": loss["loss"]} + if self.config.z_loss_enabled: - return {"loss": loss["loss"], "z_loss": loss["z_loss"]} - else: - return {"loss": loss["loss"]} + outputs["z_loss"] = loss["z_loss"] + + return outputs @torch.no_grad() def init_model_randomly(self, config: Config): @@ -930,7 +791,7 @@ def init_model_randomly(self, config: Config): continue module = model.get_submodule(module_name) - parametrizator.parametrize(param_name, module) + parametrizator.parametrize(full_param_name, module) assert full_param_name not in initialized_parameters initialized_parameters.add(full_param_name) diff --git a/src/nanotron/models/train_qwen.py b/src/nanotron/models/train_qwen.py new file mode 100644 index 000000000..3ae2b55d8 --- /dev/null +++ b/src/nanotron/models/train_qwen.py @@ -0,0 +1,855 @@ +from typing import Dict, List, Optional, Tuple, Union + +import torch +from flash_attn.modules.mha import flash_attn_varlen_kvpacked_func +from torch import nn +from torch.utils.checkpoint import CheckpointFunction + +from nanotron import distributed as dist +from nanotron import logging +from nanotron.config import Config, ParallelismArgs +from nanotron.config.models_config import Qwen2Config, RandomInit, SpectralMupInit +from nanotron.logging import log_rank +from nanotron.models import NanotronModel +from nanotron.nn.activations import ACT2FN +from nanotron.nn.attention import ALL_ATTENTION_FUNCTIONS, get_attention_mask +from nanotron.nn.layer_norm import LlamaRMSNorm as RMSNorm +from nanotron.nn.layer_norm import TritonRMSNorm +from nanotron.nn.rotary import RotaryEmbedding +from nanotron.parallel import ParallelContext +from nanotron.parallel.parameters import NanotronParameter +from nanotron.parallel.pipeline_parallel.block import PipelineBlock, TensorPointer +from nanotron.parallel.pipeline_parallel.p2p import P2P +from nanotron.parallel.tensor_parallel.functional import sharded_cross_entropy +from nanotron.parallel.tensor_parallel.nn import ( + TensorParallelColumnLinear, + TensorParallelEmbedding, + TensorParallelLinearMode, + TensorParallelRowLinear, +) +from nanotron.random import RandomStates +from nanotron.scaling.parametrization import SpectralMupParametrizator, StandardParametrizator + +logger = logging.get_logger(__name__) + + +class CoreAttention(nn.Module): + """Core attention module that can use different attention implementations""" + + def __init__( + self, + config: Qwen2Config, + tp_pg: dist.ProcessGroup, + cp_pg: dist.ProcessGroup, + layer_idx: int = 0, + ): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.head_dim = config.hidden_size // config.num_attention_heads + self.num_heads = config.num_attention_heads + self.num_kv_heads = config.num_key_value_heads + self.local_num_heads = self.num_heads // tp_pg.size() + self.local_num_kv_heads = self.num_kv_heads // tp_pg.size() + self.num_key_value_groups = ( + config.num_attention_heads // config.num_key_value_heads + ) # Important for transformers's `sdpa_attention_forward` + self._attn_implementation = config._attn_implementation + self.cp_pg = cp_pg + self.sliding_window_size = config.sliding_window_size + self.simple_causal_mask = True # Use simple causal mask instead of computing custom attention mask if not document masking / sliding window + self.flex_attention_mask = config.flex_attention_mask if hasattr(config, "flex_attention_mask") else None + + def forward( + self, + query_states: torch.Tensor, # [b*s, num_heads, head_dim] + key_states: torch.Tensor, # [b*s, num_kv_heads, head_dim] + value_states: torch.Tensor, # [b*s, num_kv_heads, head_dim] + position_ids: torch.Tensor, # [b*s] + seq_length: Optional[int], + attention_mask: Optional[torch.Tensor] = None, + dropout: float = 0.0, + **kwargs, + ): + """Forward pass applying the chosen attention implementation""" + # Get the appropriate attention function + attention_func = ALL_ATTENTION_FUNCTIONS[self._attn_implementation] + + # Initialize variables for attention parameters + cu_seqlens = kwargs.get("cu_seqlens", None) + + # Shape tensors according to attention implementation + if self._attn_implementation == "ring_flash_triton": + query_states = query_states.view(-1, seq_length, self.local_num_heads, self.head_dim) + key_states = key_states.view(-1, seq_length, self.local_num_kv_heads, self.head_dim) + value_states = value_states.view(-1, seq_length, self.local_num_kv_heads, self.head_dim) + elif self._attn_implementation == "ring": + # Warning: Since this uses _flash_attn_varlen_forward make sure we count padding tokens in cu_seqlens + query_states = query_states.view(-1, self.local_num_heads, self.head_dim) + key_states = key_states.view(-1, self.local_num_kv_heads, self.head_dim) + value_states = value_states.view(-1, self.local_num_kv_heads, self.head_dim) + else: + # Process attention mask based on implementation + if self.simple_causal_mask: + assert attention_mask is None, "Simple causal mask is not supported with custom attention mask" + assert self.sliding_window_size is None, "Simple causal mask is not supported with sliding window" + elif attention_mask is None and position_ids is not None: + # Determine if we need to create an attention mask from position_ids + if self._attn_implementation == "flex_attention" and self.sliding_window_size is not None: + # For FlexAttention with sliding window, we don't need an explicit mask + # The mask_mod function will handle it + pass + else: + # For other implementations, generate the attention mask if needed + # Only calculate if cu_seqlens wasn't passed + if cu_seqlens is None: + attention_mask, cu_seqlens = get_attention_mask(position_ids, seq_length=seq_length) + + if attention_mask is not None: + # Add batch and head dimensions for proper broadcasting + attention_mask = attention_mask.unsqueeze(0).unsqueeze(0) # [1, 1, seq_length, seq_length] + + attn_output = attention_func( + self, + query_states, # [b, num_heads, seq_len, head_dim] + key_states, # [b, num_kv_heads, seq_len, head_dim] + value_states, # [b, num_kv_heads, seq_len, head_dim] + attention_mask, # [b, num_heads, seq_len, seq_len] + max_seqlen=seq_length, + dropout=dropout, + scaling=None, # by default, scaling is head_dim**-0.5 + sliding_window=self.sliding_window_size, + ring_pg=self.cp_pg, + position_ids=position_ids if self._attn_implementation == "flex_attention" else None, + document_ids=kwargs.get("document_ids", None) if self._attn_implementation == "flex_attention" else None, + flex_attention_mask=self.flex_attention_mask if self._attn_implementation == "flex_attention" else None, + **kwargs, # Pass remaining kwargs + )[0] + + return attn_output.view( + -1, self.local_num_heads * self.head_dim + ) # [b*s, num_heads, head_dim] -> [b*s, num_heads*head_dim] + + +class Qwen2Attention(nn.Module): + def __init__( + self, + config: Qwen2Config, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + cp_pg: dist.ProcessGroup, + layer_idx: int, + ): + super().__init__() + self.config = config + self.layer_idx = layer_idx + self.hidden_size = config.hidden_size + self.tp_pg_size = tp_pg.size() + + # Head configuration + self.num_heads = config.num_attention_heads + self.local_num_heads = self.num_heads // self.tp_pg_size + + # KV head configuration + self.num_kv_heads = config.num_key_value_heads + self.local_num_kv_heads = self.num_kv_heads // self.tp_pg_size + + # Dimensions + self.head_dim = config.hidden_size // self.num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.local_q_size = self.local_num_heads * self.head_dim + self.local_kv_size = self.local_num_kv_heads * self.head_dim + + # TP mode configuration + tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + tp_linear_async_communication = ( + parallel_config.tp_linear_async_communication if parallel_config is not None else False + ) + + qkv_contiguous_chunks = ( + self.q_size, # Q chunk size + self.kv_size, # K chunk size + self.kv_size, # V chunk size + ) + self.qkv_proj = TensorParallelColumnLinear( + self.hidden_size, + self.q_size + 2 * self.kv_size, + pg=tp_pg, + mode=tp_mode, + bias=config.attention_bias, # Qwen2 uses bias for QKV, Llama doesn't + async_communication=tp_linear_async_communication, + contiguous_chunks=qkv_contiguous_chunks, + tp_recompute_allgather=parallel_config.tp_recompute_allgather, + ) + self.o_proj = TensorParallelRowLinear( + self.num_heads * self.head_dim, + self.hidden_size, + pg=tp_pg, + mode=tp_mode, + bias=False, + async_communication=tp_linear_async_communication, + ) + if config._use_qkv_packed: + from flash_attn.layers.rotary import RotaryEmbedding as FlashRotaryEmbedding + + self.rotary_emb = FlashRotaryEmbedding( + dim=self.head_dim, + base=config.rope_theta, + interleaved=config.rope_interleaved, + ) + else: + self.rotary_emb = RotaryEmbedding( + dim=self.head_dim, + max_seq_len=config.max_position_embeddings, + base=config.rope_theta, + interleaved=config.rope_interleaved, + seq_len_scaling_factor=None, + fused=config._fused_rotary_emb, + ) + self.attention = CoreAttention(config, tp_pg, cp_pg, layer_idx) + self.simple_causal_mask = True + self._use_qkv_packed = config._use_qkv_packed + + # TODO: support doc masking / SWA / SFT / inference + + def forward( + self, + hidden_states: torch.Tensor, # [batch_size*seq_length, hidden_size] + position_ids: torch.Tensor, # [batch_size, seq_length] where -1 is padding + cu_seqlens: Optional[torch.Tensor] = None, # Added cu_seqlens argument + ): + # [0, 1, 2, 3, 4, 0, 1, 2, -1, -1, -1] # 2 documents with 5 and 3 tokens then padding + # [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10] # 1 document with 11 tokens + # [0, 1, 2, 3, 4, 5, 6, 7, 8, -1, -1] # 1 document with 10 tokens then padding + # Replace -1 with 0 in position_ids to mark every padding token as a separate sequence. Ideally we want to get rid of padding tokens from qkv + # position_ids = position_ids.masked_fill(position_ids == -1, 0) + seq_length = position_ids.shape[1] + # Keep original position_ids shape for return, flatten for internal use + position_ids = position_ids.view(-1) # [batch_size*seq_length] + + qkv = self.qkv_proj(hidden_states) + + if self._use_qkv_packed: + attn_output = self._forward_packed(qkv, seq_length, position_ids, cu_seqlens) + else: + q, k, v = qkv.split( + [self.local_q_size, self.local_kv_size, self.local_kv_size], dim=-1 + ) # [batch_size*seq_length, q_size], [batch_size*seq_length, kv_size] + q = q.view(-1, self.local_num_heads, self.head_dim) # [b*s, num_heads, head_dim] + k = k.view(-1, self.local_num_kv_heads, self.head_dim) # [b*s, num_kv_heads, head_dim] + v = v.view(-1, self.local_num_kv_heads, self.head_dim) # [b*s, num_kv_heads, head_dim] + if self.config.no_rope_layer is None or (self.layer_idx + 1) % self.config.no_rope_layer != 0: + rotary_pos_emb = self.rotary_emb( + position_ids=position_ids if not self.simple_causal_mask else None, seq_length=seq_length + ) # [b*s, dim] or [seq_length, dim] + q = self.rotary_emb.apply_rotary_pos_emb( + q, rotary_pos_emb, seq_length=seq_length + ) # [b*s, num_heads, head_dim] + k = self.rotary_emb.apply_rotary_pos_emb( + k, rotary_pos_emb, seq_length=seq_length + ) # [b*s, num_kv_heads, head_dim] + else: + log_rank(f"skipping rotary for layer {self.layer_idx + 1}", logger=logger, level=logging.DEBUG, rank=0) + attn_output = self.attention( + q, k, v, position_ids=position_ids, seq_length=seq_length, cu_seqlens=cu_seqlens + ) + output = self.o_proj(attn_output) + # Return original position_ids shape + return {"hidden_states": output, "position_ids": position_ids.view(-1, seq_length)} + + def _forward_packed(self, qkv, seq_length, position_ids, cu_seqlens): + assert cu_seqlens is not None, "cu_seqlens must be provided for packed attention" + q = qkv[..., : self.local_num_heads * self.head_dim] # Not contiguous, similar to flash_attn + kv = qkv[..., self.local_num_heads * self.head_dim :] # Not contiguous, similar to flash_attn + q = q.view(-1, seq_length, self.local_num_heads, self.head_dim) + kv = kv.view(-1, seq_length, 2, self.local_num_kv_heads, self.head_dim) + if self.config.no_rope_layer is None or (self.layer_idx + 1) % self.config.no_rope_layer != 0: + q, kv = self.rotary_emb( + q, kv, seqlen_offset=0, max_seqlen=None + ) # TODO: should we use position_ids here? flash_attn doesn't + else: + log_rank(f"skipping rotary for layer {self.layer_idx + 1}", logger=logger, level=logging.DEBUG, rank=0) + q = q.view(-1, self.local_num_heads, self.head_dim) + kv = kv.view(-1, 2, self.local_num_kv_heads, self.head_dim) + max_seqlen = seq_length # TODO: should this be max position_ids? + + assert cu_seqlens.dtype == torch.int32 + assert max_seqlen is not None + assert isinstance(max_seqlen, int) + attn_output = flash_attn_varlen_kvpacked_func( + q, + kv, + cu_seqlens, + cu_seqlens, + max_seqlen, + max_seqlen, + 0.0, + softmax_scale=None, + causal=True, # TODO: double check + alibi_slopes=None, + window_size=(-1, -1), # TODO: fix + deterministic=False, + ) # Not contiguous, similar to flash_attn + # flash_attn use rearrange instead of reshape https://github.com/Dao-AILab/flash-attention/blob/1a58058a6da83bd7baaf4c512e8a1abe0240bb77/flash_attn/modules/mha.py#L730 + return attn_output.reshape(-1, self.local_num_heads * self.head_dim) # [b*s, num_heads*head_dim] + + +class Qwen2MLP(nn.Module): + def __init__( + self, + config: Qwen2Config, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + intermediate_size: int, + ) -> None: + super().__init__() + + # Get TP mode and communication settings + tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + tp_linear_async_communication = ( + parallel_config.tp_linear_async_communication if parallel_config is not None else False + ) + + gate_up_contiguous_chunks = ( + intermediate_size, # shape of gate_linear + intermediate_size, # shape of up_linear + ) + + self.gate_up_proj = TensorParallelColumnLinear( + config.hidden_size, + 2 * intermediate_size, + pg=tp_pg, + mode=tp_mode, + bias=False, # Qwen2 doesn't use bias for gate_up_proj + async_communication=tp_linear_async_communication, + contiguous_chunks=gate_up_contiguous_chunks, + tp_recompute_allgather=parallel_config.tp_recompute_allgather, + ) + + # Define down projection + self.down_proj = TensorParallelRowLinear( + intermediate_size, + config.hidden_size, + pg=tp_pg, + mode=tp_mode, + bias=False, # Qwen2 doesn't use bias for down_proj + async_communication=tp_linear_async_communication, + ) + + # Define activation function (silu followed by multiplication) + self.act = ACT2FN[config.hidden_act] + + def forward(self, hidden_states): + # Apply gate_up_proj to get gate and up projections + merged_states = self.gate_up_proj(hidden_states) + + # Apply activation function (SiLU and Mul) + gate_states, up_states = torch.split(merged_states, merged_states.shape[-1] // 2, dim=-1) + hidden_states = self.act(gate_states) * up_states + + # Apply down projection + hidden_states = self.down_proj(hidden_states) + + return {"hidden_states": hidden_states} + + +class Qwen2DecoderLayer(nn.Module): + def __init__( + self, + config: Qwen2Config, + parallel_config: Optional[ParallelismArgs], + tp_pg: dist.ProcessGroup, + cp_pg: dist.ProcessGroup, + layer_idx: int, + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Use fused RMSNorm if configured + norm_class = TritonRMSNorm if config._fused_rms_norm else RMSNorm + self.input_layernorm = norm_class(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = norm_class(config.hidden_size, eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.attn = Qwen2Attention( + config=config, + parallel_config=parallel_config, + tp_pg=tp_pg, + cp_pg=cp_pg, + layer_idx=layer_idx, + ) + self.post_attention_layernorm = norm_class(config.hidden_size, eps=config.rms_norm_eps) + + # Use MoE layer if this layer is in the MoE layers list + if config.moe_config and layer_idx in config.moe_config.layers: + from nanotron.nn.moe import Qwen2MoELayer + + self.mlp = Qwen2MoELayer( + config=config, + parallel_config=parallel_config, + tp_pg=tp_pg, + layer_idx=layer_idx, + ) + else: + self.mlp = Qwen2MLP( + config=config, + parallel_config=parallel_config, + tp_pg=tp_pg, + intermediate_size=config.intermediate_size, + ) + + self.recompute_layer = parallel_config.recompute_layer + + def _core_forward( + self, + hidden_states: Union[torch.Tensor, TensorPointer], # [batch_size*seq_length, hidden_size] + position_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] where -1 is padding + cu_seqlens: Union[torch.Tensor, TensorPointer], + ) -> List[Union[torch.Tensor, TensorPointer]]: + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + + output = self.attn(hidden_states=hidden_states, position_ids=position_ids, cu_seqlens=cu_seqlens) + hidden_states = output["hidden_states"] + hidden_states = hidden_states + residual + + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states=hidden_states)["hidden_states"] + hidden_states = hidden_states + residual + + return hidden_states, position_ids, cu_seqlens + + def _checkpointed_forward( + self, + hidden_states: torch.Tensor, + position_ids: torch.Tensor, + cu_seqlens: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return CheckpointFunction.apply(self._core_forward, True, hidden_states, position_ids, cu_seqlens) + + def forward( + self, + hidden_states: Union[torch.Tensor, TensorPointer], + position_ids: Union[torch.Tensor, TensorPointer], + cu_seqlens: Union[torch.Tensor, TensorPointer], + ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + if self.recompute_layer and not isinstance(hidden_states, TensorPointer): + hidden_states, position_ids, cu_seqlens = self._checkpointed_forward( + hidden_states, position_ids, cu_seqlens + ) + else: + hidden_states, position_ids, cu_seqlens = self._core_forward(hidden_states, position_ids, cu_seqlens) + + return { + "hidden_states": hidden_states, + "position_ids": position_ids, + "cu_seqlens": cu_seqlens, + } + + +class Embedding(nn.Module): + def __init__(self, tp_pg: dist.ProcessGroup, config: Qwen2Config, parallel_config: Optional[ParallelismArgs]): + super().__init__() + self.token_embedding = TensorParallelEmbedding( + num_embeddings=config.vocab_size, + embedding_dim=config.hidden_size, + padding_idx=config.pad_token_id, + pg=tp_pg, + mode=parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE, + ) + self.pg = tp_pg + + def forward(self, input_ids: torch.Tensor, position_ids: torch.Tensor): # [batch_size, seq_length] + input_ids = input_ids.view(-1) # [batch_size*seq_length] + input_embeds = self.token_embedding(input_ids) # [batch_size*seq_length, hidden_size] + return {"input_embeds": input_embeds, "position_ids": position_ids} + + +class Qwen2Model(nn.Module): + """Build pipeline graph for Qwen2 model""" + + def __init__( + self, + config: Qwen2Config, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + ): + super().__init__() + + # Declare all the nodes + self.p2p = P2P(parallel_context.pp_pg, device=torch.device("cuda")) + self.config = config + self.parallel_config = parallel_config + self.parallel_context = parallel_context + self.tp_mode = parallel_config.tp_mode if parallel_config is not None else TensorParallelLinearMode.ALL_REDUCE + tp_linear_async_communication = ( + parallel_config.tp_linear_async_communication if parallel_config is not None else False + ) + + self.token_position_embeddings = PipelineBlock( + p2p=self.p2p, + module_builder=Embedding, + module_kwargs={ + "config": config, + "parallel_config": parallel_config, + "tp_pg": parallel_context.tp_pg, + }, + module_input_keys={"input_ids", "position_ids"}, + module_output_keys={"input_embeds", "position_ids"}, + ) + + # Create decoder layers + self.decoder = nn.ModuleList( + [ + PipelineBlock( + p2p=self.p2p, + module_builder=Qwen2DecoderLayer, + module_kwargs={ + "config": config, + "parallel_config": parallel_config, + "tp_pg": parallel_context.tp_pg, + "cp_pg": parallel_context.cp_pg, + "layer_idx": layer_idx, + }, + module_input_keys={"hidden_states", "position_ids", "cu_seqlens"}, + module_output_keys={"hidden_states", "position_ids", "cu_seqlens"}, + ) + for layer_idx in range(config.num_hidden_layers) + ] + ) + + self.final_layer_norm = PipelineBlock( + p2p=self.p2p, + module_builder=TritonRMSNorm if config._fused_rms_norm else RMSNorm, + module_kwargs={"hidden_size": config.hidden_size, "eps": config.rms_norm_eps}, + module_input_keys={"input"}, + module_output_keys={"hidden_states"}, + ) + + self.lm_head = PipelineBlock( + p2p=self.p2p, + # Return sharded logits that will need to be gathered + module_builder=TensorParallelColumnLinear, + module_kwargs={ + "in_features": config.hidden_size, + "out_features": config.vocab_size, + "pg": parallel_context.tp_pg, + "bias": False, + "mode": self.tp_mode, + "async_communication": tp_linear_async_communication, + "tp_recompute_allgather": parallel_config.tp_recompute_allgather, + }, + module_input_keys={"x"}, + module_output_keys={"logits"}, + ) + + def forward( + self, + input_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] + position_ids: Union[torch.Tensor, TensorPointer], # [batch_size, seq_length] where -1 is padding + ): + output = self.token_position_embeddings(input_ids=input_ids, position_ids=position_ids) + # Compute cu_seqlens + if position_ids.numel() > 0: + start_indices = torch.where(position_ids.view(-1) == 0)[0] + cu_seqlens = torch.cat( + [start_indices, torch.tensor([position_ids.numel()], dtype=torch.int32, device=start_indices.device)] + ).to(torch.int32) + else: + cu_seqlens = None + + decoder_states = { + "hidden_states": output["input_embeds"], + "position_ids": output["position_ids"], + "cu_seqlens": cu_seqlens, + } + + for decoder_layer in self.decoder: + decoder_states = decoder_layer(**decoder_states) + + hidden_states = self.final_layer_norm(input=decoder_states["hidden_states"])["hidden_states"] + + sharded_logits = self.lm_head(x=hidden_states)["logits"] + + return sharded_logits + + def get_block_compute_costs(self): + """Computes the compute cost of each block in the model for load balancing.""" + model_config = self.config + d_ff = model_config.intermediate_size + d_qkv = model_config.hidden_size // model_config.num_attention_heads + block_compute_costs = { + # Self-attention (qkv proj + attn out) + MLP + Qwen2DecoderLayer: 4 * model_config.num_attention_heads * d_qkv * model_config.hidden_size + + 3 * d_ff * model_config.hidden_size, + # Final LM head + TensorParallelColumnLinear: model_config.vocab_size * model_config.hidden_size, + } + return block_compute_costs + + def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size): + """Get flops per second for the model""" + world_size = self.parallel_context.world_pg.size() + + # Get number of KV heads, accounting for potential absence in config + try: + num_key_value_heads = self.config.num_key_value_heads + except AttributeError: + num_key_value_heads = self.config.num_attention_heads + + model_flops, hardware_flops = get_flops( + num_layers=self.config.num_hidden_layers, + hidden_size=self.config.hidden_size, + num_heads=self.config.num_attention_heads, + num_key_value_heads=num_key_value_heads, + vocab_size=self.config.vocab_size, + ffn_hidden_size=self.config.intermediate_size, + seq_len=sequence_length, + batch_size=global_batch_size, + ) + + model_flops_per_s = model_flops / (iteration_time_in_sec * world_size * 1e12) + hardware_flops_per_s = hardware_flops / (iteration_time_in_sec * world_size * 1e12) + return model_flops_per_s, hardware_flops_per_s + + +@torch.jit.script +def masked_mean(loss, label_mask, dtype): + # type: (Tensor, Tensor, torch.dtype) -> Tensor + return (loss * label_mask).sum(dtype=dtype) / label_mask.sum() + + +class Loss(nn.Module): + def __init__(self, tp_pg: dist.ProcessGroup): + super().__init__() + self.tp_pg = tp_pg + + def forward( + self, + sharded_logits: torch.Tensor, # [batch_size*seq_length, logits] + label_ids: torch.Tensor, # [batch_size, seq_length] + label_mask: torch.Tensor, # [batch_size, seq_length] + ) -> Dict[str, torch.Tensor]: + sharded_logits = sharded_logits.view(label_ids.shape[0], label_ids.shape[1], -1) + loss = sharded_cross_entropy(sharded_logits, label_ids.contiguous(), group=self.tp_pg, dtype=torch.float) + loss = masked_mean(loss, label_mask, dtype=torch.float) + return {"loss": loss} + + +class LossWithZLoss(Loss): + def __init__(self, tp_pg: dist.ProcessGroup, z_loss_coefficient: float): + super().__init__(tp_pg) + self.z_loss_coef = z_loss_coefficient + + def forward( + self, + sharded_logits: torch.Tensor, # [batch_size*seq_length, logits] + label_ids: torch.Tensor, # [batch_size, seq_length] + label_mask: torch.Tensor, # [batch_size, seq_length] + ) -> Dict[str, torch.Tensor]: + sharded_logits = sharded_logits.view(label_ids.shape[0], label_ids.shape[1], -1) + loss, z_loss = sharded_cross_entropy( + sharded_logits, label_ids.contiguous(), group=self.tp_pg, dtype=torch.float, z_loss_coef=self.z_loss_coef + ) + loss = masked_mean(loss, label_mask, dtype=torch.float) + z_loss = masked_mean(z_loss.detach(), label_mask, dtype=torch.float) + return {"loss": loss, "z_loss": z_loss} + + +class Qwen2ForTraining(NanotronModel): + def __init__( + self, + config: Qwen2Config, + parallel_context: ParallelContext, + parallel_config: Optional[ParallelismArgs], + random_states: Optional[RandomStates] = None, + ): + super().__init__() + self.model = Qwen2Model(config=config, parallel_context=parallel_context, parallel_config=parallel_config) + + # Choose the appropriate loss class based on config + loss_kwargs = { + "tp_pg": parallel_context.tp_pg, + } + if config.z_loss_enabled: + loss_kwargs["z_loss_coefficient"] = config.z_loss_coefficient + + self.loss = PipelineBlock( + p2p=self.model.p2p, + module_builder=LossWithZLoss if config.z_loss_enabled else Loss, + module_kwargs=loss_kwargs, + module_input_keys={ + "sharded_logits", + "label_ids", + "label_mask", + }, + module_output_keys={"loss", "z_loss"} if config.z_loss_enabled else {"loss"}, + ) + self.parallel_context = parallel_context + self.config = config + self.parallel_config = parallel_config + + def forward( + self, + input_ids: Union[torch.Tensor, TensorPointer], + position_ids: Union[torch.Tensor, TensorPointer], + label_ids: Union[torch.Tensor, TensorPointer], + label_mask: Union[torch.Tensor, TensorPointer], + ) -> Dict[str, Union[torch.Tensor, TensorPointer]]: + sharded_logits = self.model( + input_ids=input_ids, + position_ids=position_ids, + ) + loss = self.loss( + sharded_logits=sharded_logits, + label_ids=label_ids, + label_mask=label_mask, + ) + if self.config.z_loss_enabled: + return {"loss": loss["loss"], "z_loss": loss["z_loss"]} + else: + return {"loss": loss["loss"]} + + @torch.no_grad() + def init_model_randomly(self, config: Config): + """Initialize model parameters randomly.""" + init_method = config.model.init_method + if isinstance(init_method, RandomInit): + parametrizator_cls = StandardParametrizator + elif isinstance(init_method, SpectralMupInit): + parametrizator_cls = SpectralMupParametrizator + else: + raise ValueError(f"Unknown init method {init_method}") + + parametrizator = parametrizator_cls(config=config) + + log_rank( + f"Parametrizing model parameters using {parametrizator.__class__.__name__}", + logger=logger, + level=logging.INFO, + rank=0, + ) + + model = self + initialized_parameters = set() + # Handle tensor parallelism + module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()} + # Fix the root_model + module_id_to_prefix[id(model)] = "" + + for param_name, param in model.named_parameters(): + assert isinstance(param, NanotronParameter) + + module_name, param_name = param_name.rsplit(".", 1) + + if param.is_tied: + tied_info = param.get_tied_info() + full_param_name = tied_info.get_full_name_from_module_id_to_prefix( + module_id_to_prefix=module_id_to_prefix + ) + else: + full_param_name = f"{module_name}.{param_name}" + + if full_param_name in initialized_parameters: + # Already initialized + continue + + module = model.get_submodule(module_name) + parametrizator.parametrize(full_param_name, module) + + assert full_param_name not in initialized_parameters + initialized_parameters.add(full_param_name) + + assert initialized_parameters == { + param.get_tied_info().get_full_name_from_module_id_to_prefix(module_id_to_prefix=module_id_to_prefix) + if param.is_tied + else name + for name, param in model.named_parameters() + }, f"Somehow the initialized set of parameters don't match:\n - Expected: { {name for name, _ in model.named_parameters()} }\n - Got: {initialized_parameters}" + + def get_embeddings_lm_head_tied_names(self): + """Get the names of the tied embeddings and lm_head weights""" + if self.config.tie_word_embeddings is True: + # Should be similar to ["model.token_position_embeddings.pp_block.token_embedding.weight", "model.lm_head.pp_block.weight"] + return ["model.token_position_embeddings.pp_block.token_embedding.weight", "model.lm_head.pp_block.weight"] + else: + return [] + + def get_block_compute_costs(self): + """Computes the compute cost of each block in the model so that we can do a better job of load balancing.""" + return self.model.get_block_compute_costs() + + def get_flops_per_sec(self, iteration_time_in_sec, sequence_length, global_batch_size): + """Get flops per second for a given model""" + return self.model.get_flops_per_sec(iteration_time_in_sec, sequence_length, global_batch_size) + + +def get_flops( + num_layers, + hidden_size, + num_heads, + num_key_value_heads, + vocab_size, + seq_len, + ffn_hidden_size, + batch_size=1, +): + """Counts flops in an decoder-only model + Args: + num_layers: number of decoder layers + hidden_size: hidden size of the model + num_heads: number of heads in the model + num_key_value_heads: number of key/value heads in the model + ffn_hidden_size: hidden size of the FFN + vocab_size: size of the vocabulary + seq_len: sequence length of the decoder + batch_size: batch size + Returns: + model_flops: flops in the model (should be independent of the hardware and model implementation) + hardware_flops: flops in the hardware (actual flops performed on the hardware). Check 6.3 in https://arxiv.org/pdf/2205.05198.pdf + """ + if num_key_value_heads is None: + num_key_value_heads = num_heads + hidden_size_per_head = hidden_size // num_heads + # In the following we mark the reduced dimension with parentheses + # decoder + # self attention + ## qkv projection + decoder_qkv_proj_flops_fwd = ( + 2 * num_layers * batch_size * seq_len * (hidden_size) * num_heads * hidden_size_per_head + + 2 * num_layers * batch_size * seq_len * (hidden_size) * 2 * num_key_value_heads * hidden_size_per_head + ) + ## qk logits + decoder_qk_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (hidden_size_per_head) * seq_len + ## v logits + decoder_v_logits_flops_fwd = 2 * num_layers * batch_size * num_heads * seq_len * (seq_len) * hidden_size_per_head + ## attn out + decoder_attn_out_flops_fwd = ( + 2 * num_layers * batch_size * num_heads * seq_len * (hidden_size_per_head) * hidden_size + ) + # FF + ## 1st layer + decoder_ffn_1_flops_fwd = 4 * num_layers * batch_size * seq_len * (hidden_size) * ffn_hidden_size + ## 2nd layer + decoder_ffn_2_flops_fwd = 2 * num_layers * batch_size * seq_len * (ffn_hidden_size) * hidden_size + + decoder_flops_fwd = ( + decoder_qkv_proj_flops_fwd + + decoder_qk_logits_flops_fwd + + decoder_v_logits_flops_fwd + + decoder_attn_out_flops_fwd + + decoder_ffn_1_flops_fwd + + decoder_ffn_2_flops_fwd + ) + + # lm head + lm_head_flops_fwd = 2 * batch_size * seq_len * (hidden_size) * vocab_size + + # the bwd pass requires double the flops in case of matmuls to calculate the gradients with respect to + # both input and weight tensors + model_flops = 3 * (decoder_flops_fwd + lm_head_flops_fwd) # 1 for fwd + 2 for bwd + + hardware_flops = model_flops # TODO: This is a placeholder for now + + return model_flops, hardware_flops diff --git a/src/nanotron/nn/moe.py b/src/nanotron/nn/moe.py new file mode 100644 index 000000000..99ea9c428 --- /dev/null +++ b/src/nanotron/nn/moe.py @@ -0,0 +1,491 @@ +from dataclasses import dataclass +from typing import List, Optional + +import torch +from einops import rearrange +from torch import nn +from torch.nn import functional as F +from torch.utils.checkpoint import CheckpointFunction + +from nanotron import distributed as dist +from nanotron import logging +from nanotron.config import ParallelismArgs +from nanotron.config.models_config import Qwen2Config +from nanotron.models.base import ignore_init_on_device_and_dtype +from nanotron.nn.activations import ACT2FN +from nanotron.parallel.context import ParallelContext +from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import ( + all_to_all, + differentiable_all_gather, +) + +logger = logging.get_logger(__name__) + + +try: + import grouped_gemm.ops as ops +except ImportError: + raise RuntimeError( + "Grouped GEMM is not available. Please run `pip install --no-build-isolation git+https://github.com/fanshiqing/grouped_gemm@main` (takes less than 5 minutes)" + ) + + +def is_expert_param(name: str) -> bool: + from nanotron.constants import EXPERT_PARAM_NAMES + + # return any(param in name for param in EXPERT_PARAM_NAMES) + return any(x in name for x in EXPERT_PARAM_NAMES) + + +def permute(x: torch.Tensor, routing_indices: torch.Tensor): + permuted_x, inverse_permute_mapping = ops.permute(x.to(torch.float32), routing_indices) + permuted_x = permuted_x.to(x.dtype) + return permuted_x, inverse_permute_mapping + + +def unpermute(x: torch.Tensor, inverse_mapping: torch.Tensor, routing_weights: torch.Tensor): + comebined_x = ops.unpermute(x.to(torch.float32), inverse_mapping, routing_weights) + return comebined_x.to(x.dtype) + + +def _get_dispatched_routing_indices(global_routing_indices, expert_parallel_size, num_experts): + num_local_experts = num_experts // expert_parallel_size + global_routing_indices_per_device = rearrange( + global_routing_indices, + "(expert_parallel_size num_tokens) num_experts -> expert_parallel_size (num_tokens num_experts)", + expert_parallel_size=expert_parallel_size, + ) + sorted_global_routing_indices_per_device, _ = torch.sort(global_routing_indices_per_device) + + dispatched_indices = [] + for ep_rank in range(expert_parallel_size): + start_idx = ep_rank * num_local_experts + end_idx = start_idx + num_local_experts + + mask = (sorted_global_routing_indices_per_device >= start_idx) & ( + sorted_global_routing_indices_per_device < end_idx + ) + dispatched_indices.append(sorted_global_routing_indices_per_device[mask]) + + return dispatched_indices + + +@dataclass +class MoELogging: + """ + num_local_tokens: List[torch.Tensor]: The number of tokens per local expert per layer + """ + + num_local_tokens: List[torch.Tensor] + + +class AllToAllDispatcher(nn.Module): + def __init__(self, num_local_experts: int, num_experts: int, ep_pg: dist.ProcessGroup): + super().__init__() + self.num_local_experts = num_local_experts + self.num_experts = num_experts + self.expert_parallel_size = dist.get_world_size(ep_pg) + self.ep_pg = ep_pg + + self.input_split_sizes = None + self.output_split_sizes = None + + self._use_torch_permute = True + + def _haojun_permute_topk(self, hidden_states, routing_indices): + """ + hidden_states: [num_tokens, hidden_dim] + routing_indices: [num_tokens, topk] + num_experts: total number of experts + Returns: + permuted: [num_tokens * topk, hidden_dim] + expert_counts: [num_experts], number of tokens assigned to each expert + permute_metadata: metadata for unpermute + """ + num_tokens, hidden_dim = hidden_states.shape + topk = routing_indices.shape[-1] + + # Expand hidden_states to match topk, shape: [num_tokens, topk, hidden_dim] + expanded_states = hidden_states.unsqueeze(1).expand(-1, topk, -1) + + # Flatten the batch: [num_tokens * topk, hidden_dim] + flat_states = expanded_states.reshape(-1, hidden_dim) + flat_indices = routing_indices.reshape(-1) # [num_tokens * topk] + + # Sort by expert (so tokens are grouped by expert index) + sorted_expert_indices, sort_order = flat_indices.sort() + permuted = flat_states[sort_order] # [num_tokens * topk, hidden_dim] + + # Count tokens per expert + num_tokens_per_expert = torch.bincount(sorted_expert_indices, minlength=self.num_experts) + + return permuted, (sort_order, flat_indices.shape[0]), num_tokens_per_expert + + def _haojun_unpermute_topk(self, permuted, sort_order, total_elements, routing_weights): + """ + permuted: [num_tokens * topk, hidden_dim], output from experts + sort_order: indices used to sort the tokens in permute + total_elements: num_tokens * topk + routing_weights: [num_tokens, topk], used to scale expert outputs before aggregation + Returns: + output: [num_tokens, hidden_dim], weighted sum over topk expert outputs + """ + device = permuted.device + hidden_dim = permuted.size(-1) + num_tokens, topk = routing_weights.shape + + # Restore original order + unsort_order = torch.empty_like(sort_order) + unsort_order[sort_order] = torch.arange(total_elements, device=device) + + # Restore the original [num_tokens * topk, hidden_dim] order + unpermuted = permuted[unsort_order] + + # Reshape to [num_tokens, topk, hidden_dim] + unpermuted = unpermuted.view(num_tokens, topk, hidden_dim) + + # Apply routing weights + routing_weights = routing_weights.to(permuted.dtype).unsqueeze(-1) # [num_tokens, topk, 1] + weighted_output = unpermuted * routing_weights + + # Aggregate over topk experts: sum over topk axis + output = weighted_output.sum(dim=1) # [num_tokens, hidden_dim] + + return output + + def permute( + self, + hidden_states: torch.Tensor, + routing_indices: torch.Tensor, + ): + """ + Dispatches tokens to their selected experts. + In a full implementation, this would handle the actual token routing logic + including communication between devices. + + + local_routing_indices: is the initial routing indices for the local experts's tokens + + dispatched_routing_indices: is the routing indices for the dispatched tokens corresponding to the local experts + + + inverse_permute_mapping: is the inverse of the permute mapping + + inverse_expert_sorting_index: is the inverse of the expert sorting index + + outputs: + + num_local_dispatched_tokens_per_expert: we return it in cpu so don't have to move to cpu for grouped_gemm.gmm + """ + + def calculate_output_split_sizes_for_rank(all_input_split_sizes, rank): + """ + Calculate output_split_sizes for a specific rank based on input_split_sizes from all ranks. + + Args: + all_input_split_sizes: List of lists where all_input_split_sizes[i] is the input_split_sizes for rank i + rank: The rank to calculate output_split_sizes for + + Returns: + List containing the output_split_sizes for the specified rank + """ + world_size = len(all_input_split_sizes) + output_split_sizes = [] + + # For each possible sender rank + for sender_rank in range(world_size): + # Get how much data sender_rank is sending to our rank + size = all_input_split_sizes[sender_rank][rank] + output_split_sizes.append(size.item()) + + return output_split_sizes + + with torch.autograd.profiler.record_function("AllToAllDispatcher.permute.all_to_all.pre"): + # NOTE: start from expert 0 to expert n + # NOTE: because the routing indices is global, + # but each expert device has a set of local experts + # so we need to align the routing indices to the local experts index + ep_rank = dist.get_rank(self.ep_pg) + num_tokens_per_expert = torch.bincount( + routing_indices.flatten(), minlength=self.num_experts + ) # [num_local_experts] + global_routing_indices = differentiable_all_gather(routing_indices, group=self.ep_pg) + + if self._use_torch_permute: + hidden_states, inverse_permute_mapping, _ = self._haojun_permute_topk(hidden_states, routing_indices) + else: + hidden_states, inverse_permute_mapping = permute(hidden_states, routing_indices) + + # NOTE: this part is all-to-all token dispatching + if self.expert_parallel_size > 1: + # NOTE: Reshape num_local_tokens_per_expert to [ep_size, num_local_experts] + # TODO: .view or .reshape? check which one is faster + # NOTE: this is incorrect in the case of imbalance + num_tokens_per_expert_device = num_tokens_per_expert.reshape( + self.expert_parallel_size, self.num_local_experts + ) + # NOTE: input_size_splits has a shape = [expert_parallel_size] + # where each value represent the number of tokens that we send from this device + # to [i]th device in the input_size_splits + # TODO: double check cpu-gpu sync + input_split_sizes = num_tokens_per_expert_device.sum(dim=1) + list_input_split_sizes = [ + torch.empty_like(input_split_sizes) for _ in range(self.expert_parallel_size) + ] + dist.all_gather(list_input_split_sizes, input_split_sizes, group=self.ep_pg) + + # NOTE: we can compute how many tokens this divide to receive from [i]th device globally + # NOTE: create a tensor corresponding to dist.get_rank(self.ep_pg) + # TODO: double check cpu-gpu sync + input_split_sizes = input_split_sizes.tolist() + output_split_sizes = calculate_output_split_sizes_for_rank( + list_input_split_sizes, dist.get_rank(self.ep_pg) + ) + else: + input_split_sizes, output_split_sizes = None, None + + with torch.autograd.profiler.record_function("AllToAllDispatcher.permute.all_to_all"): + dispatched_hidden_states = all_to_all( + hidden_states, + output_split_sizes=output_split_sizes, + input_split_sizes=input_split_sizes, + group=self.ep_pg, + ) + + self.input_split_sizes = input_split_sizes + self.output_split_sizes = output_split_sizes + + with torch.autograd.profiler.record_function("AllToAllDispatcher.permute.expert_index_sorting"): + # NOTE: a list of rotuing indices corresponding to the dispatched inputs + # we shouldn't sort the indices before permutation, + # but we keep the same expert value for each dispatched token, + # then the permutation function will handle the sorting and replicating for topk + dispatched_routing_indices = _get_dispatched_routing_indices( + global_routing_indices, self.expert_parallel_size, num_experts=self.num_experts + )[ep_rank] + # NOTE: we prefer to keep num_local_dispatched_tokens_per_expert on cpu, + # so we don't have to move it again for grouped_gemm + dispatched_routing_indices_cpu = dispatched_routing_indices.cpu() + + # NOTE: it should be the number of dispatched tokens per expert + # because we will use this for local grouped_gemm + # NOTE: the local_routing_indices has a global expert index, + # so we need to subtract the number of local experts to get the local expert index + num_local_dispatched_tokens_per_expert = torch.bincount( + dispatched_routing_indices_cpu - ep_rank * self.num_local_experts, minlength=self.num_local_experts + ) + + # NOTE: torch.bincount requires the indices to be int32 + # otherwise it raises: "RuntimeError: "bincount_cuda" not implemented for 'BFloat16'" + # NOTE: if dispatched_routing_indices only has a single value, + # then the shape of expert_sort_indices is a single scalar, but we want it to be a 1d tensor + # for the sorted_and_dispatched_hidden_states to has shape [num_tokens, d_model] + expert_sort_indices = torch.argsort(dispatched_routing_indices.squeeze(-1), stable=True) + expert_sort_indices = expert_sort_indices.view(-1) + + sorted_and_dispatched_hidden_states = dispatched_hidden_states[expert_sort_indices] + + return ( + sorted_and_dispatched_hidden_states, + inverse_permute_mapping, + expert_sort_indices, + num_local_dispatched_tokens_per_expert, + ) + + def unpermute(self, expert_outputs, inverse_permute_mapping, routing_weights, expert_sort_indices): + """ + Combines outputs from different experts back to the original tensor layout. + """ + + # NOTE: the expert_outputs here is sorted by the expert index + # so we need to unsort it back to the dispatching order + inverse_expert_sort_indices = torch.argsort(expert_sort_indices, stable=True) + # NOTE: expert_outputs is on cuda, inverse_expert_sort_indices is on cpu, how to remove a cuda sync point here? + expert_outputs = expert_outputs.index_select(0, inverse_expert_sort_indices) + + undispatched_expert_outputs = all_to_all( + expert_outputs, + output_split_sizes=self.input_split_sizes, + input_split_sizes=self.output_split_sizes, + group=self.ep_pg, + ) + + # NOTE: merging the expert output combination and un-permuting them back into a single operation + if self._use_torch_permute: + comebined_expert_outputs = self._haojun_unpermute_topk( + undispatched_expert_outputs, *inverse_permute_mapping, routing_weights + ) + else: + comebined_expert_outputs = unpermute(undispatched_expert_outputs, inverse_permute_mapping, routing_weights) + return comebined_expert_outputs + + +class Router(nn.Module): + def __init__(self, config: Qwen2Config, parallel_config: Optional[ParallelismArgs], layer_idx: int): + super().__init__() + self.config = config + self.parallel_config = parallel_config + self.layer_idx = layer_idx + + self.num_experts = config.moe_config.num_experts + self.num_experts_per_token = config.moe_config.top_k + + # float32 routing weights + # NOTE: qwen keep the routing weights in float32 + # https://github.com/huggingface/transformers/blob/27a25bee4fcb865e8799ba026f1ea4455f2cca98/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py#L608 + with ignore_init_on_device_and_dtype(): + self.weight = nn.Parameter( + torch.randn(self.num_experts, config.hidden_size, dtype=torch.float32, device="cuda") + ) + assert self.weight.dtype == torch.float32 + + def gating(self, x: torch.Tensor) -> torch.Tensor: + """Compute logits for all experts (no softmax).""" + # NOTE: qwen keep the routing logits in float32 + # https://github.com/huggingface/transformers/blob/27a25bee4fcb865e8799ba026f1ea4455f2cca98/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py#L613 + return F.linear(x.to(torch.float32), self.weight, bias=None) + + def routing(self, logits: torch.Tensor): + """Top-k softmax-normalized routing weights and indices.""" + routing_weights = F.softmax(logits, dim=-1, dtype=torch.float32) + routing_weights, routing_indices = torch.topk(routing_weights, k=self.num_experts_per_token, dim=-1) + routing_indices = routing_indices.to(torch.int32) # NOTE: ops.permute requires indices to be int32 + return routing_weights, routing_indices + + def forward(self, x: torch.Tensor): + logits = self.gating(x) + return self.routing(logits) + + +class GroupedMLP(nn.Module): + def __init__(self, config: Qwen2Config, parallel_config: Optional[ParallelismArgs], ep_pg: dist.ProcessGroup): + super().__init__() + moe_config = config.moe_config + + num_local_experts = moe_config.num_experts // parallel_config.expert_parallel_size + self.expert_parallel_size = parallel_config.expert_parallel_size + self.num_local_experts = torch.tensor(num_local_experts, dtype=torch.int32, device="cuda") + self.ep_pg = ep_pg + self.merged_gate_up_proj = nn.Parameter( + torch.randn(num_local_experts, moe_config.moe_hidden_size, 2 * moe_config.moe_intermediate_size) + ) + self.merged_down_proj = nn.Parameter( + torch.randn(num_local_experts, moe_config.moe_intermediate_size, moe_config.moe_hidden_size) + ) + self.act = ACT2FN[config.hidden_act] + + def forward( + self, + hidden_states: torch.Tensor, + num_local_tokens_per_expert: torch.Tensor, + ): + """ + assume hidden_states is permuted + + grouped_gemm's notes: + ops.gemm expect the inputs to have the following criteria: + + expect a, b are in bfloat16 + + expect num_tokens_per_expert is a on cpu + """ + + # NOTE: if no tokens are assigned to this expert device, then we just return the hidden states + if torch.count_nonzero(num_local_tokens_per_expert) == 0: + # NOTE: this divide don't receive any tokens + return {"hidden_states": hidden_states} + + merged_states = ops.gmm(hidden_states, self.merged_gate_up_proj, num_local_tokens_per_expert, trans_b=False) + gate_states, up_states = torch.split(merged_states, merged_states.shape[-1] // 2, dim=-1) + hidden_states = self.act(gate_states) * up_states + hidden_states = ops.gmm(hidden_states, self.merged_down_proj, num_local_tokens_per_expert, trans_b=False) + + return {"hidden_states": hidden_states} + + +class Qwen2MoEMLPLayer(nn.Module): + """Mixture of experts Layer for Qwen2 models.""" + + def __init__( + self, + config: Qwen2Config, + parallel_config: Optional[ParallelismArgs], + parallel_context: ParallelContext, + layer_idx: int = 0, + ) -> None: + super().__init__() + moe_config = config.moe_config + self.hidden_size = moe_config.moe_hidden_size + self.intermediate_size = moe_config.moe_intermediate_size + + # MoE specific configurations + num_experts = config.moe_config.num_experts # Total number of experts + num_local_experts = config.moe_config.num_experts // parallel_config.expert_parallel_size # Experts per device + self.num_experts_per_token = config.moe_config.top_k # Number of experts used per token (top-k) + self.expert_parallel_size = parallel_config.expert_parallel_size + self.num_local_experts = num_local_experts # Experts per device + + # Router for selecting experts + self.router = Router(config, parallel_config, layer_idx) + self.token_dispatcher = AllToAllDispatcher(num_local_experts, num_experts, parallel_context.ep_pg) + self.token_dispatcher._use_torch_permute = config.moe_config.use_torch_permute + + # Enable shared experts if configured + self.enable_shared_expert = config.moe_config.enable_shared_expert + if self.enable_shared_expert: + from nanotron.models.qwen import Qwen2MLP + + self.shared_expert = Qwen2MLP( + config=config, + parallel_config=parallel_config, + tp_pg=parallel_context.tp_pg, + hidden_size=moe_config.shared_expert_hidden_size, + intermediate_size=moe_config.shared_expert_intermediate_size, + ) + # TODO: duplicte the shared expert gate + self.shared_expert_gate = nn.Linear( + self.hidden_size, + 1, + bias=False, + ) # TODO: ensure shared_expert_gate is tied across TP + + self.experts = GroupedMLP(config, parallel_config, ep_pg=parallel_context.ep_pg) + # Whether to recompute MoE layer during backward pass for memory efficiency + self.recompute_layer = parallel_config.recompute_layer + self.ep_pg = parallel_context.ep_pg + self.layer_idx = layer_idx + + def _compute_expert_outputs(self, hidden_states, routing_weights, routing_indices): + ( + dispatched_inputs, + inverse_permute_mapping, + expert_sort_indices, + num_local_tokens_per_expert, + ) = self.token_dispatcher.permute(hidden_states, routing_indices) + + expert_outputs = self.experts(dispatched_inputs, num_local_tokens_per_expert) + output = self.token_dispatcher.unpermute( + expert_outputs["hidden_states"], inverse_permute_mapping, routing_weights, expert_sort_indices + ) + return output, num_local_tokens_per_expert + + def _core_forward(self, hidden_states): + """Core forward logic for MoE layer.""" + # Get top-k routing weights and indices + routing_weights, routing_indices = self.router(hidden_states) # [num_tokens, num_experts_per_token] + + output, num_local_tokens_per_expert = self._compute_expert_outputs( + hidden_states, routing_weights, routing_indices + ) + if self.enable_shared_expert: + shared_expert_output = self.shared_expert(hidden_states=hidden_states)["hidden_states"] + shared_gate = torch.sigmoid(self.shared_expert_gate(hidden_states)) + output = output + shared_gate * shared_expert_output + + return {"hidden_states": output} + + def _checkpointed_forward(self, hidden_states): + """Apply gradient checkpointing to save memory during training.""" + return CheckpointFunction.apply(self._core_forward, True, hidden_states) + + def forward(self, hidden_states): + """Forward pass for the MoE layer.""" + if self.recompute_layer and self.training: + outputs = self._checkpointed_forward(hidden_states) + else: + outputs = self._core_forward(hidden_states) + + return outputs diff --git a/src/nanotron/nn/moe_utils.py b/src/nanotron/nn/moe_utils.py new file mode 100644 index 000000000..f3acc25b6 --- /dev/null +++ b/src/nanotron/nn/moe_utils.py @@ -0,0 +1,779 @@ +# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved. + +import math +from typing import Optional + +import torch +import torch.distributed as dist + +try: + from transformer_engine.pytorch.permutation import ( + moe_permute, + moe_sort_chunks_by_index, + moe_unpermute, + ) + + fused_permute = moe_permute + fused_unpermute = moe_unpermute + fused_sort_chunks_by_index = moe_sort_chunks_by_index + HAVE_TE = True +except ImportError: + HAVE_TE = False + + +def switch_load_balancing_loss_func( + probs: torch.Tensor, + tokens_per_expert: torch.Tensor, + topk: int, + moe_aux_loss_coeff: float, + sequence_partition_group=None, +): + """Calculate the auxiliary loss for load balancing. + Refer to the Switch Transformer paper (https://arxiv.org/abs/2101.03961) for details. + + Args: + probs (torch.Tensor): Softmax probabilities output by the router for each token. + Shape in [num_tokens, num_experts]. + tokens_per_expert (torch.Tensor): Number of tokens assigned to each expert. + Shape in [num_experts] + topk (int): The number of experts selected for each token. + moe_aux_loss_coeff (float): The coefficient for the auxiliary loss. + sequence_partition_group (optional): The parallel group over which the sequence is + partitioned. If None, no partitioning is applied. + Defaults to None. + + Returns: + torch.Tensor: The auxiliary loss for load balancing. + """ + num_sub_sequence = 1 + + # If the sequence is partitioned by certain parallelism strategies like Sequence Parallelism + # or Context Parallelism, compute the gradient of the auxiliary loss with respect to the full + # sequence. + if sequence_partition_group is not None: + # We can keep `aggregated_probs_per_expert` local since we don't need the gradient for + # `tokens_per_expert`, saving one allreduce operation for `aggregated_probs_per_expert`. + num_sub_sequence = torch.distributed.get_world_size(sequence_partition_group) + torch.distributed.all_reduce(tokens_per_expert, group=sequence_partition_group) + + num_tokens = probs.shape[0] * num_sub_sequence + num_experts = probs.shape[1] + + # The formula of aux_loss: aux_loss = sum((probs_per_expert/num_tokens) * + # (tokens_per_expert/(num_tokens*topk))) * num_experts * moe_aux_loss_coeff. + # This can be simplified to fuse the division and multiplication operations. + aggregated_probs_per_expert = probs.sum(dim=0) + aux_loss = torch.sum(aggregated_probs_per_expert * tokens_per_expert) * ( + num_experts * moe_aux_loss_coeff / (num_tokens * num_tokens * topk) + ) + return aux_loss + + +def sequence_load_balancing_loss_func( + probs: torch.Tensor, + routing_map: torch.Tensor, + batch_size: int, + seq_length: int, + topk: int, + moe_aux_loss_coeff: float, + sequence_partition_group=None, +): + """ + Calculate the auxiliary loss in sequence-level by computing the loss for each individual sample. + Refer to the DeepSeek-V2 huggingface repo + (https://huggingface.co/deepseek-ai/DeepSeek-V2) for details. + + Args: + probs (torch.Tensor): Softmax probabilities output by the router for each token. + Shape in [num_tokens, num_experts]. + routing_map (torch.Tensor): Mapping of tokens to experts assignment. + Shape in [num_tokens, num_experts]. + batch_size (int): Batch size to process. + seq_length (int): Sequence length to process. + topk (int): Number of experts to route to for each token. + moe_aux_loss_coeff (float): Scaling coefficient for the auxiliary loss. + sequence_partition_group (optional): The parallel group over which the sequence is + partitioned. If None, no partitioning is applied. + Defaults to None. + + Returns: + torch.Tensor: The sequence auxiliary loss for load balancing. + """ + num_sub_sequence = 1 + num_experts = probs.shape[1] + + probs_for_aux_loss = probs.view(seq_length, batch_size, -1) + routing_map = routing_map.view(seq_length, batch_size, -1) + + # If the sequence is partitioned by certain parallelism strategies like Sequence Parallelism + # or Context Parallelism, compute the gradient of the auxiliary loss with respect to the full + # sequence. + if sequence_partition_group is not None: + num_sub_sequence = torch.distributed.get_world_size(sequence_partition_group) + seq_length *= num_sub_sequence + probs_for_aux_loss = gather_from_sequence_parallel_region(probs_for_aux_loss, group=sequence_partition_group) + + cost_coeff = routing_map.sum(dim=0, dtype=torch.float).div_(seq_length * topk / num_experts) + seq_aux_loss = (cost_coeff * probs_for_aux_loss.mean(dim=0)).sum(dim=1).mean() + seq_aux_loss *= moe_aux_loss_coeff + + return seq_aux_loss + + +def z_loss_func(logits, z_loss_coeff): + """Encourages the router's logits to remain small to enhance stability. + Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details. + + Args: + logits (torch.Tensor): The logits of the router. + + Returns: + torch.Tensor: The logits after applying the z-loss. + """ + + z_loss = torch.mean(torch.square(torch.logsumexp(logits, dim=-1))) * z_loss_coeff + return z_loss + + +def sinkhorn(cost: torch.Tensor, tol: float = 0.0001): + """Sinkhorn based MoE routing function""" + cost = torch.exp(cost) + d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype) + d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype) + + eps = 0.00000001 + error = 1e9 + d1_old = d1 + while error > tol: + d0 = (1 / d0.size(0)) * 1 / (torch.sum(d1 * cost, 1) + eps) + d1 = (1 / d1.size(0)) * 1 / (torch.sum(d0.unsqueeze(1) * cost, 0) + eps) + error = torch.mean(torch.abs(d1_old - d1)) + d1_old = d1 + return d1 * cost * d0.unsqueeze(1) + + +def get_capacity(num_tokens: int, num_experts: int, capacity_factor: float, min_capacity=None): + """ + Calculate the capacity of each expert. + + Args: + num_tokens (int): num of the input tokens. + num_experts (int): num of the experts. + capacity_factor (float): Capacity factor. + min_capacity (int, optional): Minimum capacity. Defaults to None. + + Returns: + Tensor: Capacity of each expert. + """ + capacity = math.ceil((num_tokens / num_experts) * capacity_factor) + if min_capacity is not None and capacity < min_capacity: + capacity = min_capacity + return capacity + + +class MoEAuxLossAutoScaler(torch.autograd.Function): + """An AutoScaler that triggers the backward pass and scales the grad for auxiliary loss.""" + + main_loss_backward_scale: torch.Tensor = torch.tensor(1.0) + + @staticmethod + def forward(ctx, output: torch.Tensor, aux_loss: torch.Tensor): + """Preserve the aux_loss by storing it in the context to avoid garbage collection. + + Args: + output (torch.Tensor): The output tensor. + aux_loss (torch.Tensor): The auxiliary loss tensor. + + Returns: + torch.Tensor: The output tensor. + """ + ctx.save_for_backward(aux_loss) + return output + + @staticmethod + def backward(ctx, grad_output: torch.Tensor): + """Compute and scale the gradient for auxiliary loss.. + + Args: + grad_output (torch.Tensor): The gradient of the output. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: The gradient of the output, scaled auxiliary loss + gradient. + """ + (aux_loss,) = ctx.saved_tensors + aux_loss_backward_scale = MoEAuxLossAutoScaler.main_loss_backward_scale + scaled_aux_loss_grad = torch.ones_like(aux_loss) * aux_loss_backward_scale + return grad_output, scaled_aux_loss_grad + + @staticmethod + def set_loss_scale(scale: torch.Tensor): + """set the scale of the aux loss. + + Args: + scale (torch.Tensor): The scale value to set. Please ensure that the scale passed in + matches the scale of the main_loss. + """ + MoEAuxLossAutoScaler.main_loss_backward_scale = scale + + +def permute( + tokens, + routing_map, + num_out_tokens: Optional[int] = None, + fused: bool = False, + drop_and_pad: bool = False, +): + """Permute the tokens and probs based on the mask. + Tokens with the same designated expert will be grouped together. + The shape of mask is [tokens, num_experts], it indicates which experts were selected + by each token. + + When drop_and_pad=True, in routing_map, the number of non-zeros in each column equals to + expert capacity. This function exploits this feature to use ops that support cuda graph. + + Args: + tokens (torch.Tensor): The input token tensor, [num_tokens, hidden]. + routing_map (torch.Tensor): The sparse token to expert mapping, [num_tokens, num_experts]. + num_out_tokens (int, optional): The number of output tokens. If None, it's set to + the number of input tokens. + fused (bool, optional): Whether use the fused permute function. + drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop + and pads the number of tokens to the expert capacity. + If set to true, routing_map has a fixed number of non-zeros + in each column. + """ + if fused: + if not HAVE_TE or fused_permute is None: + raise ValueError("fused_permute is not available. Please install TE >= 2.1.0.") + return fused_permute(tokens, routing_map, num_out_tokens) + + num_tokens, hidden = tokens.shape + num_experts = routing_map.shape[1] + if drop_and_pad and num_out_tokens is not None: + capacity = num_out_tokens // num_experts + assert not routing_map.requires_grad + # mask [num_tokens, num_experts] -> [num_experts, num_tokens] + routing_map = routing_map.to(dtype=torch.int8).T.contiguous() + # use argsort to put indices of all non-zeros in the beginning of list + # and keep the first `capacity` number of indices + sorted_indices = routing_map.argsort(dim=-1, descending=True, stable=True)[:, :capacity].contiguous() + # flatten from [num_experts, capacity] to 1D + sorted_indices = sorted_indices.view(-1) + else: + # mask [num_tokens, num_experts] -> [num_experts, num_tokens] + routing_map = routing_map.bool().T.contiguous() + + # Create a dense expert-to-token mapping from the sparse token-to-expert mapping + token_indices = torch.arange(num_tokens, device=routing_map.device).unsqueeze(0).expand(num_experts, -1) + sorted_indices = token_indices.masked_select(routing_map) + + # use the mapping to permute the tokens + permuted_input = tokens.index_select(0, sorted_indices) + + return permuted_input, sorted_indices + + +def unpermute( + permuted_tokens: torch.Tensor, + sorted_indices: torch.Tensor, + restore_shape: torch.Size, + probs: torch.Tensor = None, + routing_map: torch.Tensor = None, + fused: bool = False, + drop_and_pad: bool = False, +): + """ + Restore the original order of tokens after permutation. If probs are provided, it + will also apply them to the tokens before restoring the order. + + When drop_and_pad=True, the tensors will have the following properties: + - In routing_map, the number of non-zeros in each column equals to expert capacity + - The size of sorted_indices equals to num_experts * capacity, each split of `capacity` + contains the indices of tokens routed to an expert. + This function exploits these features to use ops that support cuda graph. + + Args: + permuted_tokens (torch.Tensor): The permuted token tensor. + sorted_indices (torch.Tensor): The indices used to sort the tokens. + restore_shape (torch.Size): The shape of the unpermuted tensor. + probs (torch.Tensor, optional): The unpermuted probs tensor, + routing_map (torch.Tensor, optional): Token to expert mapping, shape + [num_tokens, num_experts]. + fused (bool, optional): Whether use the fused unpermute function. + drop_and_pad (bool, optional): Whether or not the token dispatcher uses token-drop + and pads the number of tokens to the expert capacity. + + Returns: + torch.Tensor: The tokens restored to their original order. + """ + if fused: + if not HAVE_TE or fused_unpermute is None: + raise ValueError("fused_unpermute is not available. Please install TE >= 2.1.0.") + return fused_unpermute(permuted_tokens, sorted_indices, probs, restore_shape) + + _, hidden = restore_shape + input_dtype = permuted_tokens.dtype + + if probs is not None: + assert routing_map is not None, "Mask must be provided to permute the probs." + if drop_and_pad: + num_experts = routing_map.size(1) + num_permuted_tokens = sorted_indices.size(0) + capacity = num_permuted_tokens // num_experts + num_unpermuted_tokens = probs.size(0) + + # [num_unpermuted_tokens, num_experts] -> num_experts * num_unpermuted_tokens + probs_T_1D = probs.T.contiguous().view(-1) + + # get 1D indices of the probs selected by routing_map + indices_dim0 = torch.arange(num_experts, device=routing_map.device).unsqueeze(-1) + indices_dim1 = sorted_indices.view(num_experts, capacity) + indices_1D = (indices_dim0 * num_unpermuted_tokens + indices_dim1).view(-1) + + # get probs from indices + permuted_probs = probs_T_1D.index_select(0, indices_1D) + else: + permuted_probs = probs.T.contiguous().masked_select(routing_map.T.contiguous()) + # Here may promote permuted_tokens to higher precision (fp32/fp64) if probs is in + # higher precision due to moe_router_dtype being enabled. This can lead to + # additional GPU memory usage. Use --moe-permute-fusion flag to avoid this extra memory + # allocation. + permuted_tokens = permuted_tokens * permuted_probs.unsqueeze(-1) + + # Create an output tensor filled with zeros + output_tokens = torch.zeros(restore_shape, dtype=permuted_tokens.dtype, device=permuted_tokens.device) + # Scatter add the permuted_input back to the original positions + output_tokens.scatter_add_(0, sorted_indices.unsqueeze(1).expand(-1, hidden), permuted_tokens) + return output_tokens.to(dtype=input_dtype) + + +def sort_chunks_by_idxs( + input: torch.Tensor, split_sizes: torch.Tensor, sorted_idxs: torch.Tensor, fused: bool = False +): + """Split and sort the input tensor based on the split_sizes and sorted indices.""" + if fused: + if not HAVE_TE or fused_sort_chunks_by_index is None: + raise ValueError("fused_sort_chunks_by_index is not available. Please install TE >= 2.1.0.") + return fused_sort_chunks_by_index(input, split_sizes, sorted_idxs) + + input = torch.split(input, split_sizes.tolist(), dim=0) + output = torch.cat([input[i] for i in sorted_idxs.tolist()], dim=0) + return output + + +def group_limited_topk( + scores: torch.Tensor, + topk: int, + num_tokens: int, + num_experts: int, + num_groups: int, + group_topk: int, +): + """Perform top-k routing on a subset of expert groups. + + When using group-limited routing: + 1. Experts are divided into 'moe_router_num_groups' equal-sized groups + 2. For each token, 'moe_router_group_topk' groups are selected based on routing scores + (specifically, the sum of top-2 expert scores within each group) + 3. From these selected groups, 'moe_router_topk' individual experts are chosen + + Two common use cases: + - Device-limited routing: Set 'moe_router_num_groups' equal to expert parallel size (EP) + to limit each token to experts on a subset of devices + (See DeepSeek-V2: https://arxiv.org/pdf/2405.04434) + + - Node-limited routing: Set 'moe_router_num_groups' equal to number of nodes in EP group + to limit each token to experts on a subset of nodes + (See DeepSeek-V3: https://arxiv.org/pdf/2412.19437) + + Args: + scores (torch.Tensor): Softmax scores generated by the router. + topk (int): The number of experts to select for each token. + num_tokens (int): The number of tokens. + num_experts (int): The number of experts. + num_groups (int): Number of groups for routed experts. + group_topk (int): Number of groups selected for each token. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Probs and indices tensor. + """ + # Organize the experts into groups + group_scores = scores.view(num_tokens, num_groups, -1).topk(2, dim=-1)[0].sum(dim=-1) + group_idx = torch.topk(group_scores, k=group_topk, dim=-1, sorted=False)[1] + group_mask = torch.zeros_like(group_scores) + group_mask.scatter_(1, group_idx, 1) + + # Mask the experts based on selection groups + score_mask = ( + group_mask.unsqueeze(-1).expand(num_tokens, num_groups, num_experts // num_groups).reshape(num_tokens, -1) + ) + + masked_scores = scores.masked_fill(~score_mask.bool(), float("-inf")) + probs, top_indices = torch.topk(masked_scores, k=topk, dim=-1) + + return probs, top_indices + + +def topk_softmax_with_capacity( + logits: torch.Tensor, + topk: int, + capacity_factor: Optional[float] = None, + pad_to_capacity: bool = False, + drop_policy: str = "probs", + use_pre_softmax: bool = False, + num_groups: Optional[int] = None, + group_topk: Optional[int] = None, + scaling_factor: Optional[float] = None, + deterministic_mode: bool = False, + score_function: str = "softmax", + expert_bias: Optional[torch.Tensor] = None, +): + """Apply capacity and padding to the top-k selection. + Args: + logits (torch.Tensor): Logits tensor. + topk (int): The number of experts to select for each token. + capacity_factor (float): The capacity factor of each expert. Will drop tokens if the number + of tokens exceeds the capacity. + pad_to_capacity (bool): Whether to need padding in token drop mode. The probs for padded + tokens will be 0. + drop_policy (str): The policy to drop tokens. Can be either "prob" or "position". + If "prob", the tokens with the lowest probabilities will be dropped. + If "position", tokens at the end of each batch will be dropped. + use_pre_softmax (bool): Whether to apply softmax before top-k selection. + num_groups (int): Number of groups for routed experts. + group_topk (int): Number of selected groups for each token. + scaling_factor (float): Scaling factor of routing score in top-k selection. + deterministic_mode (bool): Deprecated. + score_function (str): The score function to use. Can be either "softmax" or "sigmoid". + expert_bias (torch.Tensor): The bias added to logits for expert routing. + + Returns: + Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + - routing_probs (torch.Tensor): A tensor of shape [num_tokens, num_experts] containing + the routing probabilities for each token to each expert. + - routing_map (torch.Tensor): A mask tensor of shape [num_tokens, num_experts] + indicating which experts were selected for each token. True values represent + the selected experts. + - tokens_per_expert (torch.Tensor): A tensor of shape [num_experts] containing + the number of local tokens assigned to each expert before dropping and padding. + """ + assert logits.dim() == 2, f"Expected 2D logits [num_tokens, num_experts], got {logits.dim()}." + num_tokens, num_experts = logits.shape + + def compute_topk(scores, topk, num_groups=None, group_topk=None): + if group_topk: + return group_limited_topk( + scores=scores, + topk=topk, + num_tokens=num_tokens, + num_experts=num_experts, + num_groups=num_groups, + group_topk=group_topk, + ) + else: + return torch.topk(scores, k=topk, dim=1) + + if score_function == "softmax": + if use_pre_softmax: + scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) + probs, top_indices = compute_topk(scores, topk, num_groups, group_topk) + else: + scores, top_indices = compute_topk(logits, topk, num_groups, group_topk) + probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits) + elif score_function == "sigmoid": + scores = torch.sigmoid(logits) + if expert_bias is not None: + scores_for_routing = scores + expert_bias + _, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk) + scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits) + else: + scores, top_indices = compute_topk(scores, topk, num_groups, group_topk) + probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores + else: + raise ValueError(f"Invalid score_function: {score_function}") + + if scaling_factor: + probs = probs * scaling_factor + + # TODO Try using element-wise operations instead of scatter? + topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs) + topk_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool() + tokens_per_expert = topk_map.sum(dim=0) + + if capacity_factor is None: + # TopK without capacity + return topk_masked_gates, topk_map, tokens_per_expert + else: + # TopK with capacity + expert_capacity = get_capacity( + num_tokens=num_tokens * topk, num_experts=num_experts, capacity_factor=capacity_factor + ) + + # Maskout exceeded tokens + if drop_policy == "probs": + _, capacity_indices = torch.topk(topk_masked_gates, k=expert_capacity, dim=0, sorted=False) + capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1).bool() + elif drop_policy == "position": + _, capacity_indices = torch.topk(topk_map.int(), k=expert_capacity, dim=0, sorted=False) + capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1).bool() + else: + raise ValueError(f"Invalid drop_policy: {drop_policy}") + + if pad_to_capacity: + final_map = capacity_mask + final_probs = topk_masked_gates * final_map + else: + # Get exceed mask and maskout exceeded probs and indices + final_map = torch.logical_and(topk_map, capacity_mask) + final_probs = topk_masked_gates * final_map + return final_probs, final_map, tokens_per_expert + + +def save_to_aux_losses_tracker( + name: str, + loss: torch.Tensor, + layer_number: int, + num_layers: int, + reduce_group: torch.distributed.ProcessGroup = None, + avg_group: torch.distributed.ProcessGroup = None, +): + """Save the auxiliary loss for logging. + Args: + name (str): The name of the loss. + loss (torch.Tensor): The loss tensor. + layer_number (int): Layer index of the loss. + num_layers (int): The number of total layers. + reduce_group (torch.distributed.ProcessGroup): The group for reducing the loss. + mean_group (torch.distributed.ProcessGroup): The group for averaging the loss. + """ + # Skip aux loss logging if layer_number is None. + if layer_number is None: + return + + tracker = parallel_state.get_moe_layer_wise_logging_tracker() + if name not in tracker: + tracker[name] = {} + tracker[name]["values"] = torch.zeros(num_layers, device=loss.device) + tracker[name]["values"][layer_number - 1] += loss.detach() # Aggregate the loss for the layer. + tracker[name]["reduce_group"] = reduce_group + tracker[name]["avg_group"] = avg_group + + +def clear_aux_losses_tracker(): + """Clear the auxiliary losses.""" + tracker = parallel_state.get_moe_layer_wise_logging_tracker() + for name in tracker: + tracker[name]["values"].zero_() + tracker[name]["reduce_group"] = None + tracker[name]["avg_group"] = None + + +def reduce_aux_losses_tracker_across_ranks(): + """Collect and reduce the auxiliary losses across ranks.""" + tracker = parallel_state.get_moe_layer_wise_logging_tracker() + for name in tracker: + values = tracker[name]["values"] + # Collect aux losses across PP. + torch.distributed.all_reduce(values, group=parallel_state.get_pipeline_model_parallel_group()) + # Reduce aux losses across ranks. + if tracker[name].get("reduce_group") is not None: + torch.distributed.all_reduce(values, group=tracker[name].get("reduce_group")) + if tracker[name].get("avg_group") is not None: + torch.distributed.all_reduce(values, group=tracker[name]["avg_group"], op=torch.distributed.ReduceOp.AVG) + + +def track_moe_metrics(loss_scale, iteration, writer, wandb_writer=None, total_loss_dict=None, per_layer_logging=False): + """Track the MoE metrics for logging.""" + # Aux loss logging + reduce_aux_losses_tracker_across_ranks() + tracker = parallel_state.get_moe_layer_wise_logging_tracker() + if writer is not None: + aux_losses = {k: v["values"].float() * loss_scale for k, v in tracker.items()} + for name, loss_list in aux_losses.items(): + if total_loss_dict is not None: + if name not in total_loss_dict: + total_loss_dict[name] = loss_list.mean() + else: + total_loss_dict[name] += loss_list.mean() + + # currently when using add_scalars, + # torch.utils.add_scalars makes each timer its own run, which + # pollutes the runs list, so we just add each as a scalar + writer.add_scalar(name, loss_list.mean(), iteration) + if per_layer_logging: + for i, loss in enumerate(loss_list.tolist()): + writer.add_scalar(f"moe/{name}_layer_{i}", loss, iteration) + + # W&B logging lacks support for logging multiple scalars simultaneously. + # As a workaround, we log each scalar individually first, then we can create + # a custom panel to manually group them to a single plot. + if wandb_writer: + wandb_writer.log({f"{name}": loss_list.mean()}, iteration) + if per_layer_logging: + wandb_writer.log( + {f"moe/{name}_layer_{i}": loss for i, loss in enumerate(loss_list.tolist())}, + iteration, + ) + + clear_aux_losses_tracker() + + +def get_updated_expert_bias(tokens_per_expert, expert_bias, expert_bias_update_rate): + """Update expert bias for biased expert routing. See https://arxiv.org/abs/2408.15664v1# + + Args: + tokens_per_expert (torch.Tensor): The number of tokens assigned to each expert. + expert_bias (torch.Tensor): The bias for each expert. + expert_bias_udpate_rate (float): The update rate for the expert bias. + """ + with torch.no_grad(): + # All Reduce Across TPxCPxDP group + torch.distributed.all_reduce( + tokens_per_expert, + group=parallel_state.get_tensor_and_data_parallel_group(with_context_parallel=True), + ) + average_tokens = tokens_per_expert.sum(dim=-1, keepdim=True) / tokens_per_expert.shape[-1] + offset = average_tokens - tokens_per_expert + updated_expert_bias = expert_bias + torch.sign(offset) * expert_bias_update_rate + return updated_expert_bias + + +# class _Gather(torch.autograd.Function): +# @staticmethod +# def forward(ctx: Any, input: torch.Tensor, dim: int, parallel_context: ParallelContext) -> torch.Tensor: +# ctx.dim = dim +# ctx.parallel_context = parallel_context + +# dist.all_gather(tensor_list=tensor_list, tensor=tensor, group=group) + +# return all_gather(input, dim=dim, async_op=False, parallel_context=parallel_context, parallel_mode=ParallelMode.TENSOR) + +# @staticmethod +# def backward(ctx: Any, grad: torch.Tensor) -> Tuple[torch.Tensor, None, None]: +# dim = ctx.dim +# parallel_context = ctx.parallel_context + +# return ( +# scatter(grad, dim=dim, parallel_context=parallel_context, parallel_mode=ParallelMode.TENSOR), +# None, +# None, +# ) + + +# def gather_from_sequence_parallel_region( +# tensor: torch.Tensor, +# dim: int = 0, +# group = None, +# ) -> torch.Tensor: +# """All gather tensors from all processes in parallel group. + +# Args: +# input (torch.Tensor): The tensor you want to gather. +# dim (int, optional): The dimension along which to gather the tensor.. Defaults to 0. +# group (Optional[dist.ProcessGroup], optional): _description_. Defaults to None. +# """ +# world_size = dist.get_world_size(group) + +# if world_size == 1: +# return tensor + +# tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] +# dist.all_gather(tensor_list=tensor_list, tensor=tensor, group=group) + +# if tensor.dim() == 0: +# tensor_list = [tensor.unsqueeze(dim=0) for tensor in tensor_list] + +# tensor_list = torch.cat(tensor_list, dim=dim) + +# return tensor_list + + +### Distributed Primitives +from typing import Any, Tuple + + +def scatter( + tensor: torch.Tensor, + dim: int, + group: dist.ProcessGroup, +) -> torch.Tensor: + """Scatter tensors to all ranks in parallel group.""" + world_size = dist.get_world_size(group) + rank = dist.get_rank(group) + + if world_size == 1: + return tensor + + assert tensor.size(dim) % world_size == 0 + + tensor_list = torch.chunk(tensor, world_size, dim=dim) + return tensor_list[rank] + + +def all_gather( + tensor: torch.Tensor, + dim: int, + group: dist.ProcessGroup, +) -> torch.Tensor: + world_size = dist.get_world_size(group) + + if world_size == 1: + return tensor + + tensor_list = [torch.empty_like(tensor) for _ in range(world_size)] + dist.all_gather(tensor_list=tensor_list, tensor=tensor, group=group) + + if tensor.dim() == 0: + tensor_list = [tensor.unsqueeze(dim=0) for tensor in tensor_list] + + tensor_list = torch.cat(tensor_list, dim=dim) + + return tensor_list + + +class _Gather(torch.autograd.Function): + @staticmethod + def forward(ctx: Any, input: torch.Tensor, dim: int, group: dist.ProcessGroup) -> torch.Tensor: + ctx.dim = dim + ctx.group = group + + return all_gather(input, dim=dim, group=group) + + @staticmethod + def backward(ctx: Any, grad: torch.Tensor) -> Tuple[torch.Tensor, None, None]: + dim = ctx.dim + group = ctx.group + + return ( + scatter(grad, dim=dim, group=group), + None, + None, + ) + + +class _Scatter(torch.autograd.Function): + @staticmethod + def forward(ctx: Any, input: torch.Tensor, dim: int, group: dist.ProcessGroup) -> torch.Tensor: + ctx.dim = dim + ctx.group = group + return scatter(input, dim=dim, group=group) + + @staticmethod + def backward(ctx: Any, grad_output: torch.Tensor) -> Tuple[torch.Tensor, None, None]: + dim = ctx.dim + group = ctx.group + + return ( + all_gather(grad_output, dim=dim, group=group), + None, + None, + ) + + +def gather_to_tensor_group(input: torch.Tensor, dim: int, group: dist.ProcessGroup): + return _Gather.apply(input, dim, group) + + +def scatter_to_tensor_group(input: torch.Tensor, dim: int, group: dist.ProcessGroup): + return _Scatter.apply(input, dim, group) diff --git a/src/nanotron/nn/rotary.py b/src/nanotron/nn/rotary.py index 4e78849f9..109e59f03 100644 --- a/src/nanotron/nn/rotary.py +++ b/src/nanotron/nn/rotary.py @@ -1,6 +1,11 @@ import torch from flash_attn.layers.rotary import apply_rotary_emb as flash_apply_rotary_emb from torch import nn +from flash_attn.layers.rotary import RotaryEmbedding as OrigFlashRotaryEmbedding +from einops import rearrange +from nanotron import logging +from nanotron.logging import warn_once +logger = logging.get_logger(__name__) class RotaryEmbedding(nn.Module): @@ -140,3 +145,77 @@ def apply_rotary_pos_emb(self, tensor, freqs, multi_latent_attention=False, msca if pass_through_part is not None and pass_through_part.shape[-1] > 0: return torch.cat((rotated_tensor, pass_through_part), dim=-1) return rotated_tensor + +class FlashRotaryEmbedding(OrigFlashRotaryEmbedding): + + def __init__( + self, + dim: int, + base=10000.0, + interleaved=False, + scale_base=None, + pos_idx_in_fp32=True, + device=None, + seq_len_interpolation_factor=None, + ): + super().__init__( + dim, + base, + interleaved, + scale_base, + pos_idx_in_fp32, + device, + ) + self.seq_len_interpolation_factor = seq_len_interpolation_factor + + def _update_cos_sin_cache(self, seqlen, device=None, dtype=None): + # Reset the tables if the sequence length has changed, + # if we're on a new device (possibly due to tracing for instance), + # or if we're switching from inference mode to training + if ( + seqlen > self._seq_len_cached + or self._cos_cached is None + or self._cos_cached.device != device + or self._cos_cached.dtype != dtype + or (self.training and self._cos_cached.is_inference()) + ): + self._seq_len_cached = seqlen + # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16 + # And the output of arange can be quite large, so bf16 would lose a lot of precision. + # However, for compatibility reason, we add an option to use the dtype of self.inv_freq. + if self.pos_idx_in_fp32: + t = torch.arange(seqlen, device=device, dtype=torch.float32) + # We want fp32 here as well since inv_freq will be multiplied with t, and the output + # will be large. Having it in bf16 will lose a lot of precision and cause the + # cos & sin output to change significantly. + # We want to recompute self.inv_freq if it was not loaded in fp32 + if self.inv_freq.dtype != torch.float32: + inv_freq = self._compute_inv_freq(device=device) + else: + inv_freq = self.inv_freq + else: + t = torch.arange(seqlen, device=device, dtype=self.inv_freq.dtype) + inv_freq = self.inv_freq + + # fixed linear scaling + if self.seq_len_interpolation_factor is not None: + warn_once(f"seq_len_interpolation_factor is set to {self.seq_len_interpolation_factor}", logger, rank=0) + t *= 1 / self.seq_len_interpolation_factor + + # Don't do einsum, it converts fp32 to fp16 under AMP + # freqs = torch.einsum("i,j->ij", t, self.inv_freq) + freqs = torch.outer(t, inv_freq) + if self.scale is None: + self._cos_cached = torch.cos(freqs).to(dtype) + self._sin_cached = torch.sin(freqs).to(dtype) + else: + power = ( + torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device) + - seqlen // 2 + ) / self.scale_base + scale = self.scale.to(device=power.device) ** rearrange(power, "s -> s 1") + # We want the multiplication by scale to happen in fp32 + self._cos_cached = (torch.cos(freqs) * scale).to(dtype) + self._sin_cached = (torch.sin(freqs) * scale).to(dtype) + self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype) + self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype) \ No newline at end of file diff --git a/src/nanotron/nn/te_moe.py b/src/nanotron/nn/te_moe.py new file mode 100644 index 000000000..314abbf43 --- /dev/null +++ b/src/nanotron/nn/te_moe.py @@ -0,0 +1,1308 @@ +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import torch +from torch import nn +from torch.utils.checkpoint import CheckpointFunction + +from nanotron import distributed as dist +from nanotron import logging +from nanotron.config import ParallelismArgs +from nanotron.config.models_config import Qwen2Config +from nanotron.models.base import ignore_init_on_device_and_dtype +from nanotron.nn.activations import ACT2FN +from nanotron.nn.moe import GroupedMLP +from nanotron.nn.moe_utils import scatter_to_tensor_group +from nanotron.parallel.context import ParallelContext +from nanotron.parallel.sharded_parameters import SplitConfig, mark_all_parameters_in_module_as_sharded +from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import ( + all_to_all, + differentiable_all_gather, + differentiable_reduce_scatter_sum, +) +from nanotron.parallel.tensor_parallel.enum import TensorParallelLinearMode + +from .moe_utils import get_capacity + +logger = logging.get_logger(__name__) +from .moe_utils import permute, topk_softmax_with_capacity, unpermute + +try: + import grouped_gemm.ops as ops +except ImportError: + raise RuntimeError( + "Grouped GEMM is not available. Please run `pip install --no-build-isolation git+https://github.com/fanshiqing/grouped_gemm@main` (takes less than 5 minutes)" + ) + +try: + import transformer_engine as te +except ImportError: + raise RuntimeError( + "Transformer Engine is not available. Please run `pip install --no-build-isolation transformer_engine[pytorch]`" + ) + + +@dataclass +class MoELogging: + """ + num_local_tokens: List[torch.Tensor]: The number of tokens per local expert per layer + """ + + num_local_tokens: List[torch.Tensor] + + +def fake_gather_to_tensor_group(tensor, dim, group): + """Fake gather that returns random data with correct shape for performance testing""" + # Calculate expected size after all_gather + gather_size = list(tensor.size()) + gather_size[dim] *= group.size() + + # Create uninitialized tensor with same device/dtype as input + fake_tensor = torch.empty(gather_size, dtype=tensor.dtype, device=tensor.device) + + # Fill with random values (mimic real data but no communication) + # fake_tensor.normal_() + return fake_tensor + + +class Qwen2MoEMLPLayer(nn.Module): + """Mixture of experts Layer for Qwen2 models.""" + + def __init__( + self, + config: Qwen2Config, + parallel_config: Optional[ParallelismArgs], + parallel_context: ParallelContext, + layer_idx: int = 0, + ) -> None: + super().__init__() + moe_config = config.moe_config + self.hidden_size = moe_config.moe_hidden_size + self.intermediate_size = moe_config.moe_intermediate_size + + # MoE specific configurations + num_local_experts = config.moe_config.num_experts // parallel_config.expert_parallel_size # Experts per device + self.num_experts_per_token = config.moe_config.top_k # Number of experts used per token (top-k) + self.expert_parallel_size = parallel_config.expert_parallel_size + self.num_local_experts = num_local_experts # Experts per device + + # Get TP mode configuration + + self.config = config + self.expert_parallel_size = parallel_config.expert_parallel_size + assert self.expert_parallel_size > 0, "Expected non-negative expert parallel size" + + assert self.config.moe_config.num_experts % self.expert_parallel_size == 0 + self.num_local_experts = self.config.moe_config.num_experts // self.expert_parallel_size + local_expert_indices_offset = dist.get_rank(parallel_context.ep_pg) * self.num_local_experts + + self.use_shared_expert = self.config.moe_config.enable_shared_expert + # self.shared_expert_overlap = self.config.moe_config.shared_expert_overlap + + self.local_expert_indices = [local_expert_indices_offset + i for i in range(self.num_local_experts)] + assert all((x < self.config.moe_config.num_experts for x in self.local_expert_indices)) + self.router = None + self.experts = None + self.shared_experts = None + self.token_dispatcher = None + self.layer_number = layer_idx + + # Router for selecting experts + # self.router = Router(config, parallel_config, layer_idx) + self.router = TopKRouter(config=self.config, parallel_context=parallel_context) + + # self.token_dispatcher = AllToAllDispatcher(num_local_experts, num_experts, parallel_context.ep_pg) + if self.config.moe_config.token_dispatcher_type == "allgather": + self.token_dispatcher = MoEAllGatherTokenDispatcher( + num_local_experts=self.num_local_experts, + local_expert_indices=self.local_expert_indices, + config=self.config, + parallel_context=parallel_context, + ) + elif self.config.moe_config.token_dispatcher_type == "alltoall": + self.token_dispatcher = MoEAlltoAllTokenDispatcher( + num_local_experts=self.num_local_experts, + local_expert_indices=self.local_expert_indices, + config=self.config, + parallel_context=parallel_context, + ) + else: # TODO: alltoall + raise ValueError(f"Unsupported token dispatcher type: {self.config.moe_config.token_dispatcher_type}") + + # Enable shared experts if configured + self.enable_shared_expert = config.moe_config.enable_shared_expert + if self.enable_shared_expert: # TODO: check shared + from nanotron.models.qwen import Qwen2MLP + + self.shared_expert = Qwen2MLP( + config=config, + parallel_config=parallel_config, + tp_pg=parallel_context.tp_pg, + hidden_size=moe_config.shared_expert_hidden_size, + intermediate_size=moe_config.shared_expert_intermediate_size, + ) + # TODO: duplicte the shared expert gate + self.shared_expert_gate = nn.Linear( + self.hidden_size, + 1, + bias=False, + ) # TODO: ensure shared_expert_gate is tied across TP + + if self.config.moe_config.grouped_gemm_imple == "transformer_engine": + self.experts = TEGroupedMLP(self.num_local_experts, config, parallel_config, parallel_context) + elif self.config.moe_config.grouped_gemm_imple == "megablock_grouped_gemm": + self.experts = GroupedMLP(config, parallel_config, ep_pg=parallel_context.ep_pg) + + # Whether to recompute MoE layer during backward pass for memory efficiency + self.recompute_layer = parallel_config.recompute_layer + self.ep_pg = parallel_context.ep_pg + self.layer_idx = layer_idx + + def _compute_expert_outputs(self, hidden_states, routing_weights, routing_indices): + ( + dispatched_inputs, + inverse_permute_mapping, + expert_sort_indices, + num_local_tokens_per_expert, + ) = self.token_dispatcher.permute(hidden_states, routing_indices) + + expert_outputs = self.experts(dispatched_inputs, num_local_tokens_per_expert) + output = self.token_dispatcher.unpermute( + expert_outputs["hidden_states"], inverse_permute_mapping, routing_weights, expert_sort_indices + ) + return output, num_local_tokens_per_expert + + def _core_forward(self, hidden_states, moe_logging: Optional[MoELogging]): + """Core forward logic for MoE layer.""" + # Get top-k routing weights and indices + # routing_weights, routing_indices = self.router(hidden_states) # [num_tokens, num_experts_per_token] + probs, indices = self.router(hidden_states) + + # output, num_local_tokens_per_expert = self._compute_expert_outputs( + # hidden_states, routing_weights, routing_indices + # ) + (dispatched_input, tokens_per_expert) = self.token_dispatcher.token_permutation(hidden_states, probs, indices) + + # NOTE: if there are no tokens routed to this expert device + # then we just skip the computation + if dispatched_input.shape[0] == 0: + expert_output = hidden_states + output = expert_output + mlp_bias = None + else: + if self.config.moe_config.grouped_gemm_imple == "transformer_engine": + expert_output, mlp_bias = self.experts(dispatched_input, tokens_per_expert) + elif self.config.moe_config.grouped_gemm_imple == "megablock_grouped_gemm": + expert_output = self.experts(dispatched_input, tokens_per_expert) + expert_output = expert_output["hidden_states"] + mlp_bias = None + + output, mlp_bias = self.token_dispatcher.token_unpermutation(expert_output, mlp_bias) + + if self.enable_shared_expert: + shared_expert_output = self.shared_expert(hidden_states=hidden_states)["hidden_states"] + shared_gate = torch.sigmoid(self.shared_expert_gate(hidden_states)) + output = output + shared_gate * shared_expert_output + + if moe_logging is not None: + moe_logging[self.layer_idx, :] = tokens_per_expert + + return {"hidden_states": output} + + def _checkpointed_forward(self, hidden_states): + """Apply gradient checkpointing to save memory during training.""" + return CheckpointFunction.apply(self._core_forward, True, hidden_states) + + def forward(self, hidden_states, moe_logging: Optional[MoELogging] = None): + """Forward pass for the MoE layer.""" + if self.recompute_layer and self.training: + outputs = self._checkpointed_forward(hidden_states, moe_logging) + else: + outputs = self._core_forward(hidden_states, moe_logging) + + return outputs + + +from .moe_utils import ( + MoEAuxLossAutoScaler, + sinkhorn, + switch_load_balancing_loss_func, + topk_softmax_with_capacity, + z_loss_func, +) + + +class TopKRouter(nn.Module): + """Route each token to the top-k experts.""" + + def __init__(self, config: Qwen2Config, parallel_context: ParallelContext) -> None: + """Initialize the zero token dropping router.""" + super().__init__() + self.config = config + self.num_experts = self.config.moe_config.num_experts + self.moe_aux_loss_func = None + self.layer_number = None + + # float32 routing weights + # NOTE: qwen keep the routing weights in float32 + # https://github.com/huggingface/transformers/blob/27a25bee4fcb865e8799ba026f1ea4455f2cca98/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py#L608 + with ignore_init_on_device_and_dtype(): + self.weight = nn.Parameter( + torch.randn(self.num_experts, config.hidden_size, dtype=torch.float32, device="cuda") + ) + assert self.weight.dtype == torch.float32 + + self.topk = self.config.moe_config.top_k + self.routing_type = self.config.moe_config.router_load_balancing_type + self.input_jitter = None + self.tp_size = parallel_context.tensor_parallel_size + + def sinkhorn_load_balancing(self, logits: torch.Tensor): + """Apply sinkhorn routing to the logits tensor. + + Args: + logits (torch.Tensor): The logits tensor. + + Returns: + torch.Tensor: The logits tensor after applying sinkhorn routing. + """ + + def _sinkhorn_activation(logits): + if self.topk == 1: + logits = torch.sigmoid(logits) + else: # k > 1 + logits = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) + return logits + + assert self.config.moe_config.router_aux_loss_coef == 0, "Sinkhorn routing does not support aux loss." + if self.training: + with torch.no_grad(): + norm_logits = sinkhorn(logits.to(dtype=torch.float32)) # explicit fp32 conversion for stability + _, indices = torch.topk(norm_logits, k=self.topk, dim=1) + logits = _sinkhorn_activation(logits) + scores = torch.gather(logits, 1, indices) + else: + logits = _sinkhorn_activation(logits) + scores, indices = torch.topk(logits, k=self.topk, dim=1) + return scores, indices + + def aux_loss_load_balancing(self, logits: torch.Tensor): + """Apply loss-based load balancing to the logits tensor. + + Args: + logits (torch.Tensor): the logits tensor after gating, shape: [num_tokens, num_experts]. + + Returns: + probs (torch.Tensor): the probabilities tensor after load balancing. + indices (torch.Tensor): the indices tensor after top-k selection. + """ + probs, indices, tokens_per_expert = topk_softmax_with_capacity( + logits, + self.topk, + capacity_factor=self.config.moe_config.moe_expert_capacity_factor, + pad_to_capacity=self.config.moe_config.moe_pad_expert_input_to_capacity, + drop_policy=self.config.moe_config.moe_token_drop_policy, + use_pre_softmax=self.config.moe_config.moe_router_pre_softmax, + num_groups=self.config.moe_config.moe_router_num_groups, + group_topk=self.config.moe_config.moe_router_group_topk, + scaling_factor=self.config.moe_config.moe_router_topk_scaling_factor, + score_function=self.config.moe_config.moe_router_score_function, + expert_bias=False, # TODO: no bias + ) + + # Apply load balancing loss + scores = torch.softmax(logits, dim=-1, dtype=torch.float32) + + if self.config.moe_config.router_aux_loss_coef > 0: + probs = self.apply_load_balancing_loss(scores, tokens_per_expert, activation=probs) + + return probs, indices + + def apply_load_balancing_loss( + self, + probs: torch.Tensor, + num_local_tokens_per_expert: torch.Tensor, + activation: torch.Tensor, + ): + """Applies auxiliary loss to the MoE layer. + + Args: + probs (torch.Tensor): The probs output by the router for each token. [num_tokens, num_experts] + num_local_tokens_per_expert (torch.Tensor): The number of tokens per expert. [num_experts] + activation (torch.Tensor): The activation tensor to attach the gradient function to. + + Returns: + torch.Tensor: The activation tensor with the attached gradient function. + """ + moe_aux_loss_coeff = self.config.moe_config.router_aux_loss_coef / self.tp_size + aux_loss = switch_load_balancing_loss_func(probs, num_local_tokens_per_expert, self.topk, moe_aux_loss_coeff) + # save_to_aux_losses_tracker( + # "load_balancing_loss", + # aux_loss / moe_aux_loss_coeff, + # self.layer_number, + # self.config.num_layers, + # ) + activation = MoEAuxLossAutoScaler.apply(activation, aux_loss) + return activation + + def apply_z_loss(self, logits): + """Encourages the router's logits to remain small to enhance stability. + Please refer to the ST-MoE paper (https://arxiv.org/pdf/2202.08906.pdf) for details. + + Args: + logits (torch.Tensor): The logits of the router. + + Returns: + torch.Tensor: The logits after applying the z-loss. + """ + if self.config.moe_config.moe_z_loss_coeff is not None: + moe_z_loss_coeff = self.config.moe_config.moe_z_loss_coeff / self.tp_size + z_loss = z_loss_func(logits, moe_z_loss_coeff) + logits = MoEAuxLossAutoScaler.apply(logits, z_loss) + # save_to_aux_losses_tracker( + # "z_loss", + # z_loss / self.config.moe_z_loss_coeff, + # self.layer_number, + # self.config.num_layers, + # ) + return logits + + def apply_input_jitter(self, input: torch.Tensor): + """Add noise to the input tensor. + Refer to https://arxiv.org/abs/2101.03961. + + Args: + input (Tensor): Input tensor. + + Returns: + Tensor: Jittered input. + """ + if self.config.moe_config.input_jitter_eps is not None: + eps = self.config.moe_config.input_jitter_eps + if self.input_jitter is None: + self.input_jitter = torch.distributions.uniform.Uniform( + torch.tensor(1.0 - eps, device=input.device), + torch.tensor(1.0 + eps, device=input.device), + ).rsample + return input * self.input_jitter(input.shape) + else: + return input + + def gating(self, input: torch.Tensor): + """Forward pass of the router gate. + + Args: + input (torch.Tensor): Input tensor. + + Returns: + torch.Tensor: Logits tensor. + """ + if self.weight.device.type == "cpu": + # move weights to GPU + self.weight.data = self.weight.data.to(device=torch.cuda.current_device()) + # Convert to specified datatype for routing computation if enabled + router_dtype = input.dtype + if self.config.moe_config.moe_router_dtype == "fp32": + router_dtype = torch.float32 + elif self.config.moe_config.moe_router_dtype == "fp64": + router_dtype = torch.float64 + logits = torch.nn.functional.linear(input.to(router_dtype), self.weight.to(router_dtype)) + return logits + + def routing(self, logits: torch.Tensor): + """Top-k routing function + + Args: + logits (torch.Tensor): Logits tensor after gating. + + Returns: + probs (torch.Tensor): the probabilities tensor after load balancing. + indices (torch.Tensor): the indices tensor after top-k selection. + """ + logits = logits.view(-1, self.config.moe_config.num_experts) + + # Apply Z-Loss + logits = self.apply_z_loss(logits) + + # if self.tp_size > 1 and self.config.moe_config.moe_token_dispatcher_type == "alltoall": + # # Gather the logits from the TP region + # raise NotImplementedError("fix TP in router") + # logits = gather_from_sequence_parallel_region(logits) + + if self.routing_type == "sinkhorn": + scores, indices = self.sinkhorn_load_balancing(logits) + elif self.routing_type == "aux_loss": + scores, indices = self.aux_loss_load_balancing(logits) + elif self.routing_type == "none": + # A naive top-k routing without load balancing + scores, indices, _ = topk_softmax_with_capacity( + logits, + self.topk, + capacity_factor=self.config.moe_config.moe_expert_capacity_factor, + pad_to_capacity=self.config.moe_config.moe_pad_expert_input_to_capacity, + drop_policy=self.config.moe_config.moe_token_drop_policy, + ) + else: + raise ValueError(f"Unsupported MoE routing type: {self.routing_type}") + + return scores, indices + + def forward(self, input: torch.Tensor): + """ + Forward pass of the router. + + Args: + input (torch.Tensor): Input tensor. + """ + self.hidden = input.shape[-1] + + # Apply input jitter + input = self.apply_input_jitter(input) + logits = self.gating(input) + logits = logits.view(-1, self.config.moe_config.num_experts) + + scores, indices = self.routing(logits) + + return scores, indices + + +class TEGroupedLinear(te.pytorch.GroupedLinear): + """ + Wrapper for the Transformer-Engine's `GroupedLinear` layer. + + Note that if Megatron's parallel_state has not been initialized + yet, the tp_group passed to TE will be None and must be set later + via set_tensor_parallel_group(). + """ + + def __init__( + self, + num_gemms: int, + input_size: int, + output_size: int, + parallel_mode: Optional[str], + config: Qwen2Config, + parallel_config: ParallelismArgs, + parallel_context: ParallelContext, + bias: bool, + skip_bias_add: bool, + is_expert: bool = False, + tp_comm_buffer_name: Optional[str] = None, + ): + self.config = config + self.parallel_config = parallel_config + sequence_parallel = parallel_config.tp_mode == TensorParallelLinearMode.REDUCE_SCATTER + + # TE returns a zero length Tensor when bias=False and + # return_bias=True, but we prefer None. So in that case we + # tell TE to not return the bias, and return None + # ourselves. This way our forward always returns two values + # and we don't have to deal with the zero length Tensor. + self.te_return_bias = skip_bias_add and bias + self.is_first_microbatch = True + self.disable_parameter_transpose_cache = self.config.moe_config.disable_parameter_transpose_cache + + self.expert_parallel = parallel_context.expert_parallel_size > 1 + + # The comms between TP and EP group is explicitly handled by MoE token dispatcher. + # So we disable comms by making TE agnostic of model parallel. + if is_expert: + tp_group = parallel_context.ep_tp_pg + else: + tp_group = parallel_context.tp_pg + self.explicit_expert_comm = is_expert and (tp_group.size() > 1 or self.expert_parallel) + + if self.explicit_expert_comm: + if parallel_mode == "column": + assert ( + output_size % tp_group.size() == 0 + ), f"output_size {output_size} must be divisible by tp_group.size() {tp_group.size()}" + output_size = output_size // tp_group.size() + elif parallel_mode == "row": + assert ( + input_size % tp_group.size() == 0 + ), f"input_size {input_size} must be divisible by tp_group.size() {tp_group.size()}" + input_size = input_size // tp_group.size() + parallel_mode = None + tp_group = None + + super().__init__( + num_gemms=num_gemms, + in_features=input_size, + out_features=output_size, + sequence_parallel=sequence_parallel, + fuse_wgrad_accumulation=config.moe_config.gradient_accumulation_fusion, + tp_group=tp_group, + tp_size=tp_group.size() if tp_group is not None else 1, + init_method=None, # TODO: + bias=bias, + return_bias=self.te_return_bias, + parallel_mode=parallel_mode, + ub_name=tp_comm_buffer_name, + # get_rng_state_tracker=None, + # params_dtype=torch.bfloat16, + # device=device, + # rng_tracker_name= # TODO: do i need rng tracker name? + ) + + for param in self.parameters(): + setattr( + param, "allreduce", not (is_expert and self.expert_parallel) + ) # TODO: does this work with TE or megatron? + + # self._te_state_cleanup() + + # def _te_state_cleanup(self): + # """ + # Remove uncessary TE states. + # """ + + # # remove extra_state because it seems to be related to FP8 training + # # https://github.com/NVIDIA/Megatron-LM/blob/460e9611ba7fe07dbfb40219515c91572f622db9/megatron/core/extensions/transformer_engine.py#L1068-L1107 + # # and it keeps it, it fails + # pass + + def forward(self, x, m_splits): + """Forward.""" + _is_first_microbatch = None if self.disable_parameter_transpose_cache else self.is_first_microbatch + out = super().forward(x, m_splits, is_first_microbatch=_is_first_microbatch) + self.is_first_microbatch = False + + # TE only returns a tuple when return_bias is True, otherwise + # it returns a single Tensor, we always want to return two + # values regardless of the arguments. + if self.te_return_bias: + return out + return out, None + + +class TEGroupedMLP(nn.Module): + """An efficient implementation of the Experts layer using TE's GroupedLinear. + + Executes multiple experts in parallel to maximize computational efficiency. + """ + + def __init__( + self, + num_local_experts, + config: Qwen2Config, + parallel_config: ParallelismArgs, + parallel_context: ParallelContext, + ): + super().__init__() + self.num_local_experts = num_local_experts + self.input_size = config.hidden_size + self.config = config + + # Double the output width with gated linear unit, see https://arxiv.org/pdf/2002.05202.pdf + ffn_hidden_size = config.moe_config.moe_intermediate_size + if config.hidden_act == "silu": # gated_linear_unit + ffn_hidden_size *= 2 + else: + raise ValueError(f"Unsupported activation function: {config.hidden_act}") + + self.linear_fc1 = TEGroupedLinear( # TEColumnParallelGroupedLinear + num_gemms=self.num_local_experts, + input_size=self.input_size, + output_size=ffn_hidden_size, + parallel_mode="column", + config=config, + parallel_config=parallel_config, + parallel_context=parallel_context, + bias=False, # TODO: no bias right? + skip_bias_add=True, + is_expert=True, + tp_comm_buffer_name="fc1", + ) + + self.linear_fc2 = TEGroupedLinear( # TERowParallelGroupedLinear + num_gemms=self.num_local_experts, + input_size=ffn_hidden_size + // 2, # TODO: hack for now. AssertionError: GEMM not possible: inp.shape[-1] = 1024, in_features = 2048 + output_size=config.hidden_size, + parallel_mode="row", + config=config, + parallel_config=parallel_config, + parallel_context=parallel_context, + bias=False, + skip_bias_add=True, + is_expert=True, + tp_comm_buffer_name="fc2", + ) + + self.act = ACT2FN[config.hidden_act] # TODO: cleanup + + mark_all_parameters_in_module_as_sharded( + self, + pg=parallel_context.ep_pg, + split_config=SplitConfig(split_dim=0), + ) + + def forward( + self, permuted_local_hidden_states: torch.Tensor, tokens_per_expert: torch.Tensor + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """Forward of TEGroupedMLP + + Args: + permuted_local_hidden_states (torch.Tensor): The permuted input hidden states of the + local experts. + tokens_per_expert (torch.Tensor): The number of tokens per expert. + + Return: + output (torch.Tensor): The output of the local experts. + """ + # NOTE: transformer engine requires this to be a list + tokens_per_expert = tokens_per_expert.tolist() + + intermediate_parallel, bias_parallel = self.linear_fc1(permuted_local_hidden_states, tokens_per_expert) + + if self.config.moe_config.bias_activation_fusion: # TODO: need to fix this, it's true by default in megatron + intermediate_parallel = bias_swiglu_impl( + intermediate_parallel, + bias_parallel, + ) + else: + # TODO: we assume gated + intermediate_parallel = torch.chunk(intermediate_parallel, 2, dim=-1) + intermediate_parallel = self.act(intermediate_parallel[0]) * intermediate_parallel[1] + + output, output_bias = self.linear_fc2(intermediate_parallel, tokens_per_expert) + + return output, output_bias + + +def bias_swiglu_impl(input, bias, fp8_input_store=False): + ori_shape = input.shape + assert len(ori_shape) in [2, 3] + input = input.view(-1, ori_shape[-1]) + if bias is not None: + raise NotImplementedError("Bias is not supported") + else: + output = SwiGLUFunction.apply(input, fp8_input_store) + + return output if len(ori_shape) == 2 else output.view(ori_shape[0], ori_shape[1], -1) + + +class MoEAllGatherTokenDispatcher(nn.Module): + """ + AllGather Based Token dispatcher. + Note that this allgather spans the communication domain of TP*EP: + """ + + def __init__( + self, + num_local_experts: int, + local_expert_indices: List[int], + config: Qwen2Config, + parallel_context: ParallelContext, + ) -> None: + """ + Initialize the zero token dropping router. + """ + super().__init__() + self.config = config + self.shared_experts = None + self.etp_size = parallel_context.expert_tensor_parallel_size + self.ep_size = parallel_context.expert_parallel_size + self.ep_pg = parallel_context.ep_pg + + self.num_local_experts = num_local_experts + assert self.num_local_experts > 0, "Expected at least one expert" + self.local_expert_indices = local_expert_indices + assert len(self.local_expert_indices) > 0, "Expected at least one local expert index" + self.router_topk = config.moe_config.top_k + self.add_bias = False # TODO: assume no bias + + # self.global_local_map: 2D tensor. A mask of mapping between global and local tokens where + # each element is True if it's between the local_expert_indices. Only useful when cross + # device token permutation is enabled and **AllGahter** is performed. + self.global_local_map = None + + def token_permutation(self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor): + """Dispatch tokens to local experts. It's composed of two stages: + (1) Gather the tokens across the expert parallel devices. After this stage, + each device receives all of the tokens assigned to its local set of experts + in its local HBM. + (2) Permute the tokens locally so that they are grouped by their expert + assignment. + + Args: + hidden_states: 3D tensor [S/TP, B, H]. Input tokens. + probs: 2D tensor [S/TP*B, num_experts]. Each row of probs contains + the probility distribution across `topk` experts for one local token. + routing_map: 2D tensor [S/TP*B, num_experts], representing token assignment to + global experts. + + Returns: + permuted_local_hidden_states: Permutation of tokens to local experts group. + tokens_per_expert: the number of tokens each local expert to process. + """ + self.hidden_shape = hidden_states.shape # [bs*seq_len, hidden_size] + # [S/TP, B, H] -> [S*B/TP, H] + hidden_states = hidden_states.view(-1, self.hidden_shape[-1]) + + # Permute the tokens across the expert parallel devices. + # if self.etp_size > 1 or self.ep_size > 1: + if self.etp_size > 1 or self.ep_size > 1: + assert self.etp_size == 1, "Expert tensor parallelism currently not supported" + + ## local_indices calculation + with torch.no_grad(): + # [num_local_tokens, num_experts] -> [num_global_tokens, num_experts], where: + # num_local_tokens=(S/TP)*B, num_global_tokens=S*B*EP + routing_map = differentiable_all_gather(routing_map, group=self.ep_pg) + # routing_map = fake_gather_to_tensor_group(routing_map, dim=0, group=self.ep_pg) + + ## local_probs calculation + # max_prob: [S/TP*B, num_experts] -> global_probs: [S*B*EP, num_experts] + probs = differentiable_all_gather(probs, group=self.ep_pg) + # probs = fake_gather_to_tensor_group(probs, dim=0, group=self.ep_pg) + + # Note that this allgather spans the communication domain of TP*EP. + # [(S/TP)*B, H] -> [((S/TP)*B)*(TP*EP), H] = [S*B*EP, H] + # hidden_states = gather_from_sequence_parallel_region( + # hidden_states, group=self.tp_ep_group, use_global_buffer=True + # ) + # hidden_states = gather_from_sequence_parallel_region( + # hidden_states, group=self.ep_pg, use_global_buffer=True + # ) + hidden_states = differentiable_all_gather(hidden_states, group=self.ep_pg) + # hidden_states = fake_gather_to_tensor_group(hidden_states, dim=0, group=self.ep_pg) + + self.hidden_shape_before_permute = hidden_states.shape + + # The routing map and probs that for local experts. + self.local_map = routing_map[:, self.local_expert_indices[0] : self.local_expert_indices[-1] + 1].contiguous() + # probs of global token assignment to local experts. + self.local_probs = probs[:, self.local_expert_indices[0] : self.local_expert_indices[-1] + 1].contiguous() + + tokens_per_expert = self.local_map.sum(dim=0).long() + + # NOTE: megablock grouped gemm requires tokens_per_expert to be on cpu + if self.config.moe_config.grouped_gemm_imple == "megablock_grouped_gemm": + tokens_per_expert = tokens_per_expert.cpu() + + (permuted_local_hidden_states, self.reversed_local_input_permutation_mapping) = permute( + hidden_states, + self.local_map, + num_out_tokens=tokens_per_expert.sum(), + fused=self.config.moe_config.permute_fusion, + ) + + return permuted_local_hidden_states, tokens_per_expert + + def token_unpermutation(self, hidden_states: torch.Tensor, bias: torch.Tensor = None): + """ + Reverse process of `dispatch()` which permutes the output of local + experts locallay and across expert parallel rank into the original order to + produce the final output. + + Args: + hidden_states: 2D tensor [num_permuted_tokens_for_local_experts, H], + output of local experts. + bias (optional): The bias tensor. + + Returns: + output_total: un-permuted updated hidden states output from all local experts + with shape of [S/TP, B, H] + """ + + # # Scale the expert output prior to reduction and subsequent to local unpermutation if k > 1. + # # Unpermute the expert output and bias + # permuted_probs = self.local_probs.T.contiguous().masked_select(self.local_map.T.contiguous()) + # # Here may change permuted_tokens to higher precision if probs use fp32/fp64. + # weighted_hidden_states = hidden_states * permuted_probs.unsqueeze(-1) + weighted_hidden_states = hidden_states + permuted_probs = self.local_probs + unpermuted_local_hidden = unpermute( + weighted_hidden_states, + self.reversed_local_input_permutation_mapping, + restore_shape=self.hidden_shape_before_permute, + probs=permuted_probs, + routing_map=self.local_map, + fused=self.config.moe_config.permute_fusion, + ) + + unpermuted_local_bias = None + if self.add_bias: + assert bias is not None + weighted_bias = bias * permuted_probs.unsqueeze(-1) + unpermuted_local_bias = unpermute( + weighted_bias, + self.reversed_local_input_permutation_mapping, + restore_shape=self.hidden_shape_before_permute, + routing_map=self.local_map, + fused=self.config.moe_config.permute_fusion, + ) + + output_total = unpermuted_local_hidden + output_bias_total = unpermuted_local_bias + + # Unpermute the tokens across ranks. + # if self.etp_size > 1 or self.ep_size > 1: + if self.ep_size > 1: + if self.etp_size > 1: + raise NotImplementedError("etp_size>1 not implemented") + output_total = reduce_scatter_to_sequence_parallel_region(output_total, group=self.tp_ep_group) + if self.add_bias: + # Unpermute the bias across expert parallel devices. + # bias is duplicated across tensor parallelism ranks; + output_bias_total = ( + reduce_scatter_to_sequence_parallel_region(output_bias_total, group=self.tp_ep_group) + / self.etp_size + ) + + # output_total = scatter_tensor_along_dim(output_total, dim=0, group=self.ep_pg) + output_total = scatter_to_tensor_group(output_total, dim=0, group=self.ep_pg) + # from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import differentiable_reduce_scatter_sum + # output_total = differentiable_reduce_scatter_sum(output_total, group=self.ep_pg) + + output_total = output_total.view(self.hidden_shape) + + if self.add_bias: + output_bias_total = output_bias_total.view(self.hidden_shape) + + # Restore the dtype of the output to the original dtype. + output_total = output_total.to(hidden_states.dtype) + if bias is not None: + output_bias_total = output_bias_total.to(bias.dtype) + return output_total, output_bias_total + + +class MoEAlltoAllTokenDispatcher(nn.Module): + """ + AlltoAll-based token dispatcher. + + The workflow of AlltoAll token dispatcher is as follows: + (1) preprocess(): calculate necessary metadata for communication and permute + (2) token_permutation(): permute->A2A(EP)->AG(TP)->sort_chunk(if num_local_experts>1) + (3) token_unpermutation(): sort_chunk(if num_local_experts>1)->RS(TP)->A2A(EP)->unpermute + """ + + def __init__( + self, + num_local_experts: int, + local_expert_indices: List[int], + config: Qwen2Config, + parallel_context: ParallelContext, + ) -> None: + """ + Initialize the AlltoAll token dispatcher. + + Args: + num_local_experts (int): Number of local experts on the current device. + local_expert_indices (List[int]): Indices of local experts on the current device. + config (TransformerConfig): Configuration for the transformer model. + """ + super().__init__() + self.config = config + self.shared_experts = None + self.etp_size = parallel_context.expert_tensor_parallel_size + self.ep_size = parallel_context.expert_parallel_size + self.ep_pg = parallel_context.ep_pg + self.ep_tp_pg = parallel_context.ep_tp_pg + + self.num_local_experts = num_local_experts + assert config.moe_config.num_experts is not None + self.num_experts = config.moe_config.num_experts + assert self.num_local_experts > 0, "Expected at least one expert" + self.local_expert_indices = local_expert_indices + assert len(self.local_expert_indices) == self.num_local_experts, "Invalid local expert indices" + for i in range(len(self.local_expert_indices) - 1): + assert ( + self.local_expert_indices[i] == self.local_expert_indices[i + 1] - 1 + ), "local_expert_indices must be continuous" + + # [ep_size]. Represents the number of tokens sent by the current rank to other + # EP ranks. + self.input_splits = None + # [ep_size]. Represents the number of tokens received by the current rank from + # other EP ranks. + self.output_splits = None + # [tp_size]. Represents the number of tokens received by the current rank from + # other TP ranks. + self.output_splits_tp = None + self.permute_idx_device = torch.device("cuda") if self.config.moe_config.permute_fusion else None + input_chunk_idxs = torch.arange(self.num_experts * self.etp_size, device=self.permute_idx_device) + # [num_local_experts, tp_size * ep_size]. Sort the input chunks by local experts. + self.sort_input_by_local_experts = input_chunk_idxs.reshape(-1, self.num_local_experts).T.ravel() + # [tp_size * ep_size, num_local_experts]. Restore the output chunks by local experts. + self.restore_output_by_local_experts = input_chunk_idxs.reshape(self.num_local_experts, -1).T.ravel() + + # Token drop and padding. + # Drop and pad the input to capacity. + self.drop_and_pad = self.config.moe_config.moe_pad_expert_input_to_capacity + if self.drop_and_pad: + assert self.config.moe_config.moe_expert_capacity_factor is not None + self.moe_expert_capacity_factor = self.config.moe_config.moe_expert_capacity_factor + self.capacity = None + # NOTE: since we don't have etp, assume expert tensor parallel size is always 1 + # NOTE: https://github.com/NVIDIA/Megatron-LM/blob/28118fcdc22e42621776a021af568ae39c198418/megatron/core/transformer/moe/token_dispatcher.py#L60-L68 + self.tp_rank = 0 + + # A cuda stream synchronization is needed in self.token_permutation() in some cases, + # because there are several non-blocking DtoH data transfers called in self.preprocess(). + # The synchronization happens at different points based on MoE settings as late as possible. + # Valid sync points are "before_permutation_1", "before_ep_alltoall", "before_finish", + # and "no_sync". + self.cuda_sync_point = "no_sync" + + self.shared_experts = None + + def preprocess(self, routing_map: torch.Tensor) -> torch.Tensor: + """ + Preprocess token routing map for AlltoAll communication and token permutation. + + This method computes the number of tokens assigned to each expert based on the routing_map. + It also initializes the necessary data structures for AlltoAll communication, such as input + and output splits, and the mapping between global tokens and local experts. + + Args: + routing_map (torch.Tensor): The mapping of tokens to experts, with shape + [num_tokens, num_experts]. + + Returns: + torch.Tensor: Tensor containing the number of tokens assigned to local expert. + """ + # [num_experts], number of tokens assigned to each expert from the current rank's input. + num_local_tokens_per_expert = routing_map.sum(dim=0).long() + + if self.drop_and_pad: + # Drop and pad the input to capacity. + num_tokens = routing_map.size(0) * self.config.moe_config.top_k + self.capacity = get_capacity( + num_tokens=num_tokens, + num_experts=self.num_experts, + capacity_factor=self.moe_expert_capacity_factor, + ) + self.num_out_tokens = self.capacity * self.num_experts + # [num_local_experts], number of tokens processed by each expert. + num_tokens_per_local_expert = torch.full( + (self.num_local_experts,), + self.capacity * self.etp_size * self.ep_size, + dtype=torch.long, + ) + # [tp_size * ep_size, num_local_experts]. Represents the number of tokens sent + # to each local expert by all ranks. + self.num_global_tokens_per_local_expert = torch.full( + (self.num_experts * self.etp_size,), + self.capacity, + dtype=torch.long, + device=self.permute_idx_device, + ) + return num_tokens_per_local_expert + elif self.config.moe_config.moe_expert_capacity_factor is not None: + # Drop tokens to capacity, no padding. + # A synchronization is needed before the first + # permutation to get the `num_out_tokens` CPU value. + self.num_out_tokens = num_local_tokens_per_expert.sum().to(torch.device("cpu"), non_blocking=True) + self.cuda_sync_point = "before_permutation_1" + else: + # Dropless + self.num_out_tokens = routing_map.size(0) * self.config.moe_config.top_k + if self.ep_size > 1 or self.num_local_experts > 1: + # Token dropless and enable ep. A synchronization is needed before expert parallel + # AlltoAll communication to get the `input_splits` and `output_splits` CPU values. + self.cuda_sync_point = "before_ep_alltoall" + else: + # Token dropless and no ep. A synchronization is needed before the returns + # to get the `tokens_per_expert` CPU value for + self.cuda_sync_point = "before_finish" + + if self.ep_size > 1 or self.etp_size > 1: + # =================================================== + # Calculate input_splits, output_splits for alltoall/allgather in variable size. + # =================================================== + # [ep_size]. Represents the number of tokens sent by the current rank to other + # EP ranks. + self.input_splits = ( + num_local_tokens_per_expert.reshape(self.ep_size, self.num_local_experts) + .sum(axis=1) + .to(torch.device("cpu"), non_blocking=True) + .numpy() + ) + # Gather the global distribution of tokens across ranks. + # num_global_tokens_per_expert represents the number of tokens sent to each + # expert by all ranks. + # [tp_size, ep_size, num_experts] + # TODO: do i need this when etp_size==1? + num_global_tokens_per_expert = ( + differentiable_all_gather(num_local_tokens_per_expert, group=self.ep_tp_pg) + .reshape(self.ep_size, self.etp_size, self.num_experts) + .transpose(0, 1) + ) + # [tp_size, ep_size, num_experts] -> [tp_size, ep_size, num_local_experts] + num_global_tokens_per_local_expert = num_global_tokens_per_expert[ + :, :, self.local_expert_indices[0] : self.local_expert_indices[-1] + 1 + ].contiguous() + # [tp_size, ep_size, num_local_experts] -> [tp_size, ep_size] + num_global_tokens_per_rank = num_global_tokens_per_local_expert.sum(axis=2) + # [tp_size, ep_size] -> [ep_size] + # self.output_splits represents the number of tokens received by the current rank + # from other EP rank. + self.output_splits = ( + num_global_tokens_per_rank[self.tp_rank].to(torch.device("cpu"), non_blocking=True).numpy() + ) + # [tp_size, ep_size] -> [tp_size] + # self.output_splits_tp represents the number of tokens received by the current + # rank from other TP rank. + self.output_splits_tp = ( + num_global_tokens_per_rank.sum(axis=1).to(torch.device("cpu"), non_blocking=True).numpy() + ) + # [tp_size, ep_size, num_local_experts] -> [num_local_experts] + num_tokens_per_local_expert = num_global_tokens_per_local_expert.sum(dim=(0, 1)).to( + torch.device("cpu"), non_blocking=True + ) + else: + num_global_tokens_per_local_expert = num_local_tokens_per_expert.reshape(self.num_experts) + num_tokens_per_local_expert = num_local_tokens_per_expert.to(torch.device("cpu"), non_blocking=True) + + if self.num_local_experts > 1: + # [tp_size * ep_size, num_local_experts]. Represents the number of tokens sent + # to each local expert by all ranks. + self.num_global_tokens_per_local_expert = num_global_tokens_per_local_expert.view( + -1, self.num_local_experts + ) + if not self.config.moe_config.permute_fusion: + self.num_global_tokens_per_local_expert = self.num_global_tokens_per_local_expert.to( + torch.device("cpu"), non_blocking=True + ) + + return num_tokens_per_local_expert + + def token_permutation( + self, hidden_states: torch.Tensor, probs: torch.Tensor, routing_map: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Dispatch tokens to local experts using AlltoAll communication. + + This method performs the following steps: + 1. Preprocess the routing map to get metadata for communication and permutation. + 2. Permute input tokens for AlltoAll communication. + 3. Perform expert parallel AlltoAll communication. + 4. Sort tokens by local expert (if multiple local experts exist). + + Args: + hidden_states (torch.Tensor): Input token embeddings. + probs (torch.Tensor): The probabilities of token to experts assignment. + routing_map (torch.Tensor): The mapping of token to experts assignment. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: + - Permuted token embeddings for local experts. + - Number of tokens per expert. + """ + # Preprocess: Get the metadata for communication, permutation and computation operations. + self.hidden_shape = hidden_states.shape + self.probs = probs + self.routing_map = routing_map + assert probs.dim() == 2, "Expected 2D tensor for probs" + assert routing_map.dim() == 2, "Expected 2D tensor for token2expert mask" + assert routing_map.dtype == torch.bool, "Expected bool tensor for mask" + hidden_states = hidden_states.view(-1, self.hidden_shape[-1]) + tokens_per_expert = self.preprocess(self.routing_map) + + if self.shared_experts is not None: + self.shared_experts.pre_forward_comm(hidden_states.view(self.hidden_shape)) + + # Permutation 1: input to AlltoAll input + self.hidden_shape_before_permute = hidden_states.shape + if self.cuda_sync_point == "before_permutation_1": + torch.cuda.current_stream().synchronize() + permutated_local_input_tokens, self.reversed_local_input_permutation_mapping = permute( + hidden_states, + routing_map, + num_out_tokens=self.num_out_tokens, + fused=self.config.moe_config.permute_fusion, + drop_and_pad=self.drop_and_pad, + ) + + # Perform expert parallel AlltoAll communication + if self.cuda_sync_point == "before_ep_alltoall": + torch.cuda.current_stream().synchronize() + global_input_tokens = all_to_all( + permutated_local_input_tokens, self.output_splits, self.input_splits, self.ep_pg + ) + if self.shared_experts is not None: + self.shared_experts.linear_fc1_forward_and_act(global_input_tokens) + + if self.etp_size > 1: + raise NotImplementedError("Not implemented") + if self.output_splits_tp is None: + output_split_sizes = None + else: + output_split_sizes = self.output_splits_tp.tolist() + global_input_tokens = gather_from_sequence_parallel_region( + global_input_tokens, group=self.ep_tp_pg, output_split_sizes=output_split_sizes + ) + + # Permutation 2: Sort tokens by local expert. + if self.num_local_experts > 1: + if self.drop_and_pad: + global_input_tokens = ( + global_input_tokens.view( + self.etp_size * self.ep_size, + self.num_local_experts, + self.capacity, + *global_input_tokens.size()[1:], + ) + .transpose(0, 1) + .contiguous() + .flatten(start_dim=0, end_dim=2) + ) + else: + global_input_tokens = sort_chunks_by_idxs( + global_input_tokens, + self.num_global_tokens_per_local_expert.ravel(), + self.sort_input_by_local_experts, + fused=self.config.moe_config.permute_fusion, + ) + + if self.cuda_sync_point == "before_finish": + torch.cuda.current_stream().synchronize() + + return global_input_tokens, tokens_per_expert + + def token_unpermutation( + self, hidden_states: torch.Tensor, bias: Optional[torch.Tensor] = None + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Reverse the token permutation to restore the original order. + + This method performs the following steps: + 1. Unsort tokens by local expert (if multiple local experts exist). + 2. Perform expert parallel AlltoAll communication to restore the original order. + 3. Unpermute tokens to restore the original order. + + Args: + hidden_states (torch.Tensor): Output from local experts. + bias (torch.Tensor, optional): Bias tensor (not supported). + + Returns: + Tuple[torch.Tensor, Optional[torch.Tensor]]: + - Unpermuted token embeddings in the original order. + - None (bias is not supported). + """ + assert bias is None, "Bias is not supported in MoEAlltoAllTokenDispatcher" + + # Unpermutation 2: Unsort tokens by local expert. + if self.num_local_experts > 1: + if self.drop_and_pad: + hidden_states = ( + hidden_states.view( + self.num_local_experts, + self.etp_size * self.ep_size, + self.capacity, + *hidden_states.size()[1:], + ) + .transpose(0, 1) + .contiguous() + .flatten(start_dim=0, end_dim=2) + ) + else: + hidden_states = sort_chunks_by_idxs( + hidden_states, + self.num_global_tokens_per_local_expert.T.ravel(), + self.restore_output_by_local_experts, + fused=self.config.moe_config.permute_fusion, + ) + + if self.etp_size > 1: + raise NotImplementedError("Not implemented") + if self.output_splits_tp is None: + input_split_sizes = None + else: + input_split_sizes = self.output_splits_tp.tolist() + # TODO: input_split_sizes + hidden_states = differentiable_reduce_scatter_sum( + hidden_states, group=self.ep_tp_pg, input_split_sizes=input_split_sizes + ) + + # Perform expert parallel AlltoAll communication + # hidden_states: [SEQL, H] -> [SEQL, H/TP] + permutated_local_input_tokens = all_to_all(hidden_states, self.input_splits, self.output_splits, self.ep_pg) + if self.shared_experts is not None: + self.shared_experts.linear_fc2_forward(permutated_local_input_tokens) + self.shared_experts.post_forward_comm() + + # Unpermutation 1: AlltoAll output to output + output = unpermute( + permutated_local_input_tokens, + self.reversed_local_input_permutation_mapping, + restore_shape=self.hidden_shape_before_permute, + probs=self.probs, + routing_map=self.routing_map, + fused=self.config.moe_config.permute_fusion, + drop_and_pad=self.drop_and_pad, + ) + + # Reshape the output tensor + output = output.view(self.hidden_shape) + + # Add shared experts output + if self.shared_experts is not None: + shared_expert_output = self.shared_experts.get_output() + output += shared_expert_output + return output, None + + +def sort_chunks_by_idxs( + input: torch.Tensor, split_sizes: torch.Tensor, sorted_idxs: torch.Tensor, fused: bool = False +): + """Split and sort the input tensor based on the split_sizes and sorted indices.""" + if fused: + if not HAVE_TE or fused_sort_chunks_by_index is None: + raise ValueError("fused_sort_chunks_by_index is not available. Please install TE >= 2.1.0.") + return fused_sort_chunks_by_index(input, split_sizes, sorted_idxs) + + input = torch.split(input, split_sizes.tolist(), dim=0) + output = torch.cat([input[i] for i in sorted_idxs.tolist()], dim=0) + return output + + +try: + from transformer_engine.pytorch.permutation import ( + moe_permute, + moe_sort_chunks_by_index, + moe_unpermute, + ) + + fused_permute = moe_permute + fused_unpermute = moe_unpermute + fused_sort_chunks_by_index = moe_sort_chunks_by_index + HAVE_TE = True +except ImportError: + HAVE_TE = False + + +import torch.nn.functional as F + + +@torch.compile +def swiglu(y): + y_1, y_2 = torch.chunk(y, 2, -1) + return F.silu(y_1) * y_2 + + +@torch.compile +def swiglu_back(g, y): + y_1, y_2 = torch.chunk(y, 2, -1) + return torch.cat((g * torch.sigmoid(y_1) * (1 + y_1 * (1 - torch.sigmoid(y_1))) * y_2, g * F.silu(y_1)), -1) + + +class SwiGLUFunction(torch.autograd.Function): + @staticmethod + # bias is an optional argument + def forward(ctx, input, fp8_input_store): + input_for_backward = input.to(torch.float8_e4m3fn) if fp8_input_store else input + ctx.save_for_backward(input_for_backward) + ctx.ori_input_dtype = input.dtype + ctx.fp8_input_store = fp8_input_store + return swiglu(input) + + @staticmethod + def backward(ctx, grad_output): + input = ctx.saved_tensors[0] + input = input.to(ctx.ori_input_dtype) if ctx.fp8_input_store else input + tmp = swiglu_back(grad_output, input) + return tmp, None diff --git a/src/nanotron/optim/gradient_accumulator.py b/src/nanotron/optim/gradient_accumulator.py index 8107b46e6..f553a048e 100644 --- a/src/nanotron/optim/gradient_accumulator.py +++ b/src/nanotron/optim/gradient_accumulator.py @@ -2,13 +2,15 @@ from abc import ABC, abstractmethod from collections import OrderedDict from contextlib import contextmanager -from typing import Callable, Dict, Iterator, Optional, Tuple +from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union import torch from torch.distributed import GradBucket import nanotron.distributed as dist from nanotron import logging +from nanotron.logging import log_rank +from nanotron.logging.timers import nanotron_timer from nanotron.parallel.parameters import NanotronParameter from nanotron.utils import get_untyped_storage, tensor_from_untyped_storage @@ -16,7 +18,7 @@ class GradientAccumulator(ABC): - fp32_grads_allreduce_handle: Optional[torch.futures.Future] + fp32_grads_allreduce_handle: Optional[Union[List[torch.futures.Future], torch.futures.Future]] @abstractmethod def __init__(self, named_parameters: Iterator[Tuple[str, NanotronParameter]]): @@ -117,7 +119,7 @@ def __init__( self._is_accumulation_sync_step = False # We need the last allreduce handle to make sure it finishes before the optimizer step - self.fp32_grads_allreduce_handle: Optional[torch.futures.Future] = None + self.fp32_grads_allreduce_handle: Optional[Union[List[torch.futures.Future], torch.futures.Future]] = [] def assign_param_offsets(self, param_name_to_offsets: Dict[str, Dict[int, Tuple[int, int]]], dp_rank: int): """To use only when you use with ZeRODistributedOptimizer""" @@ -182,7 +184,8 @@ def build_grad_buffers( if not param.requires_grad: continue - assert param.dtype != torch.float, f"Expected {name} not to be float" + # MoE router weights are initialized in float32 + assert param.dtype != torch.float or "router.weight" in name, f"Expected {name} not to be float" assert param.is_contiguous(), f"Expected {name} to be contiguous" next_offset = offset + param.numel() * element_size @@ -203,16 +206,21 @@ def build_grad_buffers( return fp32_grad_buffers, contiguous_buffer_f32_gradients def backward(self, loss: torch.Tensor): + nanotron_timer("loss backward", timer_type="cuda", cuda_sync=True).start() result = loss.backward() + nanotron_timer("loss backward", timer_type="cuda", cuda_sync=True).end() + nanotron_timer("accumulate_grad", timer_type="cuda", cuda_sync=True).start() for name, elt in self.fp32_grad_buffers.items(): self._accumulate_grad(name=name, half_param=elt["half"]) + nanotron_timer("accumulate_grad", timer_type="cuda", cuda_sync=True).end() return result @torch.profiler.record_function("FP32GradientAccumulator._accumulate_grad") def _accumulate_grad(self, name: str, half_param: NanotronParameter) -> None: """Accumulate grad in fp32 and set the fp32 grad to the fp32 grad buffer, so that optimizer can update fp32 weights afterwards""" + assert half_param.grad is not None, f"Expected param {name} to have gradient." fp32_grad = self.get_grad_buffer(name=name) @@ -303,7 +311,8 @@ class FP32GradBucketManager: """Manages the fp32 gradient buckets. Attributes: - dp_pg: The process group to allreduce gradients across. + dp_pg: The process group to allreduce non-expert gradients across. + ep_dp_pg: The process group to allreduce expert gradients across. accumulator: The gradient accumulator which keeps the gradient buffers. bucket_id_to_fp32_grad_buckets_and_dependencies: A dictionary mapping bucket ids to: - fp32 grad bucket (torch.Tensor) @@ -311,6 +320,7 @@ class FP32GradBucketManager: param_id_to_bucket_id: A dictionary mapping param ids to bucket ids.""" dp_pg: dist.ProcessGroup + ep_dp_pg: Optional[dist.ProcessGroup] accumulator: FP32GradientAccumulator param_id_to_name: Dict[int, str] @@ -330,27 +340,52 @@ def get_fp32_accum_hook( # s = torch.cuda.Stream() def fp32_accum_hook(state: FP32GradBucketManager, bucket: GradBucket) -> torch.futures.Future[torch.Tensor]: + from nanotron.nn.moe import is_expert_param + # nonlocal s # DDP groups grads in GradBuckets. This hook is called throughout the bwd pass, once each bucket is ready to overlap communication with computation. # See https://pytorch.org/docs/stable/ddp_comm_hooks.html#what-does-a-communication-hook-operate-on for more details. - dp_pg = state.dp_pg accumulator = state.accumulator param_id_to_name = state.param_id_to_name # Add new incoming gradient # with torch.cuda.stream(s): + param_in_buckets = [] for param, grad in zip(bucket.parameters(), bucket.gradients()): name = param_id_to_name[id(param)] fp32_grad_buffer = accumulator.get_grad_buffer(name) fp32_grad_buffer.add_(grad.view_as(fp32_grad_buffer)) + param_in_buckets.append(name) + + log_rank( + "[DDP triggered register_comm_hook] parameters in bucket: {}".format(param_in_buckets), + logger=logger, + level=logging.DEBUG, + rank=0, + ) + + expert_params = [param for param in bucket.parameters() if is_expert_param(param_id_to_name[id(param)])] + non_expert_params = [ + param for param in bucket.parameters() if not is_expert_param(param_id_to_name[id(param)]) + ] + ep_dp_pg = state.ep_dp_pg + dp_pg = state.dp_pg + + if len(expert_params) > 0: + all_reduce_group = ep_dp_pg + else: + all_reduce_group = dp_pg # sync across dp - if dp_pg.size() == 1: + if all_reduce_group.size() == 1: fut = torch.futures.Future() fut.set_result(bucket.buffer()) return fut if reduce_scatter: + if len(expert_params) > 0: + raise NotImplementedError("Reduce scatter is not implemented for MoE") + assert hasattr(accumulator, "param_name_to_offsets") grad_buffer_tensor_list = [ accumulator.get_grad_buffer(param_id_to_name[id(param)]).view(-1) for param in bucket.parameters() @@ -364,23 +399,41 @@ def fp32_accum_hook(state: FP32GradBucketManager, bucket: GradBucket) -> torch.f for grad_buffer, param in zip(grad_buffer_tensor_list, bucket.parameters()) ] input_tensor_lists = [ - torch.split(grad_buffer, split_size_or_sections=len(grad_buffer) // dp_pg.size()) + torch.split(grad_buffer, split_size_or_sections=len(grad_buffer) // all_reduce_group.size()) for grad_buffer in grad_buffer_tensor_list ] dist.reduce_scatter_coalesced( output_tensor_list=output_tensor_list, input_tensor_lists=input_tensor_lists, op=reduce_op, - group=dp_pg, + group=all_reduce_group, async_op=True, ) else: - grad_buffer_tensor_list = [ - accumulator.get_grad_buffer(param_id_to_name[id(param)]).view(-1) for param in bucket.parameters() - ] - accumulator.fp32_grads_allreduce_handle = dist.all_reduce_coalesced( - grad_buffer_tensor_list, group=dp_pg, async_op=True, op=reduce_op + + def execute_all_reduce(params: List[NanotronParameter], pg: dist.ProcessGroup): + grad_buffer_tensor_list = [ + accumulator.get_grad_buffer(param_id_to_name[id(param)]).view(-1) for param in params + ] + accumulator.fp32_grads_allreduce_handle.append( + dist.all_reduce_coalesced(grad_buffer_tensor_list, group=pg, async_op=True, op=reduce_op) + ) + + log_rank( + "[DDP triggered register_comm_hook] Detect a bucket have both expert and non-expert params: {}".format( + param_in_buckets + ), + logger=logger, + level=logging.DEBUG, + rank=0, ) + + if dist.get_world_size(ep_dp_pg) > 1 and len(expert_params) > 0: + execute_all_reduce(expert_params, pg=ep_dp_pg) + + if dist.get_world_size(dp_pg) > 1 and len(non_expert_params) > 0: + execute_all_reduce(non_expert_params, pg=dp_pg) + # we shouldn't wait for this future for the rest of the backward # with torch.cuda.stream(s): diff --git a/src/nanotron/parallel/context.py b/src/nanotron/parallel/context.py index 7820f8a25..7ea41b1c8 100644 --- a/src/nanotron/parallel/context.py +++ b/src/nanotron/parallel/context.py @@ -1,14 +1,28 @@ import os +from enum import Enum from typing import Dict, Literal import numpy as np import torch +from einops import rearrange import nanotron.distributed as dist DistributedBackend = Literal["gloo", "mpi", "nccl"] +class ParallelMode(Enum): + TP = "tp" + CP = "cp" + DP = "dp" + PP = "pp" + + EP = "ep" + EP_TP = "ep_tp" + EP_DP = "ep_dp" + EP_PP = "ep_pp" + + class ParallelContext: def __init__( self, @@ -17,15 +31,31 @@ def __init__( data_parallel_size: int, context_parallel_size: int = 1, expert_parallel_size: int = 1, + expert_tensor_parallel_size: int = 1, + expert_data_parallel_size: int = 1, + enabled_moe: bool = False, backend: DistributedBackend = "nccl", ): + """ + expert_parallel_size = 1 doesnt mean we dont have moe, it just means we dont have expert parallelism + """ """Initialize parallel context.""" world_size = int(os.environ["WORLD_SIZE"]) local_world_size = int(os.environ.get("LOCAL_WORLD_SIZE", "8")) if world_size > 8 else world_size - assert ( - tensor_parallel_size * pipeline_parallel_size * context_parallel_size * data_parallel_size - ) == world_size, f"TP*CP*DP*PP={tensor_parallel_size}*{pipeline_parallel_size}*{context_parallel_size}*{data_parallel_size}={tensor_parallel_size * pipeline_parallel_size * context_parallel_size * data_parallel_size} != WORLD_SIZE={world_size}" + if enabled_moe is False: + assert ( + tensor_parallel_size * pipeline_parallel_size * context_parallel_size * data_parallel_size + ) == world_size, f"TP*CP*DP*PP={tensor_parallel_size}*{pipeline_parallel_size}*{context_parallel_size}*{data_parallel_size}={tensor_parallel_size * pipeline_parallel_size * context_parallel_size * data_parallel_size} != WORLD_SIZE={world_size}" + else: + assert ( + data_parallel_size * tensor_parallel_size * context_parallel_size * pipeline_parallel_size + == world_size + ), f"DP*TP*CP*PP={data_parallel_size}*{tensor_parallel_size}*{context_parallel_size}*{pipeline_parallel_size}={data_parallel_size * tensor_parallel_size * context_parallel_size * pipeline_parallel_size} != WORLD_SIZE={world_size}" + assert ( + expert_data_parallel_size * expert_tensor_parallel_size * expert_parallel_size * pipeline_parallel_size + == world_size + ), f"EP_DP*EP_TP*EP*PP={expert_data_parallel_size}*{expert_tensor_parallel_size}*{expert_parallel_size}*{pipeline_parallel_size}={expert_data_parallel_size * expert_tensor_parallel_size * expert_parallel_size * pipeline_parallel_size} != WORLD_SIZE={world_size}" if not dist.is_available(): raise ValueError("torch.distributed is not available as a package, please install it.") @@ -35,6 +65,9 @@ def __init__( self.data_parallel_size = data_parallel_size self.context_parallel_size = context_parallel_size self.expert_parallel_size = expert_parallel_size + self.expert_tensor_parallel_size = expert_tensor_parallel_size + self.expert_data_parallel_size = expert_data_parallel_size + self.enabled_moe = enabled_moe self.world_size = world_size self.local_world_size = local_world_size @@ -59,57 +92,145 @@ def __init__( def _init_parallel_groups(self): """Initialize 3D parallelism's all process groups.""" dist.barrier() - ranks = np.arange(0, self.world_size).reshape( - ( - self.expert_parallel_size, - self.pipeline_parallel_size, - self.data_parallel_size, - self.context_parallel_size, - self.tensor_parallel_size, - ) - ) + self.world_ranks_to_pg = {} - self.local_pg = self.create_new_group(ranks.reshape((-1, self.local_world_size))) - assert int(os.environ.get("LOCAL_RANK")) == dist.get_rank(self.local_pg), "Local rank mismatch" - - # Relevant process groups containing the current rank - self.tp_pg = self.create_new_group(ranks.transpose((0, 1, 2, 3, 4)).reshape((-1, self.tensor_parallel_size))) - self.cp_pg = self.create_new_group(ranks.transpose((4, 0, 1, 2, 3)).reshape((-1, self.context_parallel_size))) - self.dp_pg = self.create_new_group(ranks.transpose((3, 4, 0, 1, 2)).reshape((-1, self.data_parallel_size))) - self.pp_pg = self.create_new_group(ranks.transpose((2, 3, 4, 0, 1)).reshape((-1, self.pipeline_parallel_size))) - self.ep_pg = self.create_new_group( - ranks.transpose((1, 2, 3, 4, 0)).reshape((-1, self.expert_parallel_size)) - ) # TODO: ep should be a subset of dp - - # model parallel group = combination of tp and pp and exp for a given dp rank + self._group_to_ranks = {} + + self._init_process_group() + + def _init_process_group(self): + """ + Decoupled 5D parallelism + based on the paper: + + MoE Parallel Folding: Heterogeneous Parallelism + Mappings for Efficient Large-Scale MoE Model + Training with Megatron Core + + Following the process group initialization in page 17 + + https://www.arxiv.org/abs/2504.14960 + """ + ranks = np.arange(0, self.world_size) + + # NOTE: attention parallelism + attn_ranks = ranks.reshape( + self.data_parallel_size, self.pipeline_parallel_size, self.context_parallel_size, self.tensor_parallel_size + ) + tp_ranks = rearrange( + attn_ranks, + "attn_dp pp cp tp -> (attn_dp pp cp) tp", + tp=self.tensor_parallel_size, + cp=self.context_parallel_size, + pp=self.pipeline_parallel_size, + attn_dp=self.data_parallel_size, + ).tolist() + cp_ranks = rearrange( + attn_ranks, + "attn_dp pp cp tp -> (attn_dp pp tp) cp", + tp=self.tensor_parallel_size, + cp=self.context_parallel_size, + pp=self.pipeline_parallel_size, + attn_dp=self.data_parallel_size, + ).tolist() + pp_ranks = rearrange( + attn_ranks, + "attn_dp pp cp tp -> (attn_dp cp tp) pp", + tp=self.tensor_parallel_size, + cp=self.context_parallel_size, + pp=self.pipeline_parallel_size, + attn_dp=self.data_parallel_size, + ).tolist() + dp_ranks = rearrange( + attn_ranks, + "attn_dp pp cp tp -> (pp cp tp) attn_dp", + tp=self.tensor_parallel_size, + cp=self.context_parallel_size, + pp=self.pipeline_parallel_size, + attn_dp=self.data_parallel_size, + ).tolist() + self.tp_pg = self.create_new_group(tp_ranks) + self.cp_pg = self.create_new_group(cp_ranks) + self.pp_pg = self.create_new_group(pp_ranks) + self.dp_pg = self.create_new_group(dp_ranks) self.mp_pg = self.create_new_group( [ - ranks[:, :, dp_rank, cp_rank, :].reshape(-1) + attn_ranks[dp_rank, :, cp_rank, :].reshape(-1) for cp_rank in range(self.context_parallel_size) for dp_rank in range(self.data_parallel_size) ] ) - - self.tp_and_ep_pg = self.create_new_group( + self.tp_and_cp_pg = self.create_new_group( [ - ranks[:, pp_rank, dp_rank, cp_rank, :].reshape(-1) - for cp_rank in range(self.context_parallel_size) + attn_ranks[dp_rank, pp_rank, :, :].reshape(-1) for pp_rank in range(self.pipeline_parallel_size) for dp_rank in range(self.data_parallel_size) ] ) - # self.tp_and_cp_pg = self.create_new_group( - # [ - # ranks[ep_rank, pp_rank, dp_rank, :, :].reshape(-1) - # for ep_rank in range(self.expert_parallel_size) - # for pp_rank in range(self.pipeline_parallel_size) - # for dp_rank in range(self.data_parallel_size) - # ] - # ) + _group_to_ranks = { + # NOTE: attention parallelism + ParallelMode.TP: tp_ranks, + ParallelMode.CP: cp_ranks, + ParallelMode.PP: pp_ranks, + ParallelMode.DP: dp_ranks, + } + self.parallel_order = ["dp", "pp", "cp", "tp"] + + if self.enabled_moe is True: + + # NOTE: expert parallelism + moe_ranks = ranks.reshape( + self.expert_data_parallel_size, + self.pipeline_parallel_size, + self.expert_parallel_size, + self.expert_tensor_parallel_size, + ) + ep_ranks = rearrange( + moe_ranks, + "moe_dp pp ep tp -> (moe_dp pp tp) ep", + tp=self.expert_tensor_parallel_size, + ep=self.expert_parallel_size, + pp=self.pipeline_parallel_size, + moe_dp=self.expert_data_parallel_size, + ) + ep_tp_ranks = rearrange( + moe_ranks, + "moe_dp pp ep tp -> (moe_dp pp ep) tp", + tp=self.expert_tensor_parallel_size, + ep=self.expert_parallel_size, + pp=self.pipeline_parallel_size, + moe_dp=self.expert_data_parallel_size, + ) + ep_pp_ranks = rearrange( + moe_ranks, + "moe_dp pp ep tp -> (moe_dp ep tp) pp", + tp=self.expert_tensor_parallel_size, + ep=self.expert_parallel_size, + pp=self.pipeline_parallel_size, + moe_dp=self.expert_data_parallel_size, + ) + ep_dp_ranks = rearrange( + moe_ranks, + "moe_dp pp ep tp -> (pp ep tp) moe_dp", + tp=self.expert_tensor_parallel_size, + ep=self.expert_parallel_size, + pp=self.pipeline_parallel_size, + moe_dp=self.expert_data_parallel_size, + ) + self.ep_pg = self.create_new_group(ep_ranks) + self.ep_tp_pg = self.create_new_group(ep_tp_ranks) + self.ep_pp_pg = self.create_new_group(ep_pp_ranks) + self.ep_dp_pg = self.create_new_group(ep_dp_ranks) + + _group_to_ranks[ParallelMode.EP] = ep_ranks + _group_to_ranks[ParallelMode.EP_TP] = ep_tp_ranks + _group_to_ranks[ParallelMode.EP_PP] = ep_pp_ranks + _group_to_ranks[ParallelMode.EP_DP] = ep_dp_ranks + self.parallel_ep_order = ["ep_dp", "ep_pp", "ep", "ep_tp"] - self.world_rank_matrix: np.ndarray = ranks - self.parallel_order = ["ep", "pp", "dp", "cp", "tp"] + self._group_to_ranks = _group_to_ranks + self.world_rank_matrix = attn_ranks def create_new_group(self, all_groups_ranks: np.ndarray) -> dist.ProcessGroup: dist.barrier() @@ -141,8 +262,16 @@ def set_device(self): def get_local_ranks(self, world_rank: int) -> Dict[str, int]: # return tuple(i.item() for i in np.where(self.world_rank_matrix == world_rank)) + # NOTE: return ep ranks local_ranks = np.where(self.world_rank_matrix == world_rank) - return {ax: local_ranks[i].item() for i, ax in enumerate(self.parallel_order)} + mappings = {ax: local_ranks[i].item() for i, ax in enumerate(self.parallel_order)} + if self.enabled_moe is True: + mappings["ep"] = dist.get_rank(self.ep_pg) + mappings["ep_dp"] = dist.get_rank(self.ep_dp_pg) + mappings["ep_pp"] = dist.get_rank(self.ep_pp_pg) + mappings["ep_tp"] = dist.get_rank(self.ep_tp_pg) + + return mappings def destroy(self): if not dist.is_initialized(): @@ -153,7 +282,6 @@ def destroy(self): def get_global_rank( self, - ep_rank: int, pp_rank: int, dp_rank: int, cp_rank: int, @@ -170,4 +298,4 @@ def get_global_rank( :return: numpy.int64, The global rank. """ - return self.world_rank_matrix[ep_rank, pp_rank, dp_rank, cp_rank, tp_rank] + return self.world_rank_matrix[dp_rank, pp_rank, cp_rank, tp_rank] diff --git a/src/nanotron/parallel/pipeline_parallel/engine.py b/src/nanotron/parallel/pipeline_parallel/engine.py index 141060013..fd9061dd2 100644 --- a/src/nanotron/parallel/pipeline_parallel/engine.py +++ b/src/nanotron/parallel/pipeline_parallel/engine.py @@ -9,6 +9,7 @@ from nanotron import logging from nanotron.distributed import ProcessGroup from nanotron.logging import log_rank +from nanotron.logging.timers import nanotron_timer from nanotron.optim.gradient_accumulator import GradientAccumulator from nanotron.parallel.data_parallel.utils import ddp_trigger_sync_in_bwd from nanotron.parallel.pipeline_parallel.context_manager import attach_pipeline_state_to_model @@ -16,8 +17,6 @@ from nanotron.parallel.pipeline_parallel.tensor_pointer import TensorPointer from nanotron.utils import ContextManagers -# from nanotron.logging.timers import nanotron_timer - logger = logging.get_logger(__name__) @@ -84,9 +83,11 @@ def backward( with context: if grad_accumulator is None: - sum(activations).backward() + # sum(activations).backward() + activations[0].backward() else: - grad_accumulator.backward(sum(activations)) + # grad_accumulator.backward(sum(activations)) + grad_accumulator.backward(activations[0]) # TODO @nouamane: this fixes interleaved afab but makes 1f1b hang # with context: @@ -289,10 +290,13 @@ def train_batch_iter( output = {k: v.detach() for k, v in output.items()} outputs.append(output) + nanotron_timer("Forward+backward", timer_type="cuda", cuda_sync=True).start() for micro_batch in batch: context = self._get_fwd_context(model=model) # with nanotron_timer("forward", timer_type="cuda"): + nanotron_timer("Forward", timer_type="cuda", cuda_sync=True).start() output = self.forward(context=context, state=state, micro_batch=micro_batch, model=model) + nanotron_timer("Forward", timer_type="cuda", cuda_sync=True).end() # We make `output` a dict if not isinstance(output, dict): @@ -310,7 +314,10 @@ def train_batch_iter( grad_accumulator=grad_accumulator, ) # with nanotron_timer("backward", timer_type="cuda"): + nanotron_timer("Backward", timer_type="cuda", cuda_sync=True).start() self.backward(context=context, state=state, grad_accumulator=grad_accumulator) + nanotron_timer("Backward", timer_type="cuda", cuda_sync=True).end() + nanotron_timer("Forward+backward", timer_type="cuda", cuda_sync=True).end() # Check figure in paper: The remain blocks are all backward and there is only `pg.size() - current_pp_rank - 1` blocks left assert len(state.microbatches_activations_requiring_backward) == pg.size() - current_pp_rank - 1 diff --git a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py index bd41347a8..71bd736a7 100644 --- a/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py +++ b/src/nanotron/parallel/tensor_parallel/distributed_differentiable_primitives.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional +from typing import List, Optional import torch from torch import distributed as torch_dist @@ -125,6 +125,43 @@ def backward(ctx, grad_output): return DifferentiableAllGather.apply(grad_output, group), None +class DifferentiableAllToAll(torch.autograd.Function): + """All to all in a differentiable fashion""" + + @staticmethod + def forward(ctx, input, output_split_sizes, input_split_sizes, group: Optional[ProcessGroup]): + ctx.group = group + ctx.output_split_sizes = output_split_sizes + ctx.input_split_sizes = input_split_sizes + if group.size() == 1: + return input + + # input = input.contiguous() # TODO: do we need this? + + if output_split_sizes is None: + # Equal split (all2all) + output = torch.empty_like(input) + else: + # Unequal split (all2all-v) + # NOTE: we all-to-all the first dimension, + # so the data has shape (sum(output_split_sizes), *input.shape[1:]) + output = torch.empty(sum(output_split_sizes), *input.shape[1:], device=input.device, dtype=input.dtype) + + dist.all_to_all_single( + output, input, output_split_sizes=output_split_sizes, input_split_sizes=input_split_sizes, group=group + ) + return output + + @staticmethod + def backward(ctx, grad_output): + return ( + DifferentiableAllToAll.apply(grad_output, ctx.input_split_sizes, ctx.output_split_sizes, ctx.group), + None, + None, + None, + ) + + # ----------------- # Helper functions. # ----------------- @@ -144,3 +181,12 @@ def differentiable_all_gather(tensor, group: Optional[ProcessGroup] = None): def differentiable_reduce_scatter_sum(tensor, group: Optional[ProcessGroup] = None): return DifferentiableReduceScatterSum.apply(tensor, group) + + +def all_to_all( + input: torch.Tensor, + output_split_sizes: Optional[List[int]] = None, + input_split_sizes: Optional[List[int]] = None, + group: Optional[ProcessGroup] = None, +): + return DifferentiableAllToAll.apply(input, output_split_sizes, input_split_sizes, group) diff --git a/src/nanotron/parallel/tied_parameters.py b/src/nanotron/parallel/tied_parameters.py index b2a1f7f52..3e60b4cf0 100644 --- a/src/nanotron/parallel/tied_parameters.py +++ b/src/nanotron/parallel/tied_parameters.py @@ -158,6 +158,11 @@ def sync_tied_weights_gradients( rank=0, ) key = (group_ranks, tied_info.reduce_op) + + # NOTE: quick hack + if grad_accumulator is None and tied_grad is not torch.bfloat16: + tied_grad = tied_grad.to(torch.bfloat16) + if key in group_ranks_and_reduce_op_to_tensors_to_reduce: group_ranks_and_reduce_op_to_tensors_to_reduce[(group_ranks, tied_info.reduce_op)].append(tied_grad) else: diff --git a/src/nanotron/scaling/parametrization.py b/src/nanotron/scaling/parametrization.py index 8f3062a93..e94039cb4 100644 --- a/src/nanotron/scaling/parametrization.py +++ b/src/nanotron/scaling/parametrization.py @@ -6,6 +6,8 @@ from nanotron.config import Config, ModelArgs from nanotron.config.models_config import InitScalingMethod from nanotron.nn.layer_norm import LlamaRMSNorm, TritonRMSNorm +from nanotron.nn.moe import GroupedMLP, Router +from nanotron.nn.te_moe import TEGroupedLinear, TopKRouter from nanotron.parallel.tensor_parallel.nn import ( TensorParallelColumnLinear, TensorParallelEmbedding, @@ -27,7 +29,6 @@ def __init__(self, config: ModelArgs): def parametrize(self, param_name: str, module: nn.Module): if not isinstance(module, tuple(self.MODULE_TO_PARAMETRIZE.keys())): raise Exception(f"Module {type(module)} with parameter {param_name} is not supported for initialization") - return self.MODULE_TO_PARAMETRIZE[type(module)](param_name, module) @@ -36,10 +37,17 @@ def __init__(self, config: Config): super().__init__(config) self.MODULE_TO_PARAMETRIZE = { TensorParallelColumnLinear: self._parametrize_column_linear, + # TODO: double check if correct initialization for grouped MLP TensorParallelRowLinear: self._parametrize_row_linear, TritonRMSNorm: self._parametrize_layer_norm, LlamaRMSNorm: self._parametrize_layer_norm, TensorParallelEmbedding: self._parametrize_embedding, + # NOTE: MoE's specific initialization + GroupedMLP: self._parametrize_grouped_mlp, + Router: self._parametrize_router, + nn.Linear: self._parametrize_column_linear, + TopKRouter: self._parametrize_router, + TEGroupedLinear: self._parametrize_grouped_mlp, } self.std = config.model.init_method.std @@ -49,12 +57,44 @@ def __init__(self, config: Config): self.hidden_size = config.model.model_config.hidden_size def _parametrize_column_linear(self, param_name: str, module: nn.Module): - assert param_name in ["weight", "bias"] + # assert param_name in ["weight", "bias"] + assert any(x in param_name for x in ["weight", "bias"]) - if "weight" == param_name: + if "weight" in param_name: # TODO @nouamane: should we use trunc_normal_ init.normal_(module.weight, mean=0.0, std=self.std) - elif "bias" == param_name: + elif "bias" in param_name: + module.bias.zero_() + + def _parametrize_grouped_mlp(self, param_name: str, module: nn.Module): + COLUMN_LINEAR_PARAMS = [ + "gate_up_proj", # nanotron's moe modeling + "linear_fc1", # TEGroupedLinear's first linear layer + ] + ROW_LINEAR_PARAMS = [ + "down_proj", # nanotron's moe modeling + "linear_fc2", # TEGroupedLinear's second linear layer + ] + + for n, p in module.named_parameters(): + if any(x in param_name for x in COLUMN_LINEAR_PARAMS) or ( + "weight" in n and module.parallel_mode == "column" + ): + # NOTE: the same as parametrization of column linear + init.normal_(p, mean=0.0, std=self.std) + elif any(x in param_name for x in ROW_LINEAR_PARAMS) or ("weight" in n and module.parallel_mode == "row"): + # NOTE: the same as parametrization of row linear + scaling = self._compute_scaling_factor() + adjusted_std = self.std / scaling + # TODO @nouamane: should we use trunc_normal_ + init.normal_(p, mean=0.0, std=adjusted_std) + else: + raise ValueError(f"Unknown parameter {n}") + + def _parametrize_router(self, param_name: str, module: nn.Module): + if "weight" in param_name: + init.normal_(module.weight, mean=0.0, std=self.std) + elif "bias" in param_name: module.bias.zero_() def _compute_scaling_factor(self) -> float: @@ -71,28 +111,29 @@ def _compute_scaling_factor(self) -> float: raise ValueError(f"Invalid scaling method: {self.scaling_method}") def _parametrize_row_linear(self, param_name: str, module: nn.Module): - assert param_name in ["weight", "bias"] + # assert param_name in ["weight", "bias"] + assert any(x in param_name for x in ["weight", "bias"]) - if "weight" == param_name: + if "weight" in param_name: scaling = self._compute_scaling_factor() adjusted_std = self.std / scaling # TODO @nouamane: should we use trunc_normal_ init.normal_(module.weight, mean=0.0, std=adjusted_std) - elif "bias" == param_name: + elif "bias" in param_name: module.bias.zero_() def _parametrize_layer_norm(self, param_name: str, module: nn.Module): - assert param_name in ["weight", "bias"] + assert any(x in param_name for x in ["weight", "bias"]) - if "weight" == param_name: + if "weight" in param_name: module.weight.fill_(1) - elif "bias" == param_name: + elif "bias" in param_name: module.bias.zero_() def _parametrize_embedding(self, param_name: str, module: nn.Module): - assert param_name in ["weight"] + assert "weight" in param_name - if "weight" == param_name: + if "weight" in param_name: init.normal_(module.weight, mean=0.0, std=self.std) @@ -123,9 +164,11 @@ def _compute_spectral_std(std: float, fan_in: int, fan_out: int): return (std / math.sqrt(fan_in)) * min(1, math.sqrt(fan_out / fan_in)) def _parametrize_mup_weight(self, param_name: str, module: nn.Module): - assert param_name in ["weight", "bias"] + # assert param_name in ["weight", "bias"] + assert any(x in param_name for x in ["weight", "bias"]) - data = module.weight if param_name == "weight" else module.bias + # data = module.weight if param_name == "weight" else module.bias + data = module.weight if "weight" in param_name else module.bias fan_in, fan_out = init._calculate_fan_in_and_fan_out(data) world_size = module.world_size @@ -140,20 +183,22 @@ def _parametrize_mup_weight(self, param_name: str, module: nn.Module): init.normal_(data, mean=0.0, std=std) def _parametrize_layer_norm(self, param_name: str, module: nn.Module): - assert param_name in ["weight", "bias"] + # assert param_name in ["weight", "bias"] + assert any(x in param_name for x in ["weight", "bias"]) # NOTE: you're free to change the initialization of layer norm # as it's not a part of µTransfer - if "weight" == param_name: + if "weight" in param_name: module.weight.fill_(1) - elif "bias" == param_name: + elif "bias" in param_name: module.bias.zero_() def _parametrize_embedding(self, param_name: str, module: nn.Module): - assert param_name in ["weight"] + # assert param_name in ["weight"] + assert "weight" in param_name # NOTE: you're free to change the initialization of input embedding/lm head - if "weight" == param_name: + if "weight" in param_name: init.normal_(module.weight, mean=0.0, std=self.std) diff --git a/src/nanotron/serialize/metadata.py b/src/nanotron/serialize/metadata.py index 5812bfd1d..99633746b 100644 --- a/src/nanotron/serialize/metadata.py +++ b/src/nanotron/serialize/metadata.py @@ -13,23 +13,31 @@ from nanotron.constants import CHECKPOINT_FILE_NAME, CHECKPOINT_VERSION from nanotron.parallel import ParallelContext from nanotron.parallel.parameters import SlicesPair - +from collections import defaultdict @dataclasses.dataclass class DataStageMetadata: """ - consumed_train_samples: The number of samples consumed by the model in the this stage (each stage starts from zero). - last_train_step: The last training step across all stages. - - # NOTE: we should allow people to change the name of the data stages in the config file. - # but not the start_training_step, because it could + consumed_train_samples: The number of samples consumed by the model in the this stage (resets at each stage). + consumed_tokens_per_dataset_folder: The number of tokens consumed by the model in the this stage for each dataset folder. (resets at each stage) """ name: str start_training_step: int - consumed_train_samples: int - consumed_tokens_per_dataset_folder: Dict[str, int] = dataclasses.field(default_factory=dict) + consumed_train_samples: int # We use this for sampler, and it's reset at each stage + sequence_length: Optional[int] = None # TODO: put back as non-optional + consumed_tokens_per_dataset_folder: Dict[str, int] = dataclasses.field(default_factory=dict) # this gets reset at each stage + + def __post_init__(self): + if self.sequence_length is None: + self.sequence_length = 4096 + + def sanity_consumed_train_samples(self): + assert self.consumed_train_samples*self.sequence_length == sum(self.consumed_tokens_per_dataset_folder.values()), f"Mismatch between the total consumed samples and the sum of consumed samples across dataset folders! consumed_train_samples={self.consumed_train_samples}, sequence_length={self.sequence_length}, consumed_tokens_per_dataset_folder={self.consumed_tokens_per_dataset_folder}" + @property + def consumed_tokens_all_datasets(self): + return sum(self.consumed_tokens_per_dataset_folder.values()) @dataclasses.dataclass class TrainingMetadata: @@ -40,8 +48,9 @@ class TrainingMetadata: data_stages: The metadata for each stage. """ - consumed_train_samples: int + consumed_train_samples: int # TODO: Legacy. This assumed same sequence length across all stages. Not used anymore last_train_step: int + consumed_tokens_total: Optional[int] = None # TODO: put back as non-optional # TODO(xrsrke): make this not optional, once we entirely remove # the old checkpoint version @@ -50,15 +59,31 @@ class TrainingMetadata: def __post_init__(self): # NOTE: this is a sanity check after loading a trained checkpoint - total_consumed_samples_across_stages = sum(stage.consumed_train_samples for stage in self.data_stages) assert ( - self.consumed_train_samples == total_consumed_samples_across_stages + self.consumed_train_samples == sum(stage.consumed_train_samples for stage in self.data_stages) ), "Mismatch between the total consumed samples and the sum of consumed samples across stages! Something went wrong in the training." + if self.consumed_tokens_total is not None: + assert self.consumed_tokens_total == sum(stage.consumed_tokens_all_datasets for stage in self.data_stages), "Mismatch between the total consumed tokens and the sum of consumed tokens across stages! Something went wrong in the training." + else: + self.consumed_tokens_total = sum(stage.consumed_tokens_all_datasets for stage in self.data_stages) + # TODO(xrsrke): remove this once we entirely remove non-data-stage training if self.last_stage_idx is not None: assert self.data_stages is not None, "data_stages should not be None if last_stage_idx is not None" + @property + def consumed_tokens_per_dataset_folder_total(self): + consumed = defaultdict(int) + for stage in self.data_stages: + for dataset_folder, tokens in stage.consumed_tokens_per_dataset_folder.items(): + consumed[dataset_folder] += tokens + return consumed + + @property + def current_stage(self) -> DataStageMetadata: + return self.data_stages[self.last_stage_idx] + @dataclasses.dataclass class CheckpointMetadata: diff --git a/src/nanotron/serialize/optimizer.py b/src/nanotron/serialize/optimizer.py index fc71a237c..b2372705d 100644 --- a/src/nanotron/serialize/optimizer.py +++ b/src/nanotron/serialize/optimizer.py @@ -47,7 +47,10 @@ def save_optimizer( - If Zero-0 is used, optimizer states are replicated across all DPs. Only DP-0 saves the states - If Zero-1 is used, optimizer states are sharded across all DPs. Each DP saves its own states """ - if (not optimizer.inherit_from(optim.ZeroDistributedOptimizer)) and dist.get_rank(parallel_context.dp_pg) > 0: + is_first_replicas_in_zero0 = (not optimizer.inherit_from(optim.ZeroDistributedOptimizer)) and dist.get_rank( + parallel_context.dp_pg + ) > 0 + if is_first_replicas_in_zero0 and dist.get_rank(parallel_context.ep_dp_pg) > 1: # this is Zero-0, so only DP-0 saves the optimizer states return @@ -61,7 +64,6 @@ def save_optimizer( tp_size = parallel_context.tp_pg.size() pp_size = parallel_context.pp_pg.size() dp_size = parallel_context.dp_pg.size() - expert_parallel_size = parallel_context.expert_parallel_size config = { "type": str(optimizer.__class__.__name__), @@ -69,11 +71,17 @@ def save_optimizer( "tp_size": str(tp_size), "dp_size": str(dp_size), "pp_size": str(pp_size), - "expert_parallel_size": str(expert_parallel_size), }, "configs": {}, } + if parallel_context.enabled_moe is True: + config["parallelism"]["expert_parallel_size"] = str(parallel_context.expert_parallel_size) + config["parallelism"]["expert_tensor_parallel_size"] = str( + parallel_context.expert_tensor_parallel_size + ) + config["parallelism"]["expert_data_parallel_size"] = str(parallel_context.expert_data_parallel_size) + if isinstance(optimizer, ZeroDistributedOptimizer): # NOTE: in order to serialize, we must save all keys and values as strings def convert_to_string(input_item): @@ -113,7 +121,7 @@ def save_lr_scheduler( root_folder: Path, ): """Saves lr scheduler states""" - if not is_zero and dist.get_rank(parallel_context.dp_pg) > 0: + if not is_zero and dist.get_rank(parallel_context.dp_pg) > 0 and dist.get_rank(parallel_context.ep_dp_pg) > 0: # this is Zero-0, so only DP-0 saves the optimizer states return diff --git a/src/nanotron/serialize/utils.py b/src/nanotron/serialize/utils.py index 9a5348f49..3cfc69fb7 100644 --- a/src/nanotron/serialize/utils.py +++ b/src/nanotron/serialize/utils.py @@ -1,4 +1,5 @@ import re +from dataclasses import dataclass from enum import Enum from pathlib import Path from typing import List, Optional, Tuple @@ -16,38 +17,77 @@ class ObjectType(Enum): LR_SCHEDULER = "lr_scheduler" +@dataclass +class CheckpointParallelRanks: + """ + ep_rank is optional because it's only applicable to moe params + + NOTE: because a non-moe has no ep_rank, we need to pass it as an optional + """ + + pp_rank: int + pp_world_size: int + + tp_rank: int + tp_world_size: int + + ep_rank: Optional[int] = None + ep_world_size: Optional[int] = None + + def get_exp_tp_pp_rank_and_size_from( world_rank: int, parallel_context: ParallelContext ) -> Tuple[Tuple[int, int], Tuple[int, int]]: result = parallel_context.get_local_ranks(world_rank=world_rank) - return ( - (result["ep"], parallel_context.ep_pg.size()), - (result["tp"], parallel_context.tp_pg.size()), - (result["pp"], parallel_context.pp_pg.size()), - ) + + if parallel_context.enabled_moe is True: + return CheckpointParallelRanks( + pp_rank=result["pp"], + pp_world_size=parallel_context.pp_pg.size(), + tp_rank=result["tp"], + tp_world_size=parallel_context.tp_pg.size(), + ep_rank=result["ep"], + ep_world_size=parallel_context.ep_pg.size(), + ) + else: + return CheckpointParallelRanks( + pp_rank=result["pp"], + pp_world_size=parallel_context.pp_pg.size(), + tp_rank=result["tp"], + tp_world_size=parallel_context.tp_pg.size(), + ) def get_path( tensor_name: str, type: ObjectType, - exp_tp_pp_rank_and_size: Tuple[Tuple[int, int], Tuple[int, int]], + checkpoint_parallel_ranks: CheckpointParallelRanks, is_expert_sharded: bool, prefix: Optional[Path] = None, ) -> List[str]: + def get_checkpoint_name_from_parallel_ranks(checkpoint_parallel_ranks: CheckpointParallelRanks, is_moe: bool): + pp_rank = checkpoint_parallel_ranks.pp_rank + pp_size = checkpoint_parallel_ranks.pp_world_size + + tp_rank = checkpoint_parallel_ranks.tp_rank + tp_size = checkpoint_parallel_ranks.tp_world_size + + ep_rank = checkpoint_parallel_ranks.ep_rank + ep_size = checkpoint_parallel_ranks.ep_world_size + + if not is_moe: + return f"pp-rank-{pp_rank}-of-{pp_size}_tp-rank-{tp_rank}-of-{tp_size}" + else: + return f"pp-rank-{pp_rank}-of-{pp_size}_tp-rank-{tp_rank}-of-{tp_size}_exp-rank-{ep_rank}-of-{ep_size}" + suffix = tensor_name.split(".") suffix_path, suffix_name = suffix[:-1], suffix[-1] - if exp_tp_pp_rank_and_size: + if checkpoint_parallel_ranks: # We always show pp_rank and tp_rank if `exp_tp_pp_rank_and_size` is provided - (exp_rank, exp_size), (tp_rank, tp_size), (pp_rank, pp_size) = exp_tp_pp_rank_and_size - if not is_expert_sharded or exp_size == 1: - suffix_name = ( - f"{type.value}_{suffix_name}_pp-rank-{pp_rank}-of-{pp_size}_tp-rank-{tp_rank}-of-{tp_size}.safetensors" - ) - else: - # We only show exp_rank if tensor is exp_sharded and exp_size > 1 - suffix_name = f"{type.value}_{suffix_name}_pp-rank-{pp_rank}-of-{pp_size}_tp-rank-{tp_rank}-of-{tp_size}_exp-rank-{exp_rank}-of-{exp_size}.safetensors" + suffix_name = f"{type.value}_{suffix_name}_{get_checkpoint_name_from_parallel_ranks(checkpoint_parallel_ranks, is_moe=is_expert_sharded)}.safetensors" else: + # NOTE: for params that aren't sharded, we don't need to add any parallel ranks suffix_name = f"{type.value}_{suffix_name}.safetensors" suffix_path.append(suffix_name) @@ -58,10 +98,15 @@ def get_path( def extract_tp_pp_rank_from_shard_path(shard_path: Path): - pattern = r"pp-rank-(\d+)-of-\d+_tp-rank-(\d+)-of-\d+" + from nanotron.serialize.utils import CheckpointParallelRanks + + pattern = r"pp-rank-(\d+)-of-(\d+)_tp-rank-(\d+)-of-(\d+)" match = re.search(pattern, str(shard_path)) - pp_rank, tp_rank = match.groups() - return pp_rank, tp_rank + pp_rank, pp_size, tp_rank, tp_size = match.groups() + + return CheckpointParallelRanks( + pp_rank=int(pp_rank), pp_world_size=int(pp_size), tp_rank=int(tp_rank), tp_world_size=int(tp_size) + ) def merge_and_shard_tp_tensors( @@ -82,3 +127,14 @@ def merge_and_shard_tp_tensors( buffer[local_slices] = unsharded_buffer[global_slices] return buffer + + +def merge_and_shard_ep_tensors(): + # ep_rank = dist.get_rank(parallel_context.ep_pg) + # _data = fi.get_tensor("data") + # _num_experts = _data.shape[0] + # _num_local_experts = _num_experts // parallel_context.ep_pg.size() + # _start_idx = ep_rank * _num_local_experts + # _end_idx = _start_idx + _num_local_experts + # param_or_buffer[:] = _data[_start_idx:_end_idx, :, :] + pass diff --git a/src/nanotron/serialize/weights.py b/src/nanotron/serialize/weights.py index 4c16d3a6d..2b980889b 100644 --- a/src/nanotron/serialize/weights.py +++ b/src/nanotron/serialize/weights.py @@ -32,7 +32,7 @@ def save_weights(model: nn.Module, parallel_context: ParallelContext, root_folde # We save only `dist.get_rank(parallel_context.dp_pg) == 0` # TODO @thomasw21: Figure how this works with Zero-3 - if dist.get_rank(parallel_context.dp_pg) != 0: + if dist.get_rank(parallel_context.dp_pg) != 0 and dist.get_rank(parallel_context.ep_dp_pg) != 0: return module_id_to_prefix = {id(module): f"{module_name}." for module_name, module in model.named_modules()} @@ -41,6 +41,9 @@ def save_weights(model: nn.Module, parallel_context: ParallelContext, root_folde # We chunk everything by `tp_world_size` in order to make sure that we gather all the weights into a single device before saving it for name, param_or_buffer in tqdm(model.state_dict().items(), desc="Saving weights"): + # NOTE: skipping TE's extra_state + if "_extra_state" in name: + continue # exp_rank=0 saves all weights whereas exp_rank>0 save only MLP weights if dist.get_rank(parallel_context.ep_pg) != 0: @@ -71,7 +74,7 @@ def save_weights(model: nn.Module, parallel_context: ParallelContext, root_folde if param.is_sharded: sharded_info: ShardedInfo = param.get_sharded_info() group = parallel_context.world_ranks_to_pg[sharded_info.global_ranks] - exp_tp_pp_rank_and_size = get_exp_tp_pp_rank_and_size_from( + checkpoint_parallel_ranks = get_exp_tp_pp_rank_and_size_from( world_rank=get_global_rank(group=group, group_rank=dist.get_rank(group)), parallel_context=parallel_context, ) @@ -81,15 +84,18 @@ def save_weights(model: nn.Module, parallel_context: ParallelContext, root_folde local_global_slices_pairs=sharded_info.local_global_slices_pairs, unsharded_shape=sharded_info.unsharded_shape, ).to_str_dict() - else: - exp_tp_pp_rank_and_size = None + checkpoint_parallel_ranks = None is_expert_sharded = False + # NOTE: we only save weights of the first replicas of dense + if dist.get_rank(parallel_context.dp_pg) != 0 and is_expert_sharded is False: + continue + path = get_path( base_name, type=ObjectType.MODEL, - exp_tp_pp_rank_and_size=exp_tp_pp_rank_and_size, + checkpoint_parallel_ranks=checkpoint_parallel_ranks, is_expert_sharded=is_expert_sharded, prefix=root_folder, ) @@ -170,7 +176,9 @@ def load_sharded_param_latest( if param_shard_metadata is not None: # NOTE: store how does model parameter are sharded # so that we can shard optimizer checkpoints in this way - pp_rank, tp_rank = extract_tp_pp_rank_from_shard_path(shard_path) + checkpoint_parallel_ranks = extract_tp_pp_rank_from_shard_path(shard_path) + pp_rank = checkpoint_parallel_ranks.pp_rank + tp_rank = checkpoint_parallel_ranks.tp_rank param_shard_metadata[(pp_rank, tp_rank)] = param_metadata assert checkpoint_unsharded_shape is not None @@ -214,6 +222,10 @@ def load_weights( for name, param_or_buffer in tqdm( filtered_state_dict.items(), disable=dist.get_rank(parallel_context.world_pg) != 0, desc="Loading weights" ): + # NOTE: skipping TE's extra_state + if "_extra_state" in name: + continue + # NOTE: extract how does the current model parameter are sharded # so that we can load optimizer checkpoints in this way param_shard_metadata[name] = {} @@ -241,19 +253,19 @@ def load_weights( group = parallel_context.world_ranks_to_pg[sharded_info.global_ranks] group_rank = dist.get_rank(group) - exp_tp_pp_rank_and_size = get_exp_tp_pp_rank_and_size_from( + checkpoint_parallel_ranks = get_exp_tp_pp_rank_and_size_from( world_rank=get_global_rank(group=group, group_rank=group_rank), parallel_context=parallel_context ) # TODO @nouamane: do we consider exp_size=1 expert_sharded? is_expert_sharded = sharded_info.is_expert_sharded(parallel_context) else: - exp_tp_pp_rank_and_size = None + checkpoint_parallel_ranks = None is_expert_sharded = False path = get_path( base_name, type=ObjectType.MODEL, - exp_tp_pp_rank_and_size=exp_tp_pp_rank_and_size, + checkpoint_parallel_ranks=checkpoint_parallel_ranks, prefix=param_root_folder, is_expert_sharded=is_expert_sharded, ) diff --git a/src/nanotron/trainer.py b/src/nanotron/trainer.py index 00c26943b..b61131801 100644 --- a/src/nanotron/trainer.py +++ b/src/nanotron/trainer.py @@ -172,8 +172,11 @@ def __init__( tensor_parallel_size=self.config.parallelism.tp, pipeline_parallel_size=self.config.parallelism.pp, data_parallel_size=self.config.parallelism.dp, - expert_parallel_size=self.config.parallelism.expert_parallel_size, context_parallel_size=self.config.parallelism.context_parallel_size, + expert_parallel_size=self.config.parallelism.expert_parallel_size, + expert_tensor_parallel_size=self.config.parallelism.expert_tensor_parallel_size, + expert_data_parallel_size=self.config.parallelism.expert_data_parallel_size, + enabled_moe=self.config.parallelism.enabled_moe, ) self.pre_init() @@ -245,24 +248,31 @@ def __init__( assert isinstance(checkpoint_metadata.metas, TrainingMetadata) log_rank(str(checkpoint_metadata), logger=logger, level=logging.INFO, rank=0) self.metadata: TrainingMetadata = checkpoint_metadata.metas - # NOTE: we should not change data stages + # In case of a new datastage, metadata will be updated in `get_dataloader` assert ( self.config.tokens.train_steps > self.metadata.last_train_step ), f"Loaded checkpoint has already trained {self.metadata.last_train_step} batches, you need to specify a higher `config.tokens.train_steps`" else: data_stages = [ DataStageMetadata( - name=stage.name, start_training_step=stage.start_training_step, consumed_train_samples=0 + name=stage.name, + start_training_step=stage.start_training_step, + consumed_train_samples=0, + sequence_length=stage.sequence_length, ) for stage in self.config.data_stages ] self.metadata: TrainingMetadata = TrainingMetadata( - consumed_train_samples=0, last_train_step=0, last_stage_idx=0, data_stages=data_stages + consumed_train_samples=0, + consumed_tokens_total=0, + last_train_step=0, + last_stage_idx=0, + data_stages=data_stages, ) # Setup tensorboard write and log writers on output rank self.logger_ranks = self.parallel_context.get_global_rank( - ep_rank=0, pp_rank=self.unwrapped_model.output_pp_rank, dp_rank=0, tp_rank=0, cp_rank=0 + pp_rank=self.unwrapped_model.output_pp_rank, dp_rank=0, tp_rank=0, cp_rank=0 ).flatten() self.loggerwriter = self.setup_log_writers() @@ -278,6 +288,7 @@ def __init__( self.limit_val_batches = self.config.tokens.limit_val_batches self.current_dataloader: Optional[DataLoader] = None # used for the current training stage self.current_base_dl: Optional[DataLoader] = None # used for the current training stage + self.iteration_timer = None # Will be initialized during training log_libraries_versions(logger=logger) log_rank("Config:", logger=logger, level=logging.INFO, rank=0, is_separator=True) @@ -339,7 +350,7 @@ def pre_training(self, *args, **kwargs): log_rank("Start training", logger=logger, level=logging.INFO, rank=0, is_separator=True) log_rank( - f"mbs: {self.micro_batch_size} | grad_accum: {self.n_micro_batches_per_batch} | sequence_length: {self.sequence_length} | global_batch_size: {self.global_batch_size} | train_steps: {self.config.tokens.train_steps} | start_iteration_step: {metadata.last_train_step} | consumed_train_samples: {metadata.consumed_train_samples}", # noqa + f"mbs: {self.micro_batch_size} | grad_accum: {self.n_micro_batches_per_batch} | sequence_length: {self.sequence_length} | global_batch_size: {self.global_batch_size} | train_steps: {self.config.tokens.train_steps} | start_iteration_step: {metadata.last_train_step} | consumed_tokens_total: {metadata.consumed_tokens_total}", # noqa logger=logger, level=logging.INFO, rank=0, @@ -450,9 +461,6 @@ def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[Da return assert len(dataloaders) > 0, "No dataloaders provided" - assert len(dataloaders) == len( - self.config.data_stages - ), "Number of dataloaders should match the number of dataset stages" def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str): import gc @@ -565,22 +573,31 @@ def train( prof.step() self.iteration_start_time = time.time() + nanotron_timer("update_dataloader", "cuda", cuda_sync=True).start() self._update_dataloader_based_on_training_stages(dataloader_or_dls) + nanotron_timer("update_dataloader", "cuda", cuda_sync=True).end() # Training step + nanotron_timer("training_step", "cuda", cuda_sync=True).start() outputs, loss_avg, z_loss_avg = self.training_step(dataloader=self.current_dataloader) + nanotron_timer("training_step", "cuda", cuda_sync=True).end() # Update consumption tracking for current batch - if hasattr(self.current_base_dl, "dataset"): + nanotron_timer("update_consumption_metrics", "cuda", cuda_sync=True).start() + if hasattr(self.current_base_dl, "dataset") and hasattr( + self.current_base_dl.dataset, "update_consumption_metrics" + ): + # TODO: only works for BlendableDataset self.current_base_dl.dataset.update_consumption_metrics( start_idx=(self.iteration_step - 1) * self.global_batch_size, # assumes we start from iteration_step=1 end_idx=self.iteration_step * self.global_batch_size, sequence_length=self.sequence_length, ) - + nanotron_timer("update_consumption_metrics", "cuda", cuda_sync=True).end() # Training Logs # Track consumed tokens for all dataset folders in current stage + nanotron_timer("update_consumption_metrics_2", "cuda", cuda_sync=True).start() if hasattr(self.current_base_dl, "dataset"): consumption_stats = self.current_base_dl.dataset.get_consumption_stats() current_stage = self.metadata.data_stages[self.metadata.last_stage_idx] @@ -588,16 +605,22 @@ def train( # Update consumed tokens for all folders in the consumption stats for folder_path, stats in consumption_stats.items(): current_stage.consumed_tokens_per_dataset_folder[folder_path] = stats["tokens"] - + nanotron_timer("update_consumption_metrics_2", "cuda", cuda_sync=True).end() # Original consumption tracking - self.metadata.consumed_train_samples += self.global_batch_size + self.metadata.consumed_train_samples += self.global_batch_size # TODO: Legacy: idc abt this + self.metadata.consumed_tokens_total += self.global_batch_size * self.sequence_length self.metadata.last_train_step = self.iteration_step - self.metadata.data_stages[ - self.metadata.last_stage_idx - ].consumed_train_samples += self.global_batch_size + self.metadata.current_stage.consumed_train_samples += self.global_batch_size + assert ( + self.metadata.current_stage.sequence_length == self.sequence_length + ), "Sequence length mismatch between the current stage and the global sequence length" if (self.iteration_step - 1) % self.config.logging.iteration_step_info_interval == 0: - self.train_step_logs(outputs=outputs, loss_avg=loss_avg, z_loss_avg=z_loss_avg) + self.train_step_logs( + outputs=outputs, + loss_avg=loss_avg, + z_loss_avg=z_loss_avg, + ) # Checkpoint if self.iteration_step % self.config.checkpoints.checkpoint_interval == 0: @@ -613,6 +636,8 @@ def train( def training_step( self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]] ) -> Tuple[Iterable[Dict], Optional[torch.Tensor]]: + # dist.barrier() + # log_rank(f"training_step {self.iteration_step}", logger=logger, level=logging.INFO) before_tbi_sanity_checks( self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator, self.lr_scheduler ) @@ -620,7 +645,7 @@ def training_step( if self.iteration_step < self.initial_iter_step + 5: log_memory(logger=logger, msg="Before train_batch_iter") - nanotron_timer("train_batch_iter", "cuda").start() + nanotron_timer("train_batch_iter", "cuda", cuda_sync=True).start() with torch.profiler.record_function("train_batch_iter"): outputs = self.pipeline_engine.train_batch_iter( model=self.model, @@ -629,10 +654,10 @@ def training_step( nb_microbatches=self.n_micro_batches_per_batch, grad_accumulator=self.grad_accumulator, ) - nanotron_timer("train_batch_iter", "cuda").end() + nanotron_timer("train_batch_iter", "cuda", cuda_sync=True).end() - if self.iteration_step < self.initial_iter_step + 5: - log_memory(logger=logger, msg="After train_batch_iter") + # if self.iteration_step < self.initial_iter_step + 5: + # log_memory(logger=logger, msg="After train_batch_iter") after_tbi_sanity_checks(self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator) @@ -647,6 +672,8 @@ def training_step( else: self.grad_accumulator.fp32_grads_allreduce_handle.wait() + # dist.barrier() + # log_rank(f"sync_gradients {self.iteration_step}", logger=logger, level=logging.INFO) nanotron_timer("sync_gradients", "cuda").start() # Sync tied weights if not isinstance(self.model, DistributedDataParallel): @@ -697,6 +724,7 @@ def training_step( ).sum() # already divided by n_micro_batches_per_batch else: z_loss_avg = None + # sync loss across DP (we should do the same for z_loss but it's only for logging so let's not sync it rn) handle = dist.all_reduce(loss_avg, group=self.parallel_context.dp_pg, async_op=True, op=dist.ReduceOp.AVG) else: @@ -719,19 +747,31 @@ def training_step( # Apply gradient nanotron_timer("optimizer_step", "cuda").start() self.optimizer.step() + # dist.barrier() + # log_rank(f"optimizer_step {self.iteration_step}", logger=logger, level=logging.INFO) self.optimizer.zero_grad() nanotron_timer("optimizer_step", "cuda").end() + # dist.barrier() + # log_rank(f"zero_grad {self.iteration_step}", logger=logger, level=logging.INFO) # Update the learning rate + nanotron_timer("lr_scheduler_step", "cuda").start() self.lr_scheduler.step() + nanotron_timer("lr_scheduler_step", "cuda").end() after_optim_step_sanity_checks(self.config, self.parallel_context, self.unwrapped_model, self.grad_accumulator) + # dist.barrier() + # log_rank(f"handle.wait {self.iteration_step}", logger=logger, level=logging.INFO) if handle is not None: handle.wait() + # dist.barrier() + # print(f"post_train_step {self.iteration_step}") self.post_train_step() + # TODO: return a dataclass instead of a list of tensors, + # it's more readable return outputs, loss_avg, z_loss_avg def validation_step(self, dataloader: Iterator[Dict[str, Union[torch.Tensor, TensorPointer]]]) -> Iterable[Dict]: @@ -749,8 +789,12 @@ def train_step_logs( z_loss_avg: Optional[torch.Tensor], ) -> None: # TODO @nouamanetazi: Megatron-LM seems to be using a barrier to report their interval time. Check if this is necessary. https://github.com/NouamaneTazi/Megatron-LM/blob/e241a96c3085b18e36c6cee1d68a8155de77b5a6/megatron/training.py#L607 + nanotron_timer("train_step_logs_barrier", "cuda").start() dist.barrier() + nanotron_timer("train_step_logs_barrier", "cuda").end() + nanotron_timer("train_step_logs_sync", "cuda").start() torch.cuda.synchronize() + nanotron_timer("train_step_logs_sync", "cuda").end() elapsed_time_per_iteration_ms = (time.time() - self.iteration_start_time) * 1000 tokens_per_sec = ( self.global_batch_size * self.sequence_length / (elapsed_time_per_iteration_ms / 1000) @@ -1114,14 +1158,40 @@ def _init_model( total_params = total_params.item() self.num_params = {"total": total_params, "local": num_params} - # TODO @nouamanetazi: better memory logs + # Compute active parameters for MoE + if config.model.model_config.is_moe_model: + from nanotron.nn.moe import is_expert_param + + expert_params = sum(p.numel() for n, p in model.named_parameters() if is_expert_param(n)) + non_expert_params = num_params - expert_params + active_params = ( + non_expert_params + + expert_params + * config.model.model_config.moe_config.top_k + / config.model.model_config.moe_config.num_experts + ) + active_params_t = torch.tensor(active_params, device="cuda") + dist.all_reduce(active_params_t, group=parallel_context.ep_pg) + dist.all_reduce(active_params_t, group=parallel_context.ep_pp_pg) + self.num_params["active"] = active_params_t.item() + log_rank( - f"Total number of parameters: {human_format(total_params)} ({total_size.item() / 1024**2:.2f}MiB)", + f"Total number of parameters: {human_format(total_params)} ({total_size.item() / 1024**2:.2f}MiB)\n", logger=logger, level=logging.INFO, group=parallel_context.world_pg, rank=0, ) + + if config.model.model_config.is_moe_model: + log_rank( + f"Active parameters: {human_format(self.num_params['active'])}", + logger=logger, + level=logging.INFO, + group=parallel_context.world_pg, + rank=0, + ) + log_rank( f"Local number of parameters: {human_format(num_params)} ({size_params / 1024**2:.2f}MiB)", logger=logger, @@ -1129,6 +1199,8 @@ def _init_model( group=parallel_context.dp_pg, rank=0, ) + # TODO @nouamanetazi: better memory logs + log_rank( f"[After model building] Memory usage: {torch.cuda.memory_allocated() / 1024**2:.2f}MiB." f" Peak allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f}MiB" @@ -1175,7 +1247,7 @@ def setup_log_writers( def pre_save_checkpoint(self) -> Path: # Check if eval_interval should be updated from file - eval_interval_file = self.config.lighteval.eval_interval_file + eval_interval_file = self.config.lighteval.eval_interval_file if self.config.lighteval is not None else None if eval_interval_file is not None and Path(eval_interval_file).exists(): try: with open(eval_interval_file, "r") as f: @@ -1255,8 +1327,11 @@ def save_checkpoint(self) -> Path: model=self.unwrapped_model, optimizer=self.optimizer, lr_scheduler=self.lr_scheduler, + # NOTE: we save a model weights if + # 1. the first replicas of dense + # 2. the first replicas of moe's experts should_save_model=bool( - dist.get_rank(self.parallel_context.dp_pg) == 0 + dist.get_rank(self.parallel_context.dp_pg) == 0 or dist.get_rank(self.parallel_context.ep_dp_pg) == 0 ), # We only save the weights on DP==0 should_save_optimizer=True, should_save_lr_scheduler=True, @@ -1305,7 +1380,6 @@ def mark_tied_parameters( target, ( parallel_context.get_global_rank( - ep_rank=dist.get_rank(parallel_context.ep_pg), pp_rank=get_pp_rank_of(target, module=model), dp_rank=dist.get_rank(parallel_context.dp_pg), cp_rank=dist.get_rank(parallel_context.cp_pg), @@ -1315,6 +1389,7 @@ def mark_tied_parameters( ) for target in embeddings_lm_head_tied_names ] + tie_parameters( root_module=model, ties=shared_embeddings, parallel_context=parallel_context, reduce_op=dist.ReduceOp.SUM ) @@ -1325,7 +1400,7 @@ def mark_tied_parameters( # Sync all parameters that have the same name and that are not sharded across TP and EXP assert not isinstance(model, DistributedDataParallel), "model shouldn't be DDP at this point" mark_unsharded_params_as_tied_across_tp(model, parallel_context, parallel_config) - mark_unsharded_params_as_tied_across_expert(model, parallel_context, parallel_config) + # mark_unsharded_params_as_tied_across_expert(model, parallel_context, parallel_config) create_pg_for_tied_weights(root_module=model, parallel_context=parallel_context) @@ -1385,6 +1460,8 @@ def mark_unsharded_params_as_tied_across_expert( if param.is_sharded: sharded_info = param.get_sharded_info() + + # TODO: double check and remove if necessary if sharded_info.is_expert_sharded(parallel_context): continue diff --git a/test_timer_decorator.py b/test_timer_decorator.py new file mode 100644 index 000000000..b900f33d9 --- /dev/null +++ b/test_timer_decorator.py @@ -0,0 +1,63 @@ +"""Test script for the timer decorator with both CPU and CUDA timer types.""" + +from nanotron.logging.timers import nanotron_timer, TimerType +import time +import torch + +# Enable timers for testing +nanotron_timer.enable() + +# Test with default CUDA timing +@nanotron_timer +def test_default_decorator(): + """Test function with default CUDA timing.""" + # Simulate some work + time.sleep(0.1) + if torch.cuda.is_available(): + x = torch.randn(1000, 1000, device="cuda") + y = torch.matmul(x, x) + torch.cuda.synchronize() + return "Done" + +# Test with explicit CUDA timing +@nanotron_timer(timer_type=TimerType.CUDA) +def test_cuda_decorator(): + """Test function with explicit CUDA timing.""" + # Simulate some work + time.sleep(0.1) + if torch.cuda.is_available(): + x = torch.randn(1000, 1000, device="cuda") + y = torch.matmul(x, x) + torch.cuda.synchronize() + return "Done" + +# Test with CPU timing +@nanotron_timer(timer_type=TimerType.CPU) +def test_cpu_decorator(): + """Test function with CPU timing.""" + # Simulate some CPU work + time.sleep(0.2) + return "Done" + +# Test with custom name +@nanotron_timer("custom_name") +def test_custom_name_decorator(): + """Test function with custom name.""" + # Simulate some work + time.sleep(0.1) + return "Done" + +if __name__ == "__main__": + print("Testing timer decorators...") + + # Run the test functions + test_default_decorator() + test_cuda_decorator() + test_cpu_decorator() + test_custom_name_decorator() + + # Log all timers + print("\nTimer results:") + nanotron_timer.log_all(rank=None) # Log on all ranks + + print("\nTest completed successfully!") diff --git a/tests/helpers/qwen_helper.py b/tests/helpers/qwen_helper.py index b333e2a71..0cb7846aa 100644 --- a/tests/helpers/qwen_helper.py +++ b/tests/helpers/qwen_helper.py @@ -17,30 +17,47 @@ TokensArgs, ) from nanotron.config.config import PretrainDatasetsArgs +from nanotron.config.models_config import MoEConfig from nanotron.models import build_model from nanotron.models.qwen import Qwen2Config, Qwen2ForTraining from nanotron.parallel.context import ParallelContext from nanotron.trainer import mark_tied_parameters -TINY_QWEN_CONFIG = Qwen2Config( +QWEN_MOE_CONFIG = MoEConfig( + num_experts=4, + top_k=1, + moe_intermediate_size=32 * 4, + shared_expert_intermediate_size=32 * 4, + enable_shared_expert=True, + token_dispatcher_type="alltoall", + use_torch_permute=True, +) + +QWEN_MODEL_CONFIG = { + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 32, + "initializer_range": 0.02, + "intermediate_size": 32 * 4, + "max_position_embeddings": 32, + "num_attention_heads": 4, + "num_hidden_layers": 4, + "num_key_value_heads": 2, + "pad_token_id": None, + "rms_norm_eps": 1e-06, + "rope_theta": 10000.0, + "tie_word_embeddings": False, + "use_cache": True, + "vocab_size": 4096, + "_attn_implementation": "flash_attention_2", +} + +TINY_QWEN_CONFIG = Qwen2Config(**QWEN_MODEL_CONFIG) +TINY_MOE_QWEN_CONFIG = Qwen2Config( **{ - "bos_token_id": 1, - "eos_token_id": 2, - "hidden_act": "silu", - "hidden_size": 128, - "initializer_range": 0.02, - "intermediate_size": 128 * 4, - "max_position_embeddings": 128, - "num_attention_heads": 4, - "num_hidden_layers": 4, - "num_key_value_heads": 2, - "pad_token_id": None, - "rms_norm_eps": 1e-06, - "rope_theta": 10000.0, - "tie_word_embeddings": False, - "use_cache": True, - "vocab_size": 4096, - "_attn_implementation": "flash_attention_2", + **QWEN_MODEL_CONFIG, + "moe_config": QWEN_MOE_CONFIG, } ) diff --git a/tests/helpers/utils.py b/tests/helpers/utils.py index d0fb01b57..509f64d0d 100644 --- a/tests/helpers/utils.py +++ b/tests/helpers/utils.py @@ -243,7 +243,34 @@ def _run_until_success(*args, **kwargs): return _wrapper -def global_wrapper(rank, func, tp, pp, dp, port, kwargs): +def _get_world_size( + tp: int, + pp: int, + dp: int, + expert_parallel_size: Optional[int] = None, + expert_tensor_parallel_size: Optional[int] = None, + expert_data_parallel_size: Optional[int] = None, + enabled_moe: bool = False, +): + if enabled_moe is False: + return tp * pp * dp + else: + return expert_parallel_size * expert_tensor_parallel_size * expert_data_parallel_size * pp + + +def global_wrapper( + rank, + func, + tp, + pp, + dp, + expert_parallel_size, + expert_tensor_parallel_size, + expert_data_parallel_size, + enabled_moe, + port, + kwargs, +): def setup_dist_env(rank, world_size, port): os.environ["WORLD_SIZE"] = str(world_size) os.environ["RANK"] = str(rank) @@ -253,22 +280,54 @@ def setup_dist_env(rank, world_size, port): os.environ["MASTER_ADDR"] = "localhost" os.environ["MASTER_PORT"] = str(port) - world_size = tp * pp * dp + world_size = _get_world_size( + tp, pp, dp, expert_parallel_size, expert_tensor_parallel_size, expert_data_parallel_size, enabled_moe + ) setup_dist_env(rank, world_size, port) - parallel_context = ParallelContext(data_parallel_size=dp, pipeline_parallel_size=pp, tensor_parallel_size=tp) + parallel_context = ParallelContext( + data_parallel_size=dp, + pipeline_parallel_size=pp, + tensor_parallel_size=tp, + expert_parallel_size=expert_parallel_size, + expert_tensor_parallel_size=expert_tensor_parallel_size, + expert_data_parallel_size=expert_data_parallel_size, + enabled_moe=enabled_moe, + ) func(parallel_context, **kwargs) -def init_distributed(tp: int, dp: int, pp: int): +def init_distributed( + tp: int, + dp: int, + pp: int, + expert_parallel_size: Optional[int] = None, + expert_tensor_parallel_size: Optional[int] = None, + expert_data_parallel_size: Optional[int] = None, + enabled_moe: bool = False, +): def _init_distributed(func): def wrapper(**kwargs): from nanotron.utils import find_free_port - world_size = tp * pp * dp + world_size = _get_world_size( + tp, pp, dp, expert_parallel_size, expert_tensor_parallel_size, expert_data_parallel_size, enabled_moe + ) + port = find_free_port() # Note that kwargs needs to be passed as part of args in a way that can be unpacked - args = (func, tp, pp, dp, port, kwargs) + args = ( + func, + tp, + pp, + dp, + expert_parallel_size, + expert_tensor_parallel_size, + expert_data_parallel_size, + enabled_moe, + port, + kwargs, + ) mp.spawn(global_wrapper, args=args, nprocs=world_size) return wrapper diff --git a/tests/test_distributed_primitives.py b/tests/test_distributed_primitives.py new file mode 100644 index 000000000..48b0f4803 --- /dev/null +++ b/tests/test_distributed_primitives.py @@ -0,0 +1,85 @@ +import pytest +import torch +import torch.distributed as dist +from helpers.utils import ( + init_distributed, + rerun_if_address_is_in_use, +) +from nanotron.parallel import ParallelContext +from nanotron.parallel.tensor_parallel.distributed_differentiable_primitives import all_to_all + + +def _test_all_to_all( + parallel_context: ParallelContext, inputs, expected_outputs, input_split_sizes, output_split_sizes +): + rank = dist.get_rank(parallel_context.tp_pg) + + input = inputs[rank].to("cuda") + expected_output = expected_outputs[rank].to("cuda") + input_split_sizes = input_split_sizes[rank] + output_split_sizes = output_split_sizes[rank] + + output = all_to_all( + input, group=parallel_context.tp_pg, input_split_sizes=input_split_sizes, output_split_sizes=output_split_sizes + ) + assert torch.allclose(output, expected_output) + + +@pytest.mark.parametrize( + "world_size, inputs, expected_outputs, input_split_sizes, output_split_sizes", + [ + ( + 4, + # NOTE: range(4) is range(world_size) + [torch.arange(4) + rank * 4 for rank in range(4)], + [torch.tensor([rank + i * 4 for i in range(4)]) for rank in range(4)], + [None, None, None, None], + [None, None, None, None], + ), # Default case: uniform sizes + ( + 2, + [ + torch.tensor([2, 0, 1, 3]).unsqueeze(-1).expand(-1, 4), + torch.tensor([5, 4, 6, 7]).unsqueeze(-1).expand(-1, 4), + ], + [ + torch.tensor([2, 5, 4]).unsqueeze(-1).expand(-1, 4), + torch.tensor([0, 1, 3, 6, 7]).unsqueeze(-1).expand(-1, 4), + ], + [[1, 3], [2, 2]], + [[1, 2], [3, 2]], + ), # Custom split sizes + ], +) +@rerun_if_address_is_in_use() +def test_all_to_all(world_size, inputs, expected_outputs, input_split_sizes, output_split_sizes): + init_distributed( + tp=world_size, + dp=1, + pp=1, + expert_parallel_size=1, + expert_tensor_parallel_size=1, + expert_data_parallel_size=1, + enabled_moe=False, + )(_test_all_to_all)( + inputs=inputs, + expected_outputs=expected_outputs, + input_split_sizes=input_split_sizes, + output_split_sizes=output_split_sizes, + ) + + +if __name__ == "__main__": + test_all_to_all( + 2, + [ + torch.tensor([2, 0, 1, 3]).unsqueeze(-1).expand(-1, 4), + torch.tensor([5, 4, 6, 7]).unsqueeze(-1).expand(-1, 4), + ], + [ + torch.tensor([2, 5, 4]).unsqueeze(-1).expand(-1, 4), + torch.tensor([0, 1, 3, 6, 7]).unsqueeze(-1).expand(-1, 4), + ], + [[1, 3], [2, 2]], + [[1, 2], [3, 2]], + ) diff --git a/tests/test_moe.py b/tests/test_moe.py new file mode 100644 index 000000000..3f2e1485f --- /dev/null +++ b/tests/test_moe.py @@ -0,0 +1,400 @@ +import os +from copy import copy +from dataclasses import dataclass + +import numpy as np +import pytest +import torch +import torch.distributed as dist +from helpers.qwen_helper import TINY_MOE_QWEN_CONFIG +from helpers.utils import ( + init_distributed, + rerun_if_address_is_in_use, +) +from nanotron.config.parallelism_config import ParallelismArgs +from nanotron.models.base import init_on_device_and_dtype +from nanotron.nn.moe import ( + GroupedMLP, + Qwen2MoEMLPLayer, + _get_dispatched_routing_indices, + permute, + unpermute, +) +from nanotron.parallel import ParallelContext +from nanotron.parallel.context import ParallelMode +from torch.distributed import ProcessGroup + +HIDDEN_SIZE = TINY_MOE_QWEN_CONFIG.hidden_size + + +@dataclass(frozen=True) +class ParalellismConfig: + tp: int + dp: int + pp: int + expert_parallel_size: int + expert_tensor_parallel_size: int + expert_data_parallel_size: int + + +PARALLEL_CONFIGS_TO_PARALLEL_RANKS = { + ParalellismConfig( + tp=1, dp=4, pp=1, expert_parallel_size=2, expert_tensor_parallel_size=1, expert_data_parallel_size=2 + ): { + "attn_groups": { + "tp": [[0], [1], [2], [3]], + "cp": [[0], [1], [2], [3]], + "pp": [[0], [1], [2], [3]], + "dp": [[0, 1, 2, 3]], + }, + "moe_groups": { + "tp": [[0], [1], [2], [3]], + "ep": [[0, 1], [2, 3]], + "pp": [[0], [1], [2], [3]], + "dp": [[0, 2], [1, 3]], + }, + }, + ParalellismConfig( + tp=1, dp=2, pp=1, expert_parallel_size=2, expert_tensor_parallel_size=1, expert_data_parallel_size=1 + ): { + "attn_groups": {"tp": [[0], [1]], "cp": [[0], [1]], "pp": [[0], [1]], "dp": [[0, 1]]}, + "moe_groups": {"tp": [[0], [1]], "ep": [[0, 1]], "pp": [[0], [1]], "dp": [[0], [1]]}, + }, + ParalellismConfig( + tp=2, dp=4, pp=1, expert_parallel_size=2, expert_tensor_parallel_size=2, expert_data_parallel_size=2 + ): { + "attn_groups": { + "tp": [[0, 1], [2, 3], [4, 5], [6, 7]], + "cp": [[0], [1], [2], [3], [4], [5], [6], [7]], + "pp": [[0], [1], [2], [3], [4], [5], [6], [7]], + "dp": [[0, 2, 4, 6], [1, 3, 5, 7]], + }, + "moe_groups": { + "tp": [[0, 1], [2, 3], [4, 5], [6, 7]], + "ep": [[0, 2], [1, 3], [4, 6], [5, 7]], + "pp": [[0], [1], [2], [3], [4], [5], [6], [7]], + "dp": [[0, 4], [1, 5], [2, 6], [3, 7]], + }, + }, + ParalellismConfig( + tp=2, dp=4, pp=1, expert_parallel_size=4, expert_tensor_parallel_size=2, expert_data_parallel_size=1 + ): { + "attn_groups": { + "tp": [[0, 1], [2, 3], [4, 5], [6, 7]], + "cp": [[0], [1], [2], [3], [4], [5], [6], [7]], + "pp": [[0], [1], [2], [3], [4], [5], [6], [7]], + "dp": [[0, 2, 4, 6], [1, 3, 5, 7]], + }, + "moe_groups": { + "tp": [[0, 1], [2, 3], [4, 5], [6, 7]], + "ep": [[0, 2, 4, 6], [1, 3, 5, 7]], + "pp": [[0], [1], [2], [3], [4], [5], [6], [7]], + "dp": [[0], [1], [2], [3], [4], [5], [6], [7]], + }, + }, +} + + +def get_ep_shard(input, ep_rank: int, parallel_context: ParallelContext): + return torch.chunk(input, chunks=parallel_context.expert_parallel_size, dim=0)[ep_rank] + + +def _test_init_moe_process_groups(parallel_context: ParallelContext): + assert dist.is_initialized() is True + assert isinstance(parallel_context.world_pg, ProcessGroup) + assert isinstance(parallel_context.tp_pg, ProcessGroup) if parallel_context.tensor_parallel_size > 1 else True + assert isinstance(parallel_context.pp_pg, ProcessGroup) if parallel_context.pipeline_parallel_size > 1 else True + assert isinstance(parallel_context.dp_pg, ProcessGroup) if parallel_context.data_parallel_size > 1 else True + + assert isinstance(parallel_context.ep_pg, ProcessGroup) if parallel_context.expert_parallel_size > 1 else True + assert ( + isinstance(parallel_context.ep_tp_pg, ProcessGroup) + if parallel_context.expert_tensor_parallel_size > 1 + else True + ) + assert ( + isinstance(parallel_context.ep_dp_pg, ProcessGroup) if parallel_context.expert_data_parallel_size > 1 else True + ) + assert parallel_context.enabled_moe is True + + expected_parallel_ranks = PARALLEL_CONFIGS_TO_PARALLEL_RANKS[ + ParalellismConfig( + tp=parallel_context.tensor_parallel_size, + dp=parallel_context.data_parallel_size, + pp=parallel_context.pipeline_parallel_size, + expert_parallel_size=parallel_context.expert_parallel_size, + expert_tensor_parallel_size=parallel_context.expert_tensor_parallel_size, + expert_data_parallel_size=parallel_context.expert_data_parallel_size, + ) + ] + + assert np.all(expected_parallel_ranks["attn_groups"]["dp"] == parallel_context._group_to_ranks[ParallelMode.DP]) + assert np.all(expected_parallel_ranks["attn_groups"]["tp"] == parallel_context._group_to_ranks[ParallelMode.TP]) + assert np.all(expected_parallel_ranks["attn_groups"]["pp"] == parallel_context._group_to_ranks[ParallelMode.PP]) + assert np.all(expected_parallel_ranks["attn_groups"]["cp"] == parallel_context._group_to_ranks[ParallelMode.CP]) + + assert np.all(expected_parallel_ranks["moe_groups"]["dp"] == parallel_context._group_to_ranks[ParallelMode.EP_DP]) + assert np.all(expected_parallel_ranks["moe_groups"]["tp"] == parallel_context._group_to_ranks[ParallelMode.EP_TP]) + assert np.all(expected_parallel_ranks["moe_groups"]["pp"] == parallel_context._group_to_ranks[ParallelMode.EP_PP]) + assert np.all(expected_parallel_ranks["moe_groups"]["ep"] == parallel_context._group_to_ranks[ParallelMode.EP]) + + +@pytest.mark.parametrize( + "tp,dp,pp,expert_parallel_size,expert_tensor_parallel_size,expert_data_parallel_size", + [ + (1, 4, 1, 2, 1, 2), + (1, 2, 1, 2, 1, 1), + (2, 4, 1, 2, 2, 2), + (2, 4, 1, 4, 2, 1), + ], +) +@rerun_if_address_is_in_use() +def test_init_moe_process_groups( + tp: int, + dp: int, + pp: int, + expert_parallel_size: int, + expert_tensor_parallel_size: int, + expert_data_parallel_size: int, +): + enabled_moe = True + init_distributed( + tp=tp, + dp=dp, + pp=pp, + expert_parallel_size=expert_parallel_size, + expert_tensor_parallel_size=expert_tensor_parallel_size, + expert_data_parallel_size=expert_data_parallel_size, + enabled_moe=enabled_moe, + )(_test_init_moe_process_groups)() + + +@pytest.mark.parametrize( + "global_routing_indices, expert_parallel_size, num_experts, expected_result", + [ + (torch.tensor([[0], [2], [1], [3]], dtype=torch.long), 2, 4, [torch.tensor([0, 1]), torch.tensor([2, 3])]), + ( + torch.tensor([[2, 1], [3, 0], [1, 2], [3, 1], [1, 2], [0, 1], [2, 1], [1, 2]], dtype=torch.long), + 2, + 4, + [torch.tensor([0, 1, 1, 1, 0, 1, 1, 1, 1]), torch.tensor([2, 2, 3, 3, 2, 2, 2])], + ), + ], +) +def test_get_dispatched_routing_indices(global_routing_indices, expert_parallel_size, num_experts, expected_result): + """Test basic functionality with simple routing indices.""" + result = _get_dispatched_routing_indices(global_routing_indices, expert_parallel_size, num_experts) + + assert isinstance(result, list) + assert len(result) == expert_parallel_size + + assert torch.equal(result[0], expected_result[0]) + assert torch.equal(result[1], expected_result[1]) + + +@pytest.mark.parametrize( + "routing_indices, expected_output, routing_weights", + [ + ( + torch.tensor([[2], [3], [1], [3]], dtype=torch.int32, device="cuda"), + torch.tensor([2, 0, 1, 3], dtype=torch.bfloat16, device="cuda").unsqueeze(-1).expand(-1, HIDDEN_SIZE), + torch.tensor([[1], [1], [1], [1]], dtype=torch.bfloat16, device="cuda"), + ), + ( + torch.tensor([[2, 1], [3, 0], [1, 2], [3, 1]], dtype=torch.int32, device="cuda"), + torch.tensor([1, 0, 2, 3, 0, 2, 1, 3], dtype=torch.bfloat16, device="cuda") + .unsqueeze(-1) + .expand(-1, HIDDEN_SIZE), + torch.tensor([[1, 1], [1, 1], [1, 1], [1, 1]], dtype=torch.bfloat16, device="cuda"), + ), + ], +) +def test_permute_and_unpermute(routing_indices, expected_output, routing_weights): + x = torch.tensor([0, 1, 2, 3], dtype=torch.bfloat16, device="cuda").unsqueeze(-1).expand(-1, HIDDEN_SIZE) + y, inverse_mapping = permute(x, routing_indices) + + assert torch.equal(y, expected_output) + + y_combined = unpermute(y, inverse_mapping, routing_weights) + + assert y_combined.shape == x.shape + + +def test_grouped_mlp(): + parallel_config = ParallelismArgs( + dp=1, + pp=1, + tp=1, + expert_parallel_size=1, + expert_tensor_parallel_size=1, + expert_data_parallel_size=1, + pp_engine="1f1b", + tp_mode="REDUCE_SCATTER", + tp_linear_async_communication=True, + ) + + # NOTE: num_tokens_per_experts.shape = (num_experts,) + # it should match the number of experts in TINY_MOE_QWEN_CONFIG + num_tokens_per_experts = torch.tensor([1, 2, 3, 4]) + NUM_TOKENS = num_tokens_per_experts.sum() + NUM_EXPERTS = TINY_MOE_QWEN_CONFIG.moe_config.num_experts + HIDDEN_SIZE = TINY_MOE_QWEN_CONFIG.hidden_size + permuted_hidden_states = torch.randn(NUM_TOKENS, HIDDEN_SIZE, dtype=torch.bfloat16, device="cuda") + + assert len(num_tokens_per_experts) == NUM_EXPERTS + + with init_on_device_and_dtype(device=torch.device("cuda"), dtype=torch.bfloat16): + grouped_mlp = GroupedMLP(config=TINY_MOE_QWEN_CONFIG, parallel_config=parallel_config, ep_pg=None) + + output = grouped_mlp(permuted_hidden_states, num_tokens_per_experts) + + assert output["hidden_states"].shape == (NUM_TOKENS, HIDDEN_SIZE) + assert output["hidden_states"].dtype == torch.bfloat16 + assert output["hidden_states"].device.type == "cuda" + + +@rerun_if_address_is_in_use() +def test_expert_parallelism(): + DP_SIZE = 2 + EP_SIZE = 2 + BS = 16 + SEQ_LEN = 128 + parallel_config = ParallelismArgs( + tp=1, + dp=DP_SIZE, + pp=1, + expert_parallel_size=EP_SIZE, + expert_tensor_parallel_size=1, + expert_data_parallel_size=1, + enabled_moe=True, + ) + inputs = torch.randn(BS * SEQ_LEN, HIDDEN_SIZE, dtype=torch.bfloat16, device="cuda") + + init_distributed( + tp=1, + dp=DP_SIZE, + pp=1, + expert_parallel_size=EP_SIZE, + expert_tensor_parallel_size=1, + expert_data_parallel_size=1, + enabled_moe=True, + )(_test_expert_parallelism)( + list_input_batches=inputs, + parallel_config=parallel_config, + ) + + +def _test_expert_parallelism( + parallel_context: ParallelContext, + list_input_batches: torch.Tensor, + parallel_config: ParallelismArgs, +): + os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + os.environ["TORCH_USE_CUDA_DSA"] = "1" + + ep_rank = dist.get_rank(parallel_context.ep_pg) + input_batches = get_ep_shard(list_input_batches, ep_rank, parallel_context).contiguous().cuda() + list_input_batches = list_input_batches.contiguous().cuda() + + input_batches.requires_grad = True + list_input_batches.requires_grad = True + + # TODO(xrsrke): deduplicate the code + ref_parallel_context = copy(parallel_context) + ref_parallel_context.expert_parallel_size = 1 + ref_parallel_context.expert_tensor_parallel_size = 1 + ref_parallel_context.expert_data_parallel_size = 1 + ref_parallel_context.ep_pg = parallel_context.tp_pg + + ref_parallel_config = copy(parallel_config) + ref_parallel_config.expert_parallel_size = 1 + ref_parallel_config.expert_tensor_parallel_size = 1 + ref_parallel_config.expert_data_parallel_size = 1 + + with init_on_device_and_dtype(device="cuda", dtype=torch.bfloat16): + moe_layer = Qwen2MoEMLPLayer( + config=TINY_MOE_QWEN_CONFIG, parallel_context=parallel_context, parallel_config=parallel_config + ) + ref_moe_layer = Qwen2MoEMLPLayer( + config=TINY_MOE_QWEN_CONFIG, parallel_context=ref_parallel_context, parallel_config=ref_parallel_config + ) + # NOTE: make the parameters of all ranks in the ref_moe_layer the same + for p in ref_moe_layer.parameters(): + dist.all_reduce(p, op=dist.ReduceOp.AVG) + + # NOTE: copy the parameter from ref moe to parallelized moe + def is_expert_param(name): + return any(x for x in ["experts.merged_gate_up_proj", "experts.merged_down_proj"] if x in name) + + for (n, p), (ref_n, ref_p) in zip(moe_layer.named_parameters(), ref_moe_layer.named_parameters()): + assert n == ref_n + if is_expert_param(n): + # NOTE: expert parallel sharding + num_local_experts = moe_layer.num_local_experts + start_idx = ep_rank * num_local_experts + end_idx = start_idx + num_local_experts + p.data.copy_(ref_p.data[start_idx:end_idx, :, :]) + else: + p.data.copy_(ref_p.data) + + for (name, param), (ref_name, ref_param) in zip(moe_layer.named_parameters(), ref_moe_layer.named_parameters()): + if is_expert_param(name): + continue + + assert name == ref_name + assert torch.allclose(param, ref_param) + + outputs = moe_layer(input_batches)["hidden_states"] + ref_outputs = ref_moe_layer(list_input_batches)["hidden_states"] + + assert torch.allclose( + outputs, + get_ep_shard(ref_outputs, ep_rank, parallel_context), + ) + assert outputs.requires_grad is True + + outputs.sum().backward() + ref_outputs.sum().backward() + + # NOTE: for the gradient match comparison, run end-to-end test + assert all(p.grad is not None for p in moe_layer.parameters()) + + assert 1 == 1 + + for (n, p), (ref_n, ref_p) in zip(moe_layer.named_parameters(), ref_moe_layer.named_parameters()): + # NOTE: we run end-to-end training for gradient matching test + if any(x in n for x in ["router.weight", "shared_expert"]): + continue + + assert torch.equal(p.grad, get_ep_shard(ref_p.grad, ep_rank, parallel_context)) + + +if __name__ == "__main__": + # test_grouped_mlp() + # test_init_moe_process_groups(tp=1, dp=4, pp=1, expert_parallel_size=2, expert_tensor_parallel_size=1, expert_data_parallel_size=2, enabled_moe=True) + # test_init_moe_process_groups(tp=2, dp=2, pp=2, expert_parallel_size=1, expert_tensor_parallel_size=1, expert_data_parallel_size=1, enabled_moe=False) + + # (1, 1, 1, 2, 1, 1) the test that fails + # test_init_moe_process_groups(tp=1, dp=1, pp=1, expert_parallel_size=2, expert_tensor_parallel_size=1, expert_data_parallel_size=1) + + test_expert_parallelism( + # list_routing_indicies=torch.tensor([[2], [3], [1], [3], [1], [0], [2], [3]], dtype=torch.int32), + # include_backward=True, + ) + + # test_permute( + # routing_indices=torch.tensor([[2], [3], [1], [3]], dtype=torch.int32, device="cuda"), + # expected_output=torch.tensor([2, 0, 1, 3], dtype=torch.bfloat16, device="cuda").unsqueeze(-1).expand(-1, 16) + # ) + # test_permute_and_unpermute( + # torch.tensor([[2, 1], [3, 0], [1, 2], [3, 1]], dtype=torch.int32, device="cuda"), + # torch.tensor([1, 0, 2, 3, 0, 2, 1, 3], dtype=torch.bfloat16, device="cuda").unsqueeze(-1).expand(-1, 16), + # ) + + # test_get_dispatched_routing_indices( + # # torch.tensor([[0, 2], [1, 3], [0, 3], [1, 2]], dtype=torch.long), 2, 2, [torch.tensor([0, 1]), torch.tensor([2, 3])] + # torch.tensor([[2, 1], [3, 0], [1, 2], [3, 1], [1, 2], [0, 1], [2, 1], [1, 2]], dtype=torch.long), + # 2, + # 4, + # [torch.tensor([0, 1, 1, 1, 0, 1, 1, 1, 1]), torch.tensor([2, 2, 3, 3, 2, 2, 2])], + # ) \ No newline at end of file diff --git a/tests/test_moe_dispatcher.py b/tests/test_moe_dispatcher.py new file mode 100644 index 000000000..d56f66252 --- /dev/null +++ b/tests/test_moe_dispatcher.py @@ -0,0 +1,163 @@ +import os + +import pytest +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from helpers.utils import ( + rerun_if_address_is_in_use, +) +from nanotron.distributed import initialize_torch_distributed +from nanotron.nn.moe import AllToAllDispatcher +from nanotron.utils import find_free_port + +BS = 1 +SEQ_LEN = 8 +HIDDEN_SIZE = 4 + + +def setup_dist_env(rank, world_size, port): + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["RANK"] = str(rank) + # NOTE: since we do unit tests in a + # single node => this is fine! + os.environ["LOCAL_RANK"] = str(rank) + os.environ["MASTER_ADDR"] = "localhost" + os.environ["MASTER_PORT"] = str(port) + os.environ["CUDA_LAUNCH_BLOCKING"] = "1" + + +def _test_all_to_all_dispatcher( + rank, + world_size, + port, + inputs, + routing_indices, + routing_weights, + expected_permuted_outputs, + expected_num_local_dispatched_tokens_per_expert, + expected_combined_outputs, + num_experts, +): + setup_dist_env(rank, world_size, port) + initialize_torch_distributed() + + ep_pg = dist.new_group(ranks=list(range(world_size))) + ep_rank = dist.get_rank(ep_pg) + expected_num_local_dispatched_tokens_per_expert = expected_num_local_dispatched_tokens_per_expert[ep_rank].cuda() + + # NOTE: each ep rank holds a chunk of the inputs + input = torch.chunk(inputs, world_size, dim=0)[ep_rank].contiguous().cuda() + expected_output = expected_permuted_outputs[ep_rank].contiguous().cuda() + routing_indices = torch.chunk(routing_indices, world_size, dim=0)[ep_rank].contiguous().cuda() + routing_weights = torch.chunk(routing_weights, world_size, dim=0)[ep_rank].contiguous().cuda() + expected_combined_outputs = torch.chunk(expected_combined_outputs, world_size, dim=0)[ep_rank].contiguous().cuda() + num_local_experts = num_experts // world_size + + dispatcher = AllToAllDispatcher(num_local_experts=num_local_experts, num_experts=num_experts, ep_pg=ep_pg) + + ( + dispatched_input, + inverse_permute_mapping, + inverse_expert_sorting_index, + num_local_dispatched_tokens_per_expert, + ) = dispatcher.permute(input, routing_indices) + + assert torch.allclose(dispatched_input, expected_output) + assert torch.equal(expected_num_local_dispatched_tokens_per_expert.cpu(), num_local_dispatched_tokens_per_expert) + + undispatched_and_comebined_input = dispatcher.unpermute( + dispatched_input, inverse_permute_mapping, routing_weights, inverse_expert_sorting_index + ) + + assert torch.allclose(undispatched_and_comebined_input, expected_combined_outputs) + + dist.destroy_process_group() + + +@pytest.mark.parametrize( + "routing_indices, routing_weights, expected_permuted_outputs, expected_num_local_dispatched_tokens_per_expert, expected_combined_outputs", + [ + [ + torch.tensor([[2], [3], [1], [3], [1], [0], [2], [3]], dtype=torch.int32), + torch.tensor( + [[1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0], [1.0]], dtype=torch.bfloat16, device="cuda" + ), + [ + torch.tensor([5, 2, 4], dtype=torch.bfloat16).unsqueeze(-1).expand(-1, HIDDEN_SIZE), + torch.tensor([0, 6, 1, 3, 7], dtype=torch.bfloat16).unsqueeze(-1).expand(-1, HIDDEN_SIZE), + ], + torch.tensor([[1, 2], [2, 3]], dtype=torch.bfloat16), + torch.arange(BS * SEQ_LEN, dtype=torch.bfloat16) + .unsqueeze(-1) + .expand(-1, HIDDEN_SIZE), # identical as input + ], # top-k=1 + [ + torch.tensor([[2, 1], [3, 0], [1, 2], [3, 1], [1, 2], [0, 1], [2, 1], [1, 2]], dtype=torch.int32), + torch.tensor( + [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0]], + dtype=torch.bfloat16, + device="cuda", + ), + [ + # NOTE: this isn't include expert sorting index + torch.tensor([1, 5, 0, 2, 3, 4, 5, 6, 7], dtype=torch.bfloat16).unsqueeze(-1).expand(-1, HIDDEN_SIZE), + torch.tensor([0, 2, 4, 6, 7, 1, 3], dtype=torch.bfloat16).unsqueeze(-1).expand(-1, HIDDEN_SIZE), + ], + torch.tensor([[2, 7], [5, 2]], dtype=torch.bfloat16), + torch.tensor([0, 2, 4, 6, 8, 10, 12, 14], dtype=torch.bfloat16).unsqueeze(-1).expand(-1, HIDDEN_SIZE), + ], # top-k=2 + ], +) +@rerun_if_address_is_in_use() +def test_all_to_all_dispatcher( + routing_indices, + routing_weights, + expected_permuted_outputs, + expected_num_local_dispatched_tokens_per_expert, + expected_combined_outputs, +): + port = find_free_port() + WORLD_SIZE = 2 + NUM_EXPERTS = 4 + + # NOTE: input.shape = [bs*seq_len, hidden_size] + inputs = torch.arange(BS * SEQ_LEN, dtype=torch.bfloat16).unsqueeze(-1).expand(-1, HIDDEN_SIZE) + + mp.spawn( + _test_all_to_all_dispatcher, + args=( + WORLD_SIZE, + port, + inputs, + routing_indices, + routing_weights, + expected_permuted_outputs, + expected_num_local_dispatched_tokens_per_expert, + expected_combined_outputs, + NUM_EXPERTS, + ), + nprocs=WORLD_SIZE, + ) + + +if __name__ == "__main__": + test_all_to_all_dispatcher( + # routing_indices=torch.tensor([[2], [3], [1], [3], [1], [0], [2], [3]], dtype=torch.int32) + routing_indices=torch.tensor( + [[2, 1], [3, 0], [1, 2], [3, 1], [1, 2], [0, 1], [2, 1], [1, 2]], dtype=torch.int32 + ), + routing_weights=torch.tensor( + [[1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [1.0, 1.0]], + dtype=torch.bfloat16, + device="cuda", + ), + expected_permuted_outputs=[ + torch.tensor([1, 5, 0, 2, 3, 4, 5, 6, 7], dtype=torch.bfloat16).unsqueeze(-1).expand(-1, HIDDEN_SIZE), + torch.tensor([0, 2, 4, 6, 7, 1, 3], dtype=torch.bfloat16).unsqueeze(-1).expand(-1, HIDDEN_SIZE), + ], + expected_num_local_dispatched_tokens_per_expert=torch.tensor([[2, 7], [5, 2]], dtype=torch.bfloat16), + expected_combined_outputs=torch.tensor([0, 2, 4, 6, 8, 10, 12, 14], dtype=torch.bfloat16) + .unsqueeze(-1) + .expand(-1, HIDDEN_SIZE), + ) diff --git a/tools/preprocess_data.py b/tools/preprocess_data.py index f3cdab70e..290ba5eb3 100644 --- a/tools/preprocess_data.py +++ b/tools/preprocess_data.py @@ -9,7 +9,7 @@ from datatrove.executor.local import LocalPipelineExecutor from datatrove.pipeline.readers import HuggingFaceDatasetReader, JsonlReader -from datatrove.pipeline.tokens import DocumentTokenizer +from datatrove.pipeline.tokens import DocumentTokenizer, MegatronDocumentTokenizer, DocumentTokenizerMerger def get_args(): @@ -91,6 +91,7 @@ def main(args): preprocess_executor = LocalPipelineExecutor( pipeline=[ datatrove_reader, + # for nanotron DocumentTokenizer( output_folder=args.output_folder, tokenizer_name_or_path=args.tokenizer_name_or_path, @@ -98,6 +99,20 @@ def main(args): shuffle=False, max_tokens_per_file=1e9, ), + # optional: merge files + # DocumentTokenizerMerger( + # input_folder=args.output_folder, + # output_folder=args.output_folder, + # save_filename="merged", + # shuffle=False, + # ) + # for megatron + # MegatronDocumentTokenizer( + # output_folder=args.output_folder, + # tokenizer_name_or_path=args.tokenizer_name_or_path, + # eos_token=args.eos_token, + # ), + ], tasks=args.n_tasks, logging_dir=args.logging_dir,