feat: add context parallel support for SFT #420
Changes from all commits
```diff
@@ -379,10 +379,22 @@ def build_dataset(index, name):
     )


-def build_sft_dataset(data_cfg, tokenizer, num_samples, answer_only_loss=True, is_chat=True, special_tokens=None):
+def build_sft_dataset(
+    data_cfg, tokenizer, num_samples, answer_only_loss=True, is_chat=True, special_tokens=None, model_cfg=None
+):
     packed_sequence = data_cfg.get("packed_sequence", False)
     dataset_kwargs = {}

+    # TE requires that the first input dim is divisible by 8 and the second by 16 for fp8
+    # When using sequence parallel, sequence will further be split by TP size
+    # When using context parallel, sequence is split by CP size as well
+    pad_seq_length_to_mult = 16
+    if model_cfg is not None:
+        pad_seq_length_to_mult = (
+            8 * model_cfg.get("tensor_model_parallel_size", 1) if model_cfg.get("sequence_parallel", False) else 16
+        )
+        pad_seq_length_to_mult *= model_cfg.get("context_parallel_size", 1)
```
**Comment on lines +391 to +396**

**Collaborator:** According to that fp8 comment above, should this be: [suggested change]? From the comment it sounds like if someone is doing fp8 SFT with TP=1 and sets `sequence_parallel`, then the padding would be too small.

**Collaborator (Author):** That chunk was taken directly from https://github.com/NVIDIA/NeMo/blob/b847bf75c371931e4f17ea426741c1d023afa0c0/nemo/collections/nlp/models/language_modeling/megatron_gpt_sft_model.py#L262-L268, but the code does seem to contradict the comment. I'll follow up with the TE team.

**Collaborator (Author):** From the TE team: "when SP=True, the dimensions are flipped so the sequence dimension is first. So we only need to make sure it's divisible by 8 after the TP split to comply with TE's expectations."
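To make the merged rule concrete, here is a minimal standalone sketch of the padding-multiple computation (the wrapper function name and example values are mine; the config keys match the ones read in the diff above):

```python
def pad_seq_length_multiple(model_cfg=None):
    """Multiple that sequence lengths are padded to before sharding.

    Mirrors the logic added to build_sft_dataset: TE's fp8 kernels want the
    sequence dim divisible by 16, or by 8 per shard once sequence parallel
    splits it across TP ranks; context parallel splits the sequence again.
    """
    if model_cfg is None:
        return 16
    if model_cfg.get("sequence_parallel", False):
        # Per the TE team's note: with SP the sequence dim comes first and is
        # split across TP ranks, so each shard needs divisibility by 8.
        mult = 8 * model_cfg.get("tensor_model_parallel_size", 1)
    else:
        mult = 16
    # CP also splits the sequence, so scale the multiple by CP size.
    return mult * model_cfg.get("context_parallel_size", 1)


# Example: sequence parallel with TP=2 plus CP=2 pads to multiples of 32.
print(pad_seq_length_multiple(
    {"sequence_parallel": True, "tensor_model_parallel_size": 2, "context_parallel_size": 2}
))  # -> 32
```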
```diff
     if is_chat:
         assert not packed_sequence, "Sequence packing is currently not supported with chat datasets."
         dataset_cls = GPTSFTChatDataset

@@ -401,6 +413,7 @@ def build_sft_dataset(data_cfg, tokenizer, num_samples, answer_only_loss=True, i
         tokenizer=tokenizer,
         max_seq_length=data_cfg.max_seq_length,
         min_seq_length=data_cfg.min_seq_length,
+        pad_seq_length_to_mult=pad_seq_length_to_mult,
         add_bos=data_cfg.get("add_bos", False),
         add_eos=data_cfg.get("add_eos", True),
         add_sep=data_cfg.get("add_sep", False),
```
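For intuition about how this multiple is consumed, here is a hedged sketch of rounding a tokenized example up to the pad multiple; `padded_length` is a hypothetical helper for illustration, not the actual `GPTSFTDataset` code (whose internals are not part of this diff):

```python
import math


def padded_length(seq_len, pad_seq_length_to_mult, max_seq_length):
    # Hypothetical illustration: round seq_len up to the nearest multiple,
    # capped at max_seq_length.
    target = math.ceil(seq_len / pad_seq_length_to_mult) * pad_seq_length_to_mult
    return min(target, max_seq_length)


# With the multiple of 32 from the example above, a 1000-token sample pads
# to 1024 (assuming max_seq_length is at least 1024).
print(padded_length(1000, 32, 4096))  # -> 1024
```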
```diff
@@ -88,7 +88,7 @@ def get_loss_and_metrics(self, batch, forward_only):
         set_sync_funcs(self, forward_only)

         fwd_bwd_function = get_forward_backward_func()
-        fwd_loss_fn = self.get_forward_output_and_loss_func(forward_only)
+        fwd_loss_fn = self.get_forward_output_and_loss_func(forward_only, tuning=True)
```
**Collaborator:** What does `tuning` do?

**Collaborator (Author):** It controls the keys that are returned in the batch: https://github.com/NVIDIA/NeMo/blob/b847bf75c371931e4f17ea426741c1d023afa0c0/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py#L1211-L1222. If […]. Note also that […]
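As a rough, hypothetical illustration of the author's description (the key names below are plausible stand-ins, not a copy of the list in the linked `megatron_gpt_model.py`):

```python
def batch_keys_for_forward(tuning):
    # Hypothetical sketch: the tuning flag widens the set of keys the forward
    # step pulls from the batch.
    keys = {"tokens", "position_ids", "attention_mask", "labels"}
    if tuning:
        # Fine-tuning batches (e.g. from GPTSFTDataset) also carry a per-token
        # loss mask so that loss is computed on answer tokens only.
        keys.add("loss_mask")
    return keys


print(batch_keys_for_forward(tuning=True))
```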
```diff
         losses_reduced = fwd_bwd_function(
             forward_step_func=fwd_loss_fn,
```
**Collaborator:** For my education, why do we need to specify this now?

**Collaborator (Author):** I don't actually think it's necessary since TE is enabled by default. I just wanted to make explicit the fact that we were using TE, but I will remove this.