Conversation

githubsgi
Contributor

Incorporating input from conversation in #1761

@meta-cla bot added the CLA Signed label (managed by the Meta Open Source bot) on Oct 7, 2025
@githubsgi
Contributor Author

@tianyu-l , please review.

@githubsgi changed the title from "Second version of degub/deterinistic configs." to "Second version of degub/deterministic configs." on Oct 7, 2025
Contributor

@tianyu-l left a comment


Thanks! Left some comments.

@tianyu-l linked an issue on Oct 9, 2025 that may be closed by this pull request
Contributor

@tianyu-l left a comment


@githubsgi
Contributor Author

@tianyu-l , do not review yet. Not sure the rebase was valid.

@githubsgi
Contributor Author

@tianyu-l , please review.

Contributor

@tianyu-l left a comment


LGTM, thanks a lot!

@tianyu-l
Contributor

@githubsgi Please address lint issues.

@tianyu-l
Contributor

@githubsgi
sorry it seems you need to rebase

@tianyu-l
Contributor

@githubsgi unfortunately lint needs to be fixed again

githubsgi and others added 7 commits October 15, 2025 15:52
Incorporating input from conversation in pytorch#1761
…ch#1804)

## Benchmarking

| Step | Time | Log |
| -- | -- | -- |
| to_hf() | 0.1103 s | [trainer0\|0]:[titan] 2025-10-03 17:07:45,697 - root - INFO - Completed to_hf conversion, generated 189 keys, duration: 0.1103s |
| Split local GroupedExperts DTensor into individual experts' weights | 0.008 s per layer per matrix (58 MoE layers * 3 weight matrices per layer) | [trainer0\|0]:[titan] 2025-10-03 17:07:45,697 - root - INFO - Completed _get_local_experts_weights for layer 6, abstract_key: model.layers.{}.mlp.experts.{}.up_proj.weight, duration: 0.0082s |
| dcp.load(), thread count = 4 | 193.20 s | [trainer0\|0]:[titan] 2025-10-03 17:10:58,899 - root - INFO - dcp.load with HuggingFaceStorageReader completed in 193.20 seconds |
| from_hf() | 0.48 s | [trainer0\|0]:[titan] 2025-10-03 17:10:59,378 - root - INFO - Completed from_hf conversion, processed 189 keys, duration: 0.4787s |
| Concatenate individual experts' weights into GroupedExperts weight | 0.01 s per layer per matrix (58 MoE layers * 3 weight matrices) | [trainer0\|0]:[titan] 2025-10-03 17:10:59,120 - root - INFO - Completed _concatenate_expert_weights_dtensor for layer 5, abstract_key: layers.{}.moe.experts.w2, duration: 0.0142s |
| Total | 193.87 s | [trainer0\|0]:[titan] 2025-10-03 17:10:59,458 - root - INFO - Finished loading the checkpoint in 193.87 seconds. |
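For readers unfamiliar with the split/concat steps timed above, here is a minimal, self-contained sketch (not the torchtitan implementation) of splitting a local GroupedExperts weight into per-expert HF-style keys and stacking them back. The key format follows the abstract_key in the logs; all shapes are toy assumptions.

```
# Illustrative sketch only; shapes and names are toy assumptions.
import torch

num_experts, d_out, d_in = 8, 16, 32
layer = 6
grouped = torch.randn(num_experts, d_out, d_in)  # local GroupedExperts weight

# to_hf() direction: one HF key per expert (cf. _get_local_experts_weights).
hf_state = {
    f"model.layers.{layer}.mlp.experts.{e}.up_proj.weight": grouped[e]
    for e in range(num_experts)
}

# from_hf() direction: stack per-expert weights back into one grouped tensor
# (cf. _concatenate_expert_weights_dtensor).
regrouped = torch.stack(
    [hf_state[f"model.layers.{layer}.mlp.experts.{e}.up_proj.weight"]
     for e in range(num_experts)]
)
assert torch.equal(regrouped, grouped)
```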

## End-to-End verification for 671B model
Parallelism: FSDP=32, PP=8, 1F1B, EP=32

<img width="393" height="421" alt="Screenshot 2025-10-06 at 8 32 37 PM"
src="https://github.com/user-attachments/assets/6d8dab00-a188-4c57-8348-02bae1d21d03"
/>
<img width="393" height="421" alt="Screenshot 2025-10-06 at 8 32 54 PM"
src="https://github.com/user-attachments/assets/a730f71b-3dc8-45e0-8d3e-b21080884f8d"
/>
…h#1808)

With max-autotune, FlexAttention is not deterministic even if
torch.use_deterministic_algorithms is True. When deterministic mode is
set, we should also remove the usage of `max-autotune`.
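As a rough sketch of the behavior described above (not the actual torchtitan code; the `deterministic` flag is assumed to come from the job's debug config), the compile mode for FlexAttention would be chosen like this:

```
import torch
from torch.nn.attention.flex_attention import flex_attention

deterministic = True  # assumption: driven by the debug/deterministic config
torch.use_deterministic_algorithms(deterministic)

# Skip "max-autotune" when determinism is requested, since autotuning can
# pick different kernels across runs.
compile_mode = "default" if deterministic else "max-autotune"
compiled_flex_attention = torch.compile(flex_attention, mode=compile_mode)
```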
Fix the number-of-layers issue introduced by pytorch#1804
In VLM interleaved training with native resolution and aspect ratio, the number of tokens participating in loss computation differs per rank. Naive FSDP gradient averaging across data ranks can cause tokens on ranks with fewer valid tokens to contribute more to the loss than tokens on other ranks.
This PR addresses this via loss balancing, which incurs an additional comm in the loss computation.
In practice, I haven't noticed any impact from this comm.

#### Quick sanity check
Let each rank $i$ have a sum loss over its $N_i$ tokens, $L_i = \sum_{j=1}^{N_i}\ell_{ij}$, with gradient $g_i = \sum_{j=1}^{N_i}\nabla\ell_{ij}$.

If we multiply the *loss* on each rank by a constant factor **c** (the
same for all ranks), then after `backward()`:

$$
\tilde g_i = c \cdot g_i .
$$

FSDP will *average* these gradients across ranks:

$$
g_{\text{FSDP}}=\frac{1}{R}\sum_{i=1}^{R} \tilde g_i
                =\frac{c}{R}\sum_{i=1}^{R} g_i .
$$

We want this to equal the **global‑sample average**:

$$
g_{\text{true}}
=\frac{1}{N_{\text{total}}}\sum_{i=1}^{R}\sum_{j=1}^{N_i}\nabla
\ell_{ij}
   =\frac{1}{N_{\text{total}}}\sum_{i=1}^{R} g_i .
$$

Thus for FSDP gradient to be correct, we need

$$
\frac{c}{R}= \frac{1}{N_{\text{total}}}\quad\Longrightarrow\quad
c=\frac{R}{N_{\text{total}}}.
$$

So the *right* scaling factor is $R/N_{\text{total}}$, which means dividing the per-rank sum loss by $N_{\text{total}}/R$, i.e. the **average number of tokens per rank**.
Intuitively, this is the same as the default cross-entropy loss, but instead of dividing the sum loss on a rank by the number of tokens **on that rank**, we now divide by the **average number of tokens across all ranks**.
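
A minimal sketch of this loss balancing (not the exact code in this PR; the function and argument names are assumptions) computes the per-rank sum loss and divides it by the all-reduced average token count, which is the extra comm mentioned above:

```
import torch
import torch.distributed as dist
import torch.nn.functional as F


def balanced_cross_entropy(logits: torch.Tensor, labels: torch.Tensor,
                           ignore_index: int = -100) -> torch.Tensor:
    # Per-rank sum of token losses, L_i.
    loss_sum = F.cross_entropy(
        logits.flatten(0, 1), labels.flatten(),
        ignore_index=ignore_index, reduction="sum",
    )
    # Number of valid tokens on this rank, N_i.
    num_tokens = (labels != ignore_index).sum().to(loss_sum.dtype)
    # Average tokens per rank, N_total / R (the additional all-reduce).
    dist.all_reduce(num_tokens, op=dist.ReduceOp.AVG)
    return loss_sum / num_tokens
```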


P.S.: sorry, this PR is based on pytorch#1802 but I couldn't choose that as the base branch. It may be easier to review once that PR is merged.
tianyu-l and others added 26 commits October 15, 2025 16:14
As titled, given that `experiments/deepseek_v3` has been out of maintenance for a long time.

The folder could still be valuable, so I'm keeping the content in the branch `experiments/deepseek_v3` as a reference:
https://github.com/pytorch/torchtitan/tree/experiments/deepseek_v3/torchtitan/experiments/deepseek_v3

This PR keeps the symmetric memory kernels for EP communication, whose
integration will come later.
This PR adds the autobucketing pass at aten level to simplefsdp. It runs autobucketing with the aot_eager backend, without Inductor. The aten FX autobucketing pass can be found in pytorch/pytorch#163960.

Key updates are:

1. Support a customized `aot_eager_autobucketing` backend to perform the autobucketing optimization.
2. In simplefsdp, the model_backend can be replaced by the user's customized passes using `compile.model_backend_override`.
…ytorch#1776)

Stack from [ghstack](https://github.com/ezyang/ghstack/tree/0.12.0)
(oldest at bottom):
* pytorch#1797
* __->__ pytorch#1776

**Status**
1. Change all models, including the experimental ones.
2. E2E loss verification.
3. We should add a unit test for attention. But since we don't have GPU unit tests, this can be done in a separate PR.

**Summary**
This PR aims to refactor how TorchTitan builds the attention masks and passes them to the model. Before this PR, init_attention_masks() is called in the Trainer but the masks are stored as a class variable of FlexAttentionWrapper(). We chose this shortcut to support the case where a single model requires multiple masks.

The previous design has several issues; one in particular is pytorch#1723.

pytorch/pytorch#164111 proves that we can let PP split BlockMask, so this PR performs the refactor to pass masks as an argument of model.forward().

The new design:
1. The model needs to provide `get_attention_masks()`, which accepts `create_mask_fn`, `batch`, and `eos_id` (see the sketch after this list). If the attention op is SDPA, this API should return None, as SDPA currently doesn't support varlen. But once it does, we may have to return some tuple of ints that represents the mask.

Justification: attention logic is technically a part of the model, but it requires some information from the trainer/dataloader. So it's the model author's responsibility to provide an API that the trainer calls to get the masks.

2. `get_attention_masks()` will be called from the trainer and the
resulting masks are passed to the model.forward().

Justification: this will allow us to fix
pytorch#1723 with
pytorch/pytorch#164111 and this PR.

3. Now SDPA and FlexAttention are wrapped in two different classes.
~~Note: we still have two very very thin op wrappers that are used for CP. I keep these two for CP education purposes. But this certainly can be confusing for Titan's users. I'm open to merging them into AttentionOp.~~

See the discussion in pytorch#1723.
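
A minimal sketch of the model-side hook described in point 1 above (not the torchtitan implementation; the document-masking logic and the `use_flex_attention` flag are illustrative assumptions, while the argument names follow this PR description):

```
from typing import Callable, Optional

import torch
from torch.nn.attention.flex_attention import BlockMask


def get_attention_masks(
    create_mask_fn: Callable[..., BlockMask],  # e.g. create_block_mask
    batch: torch.Tensor,                       # token ids, shape [B, S]
    eos_id: int,
    use_flex_attention: bool = True,           # assumption: known by the model
) -> Optional[BlockMask]:
    # SDPA path: varlen isn't supported yet, so return None as described above.
    if not use_flex_attention:
        return None

    B, S = batch.shape
    # Document ids: each EOS token starts a new document.
    doc_ids = torch.cumsum((batch == eos_id).int(), dim=-1)

    def mask_mod(b, h, q_idx, kv_idx):
        causal = q_idx >= kv_idx
        same_doc = doc_ids[b, q_idx] == doc_ids[b, kv_idx]
        return causal & same_doc

    # create_mask_fn is expected to follow create_block_mask's signature:
    # (mask_mod, B, H, Q_LEN, KV_LEN).
    return create_mask_fn(mask_mod, B, None, S, S)
```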

**Verification**
*llama3*
```
./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/llama3/train_configs/debug_model.toml"
```
*llama3 flex*
```
./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/llama3/train_configs/debug_model.toml" --baseline-train-options="--model.flavor=debugmodel_flex_attn"
```
*llama4*
```
./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint
```
*llama4 irope*
```
./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint
```
*deepseek*
```
./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml"
```
*deepseek flex*
```
./loss_compare.sh main 9dc16675b272ffdc3ed616e3244bcf7dc2d257f2 --steps=100 --no-seed-checkpoint --config="./torchtitan/models/deepseek_v3/train_configs/debug_model.toml" --baseline-train-options="--model.flavor=debugmodel_flex_attn"
```
Summary:
the script adds configuration options to run training locally with ft
enabled

---
[//]: # (BEGIN SAPLING FOOTER)
Stack created with [Sapling](https://sapling-scm.com). Best reviewed
with
[ReviewStack](https://reviewstack.dev/pytorch/torchtitan/pull/1812).
* pytorch#1840
* pytorch#1811
* pytorch#1810
* __->__ pytorch#1812
* pytorch#1809

---------

Co-authored-by: Tushar Jain <[email protected]>
This PR:

- let `ExpertParallel` handle indices permute / unpermute when EP is used
- move `to_local` to model code to be more explicit
- rename the `expert_parallel` wrapper which does permute / unpermute to `indices_permutation_wrapper` to be more accurate
The next step is to move `qwen3` and `llama4` to core, and remove outdated experiments.
Summary:
Composability testing with TorchComms and distributed training in
TorchTitan.
  - Training with `torchcomms.new_comm`
  - Device mesh initialization with `torchcomms.init_device_mesh`
  - Integration and testing with `fully_shard`

Differential Revision: D82171763

Test plan:
TEST_BACKEND=nccl TRAIN_FILE=torchtitan.experiments.torchcomms.train
./run_train.sh --model.name torchcomms

Loss curve:
running 1000 steps on llama3_8b.toml

<img width="1095" height="469" alt="Screenshot 2025-10-13 at 4 14 46 PM"
src="https://github.com/user-attachments/assets/3d9ddf06-af76-44cf-ac75-b9f92e6d0f06"
/>
This PR also
- updates the README of the `deepseek_v3` folder.
- moves the `generate_permute_indices` triton kernel to `torchtitan/models/moe` so that core doesn't depend on `experiments`.
- deprecates the GPU checkpoint conversion scripts, as we now natively support loading checkpoints from HF using GPUs (although the GPU is only used when doing online conversion right before training starts).
We originally thought each model should have its own `pipeline.py`
function.

However, for most LLMs, it turns out a single function suffices, and all models which need PP reuse `pipeline_llama.py`, originally written for llama3.

(For diffusion models, the model size doesn't justify the usage of PP.)

This PR consolidates them and moves `pipeline_llm` into
`torchtitan/distributed/pipeline_parallel.py`.

We can refactor later if things change.
…pytorch#1871)

In the past, the terms "args" and "config" have been used interchangeably.

To make it unambiguous, in torchtitan we use
- "args" as in `ModelArgs` to refer to parameters used to define a model
in model code
- "config" as in `JobConfig` to refer to configurable training job
commands used in training script

This PR also moves `custom_args_module` (which should be `custom_config_module` according to the naming rule above) from `Experimental` to `Job`, as it has been extensively used by various models in torchtitan, especially those in the `experiments` folder.
@tianyu-l
Contributor

@githubsgi linting failed

def apply_ac(
    model: nn.Module,
    ac_config: ACConfig,
    job_config: JobConfig,
Contributor


Why change this to pass in the entire JobConfig? I think we should keep the function signature to use only the AC config and debug config.

Contributor Author


The linter was failing on a missing DebugConfig. That could be passed as another parameter, but that also changes the function footprint, and additional changes would be required in most of the files to reflect it. Passing the full JobConfig instead of ACConfig and DebugConfig appears cleaner, simpler, and more extensible to me.

Contributor

@tianyu-l Oct 17, 2025


> Passing the full JobConfig instead of ACConfig and DebugConfig appears cleaner, simpler, and more extensible to me.

This is config leakage we should avoid. It's convenient to use, but it's an anti-pattern. For example, why should apply_ac know anything about the FaultTolerance config?

I know we have a few places where we are doing this, but they should be refactored.

Contributor Author


It is certainly a good point to minimize exposure. On the flip side, not everything is shared, only the JobConfig, which in a way is the context that every sub-component is executing under. Hence visibility into it is useful, as in this case: in addition to ACConfig, apply_ac required visibility into DebugConfig. It can also be viewed through the lens of premature optimization: TorchTitan is still under heavy development, so we may need to expose other sections in the future. In that case we would not need to add one more function argument and the supporting code changes.

Contributor


I'm not convinced by the argument. I think another acceptable way is to put the AC debug configs into ACConfig and separate them from the other debug configs (seed, deterministic, moe_force_load_balance), so that we can still pass ACConfig as the single input to apply_ac.

Contributor Author


Could do that. However, that will split the debug configs into multiple sections. BTW, that is what my first PR did, and you suggested pulling all the debug configs into their own section.

Contributor


  1. They are not the same: right now we still have a dedicated debug section for things not related to AC.
  2. This is partly a fallback because you tried my proposal and then disagreed with it, and I'm willing to take a step back and try to make it work for you.

Contributor Author


@tianyu-l , so you want these 3 fields moved out of the Debug class?

    ac_preserve_rng_state: bool = False
    """If deterministic output compared to non-checkpointed passes is required, set to true. Results in stashing and restoring the RNG state during each checkpoint, may be slower. See https://docs.pytorch.org/docs/stable/checkpoint.html for details."""

    ac_determinism_check: str = "default"
    """A string specifying the determinism function. See https://docs.pytorch.org/docs/stable/checkpoint.html for details."""

    ac_debug: bool = False
    """ Capture ac debug information. Will be slower. See https://docs.pytorch.org/docs/stable/checkpoint.html for details."""

Contributor


I'm OK with this.
I'm also OK with keeping them in Debug and sending them via DebugConfig to apply_ac.
I'm not OK with sending the entire job_config to apply_ac.

Contributor Author


Ok, let me see which one pans out better.
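
For concreteness, a sketch of the second option the reviewer is OK with above: keep the flags in Debug and pass DebugConfig, rather than the whole JobConfig, into apply_ac. The class and field names come from this thread; the body is illustrative only, not the final torchtitan code.

```
from dataclasses import dataclass

import torch.nn as nn


@dataclass
class Debug:
    # AC-related debug flags quoted above.
    ac_preserve_rng_state: bool = False
    ac_determinism_check: str = "default"
    ac_debug: bool = False


def apply_ac(model: nn.Module, ac_config: "ACConfig", debug_config: Debug) -> None:
    # apply_ac only sees the AC settings plus the AC-related debug flags,
    # not the entire JobConfig (no leakage of unrelated sections).
    checkpoint_kwargs = {
        "preserve_rng_state": debug_config.ac_preserve_rng_state,
        "determinism_check": debug_config.ac_determinism_check,
        "debug": debug_config.ac_debug,
    }
    # ... wrap the transformer blocks here, forwarding checkpoint_kwargs
    # to torch.utils.checkpoint.checkpoint.
```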


Development

Successfully merging this pull request may close these issues.

Adding options to enable some determinism related configs
