-
Couldn't load subscription status.
- Fork 576
Refactor attention and make attention mask an argument to the model #1776
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
[ghstack-poisoned]
…odel **Status** The PR is not landable yet but server as a RFC. If people are okay with this design, this PR requires following changes and verifications: 1. Change all models, including the experimental ones. 2. E2E loss verification (this has been done for functional check, but loss verification is noot done yet). 3. We should add an unittest for attention. But since we don't have GPU unittest, this can be done in a seperate PR. **Summary** This PR aims to refactor how TorchTitan build the attention masks and pass to model. Before this PR, init_attention_masks() is called in Trainer but the masks are stored as a class variable of FlexAttentionWrapper(). We chose this shortcut to support the case where a single model requires multiple masks. The previous design has several issues, one particular one is #1723. Now that pytorch/pytorch#164111 proves that we can let PP split BlockMask, this PR performs the refactor to pass masks as an argument of model.forward(). The new design: 1. Model needs to provide `get_attention_masks()` that accepts `create_mask_fn`, `batch`, and `eos_id`. If the attention op is SDPA, then this API should return None as SDPA currently doesn't support varlen. But once it does, we may have to return some tuple of int that represents the mask. Justification: attention logic is technically a part of the model, but requires some information from trainer/dataloader. So it's model author's responsibility to provide some API that let trainer calls to get the masks. 2. `get_attention_masks()` will be called from the trainer and the resulting masks are passed to the model.forward(). Justification: this will allow us to fix #1723 with pytorch/pytorch#164111 and this PR. 3. Provide a single AttentionOp instead of two. Justification: since the masking logic is moved outside, we don't need to do bookkeeping of masks in FlexAttentionWrapper. The logic is so simple that one AttentionOp makes things cleaner. Note: we still have two very very thin op wrappers that are used for CP. I keep these two for the CP education purpose. But this certinaly can be confusion for Titan's users. I'm opn to merge them to AttentionOp. See the discussion in #1723. ghstack-source-id: e869695 Pull-Request-resolved: #1776
…odel **Status** The PR is not landable yet but server as a RFC. If people are okay with this design, this PR requires following changes and verifications: 1. Change all models, including the experimental ones. 2. E2E loss verification (this has been done for functional check, but loss verification is noot done yet). 3. We should add an unittest for attention. But since we don't have GPU unittest, this can be done in a seperate PR. **Summary** This PR aims to refactor how TorchTitan build the attention masks and pass to model. Before this PR, init_attention_masks() is called in Trainer but the masks are stored as a class variable of FlexAttentionWrapper(). We chose this shortcut to support the case where a single model requires multiple masks. The previous design has several issues, one particular one is #1723. Now that pytorch/pytorch#164111 proves that we can let PP split BlockMask, this PR performs the refactor to pass masks as an argument of model.forward(). The new design: 1. Model needs to provide `get_attention_masks()` that accepts `create_mask_fn`, `batch`, and `eos_id`. If the attention op is SDPA, then this API should return None as SDPA currently doesn't support varlen. But once it does, we may have to return some tuple of int that represents the mask. Justification: attention logic is technically a part of the model, but requires some information from trainer/dataloader. So it's model author's responsibility to provide some API that let trainer calls to get the masks. 2. `get_attention_masks()` will be called from the trainer and the resulting masks are passed to the model.forward(). Justification: this will allow us to fix #1723 with pytorch/pytorch#164111 and this PR. 3. Provide a single AttentionOp instead of two. Justification: since the masking logic is moved outside, we don't need to do bookkeeping of masks in FlexAttentionWrapper. The logic is so simple that one AttentionOp makes things cleaner. Note: we still have two very very thin op wrappers that are used for CP. I keep these two for the CP education purpose. But this certinaly can be confusion for Titan's users. I'm opn to merge them to AttentionOp. See the discussion in #1723. ghstack-source-id: 35aa425 Pull-Request-resolved: #1776
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice refactor! Left many comments lol.
@wwwjn for sliding window attention, you could just create another mask_mod following the examples here.
torchtitan/models/attention.py
Outdated
| return _FlexAttentionWrapper._flex_attn(*args, **kwargs) | ||
|
|
||
|
|
||
| class _ScaledDotProductAttentionWrapper(torch.nn.Module): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
would be good to add comments why we have such wrappers
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'll evaluate if we can merge the two wrapper into AttentionOp. This seems to cause a lot of confusion. Even if we enable FlexCP in the future, people may still confuse.
|
Looks like the biggest concerns of this PR
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Very nice refactor.
I have a small suggestion around the get_attention_masks called in train.py .
As all things called in train.py, that would be better if they are a bit more flexible.
[ghstack-poisoned]
|
Hi, unsure where the best place to ask this is but this seems like a relevant recent PR. I have two questions:
|
SDPA doesn't support this so CP + SDPA doesn't support this. The current plan is to wait until SDPA support packed sequences.
Yes, it will support packing sequence. But the current implementation will use allgather only. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
sounds good to me
|
|
||
| self.use_flex_attn = model_args.use_flex_attn | ||
| if self.use_flex_attn: | ||
| self.inner_attention = FlexAttentionWrapper() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
another option is to call it self.kernel, as used by some internal
Sorry, could you elaborate why it is unsupported even in the non-CP case? According to the pseudocode can't we pass in |
When I said SDPA supports packed sequence, what I meant is that when SDPA supports varlen version. |
|
All losses match except for llama4 irope. The reason is that the original llama4 irope implementation in TorchTitan is incorrect. More precisely the fixed_block_size_mask_mod implementation is not correct. This PR also fixes it. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Had two more questions.
[ghstack-poisoned]
#1850 removed `name` field in `TrainSpec`. The experiments in simple_fsdp should also be updated. Otherwise it won't run. #1776 added `use_flex_attn` field to `apply_non_moe_tp()`, which is missing in simple_fsdp experiments ``` NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name simple_fsdp.llama3 --compile.enable ``` ``` NGPU=8 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name simple_fsdp.deepseek_v3 --compile.enable ```
…ytorch#1776) Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * pytorch#1797 * __->__ pytorch#1776 **Status** 1. Change all models, including the experimental ones. 2. E2E loss verification. 3. We should add an unittest for attention. But since we don't have GPU unittest, this can be done in a separate PR. **Summary** This PR aims to refactor how TorchTitan build the attention masks and pass to model. Before this PR, init_attention_masks() is called in Trainer but the masks are stored as a class variable of FlexAttentionWrapper(). We chose this shortcut to support the case where a single model requires multiple masks. The previous design has several issues, one particular one is pytorch#1723. pytorch/pytorch#164111 proves that we can let PP split BlockMask, this PR performs the refactor to pass masks as an argument of model.forward(). The new design: 1. Model needs to provide `get_attention_masks()` that accepts `create_mask_fn`, `batch`, and `eos_id`. If the attention op is SDPA, then this API should return None as SDPA currently doesn't support varlen. But once it does, we may have to return some tuple of int that represents the mask. Justification: attention logic is technically a part of the model, but requires some information from trainer/dataloader. So it's model author's responsibility to provide some API that let trainer calls to get the masks. 2. `get_attention_masks()` will be called from the trainer and the resulting masks are passed to the model.forward(). Justification: this will allow us to fix pytorch#1723 with pytorch/pytorch#164111 and this PR. 3. Now SDPA and FlexAttention are wrapped in two different classes. ~~Note: we still have two very very thin op wrappers that are used for CP. I keep these two for the CP education purpose. But this certainly can be confusion for Titan's users. I'm opnn to merge them to AttentionOp.~~ See the discussion in pytorch#1723. **Verification** *llama3* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/llama3/train_configs/debug_model.toml" ``` *llama3 flex* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/llama3/train_configs/debug_model.toml" --baseline-train-options="--model.flavor=debugmodel_flex_attn" ``` *llama4* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint ``` *llama4 irope* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint ``` *deepseek* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ``` *deepseek flex* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" --baseline-train-options="--model.flavor=debugmodel_flex_attn" ```
pytorch#1850 removed `name` field in `TrainSpec`. The experiments in simple_fsdp should also be updated. Otherwise it won't run. pytorch#1776 added `use_flex_attn` field to `apply_non_moe_tp()`, which is missing in simple_fsdp experiments ``` NGPU=8 CONFIG_FILE=./torchtitan/models/llama3/train_configs/debug_model.toml ./run_train.sh --model.name simple_fsdp.llama3 --compile.enable ``` ``` NGPU=8 CONFIG_FILE=./torchtitan/models/deepseek_v3/train_configs/debug_model.toml ./run_train.sh --model.name simple_fsdp.deepseek_v3 --compile.enable ```
…ytorch#1776) Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * pytorch#1797 * __->__ pytorch#1776 **Status** 1. Change all models, including the experimental ones. 2. E2E loss verification. 3. We should add an unittest for attention. But since we don't have GPU unittest, this can be done in a separate PR. **Summary** This PR aims to refactor how TorchTitan build the attention masks and pass to model. Before this PR, init_attention_masks() is called in Trainer but the masks are stored as a class variable of FlexAttentionWrapper(). We chose this shortcut to support the case where a single model requires multiple masks. The previous design has several issues, one particular one is pytorch#1723. pytorch/pytorch#164111 proves that we can let PP split BlockMask, this PR performs the refactor to pass masks as an argument of model.forward(). The new design: 1. Model needs to provide `get_attention_masks()` that accepts `create_mask_fn`, `batch`, and `eos_id`. If the attention op is SDPA, then this API should return None as SDPA currently doesn't support varlen. But once it does, we may have to return some tuple of int that represents the mask. Justification: attention logic is technically a part of the model, but requires some information from trainer/dataloader. So it's model author's responsibility to provide some API that let trainer calls to get the masks. 2. `get_attention_masks()` will be called from the trainer and the resulting masks are passed to the model.forward(). Justification: this will allow us to fix pytorch#1723 with pytorch/pytorch#164111 and this PR. 3. Now SDPA and FlexAttention are wrapped in two different classes. ~~Note: we still have two very very thin op wrappers that are used for CP. I keep these two for the CP education purpose. But this certainly can be confusion for Titan's users. I'm opnn to merge them to AttentionOp.~~ See the discussion in pytorch#1723. **Verification** *llama3* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/llama3/train_configs/debug_model.toml" ``` *llama3 flex* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/llama3/train_configs/debug_model.toml" --baseline-train-options="--model.flavor=debugmodel_flex_attn" ``` *llama4* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint ``` *llama4 irope* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint ``` *deepseek* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ``` *deepseek flex* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" --baseline-train-options="--model.flavor=debugmodel_flex_attn" ```
Introduced by #1776. Verified with the comment: ``` CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --training.steps=10 --parallelism.context_parallel_degree=8 ```
Introduced by #1776. Verified with the comment: ``` CONFIG_FILE="./torchtitan/models/llama3/train_configs/debug_model.toml" ./run_train.sh --training.steps=10 --parallelism.context_parallel_degree=8 ``` w/o this PR ``` [rank4]:[titan] 2025-10-14 23:10:44,306 - root - INFO - step: 1 loss: 8.0385 grad_norm: 1.3444 memory: 1.21GiB(1.27%) tps: 2,904 tflops: 0.21 mfu: 0.02% [rank4]:[titan] 2025-10-14 23:10:44,306 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank4]:[titan] 2025-10-14 23:10:44,347 - root - INFO - step: 2 loss: 7.6989 grad_norm: 1.4401 memory: 1.34GiB(1.41%) tps: 49,366 tflops: 3.53 mfu: 0.36% [rank4]:[titan] 2025-10-14 23:10:44,388 - root - INFO - step: 3 loss: 7.0687 grad_norm: 1.8302 memory: 1.34GiB(1.41%) tps: 51,400 tflops: 3.68 mfu: 0.37% [rank4]:[titan] 2025-10-14 23:10:44,425 - root - INFO - step: 4 loss: 6.2672 grad_norm: 2.2684 memory: 1.34GiB(1.41%) tps: 55,749 tflops: 3.99 mfu: 0.40% [rank4]:[titan] 2025-10-14 23:10:44,465 - root - INFO - step: 5 loss: 5.3015 grad_norm: 2.5508 memory: 1.34GiB(1.41%) tps: 50,835 tflops: 3.64 mfu: 0.37% [rank4]:[titan] 2025-10-14 23:10:44,522 - root - INFO - step: 6 loss: 4.7779 grad_norm: 2.4103 memory: 1.34GiB(1.41%) tps: 36,188 tflops: 2.59 mfu: 0.26% [rank4]:[titan] 2025-10-14 23:10:44,573 - root - INFO - step: 7 loss: 4.4823 grad_norm: 2.2675 memory: 1.34GiB(1.41%) tps: 41,167 tflops: 2.95 mfu: 0.30% [rank4]:[titan] 2025-10-14 23:10:44,618 - root - INFO - step: 8 loss: 4.3291 grad_norm: 1.9877 memory: 1.34GiB(1.41%) tps: 45,962 tflops: 3.29 mfu: 0.33% [rank4]:[titan] 2025-10-14 23:10:44,656 - root - INFO - step: 9 loss: 4.7022 grad_norm: 1.5639 memory: 1.34GiB(1.41%) tps: 53,689 tflops: 3.84 mfu: 0.39% [rank4]:[titan] 2025-10-14 23:10:44,695 - root - INFO - step: 10 loss: 4.1905 grad_norm: 1.8200 memory: 1.34GiB(1.41%) tps: 52,967 tflops: 3.79 mfu: 0.38% ``` w/ this PR ``` [rank4]:[titan] 2025-10-14 23:09:32,084 - root - INFO - step: 1 loss: 8.1003 grad_norm: 1.4468 memory: 0.23GiB(0.24%) tps: 150 tflops: 0.01 mfu: 0.00% [rank4]:[titan] 2025-10-14 23:09:32,085 - root - INFO - Synchronizing and adjusting timeout for all ProcessGroups to 0:01:40 [rank4]:[titan] 2025-10-14 23:09:32,151 - root - INFO - step: 2 loss: 7.7710 grad_norm: 1.5711 memory: 0.25GiB(0.26%) tps: 30,879 tflops: 2.21 mfu: 0.22% [rank4]:[titan] 2025-10-14 23:09:32,218 - root - INFO - step: 3 loss: 7.0456 grad_norm: 1.9929 memory: 0.25GiB(0.26%) tps: 30,642 tflops: 2.19 mfu: 0.22% [rank4]:[titan] 2025-10-14 23:09:32,283 - root - INFO - step: 4 loss: 6.1601 grad_norm: 2.3669 memory: 0.25GiB(0.26%) tps: 31,723 tflops: 2.27 mfu: 0.23% [rank4]:[titan] 2025-10-14 23:09:32,349 - root - INFO - step: 5 loss: 5.2561 grad_norm: 2.5374 memory: 0.25GiB(0.26%) tps: 31,047 tflops: 2.22 mfu: 0.22% [rank4]:[titan] 2025-10-14 23:09:32,420 - root - INFO - step: 6 loss: 4.8109 grad_norm: 2.8868 memory: 0.25GiB(0.26%) tps: 29,067 tflops: 2.08 mfu: 0.21% [rank4]:[titan] 2025-10-14 23:09:32,488 - root - INFO - step: 7 loss: 4.4534 grad_norm: 2.4835 memory: 0.25GiB(0.26%) tps: 30,383 tflops: 2.17 mfu: 0.22% [rank4]:[titan] 2025-10-14 23:09:32,554 - root - INFO - step: 8 loss: 4.2613 grad_norm: 2.1554 memory: 0.25GiB(0.26%) tps: 31,078 tflops: 2.22 mfu: 0.22% [rank4]:[titan] 2025-10-14 23:09:32,619 - root - INFO - step: 9 loss: 4.6215 grad_norm: 1.7431 memory: 0.25GiB(0.26%) tps: 31,814 tflops: 2.28 mfu: 0.23% [rank4]:[titan] 2025-10-14 23:09:32,687 - root - INFO - step: 10 loss: 4.0993 grad_norm: 2.0867 memory: 0.25GiB(0.26%) tps: 30,272 tflops: 2.17 mfu: 0.22% ```
…ytorch#1776) Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * pytorch#1797 * __->__ pytorch#1776 **Status** 1. Change all models, including the experimental ones. 2. E2E loss verification. 3. We should add an unittest for attention. But since we don't have GPU unittest, this can be done in a separate PR. **Summary** This PR aims to refactor how TorchTitan build the attention masks and pass to model. Before this PR, init_attention_masks() is called in Trainer but the masks are stored as a class variable of FlexAttentionWrapper(). We chose this shortcut to support the case where a single model requires multiple masks. The previous design has several issues, one particular one is pytorch#1723. pytorch/pytorch#164111 proves that we can let PP split BlockMask, this PR performs the refactor to pass masks as an argument of model.forward(). The new design: 1. Model needs to provide `get_attention_masks()` that accepts `create_mask_fn`, `batch`, and `eos_id`. If the attention op is SDPA, then this API should return None as SDPA currently doesn't support varlen. But once it does, we may have to return some tuple of int that represents the mask. Justification: attention logic is technically a part of the model, but requires some information from trainer/dataloader. So it's model author's responsibility to provide some API that let trainer calls to get the masks. 2. `get_attention_masks()` will be called from the trainer and the resulting masks are passed to the model.forward(). Justification: this will allow us to fix pytorch#1723 with pytorch/pytorch#164111 and this PR. 3. Now SDPA and FlexAttention are wrapped in two different classes. ~~Note: we still have two very very thin op wrappers that are used for CP. I keep these two for the CP education purpose. But this certainly can be confusion for Titan's users. I'm opnn to merge them to AttentionOp.~~ See the discussion in pytorch#1723. **Verification** *llama3* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/llama3/train_configs/debug_model.toml" ``` *llama3 flex* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/llama3/train_configs/debug_model.toml" --baseline-train-options="--model.flavor=debugmodel_flex_attn" ``` *llama4* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint ``` *llama4 irope* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint ``` *deepseek* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ``` *deepseek flex* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" --baseline-train-options="--model.flavor=debugmodel_flex_attn" ```
…ytorch#1776) Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0) (oldest at bottom): * pytorch#1797 * __->__ pytorch#1776 **Status** 1. Change all models, including the experimental ones. 2. E2E loss verification. 3. We should add an unittest for attention. But since we don't have GPU unittest, this can be done in a separate PR. **Summary** This PR aims to refactor how TorchTitan build the attention masks and pass to model. Before this PR, init_attention_masks() is called in Trainer but the masks are stored as a class variable of FlexAttentionWrapper(). We chose this shortcut to support the case where a single model requires multiple masks. The previous design has several issues, one particular one is pytorch#1723. pytorch/pytorch#164111 proves that we can let PP split BlockMask, this PR performs the refactor to pass masks as an argument of model.forward(). The new design: 1. Model needs to provide `get_attention_masks()` that accepts `create_mask_fn`, `batch`, and `eos_id`. If the attention op is SDPA, then this API should return None as SDPA currently doesn't support varlen. But once it does, we may have to return some tuple of int that represents the mask. Justification: attention logic is technically a part of the model, but requires some information from trainer/dataloader. So it's model author's responsibility to provide some API that let trainer calls to get the masks. 2. `get_attention_masks()` will be called from the trainer and the resulting masks are passed to the model.forward(). Justification: this will allow us to fix pytorch#1723 with pytorch/pytorch#164111 and this PR. 3. Now SDPA and FlexAttention are wrapped in two different classes. ~~Note: we still have two very very thin op wrappers that are used for CP. I keep these two for the CP education purpose. But this certainly can be confusion for Titan's users. I'm opnn to merge them to AttentionOp.~~ See the discussion in pytorch#1723. **Verification** *llama3* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/llama3/train_configs/debug_model.toml" ``` *llama3 flex* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/llama3/train_configs/debug_model.toml" --baseline-train-options="--model.flavor=debugmodel_flex_attn" ``` *llama4* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint ``` *llama4 irope* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint ``` *deepseek* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" ``` *deepseek flex* ``` ./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" --baseline-train-options="--model.flavor=debugmodel_flex_attn" ```
Stack from ghstack (oldest at bottom):
Status
Summary
This PR aims to refactor how TorchTitan build the attention masks and pass to model. Before this PR, init_attention_masks() is called in Trainer but the masks are stored as a class variable of FlexAttentionWrapper(). We chose this shortcut to support the case where a single model requires multiple masks.
The previous design has several issues, one particular one is #1723.
pytorch/pytorch#164111 proves that we can let PP split BlockMask, this PR performs the refactor to pass masks as an argument of model.forward().
The new design:
get_attention_masks()that acceptscreate_mask_fn,batch, andeos_id. If the attention op is SDPA, then this API should return None as SDPA currently doesn't support varlen. But once it does, we may have to return some tuple of int that represents the mask.Justification: attention logic is technically a part of the model, but requires some information from trainer/dataloader. So it's model author's responsibility to provide some API that let trainer calls to get the masks.
get_attention_masks()will be called from the trainer and the resulting masks are passed to the model.forward().Justification: this will allow us to fix #1723 with pytorch/pytorch#164111 and this PR.
Note: we still have two very very thin op wrappers that are used for CP. I keep these two for the CP education purpose. But this certainly can be confusion for Titan's users. I'm opnn to merge them to AttentionOp.See the discussion in #1723.
Verification
llama3
llama3 flex
llama4
llama4 irope
deepseek
deepseek flex