
Commit 134cff8

[RFC] Lift freqs_cis as an input of models
freqs_cis is sensitive to the sequence order. CP load balancing shuffles the samples, so each batch can have a different order. As a result, we have to lift this order-sensitive buffer to the model inputs and broadcast it along the batch dimension, so that PP shards freqs_cis correctly without breaking correctness.

ghstack-source-id: c30c532
Pull Request resolved: #1797
1 parent: da0756b
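To make the shape change concrete, here is a small standalone sketch (not part of this commit; the tensor sizes and the random per-sample reordering are illustrative assumptions): the per-model (max_seqlen, dim) buffer becomes a per-batch (batch_size, seqlen, dim) input, so each sample can carry its own position ordering after CP load balancing shuffles it.

```python
import torch

# Illustrative sizes only; real values come from the model config.
max_seqlen, head_dim = 16, 8
batch_size, seq_len = 4, 10

# Model-level buffer as stored before this change: (max_seqlen, head_dim).
# The real freqs_cis is complex-valued; a real-valued placeholder is enough here.
freqs_cis = torch.randn(max_seqlen, head_dim)

# Lifted, per-batch copy, matching get_order_sensitive_buffers in this commit.
per_batch = freqs_cis[:seq_len].repeat(batch_size, 1, 1)
assert per_batch.shape == (batch_size, seq_len, head_dim)

# With CP load balancing, each sample may be reordered differently (assumed
# here to be a random permutation per sample); the lifted buffer can follow it.
perm = torch.stack([torch.randperm(seq_len) for _ in range(batch_size)])
reordered = torch.gather(per_batch, 1, perm.unsqueeze(-1).expand(-1, -1, head_dim))
assert reordered.shape == per_batch.shape
```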

File tree: 3 files changed, +33 -7 lines

torchtitan/models/llama3/model/model.py

Lines changed: 14 additions & 6 deletions
@@ -92,8 +92,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
     This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
     for the purpose of broadcasting the frequency tensor during element-wise operations.

-    The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim),
-    and the first seqlen elements will be sliced, but dim must match x.
+    The input freqs_cis tensor is assumed to be of shape (batch_size, seqlen, dim).

     Args:
         freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
@@ -104,10 +103,10 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
     """
     ndim = x.ndim
     assert ndim > 1
+    batch_size = x.shape[0]
     seqlen = x.shape[1]
-    freqs_cis = freqs_cis[0:seqlen]
-    assert freqs_cis.shape == (seqlen, x.shape[-1])
-    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+    assert freqs_cis.shape == (batch_size, seqlen, x.shape[-1])
+    shape = [d if i in (0, 1, ndim - 1) else 1 for i, d in enumerate(x.shape)]
     return freqs_cis.view(*shape)


@@ -474,9 +473,18 @@ def get_attention_masks(
             and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1]
         )

+    def get_order_sensitive_buffers(
+        self,
+        batch_size: int,
+        seq_len: int,
+    ) -> tuple[dict[str, torch.Tensor], dict[str, int]]:
+        freqs_cis = self.freqs_cis[:seq_len].repeat(batch_size, 1, 1)
+        return ({"freqs_cis": freqs_cis}, {"freqs_cis": 1})
+
     def forward(
         self,
         tokens: torch.Tensor,
+        freqs_cis: torch.Tensor,
         attention_masks: AttentionMasksType | None = None,
         input_batch: torch.Tensor | None = None,
     ):
@@ -501,7 +509,7 @@ def forward(
         h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

         for layer in self.layers.values():
-            h = layer(h, self.freqs_cis, attention_masks=attention_masks)
+            h = layer(h, freqs_cis, attention_masks=attention_masks)

         h = self.norm(h) if self.norm else h
         output = self.output(h) if self.output else h
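As a quick illustration of the new broadcasting rule, here is a minimal sketch (sizes are illustrative; it assumes the usual 4-D complex-viewed activation of shape (bs, seqlen, n_heads, head_dim // 2) that apply_rotary_emb typically passes to reshape_for_broadcast):

```python
import torch

bs, seqlen, n_heads, half_dim = 2, 6, 4, 4
x = torch.randn(bs, seqlen, n_heads, half_dim, dtype=torch.complex64)
freqs_cis = torch.randn(bs, seqlen, half_dim, dtype=torch.complex64)

# Same comprehension as the updated reshape_for_broadcast: keep batch, seqlen,
# and the last dim; insert singleton dims everywhere else.
ndim = x.ndim
shape = [d if i in (0, 1, ndim - 1) else 1 for i, d in enumerate(x.shape)]
assert shape == [bs, seqlen, 1, half_dim]

# (bs, seqlen, 1, half_dim) broadcasts against x over the heads dimension.
out = x * freqs_cis.view(*shape)
assert out.shape == x.shape
```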

torchtitan/protocols/model.py

Lines changed: 8 additions & 0 deletions
@@ -70,3 +70,11 @@ def get_attention_masks(
         raise NotImplementedError(
             "This model does not support attention masking/Flex Attention."
         )
+
+    def get_order_sensitive_buffers(
+        self,
+        batch_size: int,
+        seq_len: int,
+    ) -> tuple[tuple[torch.Tensor, ...], tuple[int, ...]]:
+        raise NotImplementedError()
+        return ((), ())
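The base stub raises, so the trailing return is never reached and any model the trainer drives must provide its own implementation. A model with no order-sensitive state could satisfy the protocol with empty containers, for example (hypothetical sketch, not part of this commit):

```python
import torch
from torch import nn


class PositionFreeModel(nn.Module):
    """Hypothetical model with no buffers tied to the sequence order."""

    def get_order_sensitive_buffers(
        self,
        batch_size: int,
        seq_len: int,
    ) -> tuple[tuple[torch.Tensor, ...], tuple[int, ...]]:
        # Nothing to lift: extra_args.update(()) in the trainer is then a no-op.
        return ((), ())
```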

torchtitan/train.py

Lines changed: 11 additions & 1 deletion
@@ -422,6 +422,12 @@ def forward_backward_step(
             extra_inputs=extra_inputs,
         )

+        # Get the order sensitive buffers
+        order_sensitive_buffers = model_parts[0].get_order_sensitive_buffers(
+            inputs.size(0), inputs.size(1)
+        )
+        extra_args.update(order_sensitive_buffers[0])
+
         # apply context parallelism if cp is enabled
         # ensure CP handles the separate freqs_cis buffer for each pp stage
         cp_mesh = parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None
@@ -475,7 +481,11 @@ def forward_backward_step(
             with self.train_context(optional_context_parallel_ctx):
                 assert len(model_parts) == 1
                 with self.maybe_enable_amp:
-                    pred = model_parts[0](inputs, **extra_inputs, **extra_args)
+                    pred = model_parts[0](
+                        inputs,
+                        **extra_inputs,
+                        **extra_args,
+                    )
                 loss = self.loss_fn(pred, labels)
                 # need to free pred before bwd to avoid peaking memory
                 del pred
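Only the first element of the returned pair is consumed here; the second maps each lifted buffer name to its sequence dimension, which is what context parallelism needs in order to shard that buffer along the sequence (the actual CP wiring happens in the "apply context parallelism if cp is enabled" block referenced just below the first hunk). A conceptual, single-process sketch of what that sharding means (cp_degree and the sizes are illustrative assumptions):

```python
import torch

cp_degree = 2
# As lifted above: (bs, seqlen, head_dim) with seq dim 1, mirroring
# ({"freqs_cis": freqs_cis}, {"freqs_cis": 1}) returned by the llama3 model.
buffers = {"freqs_cis": torch.randn(4, 16, 8)}
seq_dims = {"freqs_cis": 1}

for name, buf in buffers.items():
    # Each CP rank ends up with one sequence shard of the per-batch buffer.
    shards = buf.chunk(cp_degree, dim=seq_dims[name])
    assert all(s.shape == (4, 16 // cp_degree, 8) for s in shards)
```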
