# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from functools import partial

import torch
import torch.distributed._functional_collectives as funcol
import torch.distributed.distributed_c10d as c10d
from torch import distributed as dist
from torch.distributed.device_mesh import DeviceMesh

from torchtitan.components.ft.manager import FTManager
from torchtitan.config.job_config import JobConfig
from torchtitan.distributed.parallel_dims import ParallelDims
from torchtitan.tools.logging import logger

IGNORE_INDEX = -100  # PyTorch's default ignore_index for F.cross_entropy


# WARNING: currently this does not take gradient accumulation into account,
# so the gradient can still be biased toward grad accum steps with fewer valid tokens.
# See: https://github.com/pytorch/torchtitan/issues/1842
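# Illustrative sketch of the caveat (hypothetical numbers; single rank for simplicity):
# with two grad accum steps holding 10 and 1000 valid tokens, each step's loss is
# normalized by its own token count, so once the per-step gradients are summed, a
# token from the 10-token step carries roughly 100x the weight of a token from the
# 1000-token step.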
def token_imbalance_ce_loss(
    pred: torch.Tensor,
    labels: torch.Tensor,
    token_mesh: DeviceMesh,
    ft_pg: dist.ProcessGroup | None,
) -> torch.Tensor:
| 33 | + """ |
| 34 | + Cross‑entropy loss that is *robust* to varying numbers of valid tokens across ranks. |
| 35 | +
|
| 36 | + In a typical distributed training setup (data parallel + sequence parallel), |
| 37 | + each rank computes the loss over **only its local tokens** and returns an |
| 38 | + *average* over those tokens: |
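
        # illustrative pseudocode; these names are not identifiers in this module
        loss_rank = sum_of_CE_over_local_valid_tokens / num_local_valid_tokens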

    Afterwards, when Fully-Sharded Data Parallel (FSDP) averages the gradients
    across all ranks, the resulting update is equivalent to a **global sample
    average** *only if every rank contains the same number of tokens*.
    In practice that assumption is violated in many workloads:
    - Sequences are padded to a fixed length -> some ranks see fewer real tokens.
    - SFT fine-tuning, where the user's query tokens are masked out.
    - Vision encoders often inject a large number of “ignored” tokens as
      context that are not trained with the text tokens' loss.

    This function fixes the issue by **dividing the sum of per-token losses** by the
    *average* number of non-ignored tokens per rank, computed via an all-reduce over
    `token_mesh`. The returned scalar therefore represents the loss that would
    be obtained if every token in the entire distributed batch contributed with
    equal weight to the global gradient, regardless of how many padded or
    ignored tokens each rank contains.
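
        # illustrative pseudocode, mirroring the sketch above
        loss_rank = sum_of_CE_over_local_valid_tokens / mean_over_ranks(num_valid_tokens)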

    Parameters
    ----------
    pred : torch.Tensor
        Model output logits of shape ``(batch, seq_len, vocab_size)``.
    labels : torch.Tensor
        Target token ids of shape ``(batch, seq_len)``; positions equal to
        ``IGNORE_INDEX`` are excluded from the loss.
    token_mesh : DeviceMesh
        A device mesh that contains all ranks participating in this training step's
        loss computation. The function performs an ``all_reduce`` (mean) of each
        rank's `num_tokens` tensor across this mesh.
    ft_pg : dist.ProcessGroup | None
        Optional process group for fault-tolerant training; if provided, the token
        count is additionally averaged across it.

    Returns
    -------
    torch.Tensor
        A scalar loss tensor, ready for ``backward()`` followed by FSDP's
        all-reduce (mean) of gradients.

    Notes
    -----
    * The function internally uses :func:`torch.nn.functional.cross_entropy`
      with ``reduction="sum"`` so that each token contributes exactly once to
      the numerator. The denominator is the **average** number of valid tokens
      per rank, not the local count.
    * If a rank contains no valid tokens (i.e., all labels are ``IGNORE_INDEX``),
      its contribution to the sum is zero and its `num_tokens` becomes zero.
      In that case the mean across ranks is still well-defined as long as
      at least one rank has a non-zero token count.
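    * Worked example (illustrative numbers): suppose rank 0 has 100 valid tokens
      with summed loss 200.0 and rank 1 has 50 valid tokens with summed loss 80.0.
      Per-rank averaging would return 2.0 and 1.6, and FSDP's gradient mean would
      weight rank 1's tokens twice as heavily as rank 0's. With the shared
      denominator (100 + 50) / 2 = 75, the ranks instead return 200 / 75 and
      80 / 75, and their mean equals the true global average (200 + 80) / 150.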
| 82 | + """ |
    # Sum of per-token CE (not a mean), so each valid token contributes exactly once.
    sum_loss = torch.nn.functional.cross_entropy(
        pred.flatten(0, 1).float(),
        labels.flatten(0, 1),
        reduction="sum",
        ignore_index=IGNORE_INDEX,
    )
    # Number of non-ignored tokens on this rank.
    num_tokens = (labels != IGNORE_INDEX).sum()
    # Average valid-token count across all ranks sharing this batch (DP and CP).
    avg_num_tokens_per_rank = funcol.all_reduce(
        num_tokens, reduceOp=c10d.ReduceOp.AVG.name, group=token_mesh
    )
    if ft_pg is not None:
        # Also average across the fault-tolerance process group, if one is used.
        avg_num_tokens_per_rank = funcol.all_reduce(
            avg_num_tokens_per_rank, reduceOp=c10d.ReduceOp.AVG.name, group=ft_pg
        )
    return sum_loss / avg_num_tokens_per_rank


def build_token_imbalance_ce_loss(
    job_config: JobConfig, parallel_dims: ParallelDims, ft_manager: FTManager, **kwargs
):
    del kwargs  # discard any unused arguments
    # NOTE: the device mesh over which the input tokens of shape (B, S, D) can be sliced:
    #   DP splits the batch dim B; CP splits the sequence dim S.
    token_mesh = parallel_dims.world_mesh["dp_cp"]
    ft_pg = ft_manager.loss_sync_pg
    loss_fn = partial(token_imbalance_ce_loss, token_mesh=token_mesh, ft_pg=ft_pg)
    if job_config.compile.enable and "loss" in job_config.compile.components:
        logger.info("Compiling the loss function with torch.compile")
        loss_fn = torch.compile(loss_fn, backend=job_config.compile.backend)
    return loss_fn
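

# Hedged usage sketch (illustration only; `model`, `input_ids`, and `labels` are
# placeholders, not identifiers defined in torchtitan):
#
#     loss_fn = build_token_imbalance_ce_loss(job_config, parallel_dims, ft_manager)
#     pred = model(input_ids)       # logits of shape (batch, seq_len, vocab_size)
#     loss = loss_fn(pred, labels)  # labels use IGNORE_INDEX at masked positions
#     loss.backward()               # FSDP then averages gradients across ranks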