
Conversation

limou102

1) Add Context Parallelism (CP) support to Flux model training
Context Parallelism is mainly used for video generation models. In the Flux model, the sequence length used in attention computation is very small (512), so context parallelism provides no speedup here; other multimodal models can refer to this modification.
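For orientation, here is a minimal sketch of the mechanism this builds on, not the PR's actual integration: PyTorch's experimental context_parallel context manager shards the listed buffers along their sequence dimension and routes SDPA to ring attention. The mesh setup, cp_degree, and tensor shapes are illustrative, it assumes a torchrun launch with 2 GPUs, and both imports are experimental APIs that may move.

# Minimal CP sketch (not the PR's code); run with: torchrun --nproc-per-node=2 this_file.py
import os

import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.experimental import context_parallel
from torch.distributed.tensor.experimental._attention import _cp_options

# Ring-attention load balancing assumes is_causal=True; Flux attention is
# non-causal, so the flag is disabled before entering the CP context.
_cp_options.enable_load_balance = False

torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
cp_degree = 2  # illustrative; the PR's experiments use cp=2 and cp=4
cp_mesh = init_device_mesh("cuda", (cp_degree,), mesh_dim_names=("cp",))

# (batch, heads, seq, head_dim); seq=512 mirrors the sequence length cited above.
q = torch.randn(1, 8, 512, 64, device="cuda", dtype=torch.bfloat16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Inside the context, q/k/v are sharded along dim 2 (sequence) across the CP
# mesh and scaled_dot_product_attention runs as ring attention.
with context_parallel(cp_mesh, buffers=[q, k, v], buffer_seq_dims=[2, 2, 2]):
    out = F.scaled_dot_product_attention(q, k, v, is_causal=False)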

The comparison of loss curves with CP enabled/disabled is shown below (gray represents CP=4), with the same global_batch_size=32.
[screenshot: training loss curves, CP=4 (gray) vs. CP disabled]

The validation loss curve (on the COCO dataset) is shown below.
[screenshot: validation loss curves]

2) Fix compatibility issues between the Flux code and the latest main branch


meta-cla bot commented Oct 10, 2025

Hi @limou102!

Thank you for your pull request and welcome to our community.

Action Required

In order to merge any pull request (code, docs, etc.), we require contributors to sign our Contributor License Agreement, and we don't seem to have one on file for you.

Process

In order for us to review and merge your suggested changes, please sign at https://code.facebook.com/cla. If you are contributing on behalf of someone else (eg your employer), the individual CLA may not be sufficient and your employer may need to sign the corporate CLA.

Once the CLA is signed, our tooling will perform checks and validations. Afterwards, the pull request will be tagged with CLA signed. The tagging process may take up to 1 hour after signing. Please give it that time before contacting us about it.

If you have received this in error or have any questions, please contact us at [email protected]. Thanks!

The meta-cla bot added the "CLA Signed" label (managed by the Meta Open Source bot) on Oct 10, 2025.

meta-cla bot commented Oct 10, 2025

Thank you for signing our Contributor License Agreement. We can now accept your code for this (and any) Meta Open Source project. Thanks!

@limou102
Author

2025-10-14 update
I conducted an additional experiment using global_batch_size=128 and iterations=20000, training two tasks:
Task A: local_batch=16, fsdp=8, cp=1
Task B: local_batch=32, fsdp=4, cp=2

The validation dataset loss curves for both tasks are nearly identical.
[screenshot: validation loss curves for Task A and Task B]

Moreover, when performing actual inference with the checkpoints from iteration 20000, the generated images from both tasks align well with the prompts.
[screenshot: sample generated images]

All test case outputs are attached below.
inference_results_gb128_iter20000.zip

@fegin
Contributor

fegin commented Oct 14, 2025

Can you also paste memory usage?

@limou102
Author

Can you also paste memory usage?

[screenshot: memory usage logs]
The upper log corresponds to local_batch=32, cp=2, fsdp=4, and the lower one to local_batch=16, cp=1, fsdp=8. In this case the sequence length is quite small, so enabling CP doesn't show a significant benefit; instead, it increases memory usage by about 10GB (132GB vs 122GB).

Contributor

@tianyu-l left a comment


Could you add an integration test here, with FSDP+CP? https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/flux/tests/integration_tests.py

Could you help update supported features and TODO in README.md? https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/flux/README.md#supported-features


In terms of loss comparison, what you did makes sense. A more recommended and rigorous approach is to make sure both data loading behavior and parameter init are exactly the same across the two runs.
So instead of

I conducted an additional experiment using global_batch_size=128 and iterations=20000, training two tasks:
Task A: local_batch=16, fsdp=8, cp=1
Task B: local_batch=32, fsdp=4, cp=2

We can do the following to make sure data loading behavior aligns; it depends only on the DP degree (the batch sharding degree):

  • Task A: local_batch=16, fsdp=2, cp=1 (2 GPU)
  • Task B: local_batch=16, fsdp=2, cp=4 (8 GPU)

Then we can use a seed checkpoint to make sure model parameters are the same; see https://github.com/pytorch/torchtitan/blob/main/docs/debugging.md#seed-checkpoint-based-reproducibility

Finally we can fix seed (training.seed) so that random behavior (e.g. dropout) is consistent, too.
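For reference, a plain-PyTorch sketch of what fixed seeding plus determinism involves is below; torchtitan wires the equivalent behavior through its training.seed and deterministic options, so this helper is only an illustration of the underlying calls, not the torchtitan code path.

# Generic seeding/determinism illustration; torchtitan configures this via its
# training config rather than a helper like this.
import torch

def set_deterministic(seed: int = 0) -> None:
    torch.manual_seed(seed)                   # seed the CPU (and current CUDA) RNG
    torch.cuda.manual_seed_all(seed)          # seed all local GPUs
    torch.use_deterministic_algorithms(True)  # fail on known nondeterministic ops
    torch.backends.cudnn.benchmark = False    # avoid nondeterministic autotuning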

@limou102
Author


I first generated a seed checkpoint, then ran 2 tasks with deterministic=True:
Task A: local_batch=16, fsdp=2, cp=1 using 2 GPUs
Task B: local_batch=16, fsdp=2, cp=4 using 8 GPUs

Their loss curves and values are completely identical.
[screenshots: overlapping loss curves for Task A and Task B]

In addition, an integration test was added and the Flux README was updated.
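As a rough illustration of what an FSDP+CP test entry could look like: the real tests file defines its own OverrideDefinitions helper, so the stand-in dataclass, field names, and flag names below are assumptions modeled on torchtitan's main integration tests, not copied from this PR.

# Illustrative stand-in only; not the actual entry added in this PR.
from dataclasses import dataclass
from typing import List

@dataclass
class TestEntry:  # stand-in for the repo's OverrideDefinitions
    override_args: List[List[str]]
    test_descr: str
    test_name: str
    ngpu: int = 4

fsdp_cp_test = TestEntry(
    override_args=[
        [
            # assumed torchtitan [parallelism] flags for an FSDP=2, CP=2 run
            "--parallelism.data_parallel_shard_degree=2",
            "--parallelism.context_parallel_degree=2",
        ]
    ],
    test_descr="FSDP+CP on the Flux debug model",
    test_name="fsdp_cp",
    ngpu=4,
)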

try:
    from torch.distributed.tensor.experimental._attention import _cp_options

    _cp_options.enable_load_balance = False
Contributor


Why set it to False here?

Author


In PyTorch, this is an experimental module where _cp_options.enable_load_balance defaults to True.
https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/experimental/_attention.py#L71

With load balancing enabled, only attention computations with is_causal=True are supported. For multimodal models like Flux, where attention runs with is_causal=False, you need to manually set _cp_options.enable_load_balance to False.
https://github.com/pytorch/pytorch/blob/main/torch/distributed/tensor/experimental/_attention.py#L389

Author


This module only affects ring_attention computation under Context Parallelism.

Contributor


@limou102
Got it, but why do you do the try-catch here?

I think the place to set this is tricky. If you use the context manager in train.py, it will shard the input along the seq dimension anyway. If it's going to the else path, _cp_options.enable_load_balance would be True and I expect the program would always fail (because of this line https://github.com/pytorch/pytorch/blob/d0c24b392cbb7b213d22e42c52c6c2d1ac2da1bd/torch/distributed/tensor/experimental/_attention.py#L388-L389)

What stops us from always setting it to False?

cc @fegin as you are planning next steps. I think we can land this for now.

Author


The reason I used a try-catch is that this code is part of PyTorch's experimental module, and it might change in the future; for example, _cp_options could be moved elsewhere.
If that happens, the catch block will trigger a log message to notify us early.
In theory, if the catch is triggered, we could call logger.error() and then abort.

As for why I don't always set enable_load_balance to False: I'm not entirely sure whether that would cause a performance regression. However, setting it to False has been verified to be correct for ring_attention when is_causal=False.

@lxgsbqylbk


_cp_options.enable_load_balance should be False if it's going to the else path, since no exception occurred.

Contributor


@lxgsbqylbk Sorry, I didn't mean to say "else path"; I wanted to say "except path". In that case I believe the behavior is unexpected (we do input sharding but don't set load_balance to False).

@limou102
Given that you can't achieve correct behavior with try-catch, I'd suggest we just do _cp_options.enable_load_balance = False without try-catch for now. I believe @fegin will do some refactoring on this code soon anyway.
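For concreteness, the suggested simplification amounts to something like the sketch below (an illustration of the suggestion above, not the final merged code):

# Set the flag unconditionally, with no try/except, accepting that this private,
# experimental API may move in a future PyTorch release.
from torch.distributed.tensor.experimental._attention import _cp_options

# Flux attention is non-causal, so ring-attention load balancing must be off.
_cp_options.enable_load_balance = False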

