diff --git a/torchtitan/distributed/utils.py b/torchtitan/distributed/utils.py
index c2ec7bd777..30bbf207e4 100644
--- a/torchtitan/distributed/utils.py
+++ b/torchtitan/distributed/utils.py
@@ -19,6 +19,7 @@
 from torchtitan.config import Comm as CommConfig, TORCH_DTYPE_MAP
 from torchtitan.distributed.parallel_dims import ParallelDims
+from torchtitan.protocols.model import AttentionMasksType
 from torchtitan.tools.logging import logger
 from torchtitan.tools.utils import device_module, device_type
 
 
@@ -449,3 +450,52 @@ def _clip_grad_norm_with_ep(
     torch.nn.utils.clip_grads_with_norm_(non_ep_params, max_norm, total_norm, foreach)
 
     return total_norm
+
+
+def cp_shard(
+    cp_mesh: DeviceMesh,
+    inputs: torch.Tensor,
+    labels: torch.Tensor,
+    attention_masks: AttentionMasksType | None,
+    order_sensitive_buffers: dict[str, torch.Tensor],
+    order_sensitive_buffers_seq_dims: dict[str, int],
+):
+    from torch.distributed.tensor.experimental._attention import _context_parallel_shard
+    from torch.nn.attention.flex_attention import BlockMask
+
+    load_balancer = None
+    inputs, labels = _context_parallel_shard(
+        mesh=cp_mesh,
+        buffers=(inputs, labels),
+        seq_dims=(1, 1),
+        load_balancer=load_balancer,
+    )
+
+    order_sensitive_buffers = _context_parallel_shard(
+        mesh=cp_mesh,
+        buffers=order_sensitive_buffers,
+        seq_dims=order_sensitive_buffers_seq_dims,
+        load_balancer=load_balancer,
+    )
+
+    if attention_masks is None:
+        return inputs, labels, None, order_sensitive_buffers
+
+    masks = (
+        [attention_masks]
+        if isinstance(attention_masks, BlockMask)
+        else list(attention_masks.values())
+    )
+    masks = _context_parallel_shard(
+        mesh=cp_mesh,
+        buffers=masks,
+        seq_dims=(2,) * len(masks),
+        load_balancer=load_balancer,
+    )
+    attention_masks = (
+        masks[0]
+        if isinstance(attention_masks, BlockMask)
+        else {k: v for k, v in zip(attention_masks.keys(), masks)}
+    )
+
+    return inputs, labels, attention_masks, order_sensitive_buffers
diff --git a/torchtitan/models/attention.py b/torchtitan/models/attention.py
index bf963a5b5f..9e959978b3 100644
--- a/torchtitan/models/attention.py
+++ b/torchtitan/models/attention.py
@@ -16,7 +16,6 @@
 from torch.nn.attention.flex_attention import (
     _mask_mod_signature,
     AuxOutput,
-    BlockMask,
     create_block_mask,
     flex_attention,
 )
@@ -49,23 +48,13 @@ class FlexAttentionWrapper(torch.nn.Module):
         flex_attention, mode="max-autotune-no-cudagraphs"
     )
 
-    def forward(
-        self,
-        q: torch.Tensor,
-        k: torch.Tensor,
-        v: torch.Tensor,
-        *,
-        block_mask: BlockMask,
-        scale: float | None = None,
-    ) -> torch.Tensor | tuple[torch.Tensor, AuxOutput]:
+    def forward(self, *args, **kwargs) -> torch.Tensor | tuple[torch.Tensor, AuxOutput]:
         # 1. _compiled_flex_attn has to be a class variable, otherwise there will
         #    be multiple compiled flex_attention instances, which can be slow.
         # 2. `self._compiled_flex_attn` is not correct, `self` will be passed in
         #    as the first argument, which will cause an error.
         #    `FlexAttentionWrapper._compiled_flex_attn` is correct.
-        return FlexAttentionWrapper._compiled_flex_attn(
-            q, k, v, block_mask=block_mask, scale=scale
-        )
+        return FlexAttentionWrapper._compiled_flex_attn(*args, **kwargs)
 
 
 class ScaledDotProductAttentionWrapper(torch.nn.Module):
diff --git a/torchtitan/models/llama3/infra/parallelize.py b/torchtitan/models/llama3/infra/parallelize.py
index 4944af569e..afcfdff58c 100644
--- a/torchtitan/models/llama3/infra/parallelize.py
+++ b/torchtitan/models/llama3/infra/parallelize.py
@@ -27,6 +27,7 @@
 from torchtitan.distributed import ParallelDims
 from torchtitan.distributed.activation_checkpoint import apply_ac
 from torchtitan.distributed.tensor_parallel import maybe_enable_async_tp
+from torchtitan.protocols.model import AttentionMasksType
 from torchtitan.tools.logging import logger
 
 
@@ -67,10 +68,6 @@ def parallelize_llama(
         ({parallel_dims.tp}) and 2 * CP degree ({parallel_dims.cp}).
         """
 
-    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
-    if job_config.parallelism.context_parallel_degree > 1 and use_flex_attn:
-        raise NotImplementedError("CP support for FlexAttention is still in progress.")
-
     if parallel_dims.tp_enabled:
         enable_float8_linear = "float8" in job_config.model.converters
         float8_is_rowwise = job_config.quantize.linear.float8.recipe_name in (
@@ -91,6 +88,11 @@
         )
         maybe_enable_async_tp(job_config, world_mesh["tp"])
 
+    use_flex_attn = getattr(model.model_args, "use_flex_attn", False)
+    if parallel_dims.cp_enabled:
+        logger.info("Applied Context Parallel to the model")
+        apply_cp(model, world_mesh["cp"], use_flex_attn)
+
     model_compile_enabled = (
         job_config.compile.enable and "model" in job_config.compile.components
     )
@@ -131,9 +133,6 @@
         else:
             logger.info("Applied FSDP to the model")
 
-        if parallel_dims.cp_enabled:
-            logger.info("Applied Context Parallel to the model")
-
         if job_config.training.enable_cpu_offload:
             logger.info("Applied CPU Offloading to the model")
     elif parallel_dims.dp_replicate_enabled:
@@ -335,3 +334,87 @@ def apply_ddp(
     replicate(model, device_mesh=dp_mesh, bucket_cap_mb=100)
 
     logger.info("Applied DDP to the model")
+
+
+def apply_cp(
+    model: nn.Module,
+    cp_mesh: DeviceMesh,
+    use_flex_attn: bool,
+) -> None:
+    """
+    Apply context parallelism to the model.
+    """
+    from torch.distributed.tensor.experimental._attention import (
+        _ContextParallel,
+        _enable_context_parallel_dispatcher,
+    )
+
+    # Apply context parallelism to every transformer block
+    # TODO: make seq_dim configurable once the implementation doesn't assume 2
+    # internally.
+    if use_flex_attn:
+        cp_plan = _ContextParallel(
+            seq_dim=2, attention_type=_ContextParallel.AttentionType.FLEX
+        )
+    else:
+        # This is currently required because the DTensor dispatcher does not
+        # dispatch SDPA to the CP implementation by default. TorchTitan does
+        # not disable CP dispatching afterwards, as that is not needed, but
+        # the corresponding API, _disable_context_parallel_dispatcher, is
+        # available for users who have that use case.
+        _enable_context_parallel_dispatcher()
+        cp_plan = _ContextParallel(
+            seq_dim=2, attention_type=_ContextParallel.AttentionType.SDPA
+        )
+
+    for transformer_block in model.layers.values():
+        parallelize_module(
+            module=transformer_block.attention.inner_attention,
+            device_mesh=cp_mesh,
+            parallelize_plan=cp_plan,
+        )
+
+
+def cp_shard(
+    cp_mesh: DeviceMesh,
+    inputs: torch.Tensor,
+    labels: torch.Tensor,
+    attention_masks: AttentionMasksType,
+    order_sensitive_buffers: dict[str, torch.Tensor],
+    order_sensitive_buffers_seq_dims: dict[str, int],
+):
+    from torch.distributed.tensor.experimental._attention import _context_parallel_shard
+    from torch.nn.attention.flex_attention import BlockMask
+
+    load_balancer = None
+    inputs, labels = _context_parallel_shard(
+        mesh=cp_mesh,
+        buffers=(inputs, labels),
+        seq_dims=(1, 1),
+        load_balancer=load_balancer,
+    )
+
+    masks = (
+        [attention_masks]
+        if isinstance(attention_masks, BlockMask)
+        else list(attention_masks.values())
+    )
+    masks = _context_parallel_shard(
+        mesh=cp_mesh,
+        buffers=masks,
+        seq_dims=(2,) * len(masks),
+        load_balancer=load_balancer,
+    )
+    attention_masks = (
+        masks[0]
+        if isinstance(attention_masks, BlockMask)
+        else {k: v for k, v in zip(attention_masks.keys(), masks)}
+    )
+
+    order_sensitive_buffers = _context_parallel_shard(
+        mesh=cp_mesh,
+        buffers=order_sensitive_buffers,
+        seq_dims=order_sensitive_buffers_seq_dims,
+        load_balancer=load_balancer,
+    )
+    return inputs, labels, attention_masks, order_sensitive_buffers
diff --git a/torchtitan/models/llama3/model/args.py b/torchtitan/models/llama3/model/args.py
index 2bdafa6964..f727b6f80d 100644
--- a/torchtitan/models/llama3/model/args.py
+++ b/torchtitan/models/llama3/model/args.py
@@ -55,11 +55,6 @@ def update_from_config(self, job_config: JobConfig, **kwargs) -> None:
         )
         self.max_seq_len = seq_len
 
-        if job_config.parallelism.context_parallel_degree > 1 and self.use_flex_attn:
-            raise NotImplementedError(
-                "CP support for FlexAttention is still in progress."
-            )
-
     def get_nparams_and_flops(
         self, model: nn.Module, seq_len: int
     ) -> tuple[int, float]:
diff --git a/torchtitan/train.py b/torchtitan/train.py
index 0a3831f485..8c60d858f0 100644
--- a/torchtitan/train.py
+++ b/torchtitan/train.py
@@ -426,50 +426,48 @@ def forward_backward_step(
                 tokenizer=self.tokenizer,
                 extra_inputs=extra_inputs,
             )
+        else:
+            extra_args["attention_masks"] = None
 
         # Get the order sensitive buffers
         order_sensitive_buffers = model_parts[0].get_order_sensitive_buffers(
             inputs.size(0), inputs.size(1)
         )
-        extra_args.update(order_sensitive_buffers[0])
-
-        # apply context parallelism if cp is enabled
-        # ensure CP handles the separate freqs_cis buffer for each pp stage
         cp_mesh = parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None
-        optional_context_parallel_ctx = (
-            dist_utils.create_context_parallel_ctx(
-                cp_mesh=parallel_dims.world_mesh["cp"],
-                cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
-                cp_seq_dims=[1, 1] + [0 for _ in model_parts],
-                cp_no_restore_buffers={inputs, labels},
-                cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
+        if cp_mesh:
+            (
+                inputs,
+                labels,
+                extra_args["attention_masks"],
+                *order_sensitive_buffers,
+            ) = dist_utils.cp_shard(
+                cp_mesh,
+                inputs,
+                labels,
+                extra_args["attention_masks"],
+                *order_sensitive_buffers,
             )
-            if parallel_dims.cp_enabled
-            else None
-        )
+        extra_args.update(order_sensitive_buffers[0])
 
         if parallel_dims.pp_enabled:
             # Pipeline Parallel forward / backward inside step() call
-            with self.train_context(optional_context_parallel_ctx):
-                targets, losses = (
-                    (labels, []) if self.pp_has_last_stage else (None, None)
+            targets, losses = (labels, []) if self.pp_has_last_stage else (None, None)
+            if self.pp_has_first_stage:
+                self.pp_schedule.step(
+                    inputs,
+                    **extra_inputs,
+                    **extra_args,
+                    target=targets,
+                    losses=losses,
+                    input_batch=inputs,
+                )
+            else:
+                self.pp_schedule.step(
+                    **extra_args,
+                    target=targets,
+                    losses=losses,
+                    input_batch=inputs,
                 )
-                if self.pp_has_first_stage:
-                    self.pp_schedule.step(
-                        inputs,
-                        **extra_inputs,
-                        **extra_args,
-                        target=targets,
-                        losses=losses,
-                        input_batch=inputs,
-                    )
-                else:
-                    self.pp_schedule.step(
-                        **extra_args,
-                        target=targets,
-                        losses=losses,
-                        input_batch=inputs,
-                    )
 
             # accumulate losses across pipeline microbatches
             # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
@@ -483,18 +481,17 @@ def forward_backward_step(
             )
         else:
             # Non-PP forward / backward
-            with self.train_context(optional_context_parallel_ctx):
-                assert len(model_parts) == 1
-                with self.maybe_enable_amp:
-                    pred = model_parts[0](
-                        inputs,
-                        **extra_inputs,
-                        **extra_args,
-                    )
-                    loss = self.loss_fn(pred, labels)
-
-                # need to free pred before bwd to avoid peaking memory
-                del pred
-                loss.backward()
+            assert len(model_parts) == 1
+            with self.maybe_enable_amp:
+                pred = model_parts[0](
+                    inputs,
+                    **extra_inputs,
+                    **extra_args,
+                )
+                loss = self.loss_fn(pred, labels)
+
+            # need to free pred before bwd to avoid peak memory usage
+            del pred
+            loss.backward()
 
         return loss
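
Note: the sketch below is not part of the patch; it is a minimal illustration of how the two
new entry points compose, assuming a hypothetical caller that already owns `model`,
`world_mesh`, the per-step `inputs`/`labels`, the attention masks, and the order-sensitive
buffers. Only `apply_cp` and `cp_shard` (and their module paths) come from the diff above;
every other name is an illustrative stand-in.

    from torchtitan.distributed.utils import cp_shard
    from torchtitan.models.llama3.infra.parallelize import apply_cp

    # Parallelize time: register the CP plan on each block's inner attention module.
    # use_flex_attn selects the FLEX plan; the SDPA path also enables the
    # context-parallel dispatcher.
    apply_cp(model, world_mesh["cp"], use_flex_attn=True)

    # Step time: shard the batch, the attention masks (a BlockMask, a dict of
    # BlockMasks, or None), and the order-sensitive buffers along the sequence
    # dimension before running forward/backward.
    inputs, labels, attention_masks, order_sensitive_buffers = cp_shard(
        world_mesh["cp"],
        inputs,                             # shape [batch, seq]
        labels,                             # shape [batch, seq]
        attention_masks,
        order_sensitive_buffers,            # e.g. {"freqs_cis": ...} (hypothetical)
        order_sensitive_buffers_seq_dims,   # e.g. {"freqs_cis": 0} (hypothetical)
    )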