Commit 638da2e

lkhphuc authored and githubsgi committed
[VLM] Add token-imbalance loss (pytorch#1803)
In VLM interleaved training with native resolution and aspect ratio, the number of tokens participating in the loss computation differs per rank. Naive FSDP gradient averaging across data ranks can cause tokens on ranks with fewer valid tokens to contribute more to the loss than tokens on other ranks. This PR addresses this via loss balancing, which incurs an additional comm in the loss computation. In practice, I haven't noticed any impact from this comm.

#### Quick sanity check

Let the sum loss over all tokens on rank $i$, with $N_i$ tokens, be $L_i = \sum_{j=1}^{N_i}\ell_{ij}$ and its gradient $g_i = \sum_{j=1}^{N_i}\nabla\ell_{ij}$.

If we multiply the *loss* on each rank by a constant factor **c** (the same for all ranks), then after `backward()`:

$$ \tilde g_i = c \cdot g_i . $$

FSDP will *average* these gradients across ranks:

$$ g_{\text{FSDP}}=\frac{1}{R}\sum_{i=1}^{R} \tilde g_i =\frac{c}{R}\sum_{i=1}^{R} g_i . $$

We want this to equal the **global-sample average**:

$$ g_{\text{true}} =\frac{1}{N_{\text{total}}}\sum_{i=1}^{R}\sum_{j=1}^{N_i}\nabla \ell_{ij} =\frac{1}{N_{\text{total}}}\sum_{i=1}^{R} g_i . $$

Thus, for the FSDP gradient to be correct, we need

$$ \frac{c}{R}= \frac{1}{N_{\text{total}}}\quad\Longrightarrow\quad c=\frac{R}{N_{\text{total}}}. $$

So the *right* scaling factor is $R/N_{\text{total}}$, which means dividing the per-rank sum loss by $N_{\text{total}}/R$, the **average number of tokens per rank**. Intuitively, this is the same as the default cross-entropy loss, but instead of dividing the sum loss on a rank by the number of tokens **on that rank**, we divide by the **average number of tokens across all ranks**.

P/s: sorry, this PR is based on pytorch#1802 but I couldn't choose that as the base branch. Maybe it will be easier to review once that PR is merged.
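The derivation can be checked numerically. Below is a minimal single-process sketch (not part of the PR; the loss, shapes, and token counts are made up) that emulates `R` imbalanced ranks and confirms that dividing each rank's summed loss by the average number of tokens per rank, then averaging the per-rank gradients as FSDP does, reproduces the global per-token-average gradient:

```python
# Toy check of the scaling argument: R "ranks" with different valid-token counts.
import torch

torch.manual_seed(0)
R = 4                                     # number of data-parallel ranks
tokens_per_rank = [3, 7, 2, 8]            # imbalanced valid-token counts
N_total = sum(tokens_per_rank)
avg_tokens_per_rank = N_total / R         # the per-rank divisor derived above

w = torch.randn(5, requires_grad=True)             # shared "model" parameter
xs = [torch.randn(n, 5) for n in tokens_per_rank]  # per-rank token features

def rank_sum_loss(x):
    # stand-in per-token loss, summed over the rank's valid tokens
    return ((x @ w) ** 2).sum()

# Reference gradient: global average over all N_total tokens.
ref_loss = sum(rank_sum_loss(x) for x in xs) / N_total
ref_grad = torch.autograd.grad(ref_loss, w)[0]

# Token-imbalance scaling: each rank divides its summed loss by the *average*
# tokens per rank; FSDP then averages the resulting per-rank gradients.
rank_grads = [
    torch.autograd.grad(rank_sum_loss(x) / avg_tokens_per_rank, w)[0]
    for x in xs
]
fsdp_grad = torch.stack(rank_grads).mean(dim=0)

print(torch.allclose(ref_grad, fsdp_grad))  # True
```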
1 parent fd224b3 commit 638da2e

File tree

4 files changed: +118 -3 lines changed


torchtitan/components/loss.py

Lines changed: 2 additions & 1 deletion
@@ -23,7 +23,8 @@ def cross_entropy_loss(pred: torch.Tensor, labels: torch.Tensor) -> torch.Tensor
     )


-def build_cross_entropy_loss(job_config: JobConfig):
+def build_cross_entropy_loss(job_config: JobConfig, **kwargs):
+    del kwargs  # delete any unused arguments
     loss_fn = cross_entropy_loss
     if job_config.compile.enable and "loss" in job_config.compile.components:
         logger.info("Compiling the loss function with torch.compile")

torchtitan/experiments/vlm/__init__.py

Lines changed: 0 additions & 1 deletion
@@ -6,7 +6,6 @@

 from dataclasses import asdict, replace

-from torchtitan.components.loss import build_cross_entropy_loss
 from torchtitan.components.lr_scheduler import build_lr_schedulers
 from torchtitan.components.optimizer import build_optimizers
 from torchtitan.components.tokenizer import build_hf_tokenizer
Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from functools import partial
+
+import torch
+import torch.distributed._functional_collectives as funcol
+import torch.distributed.distributed_c10d as c10d
+from torch import distributed as dist
+from torch.distributed.device_mesh import DeviceMesh
+
+from torchtitan.components.ft.manager import FTManager
+from torchtitan.config.job_config import JobConfig
+from torchtitan.distributed.parallel_dims import ParallelDims
+from torchtitan.tools.logging import logger
+
+
+IGNORE_INDEX = -100  # PyTorch's default for F.cross_entropy
+
+
+# WARNING: currently this does not take into account gradient accumulation
+# and the gradient can still be biased toward grad accum steps with fewer valid tokens
+# See: https://github.com/pytorch/torchtitan/issues/1842
+def token_imbalance_ce_loss(
+    pred: torch.Tensor,
+    labels: torch.Tensor,
+    token_mesh: DeviceMesh,
+    ft_pg: dist.ProcessGroup | None,
+) -> torch.Tensor:
+    """
+    Cross‑entropy loss that is *robust* to varying numbers of valid tokens across ranks.
+
+    In a typical distributed training setup (data parallel + sequence parallel),
+    each rank computes the loss over **only its local tokens** and returns an
+    *average* over those tokens.
+
+    Afterwards, when Fully‑Sharded Data Parallel (FSDP) averages the gradients
+    across all ranks, the resulting update is equivalent to a **global sample
+    average** *only if every rank contains the same number of tokens*.
+    In practice that assumption is violated for many workloads:
+    - Sequences are padded to a fixed length -> some ranks see fewer real tokens.
+    - SFT finetuning where user query tokens are masked out.
+    - Vision encoders often inject a large number of “ignored”
+      tokens as context that are not trained with the text tokens' loss.
+
+    This function fixes the issue by **scaling the sum-of-loss** with the *average*
+    number of non‑ignored tokens per rank, computed via an all-reduce over
+    `token_mesh`. The returned scalar therefore represents the loss that would
+    be obtained if every token in the entire distributed batch contributed with
+    equal weight to the global gradient, regardless of how many padded or
+    ignored tokens each rank contains.
+
+    Parameters
+    ----------
+    pred : torch.Tensor
+    labels : torch.Tensor
+    token_mesh : DeviceMesh
+        A device mesh that contains all ranks participating in this training step's
+        loss computation. The function performs an ``all_reduce`` (mean) over the
+        `num_tokens` tensor of a rank across this mesh.
+    ft_pg : dist.ProcessGroup | None
+        Optional process group for Fault Tolerance training.
+
+    Returns
+    -------
+    torch.Tensor
+        A scalar loss tensor, ready for ``backward()`` and FSDP all-reduce mean.
+
+    Notes
+    -----
+    * The function internally uses :func:`torch.nn.functional.cross_entropy`
+      with ``reduction="sum"`` so that each token contributes exactly once to
+      the numerator. The denominator is the **average** number of valid tokens
+      per rank, not the local count.
+    * If a rank contains no valid tokens (i.e., all labels are ``IGNORE_INDEX``),
+      its contribution to the sum is zero and its `num_tokens` becomes zero.
+      In that case the mean across ranks will still be well‑defined as long as
+      at least one rank has a non‑zero token count.
+    """
+    sum_loss = torch.nn.functional.cross_entropy(
+        pred.flatten(0, 1).float(),
+        labels.flatten(0, 1),
+        reduction="sum",
+        ignore_index=IGNORE_INDEX,
+    )
+    num_tokens = (labels != IGNORE_INDEX).sum()
+    avg_num_tokens_per_rank = funcol.all_reduce(
+        num_tokens, reduceOp=c10d.ReduceOp.AVG.name, group=token_mesh
+    )
+    if ft_pg is not None:
+        avg_num_tokens_per_rank = funcol.all_reduce(
+            avg_num_tokens_per_rank, reduceOp=c10d.ReduceOp.AVG.name, group=ft_pg
+        )
+    return sum_loss / avg_num_tokens_per_rank
+
+
+def build_token_imbalance_ce_loss(
+    job_config: JobConfig, parallel_dims: ParallelDims, ft_manager: FTManager, **kwargs
+):
+    del kwargs  # delete any unused arguments
+    # NOTE: the device mesh where the input tokens w/ shape BSD can be sliced:
+    #   DP splits the batch dim B
+    #   CP splits the sequence dim S
+    token_mesh = parallel_dims.world_mesh["dp_cp"]
+    ft_pg = ft_manager.loss_sync_pg
+    loss_fn = partial(token_imbalance_ce_loss, token_mesh=token_mesh, ft_pg=ft_pg)
+    if job_config.compile.enable and "loss" in job_config.compile.components:
+        logger.info("Compiling the loss function with torch.compile")
+        loss_fn = torch.compile(loss_fn, backend=job_config.compile.backend)
+    return loss_fn
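
For intuition on the per-rank pieces computed inside `token_imbalance_ce_loss`, here is a hypothetical single-process illustration (no collectives; batch and vocab sizes are made up): the summed loss skips `IGNORE_INDEX` positions, and `num_tokens` counts only the labels that participate.

```python
# Local illustration of the numerator and token count; no distributed setup.
import torch
import torch.nn.functional as F

IGNORE_INDEX = -100                        # matches the module constant
B, S, V = 2, 4, 11                         # batch, sequence, vocab sizes (made up)
pred = torch.randn(B, S, V)
labels = torch.randint(0, V, (B, S))
labels[0, 2:] = IGNORE_INDEX               # e.g. padded or image-context positions

# Numerator: sum of per-token losses; ignored positions contribute nothing.
sum_loss = F.cross_entropy(
    pred.flatten(0, 1).float(),
    labels.flatten(0, 1),
    reduction="sum",
    ignore_index=IGNORE_INDEX,
)
# Local count of tokens that actually take part in the loss.
num_tokens = (labels != IGNORE_INDEX).sum()
print(sum_loss.item(), num_tokens.item())  # 6 of the 8 positions are valid here
```

In the committed function, `num_tokens` is then AVG all-reduced over `token_mesh` (and over `ft_pg` when fault-tolerant training is enabled), and the final loss is `sum_loss / avg_num_tokens_per_rank`.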

torchtitan/train.py

Lines changed: 3 additions & 1 deletion
@@ -203,7 +203,9 @@ def __init__(self, job_config: JobConfig):
         init_device = device_type
         buffer_device = None

-        self.loss_fn = self.train_spec.build_loss_fn(job_config)
+        self.loss_fn = self.train_spec.build_loss_fn(
+            job_config, parallel_dims=parallel_dims, ft_manager=self.ft_manager
+        )

         # verify batch sizes
         global_batch_size = job_config.training.global_batch_size
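
Together with the `loss.py` change above, this gives every loss builder a uniform call signature: builders that don't need `parallel_dims` or `ft_manager` simply absorb and discard them. A condensed sketch of the convention (bodies elided):

```python
# Sketch of the builder-signature convention implied by this commit (bodies elided).
def build_cross_entropy_loss(job_config, **kwargs):
    del kwargs  # this builder ignores the extra arguments
    ...

def build_token_imbalance_ce_loss(job_config, parallel_dims, ft_manager, **kwargs):
    del kwargs  # any future extra arguments are ignored here too
    ...

# train.py can now call whichever builder the train spec registers:
# self.loss_fn = self.train_spec.build_loss_fn(
#     job_config, parallel_dims=parallel_dims, ft_manager=self.ft_manager
# )
```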
