
Commit cdadbf8

[RFC] Lift freqs_cis as an input of models
freqs_cis is sensitive to the sequence order. CP load balancing will shuffle the samples, so each batch will have a different order. As a result, we have to lift these order-sensitive buffers to the model inputs and broadcast them along the batch dimension so that PP will correctly shard freqs_cis without breaking correctness.

ghstack-source-id: 2d88844
Pull-Request-resolved: #1797
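As a rough illustration of the motivation (a toy sketch only, not torchtitan code; all names below are made up): a shared freqs_cis buffer of shape (max_seqlen, dim) assumes row i belongs to token position i, so once CP load balancing reorders each sample's token positions, every sample needs its own rows, which is exactly what a lifted (batch_size, seqlen, dim) input provides.

import torch

# Toy sketch: a per-position frequency table is order sensitive --
# row i encodes position i, so it must follow each sample's token order.
max_seqlen, dim, batch_size, seqlen = 16, 4, 2, 8
freqs_cis = torch.randn(max_seqlen, dim)

# Hypothetical per-sample position orders after CP load balancing.
positions = torch.stack(
    [torch.arange(seqlen), torch.arange(seqlen).flip(0)]
)  # (batch_size, seqlen)

# Lifting the buffer to an input: one (seqlen, dim) slice per sample,
# giving a (batch_size, seqlen, dim) tensor that PP/CP can shard with the batch.
freqs_cis_input = freqs_cis[positions]
assert freqs_cis_input.shape == (batch_size, seqlen, dim)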
Parent: 24f3fd4

File tree: 3 files changed, +34 −7 lines changed

torchtitan/models/llama3/model/model.py

Lines changed: 14 additions & 6 deletions
@@ -92,8 +92,7 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor
     This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
     for the purpose of broadcasting the frequency tensor during element-wise operations.

-    The input freqs_cis tensor is assumed to be of shape (max_seqlen, dim),
-    and the first seqlen elements will be sliced, but dim must match x.
+    The input freqs_cis tensor is assumed to be of shape (batch_size, seqlen, dim).

     Args:
         freqs_cis (torch.Tensor): Frequency tensor to be reshaped.
@@ -104,10 +103,10 @@ def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor
     """
     ndim = x.ndim
     assert ndim > 1
+    batch_size = x.shape[0]
     seqlen = x.shape[1]
-    freqs_cis = freqs_cis[0:seqlen]
-    assert freqs_cis.shape == (seqlen, x.shape[-1])
-    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+    assert freqs_cis.shape == (batch_size, seqlen, x.shape[-1])
+    shape = [d if i in (0, 1, ndim - 1) else 1 for i, d in enumerate(x.shape)]
     return freqs_cis.view(*shape)


@@ -472,9 +471,18 @@ def get_attention_masks(
             and_masks(*mask_mods), B, None, input_batch.shape[1], input_batch.shape[1]
         )

+    def get_order_sensitive_buffers(
+        self,
+        batch_size: int,
+        seq_len: int,
+    ) -> tuple[tuple[torch.Tensor, ...], tuple[int, ...]]:
+        freqs_cis = self.freqs_cis[:seq_len].repeat(batch_size, 1, 1)
+        return ((freqs_cis,), (1,))
+
     def forward(
         self,
         tokens: torch.Tensor,
+        freqs_cis: torch.Tensor,
         attention_masks: AttentionMasksType | None = None,
         input_batch: torch.Tensor | None = None,
     ):
@@ -499,7 +507,7 @@ def forward(
         h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens

         for layer in self.layers.values():
-            h = layer(h, self.freqs_cis, attention_masks=attention_masks)
+            h = layer(h, freqs_cis, attention_masks=attention_masks)

         h = self.norm(h) if self.norm else h
         output = self.output(h) if self.output else h
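For reference, a minimal shape check of the updated broadcasting contract, with reshape_for_broadcast copied from the hunk above; the concrete tensor sizes are made-up examples:

import torch

def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
    # Same logic as the hunk above: freqs_cis is (batch_size, seqlen, dim) and is
    # viewed so it broadcasts over every dim of x except batch, seqlen, and the last.
    ndim = x.ndim
    assert ndim > 1
    batch_size = x.shape[0]
    seqlen = x.shape[1]
    assert freqs_cis.shape == (batch_size, seqlen, x.shape[-1])
    shape = [d if i in (0, 1, ndim - 1) else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)

# Example sizes (arbitrary): x as it typically appears inside attention,
# (batch, seqlen, n_heads, head_dim-like last dim).
x = torch.randn(2, 8, 4, 16)
freqs_cis = torch.randn(2, 8, 16)  # the lifted per-batch buffer
assert reshape_for_broadcast(freqs_cis, x).shape == (2, 8, 1, 16)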

torchtitan/protocols/model.py

Lines changed: 8 additions & 0 deletions
@@ -70,3 +70,11 @@ def get_attention_masks(
         raise NotImplementedError(
             "This model does not support attention masking/Flex Attention."
         )
+
+    def get_order_sensitive_buffers(
+        self,
+        batch_size: int,
+        seq_len: int,
+    ) -> tuple[tuple[torch.Tensor, ...], tuple[int, ...]]:
+        raise NotImplementedError()
+        return ((), ())
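A hedged sketch of what a concrete override of this protocol method might look like (illustrative class, not from the repo). The first tuple holds the buffers to lift as inputs; the (1,) returned by the llama3 implementation above appears to mark the sequence dimension of each buffer, and the sketch follows that assumption.

import torch
from torch import nn

class ToyTransformer(nn.Module):
    """Illustrative only: one way to satisfy the protocol method above."""

    def __init__(self, max_seqlen: int = 2048, dim: int = 64) -> None:
        super().__init__()
        # Positional buffer analogous to freqs_cis (random values for illustration).
        self.register_buffer("freqs_cis", torch.randn(max_seqlen, dim))

    def get_order_sensitive_buffers(
        self,
        batch_size: int,
        seq_len: int,
    ) -> tuple[tuple[torch.Tensor, ...], tuple[int, ...]]:
        # Expand to (batch_size, seq_len, dim) so the buffer can be sharded
        # along the batch together with the inputs; the second tuple marks
        # the sequence dim of each buffer (assumed, following llama3's (1,)).
        freqs_cis = self.freqs_cis[:seq_len].repeat(batch_size, 1, 1)
        return ((freqs_cis,), (1,))

buffers, seq_dims = ToyTransformer().get_order_sensitive_buffers(batch_size=2, seq_len=8)
assert buffers[0].shape == (2, 8, 64) and seq_dims == (1,)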

torchtitan/train.py

Lines changed: 12 additions & 1 deletion
@@ -421,6 +421,11 @@ def forward_backward_step(
             extra_inputs=extra_inputs,
         )

+        # Get the order sensitive buffers
+        order_sensitive_buffers = model_parts[0].get_order_sensitive_buffers(
+            inputs.size(0), inputs.size(1)
+        )
+
         # apply context parallelism if cp is enabled
         # ensure CP handles the separate freqs_cis buffer for each pp stage
         cp_mesh = parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None
@@ -445,13 +450,15 @@ def forward_backward_step(
             if self.pp_has_first_stage:
                 self.pp_schedule.step(
                     inputs,
+                    *order_sensitive_buffers[0],
                     **extra_inputs,
                     target=targets,
                     losses=losses,
                     input_batch=inputs,
                 )
             else:
                 self.pp_schedule.step(
+                    *order_sensitive_buffers[0],
                     target=targets,
                     losses=losses,
                     input_batch=inputs,
@@ -472,7 +479,11 @@ def forward_backward_step(
             with self.train_context(optional_context_parallel_ctx):
                 assert len(model_parts) == 1
                 with self.maybe_enable_amp:
-                    pred = model_parts[0](inputs, **extra_inputs)
+                    pred = model_parts[0](
+                        inputs,
+                        *order_sensitive_buffers[0],
+                        **extra_inputs,
+                    )
                     loss = self.loss_fn(pred, labels)
                 # need to free pred before bwd to avoid peaking memory
                 del pred
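Trainer-side, the splicing in the hunks above can be read as a small helper (a sketch only; run_step is not a torchtitan function, and it assumes the model part's forward accepts the lifted buffers as positional arguments right after the tokens, as in the llama3 change above):

import torch
from torch import nn

def run_step(model_part: nn.Module, inputs: torch.Tensor, **extra_inputs) -> torch.Tensor:
    # Mirror of the non-PP branch above: fetch the order-sensitive buffers for
    # this batch shape, then pass them positionally right after the tokens.
    buffers, _seq_dims = model_part.get_order_sensitive_buffers(
        inputs.size(0), inputs.size(1)
    )
    return model_part(inputs, *buffers, **extra_inputs)

With a model exposing get_order_sensitive_buffers and a matching forward (e.g. the toy sketch after the protocol diff), this reproduces the pred = model_parts[0](inputs, *order_sensitive_buffers[0], **extra_inputs) call.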
