diff --git a/torchtitan/experiments/README.md b/torchtitan/experiments/README.md index 46d4e57539..29fc1ea44b 100644 --- a/torchtitan/experiments/README.md +++ b/torchtitan/experiments/README.md @@ -28,4 +28,5 @@ We provide this `experiments/` folder to host experiments that add significant v | [vlm](./vlm/) | [![VLM 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_vlm.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_vlm.yaml?query=branch%3Amain) | [@lkhphuc](https://github.com/lkhphuc) | | [forge](./forge/) | TBA | [@allenwang28](https://github.com/allenwang28) [@ebsmothers](https://github.com/ebsmothers) [@joecummings](https://github.com/joecummings) [@pbontrager](https://github.com/pbontrager) | | [torchcomms](./torchcomms/) | TBA | [@d4l3k](https://https://github.com/d4l3k) [@fduwjj](https://github.com/fduwjj) [@mori360 ](https://github.com/mori360) | -| [moe_symm_mem_kernels](./moe_symm_mem_kernels/) | TBA | [@kwen2501](https://github.com/pytorch/torchtitan/pulls/kwen2501) | +| [moe_symm_mem_kernels](./moe_symm_mem_kernels/) | TBA | [@kwen2501](https://github.com/kwen2501) | +| [gpt_oss](./gpt_oss/) | TBA | [@jianiw](https://github.com/jianiw) | diff --git a/torchtitan/experiments/__init__.py b/torchtitan/experiments/__init__.py index 928683b147..c5783720ea 100644 --- a/torchtitan/experiments/__init__.py +++ b/torchtitan/experiments/__init__.py @@ -5,5 +5,5 @@ # LICENSE file in the root directory of this source tree. _supported_experiments = frozenset( - ["flux", "simple_fsdp.llama3", "simple_fsdp.deepseek_v3", "vlm"] + ["flux", "gpt_oss", "simple_fsdp.llama3", "simple_fsdp.deepseek_v3", "vlm"] ) diff --git a/torchtitan/experiments/gpt_oss/README.md b/torchtitan/experiments/gpt_oss/README.md new file mode 100644 index 0000000000..a8283ab7b6 --- /dev/null +++ b/torchtitan/experiments/gpt_oss/README.md @@ -0,0 +1,17 @@ +# gpt-oss Model in torchtitan + +## Quick Start +```bash +CONFIG_FILE="./torchtitan/experiments/gpt_oss/train_configs/debug_model.toml" ./run_train.sh +``` + +## Supported Features +- FSDP/HSDP, TP, EP, ETP +- Grouped matrix multiplication for efficient computation + + +## TODO +1. More parallelism support: CP, PP +2. Conversion between HF weights (StateDictAdapter) +3. Forward parity verification +4. CI support diff --git a/torchtitan/experiments/gpt_oss/__init__.py b/torchtitan/experiments/gpt_oss/__init__.py new file mode 100644 index 0000000000..9e52796413 --- /dev/null +++ b/torchtitan/experiments/gpt_oss/__init__.py @@ -0,0 +1,87 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
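+
+# This package registers the gpt-oss TrainSpec with torchtitan. Three model
+# flavors are defined below ("debugmodel", "20b", "120b"); the flavor is
+# selected via the `flavor` field of the `[model]` section in the training
+# TOML (see train_configs/debug_model.toml).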
+ +from torchtitan.components.loss import build_cross_entropy_loss +from torchtitan.components.lr_scheduler import build_lr_schedulers +from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing +from torchtitan.components.tokenizer import build_hf_tokenizer +from torchtitan.datasets.hf_datasets import build_hf_dataloader +from torchtitan.models.moe import MoEArgs + +from torchtitan.protocols.train_spec import TrainSpec + +from .infra.parallelize import parallelize_gptoss +from .model.args import GptOssModelArgs +from .model.model import GptOssModel + +__all__ = [ + "parallelize_gptoss", + "GptOssModelArgs", + "GptOssModel", + "gptoss_configs", +] + + +gptoss_configs = { + "debugmodel": GptOssModelArgs( + dim=256, + n_layers=4, + moe_args=MoEArgs( + num_experts=8, + num_shared_experts=0, + score_func="softmax", + route_norm=False, + route_scale=1.0, + score_before_experts=False, + top_k=4, + use_grouped_mm=True, + load_balance_coeff=1e-3, + ), + attn_mask_type="causal", + ), + "20b": GptOssModelArgs( + n_layers=24, + moe_args=MoEArgs( + num_experts=32, + num_shared_experts=0, + score_func="softmax", + route_norm=False, + route_scale=1.0, + score_before_experts=False, + top_k=4, + use_grouped_mm=True, + load_balance_coeff=1e-3, + ), + ), + "120b": GptOssModelArgs( + n_layers=36, + moe_args=MoEArgs( + num_experts=128, + num_shared_experts=0, + score_func="softmax", + route_norm=False, + route_scale=1.0, + score_before_experts=False, + top_k=4, + use_grouped_mm=True, + load_balance_coeff=1e-3, + ), + ), +} + + +def get_train_spec() -> TrainSpec: + return TrainSpec( + model_cls=GptOssModel, + model_args=gptoss_configs, + parallelize_fn=parallelize_gptoss, + pipelining_fn=None, + build_optimizers_fn=build_optimizers_with_moe_load_balancing, + build_lr_schedulers_fn=build_lr_schedulers, + build_dataloader_fn=build_hf_dataloader, + build_tokenizer_fn=build_hf_tokenizer, + build_loss_fn=build_cross_entropy_loss, + ) diff --git a/torchtitan/experiments/gpt_oss/infra/expert_parallel.py b/torchtitan/experiments/gpt_oss/infra/expert_parallel.py new file mode 100644 index 0000000000..96ad157c2f --- /dev/null +++ b/torchtitan/experiments/gpt_oss/infra/expert_parallel.py @@ -0,0 +1,65 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
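+
+# Shapes of the GptOssGroupedExperts parameters (see model/moe.py):
+#   mlp1_weight: (num_experts, 2 * hidden_dim, dim)
+#   mlp1_bias:   (num_experts, 2 * hidden_dim)
+#   mlp2_weight: (num_experts, dim, hidden_dim)
+#   mlp2_bias:   (num_experts, dim)
+# Under TP, Shard(1) on mlp1_weight/mlp1_bias splits the expert output
+# dimension (column-wise), Shard(2) on mlp2_weight splits its input dimension
+# (row-wise), and mlp2_bias stays replicated so it is added once after the
+# row-wise partial results are reduced. The 2D (EP x TP) variant additionally
+# shards dim 0, the expert dimension, across the EP mesh.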
+ + +import torch.nn as nn +from torch.distributed.tensor import distribute_tensor, Replicate, Shard +from torchtitan.distributed.expert_parallel import ExpertTensorParallel, TensorParallel + +# implementation of Tensor Parallel for the GroupedExperts in MoE +class GptossTensorParallel(TensorParallel): + def _partition_fn(self, name, module, device_mesh): + module.register_parameter( + "mlp1_weight", + nn.Parameter( + distribute_tensor(module.mlp1_weight, device_mesh, [Shard(1)]) + ), + ) # Column-wise sharding + module.register_parameter( + "mlp1_bias", + nn.Parameter(distribute_tensor(module.mlp1_bias, device_mesh, [Shard(1)])), + ) # Column-wise sharding + module.register_parameter( + "mlp2_weight", + nn.Parameter( + distribute_tensor(module.mlp2_weight, device_mesh, [Shard(2)]) + ), + ) # Row-wise sharding + module.register_parameter( + "mlp2_bias", + nn.Parameter( + distribute_tensor(module.mlp2_bias, device_mesh, [Replicate()]) + ), + ) # Replicate + + +# This class is for dp2ep with TP (without TP we can just use GptossExpertParallel) +class GptossExpertTensorParallel(ExpertTensorParallel): + def _partition_fn_2d(self, name, mod, ep_tp_mesh): + mod.register_parameter( + "mlp1_weight", + nn.Parameter( + distribute_tensor(mod.mlp1_weight, ep_tp_mesh, [Shard(0), Shard(1)]) + ), + ) # Column-wise sharding + mod.register_parameter( + "mlp1_bias", + nn.Parameter( + distribute_tensor(mod.mlp1_bias, ep_tp_mesh, [Shard(0), Shard(1)]) + ), + ) # Column-wise sharding + mod.register_parameter( + "mlp2_weight", + nn.Parameter( + distribute_tensor(mod.mlp2_weight, ep_tp_mesh, [Shard(0), Shard(2)]) + ), + ) # Row-wise sharding + mod.register_parameter( + "mlp2_bias", + nn.Parameter( + distribute_tensor(mod.mlp2_bias, ep_tp_mesh, [Shard(0), Replicate()]) + ), + ) # Replicate diff --git a/torchtitan/experiments/gpt_oss/infra/parallelize.py b/torchtitan/experiments/gpt_oss/infra/parallelize.py new file mode 100644 index 0000000000..9d538e13a1 --- /dev/null +++ b/torchtitan/experiments/gpt_oss/infra/parallelize.py @@ -0,0 +1,310 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
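+
+# parallelize_gptoss applies, in order: tensor parallelism to the non-MoE
+# modules (embedding, attention, norms, output), EP / ETP to the MoE experts,
+# activation checkpointing, and finally FSDP/HSDP (or DDP when only
+# data-parallel replication is enabled).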
+ +import torch +import torch.nn as nn +from torch.distributed.device_mesh import DeviceMesh + +from torch.distributed.tensor import distribute_tensor, Partial, Replicate, Shard +from torch.distributed.tensor.parallel import ( + ColwiseParallel, + parallelize_module, + PrepareModuleInput, + PrepareModuleInputOutput, + RowwiseParallel, + SequenceParallel, +) +from torchtitan.config import TORCH_DTYPE_MAP +from torchtitan.config.job_config import JobConfig +from torchtitan.distributed import NoParallel, ParallelDims +from torchtitan.distributed.expert_parallel import ( + ExpertParallel, + ReordererSequenceParallel, +) +from torchtitan.models.llama3.infra.parallelize import apply_ac, apply_ddp +from torchtitan.models.llama4.infra.parallelize import apply_fsdp +from torchtitan.tools.logging import logger + +from .expert_parallel import GptossExpertTensorParallel, GptossTensorParallel + + +# for selective op activation checkpointing +_op_sac_save_list = { + torch.ops.aten.mm.default, + torch.ops.aten._scaled_dot_product_efficient_attention.default, + torch.ops.aten._scaled_dot_product_flash_attention.default, + torch.ops._c10d_functional.reduce_scatter_tensor.default, + torch.ops._c10d_functional.all_to_all_single.default, + # for low precision training, it's useful to always save + # the result of max, since the absolute maximum is + # used to compute the scaling factor for quantization. + torch.ops.aten.max.default, + torch._higher_order_ops.flex_attention, +} + +# Adapted from llama4/infra/parallelize.py +def parallelize_gptoss( + model: nn.Module, + parallel_dims: ParallelDims, + job_config: JobConfig, +): + world_mesh = parallel_dims.world_mesh + + assert ( + job_config.training.seq_len % parallel_dims.seq_len_divisor == 0 + ), f""" + Sequence length {job_config.training.seq_len} must be divisible by the product of TP degree + ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}). + """ + + use_flex_attn = getattr(model.model_args, "use_flex_attn", False) + if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn: + raise NotImplementedError("CP support for FlexAttention is still in progress.") + + if parallel_dims.tp_enabled: + if ( + job_config.parallelism.enable_async_tensor_parallel + and not model_compile_enabled + ): + raise RuntimeError("Async TP requires torch.compile") + + enable_float8_linear = "float8" in job_config.model.converters + float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in ( + "rowwise", + "rowwise_with_gw_hp", + ) + + # For now, float8 all-gather with TP is only supported for tensorwise + # float8 scaling recipes. For rowwise recipes, we use regular TP and + # all-gather happens in high precision. 
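+        # NOTE: apply_non_moe_tp below is currently called with
+        # enable_float8_tensorwise_tp=False and enable_async_tp=False, so this
+        # flag is informational for gpt-oss for now.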
+ enable_float8_tensorwise_tp = enable_float8_linear and not float8_is_rowwise + + apply_non_moe_tp( + model, + world_mesh["tp"], + loss_parallel=not job_config.parallelism.disable_loss_parallel, + enable_float8_tensorwise_tp=False, + enable_async_tp=False, + ) + + if parallel_dims.tp_enabled or parallel_dims.ep_enabled: + apply_moe_ep_tp( + model, + tp_mesh=world_mesh["tp"] if parallel_dims.tp_enabled else None, + ep_mesh=world_mesh["ep"] if parallel_dims.ep_enabled else None, + ep_tp_mesh=( + world_mesh["ep", "tp"] + if parallel_dims.tp_enabled + and parallel_dims.ep_enabled + and parallel_dims.etp_enabled + else None + ), + etp_enabled=parallel_dims.etp_enabled, + ) + + model_compile_enabled = ( + job_config.compile.enable and "model" in job_config.compile.components + ) + + if job_config.activation_checkpoint.mode != "none": + apply_ac( + model, + job_config.activation_checkpoint, + model_compile_enabled=model_compile_enabled, + use_flex_attn=use_flex_attn, + save_list=_op_sac_save_list, + ) + + dp_mesh: DeviceMesh | None = None + if parallel_dims.fsdp_enabled or parallel_dims.ep_enabled: + # apply FSDP or HSDP, potentially with Context Parallel + if parallel_dims.dp_replicate_enabled: + dp_mesh_dim_names = ("dp_replicate", "dp_shard_cp") + else: + dp_mesh_dim_names = ("dp_shard_cp",) + dp_mesh = world_mesh[tuple(dp_mesh_dim_names)] + + # the mesh dim names of which the MoE params are sharded on via FSDP/HSDP + dp_mod_ep_mesh_dim_names = [] + if parallel_dims.ep_enabled: + if parallel_dims.dp_replicate_enabled: + dp_mod_ep_mesh_dim_names.append("dp_replicate") + dp_mod_ep_mesh_dim_names.append("dp_shard_mod_ep") + + apply_fsdp( + model, + dp_mesh, + param_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_param], + reduce_dtype=TORCH_DTYPE_MAP[job_config.training.mixed_precision_reduce], + pp_enabled=parallel_dims.pp_enabled, + cpu_offload=job_config.training.enable_cpu_offload, + reshard_after_forward_policy=job_config.parallelism.fsdp_reshard_after_forward, + ep_degree=parallel_dims.ep, + dp_mod_ep_mesh=( + world_mesh[tuple(dp_mod_ep_mesh_dim_names)] + if parallel_dims.ep_enabled + else None + ), + ) + + if parallel_dims.dp_replicate_enabled: + logger.info("Applied HSDP to the model") + else: + logger.info("Applied FSDP to the model") + + if parallel_dims.cp_enabled: + logger.info("Applied Context Parallel to the model") + + if job_config.training.enable_cpu_offload: + logger.info("Applied CPU Offloading to the model") + elif parallel_dims.dp_replicate_enabled: + if world_mesh.ndim > 1: + raise RuntimeError("DDP has not supported > 1D parallelism") + dp_mesh = world_mesh + apply_ddp( + model, + dp_mesh, + enable_compile=model_compile_enabled, + enable_compiled_autograd=job_config.parallelism.enable_compiled_autograd, + ) + + return model + + +def apply_non_moe_tp( + model: nn.Module, + tp_mesh: DeviceMesh, + loss_parallel: bool, + enable_float8_tensorwise_tp: bool, + enable_async_tp: bool, +): + """Apply tensor parallelism.""" + # 1. Parallelize the embedding and shard its outputs (which are the first + # transformer block's inputs) + # 2. Parallelize the root norm layer over the sequence dim + # 3. 
Parallelize the final linear output layer + parallelize_module( + model, + tp_mesh, + { + "tok_embeddings": RowwiseParallel( + input_layouts=Replicate(), + output_layouts=Shard(1), + ), + "norm": SequenceParallel(), + "output": ColwiseParallel( + input_layouts=Shard(1), + output_layouts=Shard(-1) if loss_parallel else Replicate(), + use_local_output=not loss_parallel, + ), + }, + ) + + # Apply tensor + sequence parallelism to every transformer block + for transformer_block in model.layers.values(): + layer_plan = { + "attention_norm": SequenceParallel(), + "attention": PrepareModuleInput( + input_layouts=(Shard(1), Replicate(), None), + desired_input_layouts=(Replicate(), Replicate(), None), + ), + "attention.wq": ColwiseParallel(use_local_output=False), + "attention.wk": ColwiseParallel(use_local_output=False), + "attention.wv": ColwiseParallel(use_local_output=False), + "attention.inner_attention": PrepareModuleInputOutput( + input_layouts=(Shard(1), Shard(1), Shard(1)), + desired_input_layouts=(Shard(1), Shard(1), Shard(1)), + use_local_input=True, + output_layouts=(Shard(1), Shard(1)), + desired_output_layouts=(Shard(1), Shard(1)), + use_local_output=False, + ), + "attention.wo": RowwiseParallel(output_layouts=Shard(1)), + "ffn_norm": SequenceParallel(), + } + + # shard attention.sinks across heads + attn = transformer_block.attention + attn.register_parameter( + "sinks", + nn.Parameter(distribute_tensor(attn.sinks, tp_mesh, [Shard(0)])), + ) + + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + parallelize_plan=layer_plan, + ) + + if enable_async_tp: + from torch.distributed._symmetric_memory import enable_symm_mem_for_group + + torch._inductor.config._micro_pipeline_tp = True + enable_symm_mem_for_group(tp_mesh.get_group().group_name) + + logger.info( + f"Applied {'Float8 tensorwise ' if enable_float8_tensorwise_tp else ''}{'Async ' if enable_async_tp else ''}" + "Tensor Parallelism to the model" + ) + + +def apply_moe_ep_tp( + model: nn.Module, + tp_mesh: DeviceMesh | None, + ep_mesh: DeviceMesh | None, + ep_tp_mesh: DeviceMesh | None, + etp_enabled: bool, +): + assert ep_mesh is not None or tp_mesh is not None + + for transformer_block in model.layers.values(): + if not transformer_block.moe_enabled: + continue + + if tp_mesh is not None: + moe_layer_plan = { + # input / output sharding on the seqlen dim + # all-gather for input, reduce-scatter for output + "moe": PrepareModuleInputOutput( + input_layouts=(Shard(1),), + desired_input_layouts=(Replicate(),), + use_local_input=True, + output_layouts=(Partial(),), + desired_output_layouts=(Shard(1),), + ), + # replicate computation for the router + "moe.router.gate": NoParallel(), + } + if ep_mesh is not None and not etp_enabled: + # If TP is borrowed for EP, then split the tokens across TP ranks so that + # the reorderer, the all-to-all comms, and routed experts computation + # are effectively running Sequence Parallel (split along the folded bs*slen dim) + moe_layer_plan.update({"moe.reorderer": ReordererSequenceParallel()}) + + parallelize_module( + module=transformer_block, + device_mesh=tp_mesh, + parallelize_plan=moe_layer_plan, + ) + + experts_mesh, experts_plan = None, None + if ep_mesh is None: + experts_mesh = tp_mesh + # input Replicate, output Partial + experts_plan = GptossTensorParallel() + elif tp_mesh is None or not etp_enabled: + experts_mesh = ep_mesh + # input / output sharding on the batch / tokens dim + experts_plan = ExpertParallel() + else: + experts_mesh = ep_tp_mesh + experts_plan = 
GptossExpertTensorParallel() + + parallelize_module( + module=transformer_block.moe.experts, + device_mesh=experts_mesh, + parallelize_plan=experts_plan, + ) diff --git a/torchtitan/experiments/gpt_oss/model/args.py b/torchtitan/experiments/gpt_oss/model/args.py new file mode 100644 index 0000000000..60e59dc821 --- /dev/null +++ b/torchtitan/experiments/gpt_oss/model/args.py @@ -0,0 +1,97 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from dataclasses import dataclass, field +from typing import Literal + +from torch import nn + +from torchtitan.config.job_config import JobConfig +from torchtitan.models.moe import MoEArgs +from torchtitan.models.utils import get_moe_model_nparams_and_flops +from torchtitan.protocols.train_spec import BaseModelArgs +from torchtitan.tools.logging import logger +from torchtitan.tools.utils import has_cuda_capability + + +@dataclass +class GptOssModelArgs(BaseModelArgs): + """ + Data class for defining model arguments and hyperparameters. + + Attributes: + max_batch_size (int): Maximum batch size. + max_seq_len (int): Maximum sequence length. + dtype (Literal["bf16", "fp8"]): Data type for computations. + vocab_size (int): Vocabulary size. + dim (int): Model dimension. + moe_inter_dim (int): Intermediate dimension for MoE layers. + n_layers (int): Number of transformer layers. + norm_eps (float): Epsilon used for RMSNorm. + moe_args (MoEArgs): Arguments for Mixture of Experts (MoE) layers. + swiglu_limit (float): SwiGLU activation limit. + head_dim (int): Dimension of each attention head. + n_heads (int): Number of attention heads. + n_kv_heads (int): Number of key-value heads. + sliding_window_size (int): Size of the sliding attention window. + attn_mask_type (str): Type of basic attention mask. + use_flex_attn (bool): Whether to use FlexAttention. Only supports True. + original_seq_len (int): Original sequence length. + rope_theta (float): Base for rotary positional encoding. + rope_factor (float): Scaling factor for extended sequence lengths. + beta_fast (int): Fast beta correction factor. + beta_slow (int): Slow beta correction factor. + """ + + max_batch_size: int = 8 + max_seq_len: int = 131072 + dtype: Literal["bf16", "fp8"] = "bf16" + vocab_size: int = 201088 + dim: int = 2880 + moe_inter_dim: int = 2880 + n_layers: int = 24 + norm_eps: float = 1e-5 # eps used for RMSNorm + # MoE + moe_args: MoEArgs = field(default_factory=MoEArgs) + swiglu_limit: float = 7.0 + # Multi-Head Latent Attention (MLA) + head_dim: int = 64 + n_heads: int = 64 + n_kv_heads: int = 8 + sliding_window_size: int = 128 + attn_mask_type: str = "causal" + use_flex_attn: bool = True # NOTE: gpt-oss only support FlexAttention + # yarn + original_seq_len: int = 4096 + rope_theta: float = 150000.0 + rope_factor: float = 32 + beta_fast: int = 32 + beta_slow: int = 1 + + def update_from_config(self, job_config: JobConfig, **kwargs) -> None: + seq_len = job_config.training.seq_len + if seq_len > self.max_seq_len: + logger.warning( + f"Sequence length {seq_len} exceeds original maximum {self.max_seq_len}." 
+ ) + self.max_seq_len = seq_len + + if self.moe_args.use_grouped_mm and not has_cuda_capability(9, 0): + logger.warning( + "Failed to use grouped mm, which is only supported on SM90 or later", + ) + self.moe_args.use_grouped_mm = False + + if job_config.parallelism.context_parallel_degree > 1: + raise NotImplementedError( + "CP support for gpt-oss model is still in progress." + ) + + def get_nparams_and_flops( + self, model: nn.Module, seq_len: int + ) -> tuple[int, float]: + return get_moe_model_nparams_and_flops(self, model, seq_len) diff --git a/torchtitan/experiments/gpt_oss/model/model.py b/torchtitan/experiments/gpt_oss/model/model.py new file mode 100644 index 0000000000..1fcd12eaa9 --- /dev/null +++ b/torchtitan/experiments/gpt_oss/model/model.py @@ -0,0 +1,394 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from torch import nn +from torch.nn.attention.flex_attention import and_masks, BlockMask +from torchtitan.components.tokenizer import BaseTokenizer +from torchtitan.models.attention import ( + create_attention_mask, + FlexAttentionWrapper, + get_causal_mask_mod, + get_document_mask_mod, + get_sliding_window_mask_mod, +) +from torchtitan.protocols.model import AttentionMasksType +from torchtitan.protocols.train_spec import ModelProtocol + +from .args import GptOssModelArgs +from .moe import GptOssMoE + + +def precompute_rope_cache( + dim: int, max_seq_len: int, base: float = 1_000_000.0 +) -> torch.Tensor: + freqs = 1.0 / (base ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + # Create position indexes `[0, 1, ..., max_seq_len - 1]` + t = torch.arange(max_seq_len, dtype=freqs.dtype, device=freqs.device) + + # Outer product of theta and position index; output tensor has + # a shape of [max_seq_len, dim // 2] + idx_theta = torch.outer(t, freqs).float() + + # We cache the cos and sin embeddings instead of the IDs. This helps + # ensure we have correct behavior when training with bf16 + # Size: [max_seq_len, (dim * 2)] + freqs = torch.cat([idx_theta, idx_theta], dim=-1) + rope_cache = torch.cat([freqs.cos(), freqs.sin()], dim=-1) + return rope_cache + + +def rotate_half(x: torch.Tensor) -> torch.Tensor: + """Rotates half the hidden dims of the input.""" + x1 = x[..., : x.shape[-1] // 2] + x2 = x[..., x.shape[-1] // 2 :] + return torch.cat((-x2, x1), dim=-1) + + +def reshape_for_broadcast(rope_cache: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + """ + Reshape frequency tensor (represented by cos, sin) for broadcasting it with another tensor. + + This function reshapes the frequency tensor to have the same shape as the target tensor 'x' + for the purpose of broadcasting the frequency tensor during element-wise operations. + + The input freqs_cis tensor is assumed to be of shape (max_seqlen, head_dim * 2), + and the first seqlen elements will be sliced, but dim must match x. + + Args: + rope_cache (torch.Tensor): RoPE tensor (cos and sin) to be reshaped. + x (torch.Tensor): Target tensor for broadcasting compatibility. + + Returns: + torch.Tensor: Reshaped frequency tensor. 
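+
+    Example (illustrative shapes):
+        >>> rope_cache = precompute_rope_cache(dim=64, max_seq_len=4096)
+        >>> x = torch.randn(2, 128, 8, 64)  # (bsz, seqlen, n_heads, head_dim)
+        >>> reshape_for_broadcast(rope_cache, x).shape
+        torch.Size([1, 128, 1, 128])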
+ """ + ndim = x.ndim + assert ndim > 1 + _, seqlen, _, head_dim = x.shape + rope_cache = rope_cache[0:seqlen] + # The shape of rope_cache is (seqlen, head_dim * 2) because we concate cos and sin + assert rope_cache.shape == (seqlen, head_dim * 2) + shape = [-1, seqlen, 1, head_dim * 2] + return rope_cache.view(*shape) + + +def apply_rotary_emb( + xq: torch.Tensor, xk: torch.Tensor, rope_cache: torch.Tensor +) -> tuple[torch.Tensor, torch.Tensor]: + # input tensor x has shape [bsz, seq_len, n_heads, head_dim] + head_dim = xq.shape[-1] + + # reshape for broadcast + rope_cache = reshape_for_broadcast(rope_cache, xq) + + # [bsz, seq_len, 1, head_dim] + cos = rope_cache[..., :head_dim].to(dtype=xq.dtype, device=xq.device) + sin = rope_cache[..., head_dim:].to(dtype=xq.dtype, device=xq.device) + + # xq: [bsz, seq_len, n_heads, head_dim] + # xk: [bsz, seq_len, n_kv_heads, head_dim] + xq_out = (xq * cos) + (rotate_half(xq) * sin) + xk_out = (xk * cos) + (rotate_half(xk) * sin) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + bs, slen, n_kv_heads, head_dim = x.shape + if n_rep == 1: + return x + return ( + torch.unsqueeze(x, dim=3) + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class Attention(nn.Module): + """ + Multi-head attention (MLA) module. + """ + + def __init__(self, model_args: GptOssModelArgs): + super().__init__() + self.head_dim = model_args.head_dim + self.n_heads = model_args.n_heads + self.n_kv_heads = model_args.n_kv_heads + + self.n_rep = self.n_heads // self.n_kv_heads + + self.wq = nn.Linear( + model_args.dim, + model_args.n_heads * model_args.head_dim, + bias=True, + ) + self.wk = nn.Linear( + model_args.dim, + model_args.n_kv_heads * model_args.head_dim, + bias=True, + ) + self.wv = nn.Linear( + model_args.dim, + model_args.n_kv_heads * model_args.head_dim, + bias=True, + ) + self.wo = nn.Linear( + model_args.n_heads * model_args.head_dim, + model_args.dim, + bias=True, + ) + self.sinks = nn.Parameter(torch.empty(model_args.n_heads)) + self.inner_attention = FlexAttentionWrapper() + + def init_weights(self, init_std: float): + linear_list = [ + self.wq, + self.wk, + self.wv, + ] + + nn.init.trunc_normal_(self.sinks, mean=0.0, std=init_std) + for linear in linear_list: + nn.init.trunc_normal_(linear.weight, mean=0.0, std=init_std) + nn.init.trunc_normal_(linear.bias, mean=0.0, std=init_std) + nn.init.trunc_normal_(self.wo.weight, mean=0.0, std=init_std) + nn.init.trunc_normal_(self.wo.bias, mean=0.0, std=init_std) + + def forward( + self, + x: torch.Tensor, + rope_cache: torch.Tensor, + attention_masks: AttentionMasksType | None, + ): + """ + Forward pass for the Multi-Head Latent Attention (MLA) Layer. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). + rope_cache (torch.Tensor): Precomputed cosine and sine frequencies for rope embedding. + + Returns: + torch.Tensor: Output tensor with the same shape as the input. 
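+
+        Note:
+            After FlexAttention, the per-head output is rescaled by
+            sigmoid(lse - sinks[h]); this is mathematically equivalent to
+            attending to an extra learnable per-head sink logit.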
+ """ + bsz, seqlen, _ = x.size() + hidden_shape = (bsz, seqlen, -1, self.head_dim) + + q = self.wq(x).view(hidden_shape) + k = self.wk(x).view(hidden_shape) + v = self.wv(x).view(hidden_shape) + + q, k = apply_rotary_emb(q, k, rope_cache) + + # repeat k/v heads if n_kv_heads < n_heads + keys = repeat_kv(k, self.n_rep) + values = repeat_kv(v, self.n_rep) + + xq = q.transpose(1, 2).contiguous() + xk = keys.transpose(1, 2).contiguous() + xv = values.transpose(1, 2).contiguous() + + assert isinstance(attention_masks, BlockMask), attention_masks + output, lse = self.inner_attention( + xq, xk, xv, block_mask=attention_masks, scale=None, return_lse=True + ) + + # Apply attention sink rescaling: rescale by σ(lse - w[h]) + # This is mathematically equivalent to concatenating learnable sink weights + sink_scale = torch.sigmoid(lse - self.sinks.view(1, -1, 1)).unsqueeze(-1) + output = output * sink_scale.to(output.dtype) + + output = output.transpose(1, 2).contiguous() # (B, H, T, D) -> (B, T, H, D) + + # Reshape and project output + output = output.reshape( + bsz, seqlen, -1 + ).contiguous() # (bsz, seqlen, n_heads * v_head_dim) + output = self.wo(output) # (bsz, seqlen, dim) + return output + + +class TransformerBlock(nn.Module): + """ + Transformer block with attention and feed-forward layers. + """ + + def __init__(self, layer_id: int, model_args: GptOssModelArgs): + + super().__init__() + self.use_sliding_attention = layer_id % 2 == 0 + self.attention = Attention(model_args) + self.attention_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + self.ffn_norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + + self.moe = GptOssMoE( + model_args, dim=model_args.dim, hidden_dim=model_args.moe_inter_dim + ) + self.moe_enabled = True # for composability with load balancing + + self.weight_init_std = 0.02 / (2 * (layer_id + 1)) ** 0.5 + self.layer_id = layer_id + + def forward( + self, + x: torch.Tensor, + rope_cache: torch.Tensor, + attention_masks: AttentionMasksType, + ): + """ + Forward pass for the Transformer block. + + Args: + x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim). + rope_cache (torch.Tensor): Precomputed cosine and sine frequencies. + attention_masks (AttentionMasksType): a dict of BlockMasks. + + Returns: + torch.Tensor: Output tensor with the same shape as the input. + """ + # Extract the appropriate mask for this layer + if self.use_sliding_attention: + layer_mask = attention_masks.get("sliding_window_mask", None) + else: + layer_mask = attention_masks.get("basic_mask", None) + assert layer_mask is not None + + x = x + self.attention(self.attention_norm(x), rope_cache, layer_mask) + x = x + self.moe(self.ffn_norm(x)) + return x + + def init_weights(self, buffer_device: torch.device): + for norm in (self.attention_norm, self.ffn_norm): + norm.reset_parameters() + self.attention.init_weights(self.weight_init_std) + self.moe.init_weights(self.weight_init_std, buffer_device) + + +class GptOssModel(nn.Module, ModelProtocol): + """ + GPT-OSS Transformer model with attention and feed-forward layers. 
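+
+    Layers alternate between sliding-window attention (even layer indices)
+    and full attention (odd layer indices); every layer uses a grouped-expert
+    MoE feed-forward (GptOssMoE).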
+ """ + + def __init__(self, model_args: GptOssModelArgs): + super().__init__() + self.model_args = model_args + self.max_seq_len = model_args.max_seq_len + self.tok_embeddings = nn.Embedding(model_args.vocab_size, model_args.dim) + self.register_buffer( + "rope_cache", self._precompute_rope_cache(), persistent=False + ) + + self.layers = torch.nn.ModuleDict() + for layer_id in range(model_args.n_layers): + self.layers[str(layer_id)] = TransformerBlock(layer_id, model_args).to( + torch.bfloat16 + ) + + self.norm = nn.RMSNorm(model_args.dim, eps=model_args.norm_eps) + self.output = nn.Linear( + model_args.dim, + model_args.vocab_size, + dtype=torch.get_default_dtype(), + bias=False, + ) + self.model_args = model_args + self.init_weights() + + def init_weights(self, buffer_device: torch.device | None = None) -> None: + buffer_device = buffer_device or self.rope_cache.device + with torch.device(buffer_device): + self.rope_cache = self._precompute_rope_cache() + if self.tok_embeddings is not None: + nn.init.normal_(self.tok_embeddings.weight) + for layer in self.layers.values(): + if layer is not None: + layer.init_weights(buffer_device=buffer_device) + if self.norm is not None: + self.norm.reset_parameters() + final_out_std = self.model_args.dim**-0.5 + cutoff_factor = 3 + if self.output is not None: + nn.init.trunc_normal_( + self.output.weight, + mean=0.0, + std=final_out_std, + a=-cutoff_factor * final_out_std, + b=cutoff_factor * final_out_std, + ) + + def _precompute_rope_cache(self) -> torch.Tensor: + return precompute_rope_cache( + self.model_args.head_dim, + self.model_args.max_seq_len, + self.model_args.rope_theta, + ) + + def get_attention_masks( + self, + input_batch: torch.Tensor, + tokenizer: BaseTokenizer, + extra_inputs: dict[str, torch.Tensor] | None = None, + ) -> AttentionMasksType: + + basic_mask_mods = [] + sliding_window_mask_mods = [ + get_sliding_window_mask_mod(self.model_args.sliding_window_size) + ] + match self.model_args.attn_mask_type: + case "causal": + B = 1 + basic_mask_mods.append(get_causal_mask_mod()) + case "block_causal": + B = input_batch.shape[0] + basic_mask_mods.append( + get_document_mask_mod(input_batch, tokenizer.eos_id) + ) + case _: + raise ValueError( + f"Unknown attention mask type: {self.model_args.attn_mask_type}" + ) + + # create basic attention mask: causal or block_causal + basic_mask = create_attention_mask( + and_masks(*basic_mask_mods), + B, + None, + input_batch.shape[1], + input_batch.shape[1], + ) + + # create sliding window mask, has to be created on top of basic attention mask + sliding_window_mask = create_attention_mask( + and_masks(*basic_mask_mods, *sliding_window_mask_mods), + B, + None, + input_batch.shape[1], + input_batch.shape[1], + ) + + return {"basic_mask": basic_mask, "sliding_window_mask": sliding_window_mask} + + def forward( + self, + tokens: torch.Tensor, + attention_masks: AttentionMasksType, + ): + """ + Forward pass for the Transformer model. + + Args: + tokens (torch.Tensor): Input tensor of token IDs with shape (batch_size, seq_len). + attention_masks (AttentionMasksType): a dict of BlockMasks. + + Returns: + torch.Tensor: Logits tensor of shape (batch_size, vocab_size). 
+ """ + h = self.tok_embeddings(tokens) + + for layer in self.layers.values(): + h = layer(h, self.rope_cache, attention_masks) + h = self.norm(h) + output = self.output(h) + return output diff --git a/torchtitan/experiments/gpt_oss/model/moe.py b/torchtitan/experiments/gpt_oss/model/moe.py new file mode 100644 index 0000000000..94cd266761 --- /dev/null +++ b/torchtitan/experiments/gpt_oss/model/moe.py @@ -0,0 +1,275 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + + +from typing import Callable + +import torch +from torch import nn +from torch.distributed.tensor import DTensor +from torchtitan.models.moe.moe import MoE +from torchtitan.models.moe.utils import _permute, _unpermute + +from .args import GptOssModelArgs + + +class ScaleBiasForward(torch.autograd.Function): + """ + Custom autograd function that scales bias in forward pass but not in backward. + + For tensor parallel MoE, we need to scale the bias by 1/tp_degree in forward + to cancel the extra reduction effect, but keep the gradient unchanged in backward. + """ + + @staticmethod + def forward(ctx, bias, tp_degree): + ctx.tp_degree = tp_degree + if tp_degree > 1: + return bias / tp_degree + return bias + + @staticmethod + def backward(ctx, grad_output): + # Don't scale the gradient - pass it through as-is + return grad_output, None + + +def indices_padding_wrapper(func: Callable) -> Callable: + """ + In order to use torch._grouped_mm, we need to make sure the number of + tokens each expert gets is a multiple of TOKEN_GROUP_ALIGN_SIZE_M. The + generate_permute_indices kernel also helps achieve this via padding, + without incurring synchronization between device and host. 
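+
+    Concretely, the wrapper (1) permutes tokens so that tokens routed to the
+    same expert are contiguous, with per-expert counts padded to the alignment
+    size, (2) runs `func` on the permuted tokens, and (3) unpermutes the
+    result back to the original token order.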
+ """ + + def wrapper( + mlp1_weight: torch.Tensor, + mlp1_bias: torch.Tensor, + mlp2_weight: torch.Tensor, + mlp2_bias: torch.Tensor, + swiglu_limit: float, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor, + tp_degree: int = 1, + ) -> torch.Tensor: + num_local_experts = mlp1_weight.shape[0] + ep_degree = num_tokens_per_expert.shape[0] // num_local_experts + + input_shape, x, permuted_indices, num_tokens_per_expert = _permute( + x, num_tokens_per_expert, ep_degree, num_local_experts + ) + + out = func( + mlp1_weight, + mlp1_bias, + mlp2_weight, + mlp2_bias, + swiglu_limit, + x, + num_tokens_per_expert, + tp_degree, + ) + + out = _unpermute(out, input_shape, permuted_indices) + + return out + + return wrapper + + +def swiglu(x, alpha: float = 1.702, limit: float = 7.0): + x_glu, x_linear = x[..., ::2], x[..., 1::2] + # Clamp the input values + x_glu = x_glu.clamp(min=None, max=limit) + x_linear = x_linear.clamp(min=-limit, max=limit) + out_glu = x_glu * torch.sigmoid(alpha * x_glu) + # Note we add an extra bias of 1 to the linear layer + return out_glu * (x_linear + 1) + + +def _run_experts_for_loop( + mlp1_weight: torch.Tensor, + mlp1_bias: torch.Tensor, + mlp2_weight: torch.Tensor, + mlp2_bias: torch.Tensor, + swiglu_limit: float, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor, + tp_degree: int = 1, +) -> torch.Tensor: + # NOTE: this would incur a synchronization between device and host + num_tokens_per_expert = num_tokens_per_expert.tolist() + + # side-effect code due to the usage of generate_permute_indices + num_padding = x.shape[0] - sum(num_tokens_per_expert) + + # a tuple of tensors indexed by experts + # each with shape (tokens_per_expert(varying), dim) + x = torch.split( + x[: sum(num_tokens_per_expert)], + split_size_or_sections=num_tokens_per_expert, + dim=0, + ) + out_experts_splits = [] + for expert_idx, x_expert in enumerate(x): + h = ( + torch.matmul(x_expert, mlp1_weight[expert_idx].transpose(-2, -1)) + + mlp1_bias[expert_idx] + ) + h = swiglu(h, limit=swiglu_limit) + # Apply custom autograd function to scale bias in forward but not in backward + b2 = ScaleBiasForward.apply(mlp2_bias[expert_idx], tp_degree) + h = torch.matmul(h, mlp2_weight[expert_idx].transpose(-2, -1)) + b2 + out_experts_splits.append(h) + out = torch.cat(out_experts_splits, dim=0) + + # side-effect code due to the usage of generate_permute_indices + out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1])))) + + return out + + +def _run_experts_grouped_mm( + mlp1_weight: torch.Tensor, + mlp1_bias: torch.Tensor, + mlp2_weight: torch.Tensor, + mlp2_bias: torch.Tensor, + swiglu_limit: float, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor, + tp_degree: int = 1, +) -> torch.Tensor: + offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) + num_tokens_per_expert_long = num_tokens_per_expert.to(torch.long) + + h = torch._grouped_mm( + x.bfloat16(), mlp1_weight.transpose(-2, -1).bfloat16(), offs=offsets + ) + + b1 = mlp1_bias.repeat_interleave(num_tokens_per_expert_long, dim=0) + tail_slack = x.shape[0] - int(offsets[-1]) + if tail_slack: + b1 = torch.cat([b1, b1.new_zeros((tail_slack, b1.shape[-1]))], dim=0) + h = h + b1.to(h.dtype) + + h = swiglu(h, limit=swiglu_limit) + h = torch._grouped_mm(h, mlp2_weight.transpose(-2, -1).bfloat16(), offs=offsets) + + # Apply custom autograd function to scale bias in forward but not in backward + b2_base = mlp2_bias.repeat_interleave(num_tokens_per_expert_long, dim=0) + b2 = ScaleBiasForward.apply(b2_base, tp_degree) 
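+    # Pad b2 with zero rows so it matches the (possibly alignment-padded)
+    # token dimension of h; padded rows contribute no bias.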
+ tail_slack = x.shape[0] - int(offsets[-1]) + if tail_slack: # padding + b2 = torch.cat([b2, b2.new_zeros((tail_slack, b2.shape[-1]))], dim=0) + h = h + b2.to(h.dtype) + + return h + + +class GptOssGroupedExperts(nn.Module): + def __init__( + self, + dim: int, + hidden_dim: int, + num_experts: int, + swiglu_limit: float, + use_grouped_mm: bool, + ): + super().__init__() + self.num_experts = num_experts + self.use_grouped_mm = use_grouped_mm + self.swiglu_limit = swiglu_limit + + self.mlp1_weight = nn.Parameter( + torch.empty((num_experts, hidden_dim * 2, dim)) + ) # (num_experts, out_dim, in_dim) + self.mlp1_bias = nn.Parameter(torch.empty((num_experts, hidden_dim * 2))) + self.mlp2_weight = nn.Parameter( + torch.empty((num_experts, dim, hidden_dim)) + ) # (num_experts, out_dim, in_dim) + self.mlp2_bias = nn.Parameter(torch.empty((num_experts, dim))) + + def forward( + self, + x: torch.Tensor, + num_tokens_per_expert: torch.Tensor, + ) -> torch.Tensor: + if isinstance(self.mlp1_weight, DTensor): + # Convert parameters from DTensors to plain Tensors, to work with + # dynamic-shape inputs in EP which cannot be easily expressed as DTensors. + mlp1_weight = self.mlp1_weight.to_local() + mlp1_bias = self.mlp1_bias.to_local() + mlp2_weight = self.mlp2_weight.to_local() + mlp2_bias = self.mlp2_bias.to_local() + else: + mlp1_weight = self.mlp1_weight + mlp1_bias = self.mlp1_bias + mlp2_weight = self.mlp2_weight + mlp2_bias = self.mlp2_bias + + # Determine tp_degree from device mesh if available + tp_degree = 1 + if isinstance(self.mlp1_weight, DTensor): + mesh_dim_names = self.mlp1_weight.device_mesh.mesh_dim_names + if "tp" in mesh_dim_names: + tp_dim_idx = mesh_dim_names.index("tp") + tp_degree = self.mlp1_weight.device_mesh.size(tp_dim_idx) + + if self.use_grouped_mm: + if ( + not isinstance(self.mlp1_weight, DTensor) + or "ep" not in self.mlp1_weight.device_mesh.mesh_dim_names + ): + run_experts_fn = indices_padding_wrapper(_run_experts_grouped_mm) + else: + run_experts_fn = _run_experts_grouped_mm + return run_experts_fn( + mlp1_weight, + mlp1_bias, + mlp2_weight, + mlp2_bias, + self.swiglu_limit, + x, + num_tokens_per_expert, + tp_degree, + ) + else: + return _run_experts_for_loop( + mlp1_weight, + mlp1_bias, + mlp2_weight, + mlp2_bias, + self.swiglu_limit, + x, + num_tokens_per_expert, + tp_degree, + ) + + def init_weights(self, init_std: float): + nn.init.trunc_normal_(self.mlp1_weight, mean=0.0, std=init_std) + nn.init.trunc_normal_(self.mlp1_bias, mean=0.0, std=init_std) + nn.init.trunc_normal_(self.mlp2_weight, mean=0.0, std=init_std) + nn.init.trunc_normal_(self.mlp2_bias, mean=0.0, std=init_std) + + +class GptOssMoE(MoE): + """GptOss MoE implementation that inherits from the base MoE class.""" + + def __init__(self, model_args: GptOssModelArgs, dim: int, hidden_dim: int): + # Convert GptOssModelArgs to MoEArgs for base class compatibility + moe_args = model_args.moe_args + + # Initialize the base MoE class + super().__init__(moe_args, dim, hidden_dim) + + # Override the base GroupedExperts with GptOssGroupedExperts + self.experts = GptOssGroupedExperts( + dim=dim, + hidden_dim=hidden_dim, + num_experts=moe_args.num_experts, + swiglu_limit=model_args.swiglu_limit, + use_grouped_mm=moe_args.use_grouped_mm, + ) diff --git a/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml b/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml new file mode 100644 index 0000000000..82b9d0e2ed --- /dev/null +++ b/torchtitan/experiments/gpt_oss/train_configs/debug_model.toml @@ -0,0 
+1,82 @@ +[job] +dump_folder = "./outputs" +description = "Gpt-oss debug training" +print_config = false + +[profiling] +enable_profiling = false +save_traces_folder = "profile_trace" +profile_freq = 10 +enable_memory_snapshot = false +save_memory_snapshot_folder = "memory_snapshot" + +[metrics] +log_freq = 1 +disable_color_printing = false +enable_tensorboard = false +save_tb_folder = "tb" +enable_wandb = false + +[model] +name = "gpt_oss" +flavor = "debugmodel" +# test folder with tokenizer.json, for debug purpose only +hf_assets_path = "./tests/assets/tokenizer" +# converters = ["float8"] + +[optimizer] +name = "AdamW" +lr = 8e-4 +eps = 1e-8 + +[lr_scheduler] +warmup_steps = 2 # lr scheduler warm up, normally 20% of the train steps +decay_ratio = 0.8 # lr scheduler decay ratio, 80% of the train steps +decay_type = "linear" +min_lr_factor = 0.0 + +[training] +local_batch_size = 8 +seq_len = 2048 +max_norm = 1.0 # grad norm clipping +steps = 10 +dataset = "c4_test" # supported datasets: c4_test (2K), c4 (177M) + +[parallelism] +data_parallel_replicate_degree = 1 +data_parallel_shard_degree = -1 +fsdp_reshard_after_forward = "default" # default / never / always +tensor_parallel_degree = 1 +enable_async_tensor_parallel = false +expert_parallel_degree = 1 +expert_tensor_parallel_degree = 1 + +[checkpoint] +enable = false +folder = "checkpoint" +interval = 10 +last_save_model_only = false +export_dtype = "float32" +async_mode = "disabled" # ["disabled", "async", "async_with_pinned_mem"] + +[activation_checkpoint] +mode = "none" # ["none", "selective", "full"] +selective_ac_option = '2' # 'int' = ac every positive int layer or 'op', ac based on ops policy + +[compile] +enable = false +components = ["model", "loss"] + +[quantize.linear.float8] +enable_fsdp_float8_all_gather = false +precompute_float8_dynamic_scale_for_fsdp = false +filter_fqns = ["output", "router.gate"] + +[quantize.grouped_mm.float8] +fqns = ["experts"] + +[validation] +enable = false +dataset = "c4_validation" +freq = 5 +steps = 10 diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py index bf963a5b5f..3d784f0875 100644 --- a/torchtitan/models/attention.py +++ b/torchtitan/models/attention.py @@ -15,7 +15,6 @@ from torch.nn.attention import sdpa_kernel, SDPBackend from torch.nn.attention.flex_attention import ( _mask_mod_signature, - AuxOutput, BlockMask, create_block_mask, flex_attention, @@ -27,6 +26,7 @@ "ScaledDotProductAttentionWrapper", "get_causal_mask_mod", "get_document_mask_mod", + "get_sliding_window_mask_mod", "get_fixed_block_mask_mod", "create_attention_mask", ] @@ -57,14 +57,23 @@ def forward( *, block_mask: BlockMask, scale: float | None = None, - ) -> torch.Tensor | tuple[torch.Tensor, AuxOutput]: + return_lse: bool = False, + ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: # 1. _compiled_flex_attn has to be a class variable, otherwise there will # be multiple compiled flex_attention instances, which can be slow. # 2. `self._compiled_flex_attn` is not correct, `self` will be passed in # as the first argument, which will cause an error. # `FlexAttentionWrapper._compiled_flex_attn` is correct. + # 3. Used `return_lse` instead of `return_aux` because of easier TP module notation + # to convert `lse` to be DTensor. 
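+        # 4. When `return_lse=True`, flex_attention returns an (output, lse)
+        #    tuple; gpt-oss uses the logsumexp to rescale the output by its
+        #    learnable attention sinks.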
+ return FlexAttentionWrapper._compiled_flex_attn( - q, k, v, block_mask=block_mask, scale=scale + q, + k, + v, + block_mask=block_mask, + scale=scale, + return_lse=return_lse, ) @@ -174,6 +183,37 @@ def blocked_mask_mod( return blocked_mask_mod +def get_sliding_window_mask_mod(window_size: int) -> _mask_mod_signature: + """Creates a sliding window mask that only attends to tokens within a fixed window size. + + This implements causal sliding window attention where each token can only attend to: + - Itself (current token) + - Up to `window_size - 1` previous tokens + Args: + window_size: The maximum number of tokens to attend to (including current token). + Must be >= 1. A window_size of 1 means attend only to self. + + Returns: + A mask modifier function that implements causal sliding window masking. + """ + + if window_size < 1: + raise ValueError( + f"window_size must be >= 1 for sliding window attention mask, got {window_size}" + ) + + def sliding_window_mod( + b: torch.Tensor, h: torch.Tensor, q_idx: torch.Tensor, kv_idx: torch.Tensor + ) -> torch.Tensor: + # Window mask: can only attend within the window + # q_idx - kv_idx < window_size ensures we look at most window_size-1 tokens back + return (kv_idx <= q_idx) & (q_idx - kv_idx < window_size) + + sliding_window_mod.__name__ = f"sliding_window_mod_window_size_{window_size}" + + return sliding_window_mod + + _compiled_create_block_mask = torch.compile(create_block_mask) diff --git a/torchtitan/models/moe/moe.py b/torchtitan/models/moe/moe.py index 99e0e5d24c..295e2193a5 100644 --- a/torchtitan/models/moe/moe.py +++ b/torchtitan/models/moe/moe.py @@ -109,7 +109,7 @@ def _run_experts_grouped_mm( w2: torch.Tensor, w3: torch.Tensor, x: torch.Tensor, - num_tokens_per_expert: torch.Tensor | None, + num_tokens_per_expert: torch.Tensor, ) -> torch.Tensor: offsets = torch.cumsum(num_tokens_per_expert, dim=0, dtype=torch.int32) diff --git a/torchtitan/models/qwen3/model/model.py b/torchtitan/models/qwen3/model/model.py index 0fff490bf3..d05ed67be4 100644 --- a/torchtitan/models/qwen3/model/model.py +++ b/torchtitan/models/qwen3/model/model.py @@ -336,7 +336,7 @@ def forward( Args: x (torch.Tensor): Input tensor. - freqs_cis (torch.Tensor): Precomputed cosine and sine frequencies. + rope_cache (torch.Tensor): Precomputed cosine and sine frequencies. Returns: torch.Tensor: Output tensor after applying attention and feedforward layers.