Add Context Parallelism to Flux model training #1851
Conversation
Force-pushed from 46ad69d to 11e3084
2025-10-14 update: The validation dataset loss curves for both tasks are nearly identical. Moreover, when running actual inference with the checkpoints from iterations=20000, the generated images from both runs align well with the prompts. Below are all test case outputs.
Can you also paste memory usage?
Could you add an integration test here, with FSDP+CP? https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/flux/tests/integration_tests.py
Could you help update supported features and TODO in README.md? https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/flux/README.md#supported-features
In terms of loss comparison, what you did makes sense. A more recommended and rigorous approach is to make sure both the data loading behavior and the parameter init are exactly the same across the two runs.
So instead of
> I conducted an additional experiment using global_batch_size=128 and iterations=20000, training two tasks:
> Task A: local_batch=16, fsdp=8, cp=1
> Task B: local_batch=32, fsdp=4, cp=2
we can do the following to make sure the data loading behaviors align (data loading depends only on the DP degree, i.e. the batch sharding degree):
- Task A: local_batch=16, fsdp=2, cp=1 (2 GPU)
- Task B: local_batch=16, fsdp=2, cp=4 (8 GPU)
Then we can use a seed checkpoint to make sure the model parameters are the same; see https://github.com/pytorch/torchtitan/blob/main/docs/debugging.md#seed-checkpoint-based-reproducibility
Finally we can fix seed (training.seed) so that random behavior (e.g. dropout) is consistent, too.
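To illustrate the data loading point, here is a minimal sketch with a hypothetical helper (not torchtitan code): the dataset is sharded only across the DP/batch-sharding dimension, so two runs with the same dp_shard degree and local batch size consume identical batches in the same order, regardless of the CP degree.

```python
# Hypothetical helper for illustration only; not part of torchtitan.
# CP only splits the sequence dimension, so the number of data shards (and
# hence the data order each shard sees) is determined by the DP degree alone.
def num_data_shards(world_size: int, dp_shard: int, cp: int) -> int:
    assert world_size == dp_shard * cp, "device mesh must cover all GPUs"
    return dp_shard  # all ranks in the same DP group read the same shard

# Task A: 2 GPUs, fsdp=2, cp=1 -> 2 data shards
# Task B: 8 GPUs, fsdp=2, cp=4 -> 2 data shards, identical batches to Task A
assert num_data_shards(2, dp_shard=2, cp=1) == num_data_shards(8, dp_shard=2, cp=4)
```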
Force-pushed from 4811bd1 to df7d040
I first generated a seed checkpoint, then ran the two tasks with deterministic=True: their loss curves and values are completely identical. In addition, an integration test was added and the Flux README was updated.
…un with is_causal=False in the Flux model.
try:
    from torch.distributed.tensor.experimental._attention import _cp_options

    _cp_options.enable_load_balance = False
Why set it to False here?
In PyTorch, this is an experimental module where _cp_options.enable_load_balance = True is hardcoded by default.
https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/experimental/_attention.py#L71
In this state, only attention computations with is_causal=True are supported.
However, for multimodal models like Flux, whose attention runs with is_causal=False, you need to manually set _cp_options.enable_load_balance to False.
https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/experimental/_attention.py#L389
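For context, a minimal sketch of the setting being discussed, assuming the current experimental module layout (the `_cp_options` import path may change between PyTorch releases); the plain SDPA call stands in for Flux's bidirectional attention:

```python
import torch
import torch.nn.functional as F

# Experimental API: the location of _cp_options may change in future releases.
from torch.distributed.tensor.experimental._attention import _cp_options

# Ring attention under CP defaults to a load-balanced reordering of sequence
# chunks that assumes a causal mask; Flux attention is bidirectional
# (is_causal=False), so load balancing must be switched off.
_cp_options.enable_load_balance = False

# The attention call that CP later intercepts is an ordinary SDPA call:
q = k = v = torch.randn(1, 8, 512, 64)
out = F.scaled_dot_product_attention(q, k, v, is_causal=False)
```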
This module only affects ring_attention computation under Context Parallelism.
@limou102
Got it, but why do you do the try-catch here?
I think the place to set this is tricky. If you use the context manager in train.py, it'll shard the input along the seq dimension anyway. If it's going to the else path, _cp_options.enable_load_balance would be True and I expect the program would always fail (because of this line https://github.com/pytorch/pytorch/blob/d0c24b392cbb7b213d22e42c52c6c2d1ac2da1bd/torch/distributed/tensor/experimental/_attention.py#L388-L389).
What stops us from always setting it to False?
cc @fegin as you are planning next steps. I think we can land this for now.
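For readers following along, a rough sketch of the context-manager path being referred to, assuming PyTorch's experimental `context_parallel` API (argument names may differ across versions, and `model`, `cp_mesh`, and the token tensors are placeholders):

```python
import torch
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import _cp_options

_cp_options.enable_load_balance = False  # Flux attention is non-causal

def forward_with_cp(model, img_tokens: torch.Tensor, txt_tokens: torch.Tensor, cp_mesh):
    # Inside this context, each listed buffer is sharded along its sequence
    # dimension across the cp mesh, and SDPA calls run as ring attention.
    with context_parallel(
        cp_mesh,
        buffers=[img_tokens, txt_tokens],
        buffer_seq_dims=[1, 1],
    ):
        return model(img_tokens, txt_tokens)
```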
The reason I used a try-catch is that this code is part of PyTorch's experimental module, and it might change in the future; for example, _cp_options could be moved elsewhere.
If that happens, the except block will trigger a log message to notify us early.
In theory, if the except block is triggered, we could call logger.error() and then abort.
As for why I don't always set enable_load_balance to False: I'm not entirely sure whether that would cause a performance regression. However, setting it to False has been verified to be correct for ring_attention when is_causal=False.
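For reference, this is roughly the pattern described above (a sketch of the idea, not the exact code in this PR): guard the experimental import and fail loudly if it ever moves.

```python
import logging

logger = logging.getLogger(__name__)

try:
    # Experimental API: may move or be renamed in a future PyTorch release.
    from torch.distributed.tensor.experimental._attention import _cp_options

    # Flux attention runs with is_causal=False, so disable the causal
    # load-balancing shard order used by ring attention.
    _cp_options.enable_load_balance = False
except ImportError:
    # If the flag cannot be set, CP would still shard the sequence but ring
    # attention would reject is_causal=False, so abort early and visibly.
    logger.error(
        "_cp_options not found; context parallelism with non-causal attention "
        "is not supported by this PyTorch build."
    )
    raise
```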
> @limou102 Got it, but why do you do the try-catch here?
> I think the place to set this is tricky. If you use the context manager in train.py, it'll anyway shard the input along seq dimension. If it's going to the else path, _cp_options.enable_load_balance would be True and I expect the program would always fail (because of this line https://github.com/pytorch/pytorch/blob/d0c24b392cbb7b213d22e42c52c6c2d1ac2da1bd/torch/distributed/tensor/experimental/_attention.py#L388-L389). What stops us from always setting it to False?
> cc @fegin as you are planning next steps. I think we can land this for now.
_cp_options.enable_load_balance should be False if it goes to the else path, since no exception occurred.
@lxgsbqylbk sorry, I didn't mean to say "else path", I wanted to say "except path". In that case I believe the behavior is unexpected (we do input sharding, but do not set load_balance to False).
@limou102
Given that you can't achieve correct behavior with try-catch, I'd suggest we just do _cp_options.enable_load_balance = False without try-catch for now. I believe @fegin will do some refactoring of this code soon anyway.
1) Add Context Parallelism (CP) support to Flux model training
Context Parallelism is mainly used for video generation models. In the Flux model, the sequence length used in attention computation is very small (512), so context parallelism provides no speedup here; other multimodal models can refer to this modification.
A comparison of the loss curves with CP enabled and disabled is shown below (gray represents CP=4); both runs use the same global_batch_size=32.

The validation loss curve (with the coco dataset) is shown below.

2) Fix compatibility issues between the Flux code and the latest main branch