@XilunWu XilunWu commented Oct 15, 2025

Stack from ghstack (oldest at bottom):

Summary
Enable the llama3-8B model to use `flex_attention`.

Test
CONFIG_FILE=./torchtitan/models/llama3/train_configs/llama3_8b.toml NGPU=4 LOG_RANK=0,1,2,3 ./run_train.sh --model.flavor "8B_flex_attn" --activation_checkpoint.mode "full" --parallelism.context_parallel_degree 2
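
As a rough sketch of how the four GPUs in this command are divided: assuming tensor and pipeline parallel degrees stay at 1 and the remaining ranks go to data parallelism (the PR does not state this), a context parallel degree of 2 leaves a data-parallel degree of 2. The helper name `data_parallel_degree` is hypothetical and only illustrates the arithmetic.

```python
def data_parallel_degree(ngpu: int, cp: int = 1, tp: int = 1, pp: int = 1) -> int:
    """Return the data-parallel degree left after the other parallel dimensions.

    Assumption: world size = dp * cp * tp * pp. This helper is illustrative
    and is not part of torchtitan.
    """
    other = cp * tp * pp
    if other <= 0 or ngpu % other != 0:
        raise ValueError(f"ngpu={ngpu} is not divisible by cp*tp*pp={other}")
    return ngpu // other


# NGPU=4 with --parallelism.context_parallel_degree 2, other degrees assumed 1.
dp = data_parallel_degree(ngpu=4, cp=2)
print(f"data-parallel degree: {dp}")
```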

[rank3]:[titan] 2025-10-15 10:37:01,390 - root - INFO - step:  1  loss: 12.2465  grad_norm:  4.5253  memory: 33.21GiB(34.95%)  tps: 86  tflops: 4.98  mfu: 0.50%
[rank0]:[titan] 2025-10-15 10:37:01,391 - root - INFO - step:  1  loss: 12.2465  grad_norm:  4.5253  memory: 33.21GiB(34.95%)  tps: 87  tflops: 5.02  mfu: 0.51%
[rank1]:[titan] 2025-10-15 10:37:01,391 - root - INFO - step:  1  loss: 12.2465  grad_norm:  4.5253  memory: 33.21GiB(34.95%)  tps: 87  tflops: 5.03  mfu: 0.51%
[rank2]:[titan] 2025-10-15 10:37:01,391 - root - INFO - step:  1  loss: 12.2465  grad_norm:  4.5253  memory: 33.21GiB(34.95%)  tps: 86  tflops: 4.99  mfu: 0.50%
[rank2]:[titan] 2025-10-15 10:37:07,957 - root - INFO - step: 10  loss: 10.2391  grad_norm: 16.1413  memory: 44.84GiB(47.19%)  tps: 5,615  tflops: 325.20  mfu: 32.88%
[rank3]:[titan] 2025-10-15 10:37:07,957 - root - INFO - step: 10  loss: 10.2391  grad_norm: 16.1413  memory: 44.84GiB(47.19%)  tps: 5,615  tflops: 325.20  mfu: 32.88%
[rank0]:[titan] 2025-10-15 10:37:07,957 - root - INFO - step: 10  loss: 10.2391  grad_norm: 16.1413  memory: 44.84GiB(47.19%)  tps: 5,616  tflops: 325.22  mfu: 32.88%
[rank1]:[titan] 2025-10-15 10:37:07,956 - root - INFO - step: 10  loss: 10.2391  grad_norm: 16.1413  memory: 44.84GiB(47.19%)  tps: 5,615  tflops: 325.20  mfu: 32.88%
[rank3]:[titan] 2025-10-15 10:37:15,308 - root - INFO - step: 20  loss:  8.5913  grad_norm:  7.1942  memory: 44.84GiB(47.19%)  tps: 5,573  tflops: 322.73  mfu: 32.63%
[rank0]:[titan] 2025-10-15 10:37:15,308 - root - INFO - step: 20  loss:  8.5913  grad_norm:  7.1942  memory: 44.84GiB(47.19%)  tps: 5,573  tflops: 322.74  mfu: 32.63%
[rank1]:[titan] 2025-10-15 10:37:15,308 - root - INFO - step: 20  loss:  8.5913  grad_norm:  7.1942  memory: 44.84GiB(47.19%)  tps: 5,573  tflops: 322.73  mfu: 32.63%
[rank2]:[titan] 2025-10-15 10:37:15,308 - root - INFO - step: 20  loss:  8.5913  grad_norm:  7.1942  memory: 44.84GiB(47.19%)  tps: 5,573  tflops: 322.73  mfu: 32.63%
[rank3]:[titan] 2025-10-15 10:37:22,689 - root - INFO - step: 30  loss:  7.7257  grad_norm:  2.7261  memory: 44.84GiB(47.19%)  tps: 5,550  tflops: 321.44  mfu: 32.50%
[rank0]:[titan] 2025-10-15 10:37:22,689 - root - INFO - step: 30  loss:  7.7257  grad_norm:  2.7261  memory: 44.84GiB(47.19%)  tps: 5,550  tflops: 321.44  mfu: 32.50%
[rank1]:[titan] 2025-10-15 10:37:22,689 - root - INFO - step: 30  loss:  7.7257  grad_norm:  2.7261  memory: 44.84GiB(47.19%)  tps: 5,550  tflops: 321.44  mfu: 32.50%
[rank2]:[titan] 2025-10-15 10:37:22,689 - root - INFO - step: 30  loss:  7.7257  grad_norm:  2.7261  memory: 44.84GiB(47.19%)  tps: 5,550  tflops: 321.44  mfu: 32.50%
[rank3]:[titan] 2025-10-15 10:37:30,028 - root - INFO - step: 40  loss:  7.4543  grad_norm:  3.6042  memory: 44.84GiB(47.19%)  tps: 5,582  tflops: 323.26  mfu: 32.69%
[rank0]:[titan] 2025-10-15 10:37:30,028 - root - INFO - step: 40  loss:  7.4543  grad_norm:  3.6042  memory: 44.84GiB(47.19%)  tps: 5,582  tflops: 323.27  mfu: 32.69%
[rank1]:[titan] 2025-10-15 10:37:30,028 - root - INFO - step: 40  loss:  7.4543  grad_norm:  3.6042  memory: 44.84GiB(47.19%)  tps: 5,582  tflops: 323.25  mfu: 32.69%
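
To compare runs, the metric lines above can be parsed with a small script. This is a minimal sketch that assumes the log format shown here; the names `parse_line` and `implied_peak_tflops` are hypothetical helpers, not torchtitan APIs. It also back-computes the peak throughput implied by the reported `tflops` and `mfu` values, assuming `mfu` is simply achieved TFLOPS divided by the hardware peak. For the step 10 lines this comes out near 989 TFLOPS, which would be consistent with an H100-class BF16 peak, though the PR does not state the hardware.

```python
import re

# Matches one metrics line in the format shown in the log above.
LINE_RE = re.compile(
    r"\[rank(?P<rank>\d+)\].*?step:\s*(?P<step>\d+)\s+"
    r"loss:\s*(?P<loss>[\d.]+)\s+"
    r"grad_norm:\s*(?P<grad_norm>[\d.]+)\s+"
    r"memory:\s*(?P<mem_gib>[\d.]+)GiB\((?P<mem_pct>[\d.]+)%\)\s+"
    r"tps:\s*(?P<tps>[\d,]+)\s+"
    r"tflops:\s*(?P<tflops>[\d.]+)\s+"
    r"mfu:\s*(?P<mfu>[\d.]+)%"
)


def parse_line(line: str):
    """Return a dict of metrics for one log line, or None if it does not match."""
    m = LINE_RE.search(line)
    if m is None:
        return None
    g = m.groupdict()
    return {
        "rank": int(g["rank"]),
        "step": int(g["step"]),
        "loss": float(g["loss"]),
        "grad_norm": float(g["grad_norm"]),
        "mem_gib": float(g["mem_gib"]),
        "mem_pct": float(g["mem_pct"]),
        "tps": int(g["tps"].replace(",", "")),  # drop thousands separator
        "tflops": float(g["tflops"]),
        "mfu": float(g["mfu"]),
    }


def implied_peak_tflops(rec: dict) -> float:
    """Peak TFLOPS implied by the log, assuming mfu = tflops / peak."""
    return rec["tflops"] / (rec["mfu"] / 100.0)


sample = (
    "[rank2]:[titan] 2025-10-15 10:37:07,957 - root - INFO - step: 10  "
    "loss: 10.2391  grad_norm: 16.1413  memory: 44.84GiB(47.19%)  "
    "tps: 5,615  tflops: 325.20  mfu: 32.88%"
)
rec = parse_line(sample)
print(rec["step"], rec["loss"], rec["tps"], round(implied_peak_tflops(rec)))
```

The first-step lines report much lower `tps` and `mfu` than later steps, so any averaging over parsed records should probably skip step 1.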

XilunWu added a commit that referenced this pull request Oct 15, 2025
ghstack-source-id: 2e7013d
Pull Request resolved: #1884
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Meta Open Source bot. label Oct 15, 2025
@XilunWu XilunWu marked this pull request as draft October 15, 2025 18:07