gpt-oss model enablement #1754
Merged
Changes from all commits (30 commits):
- 4d3a84f gptoss experimental support
- e0306fe clean up tentative licensing
- b93f650 training fixes: expert load balancing, TP for sinks + experts, EP wor…
- b5290a2 only assert sdpa backends if using sdpa; improve conversion script
- c08a485 fixed conversion script with param by param
- 6b51de6 new lse-based flexattn implementation for sinks
- 2aef02e test
- ace1a0f rebase (wwwjn)
- 0245062 fix flexattn (wwwjn)
- 0e846f5 check and replace rope (wwwjn)
- 0fb65b8 FSDP work, TP doesn't work (wwwjn)
- c6748c4 test (wwwjn)
- 54b7748 fix sink (wwwjn)
- 92a68e1 test EP (wwwjn)
- 7bd9e4d working on ETP (wwwjn)
- 7e4f38f clean up (wwwjn)
- afdb630 clean up (wwwjn)
- a093727 rebase + address comments (wwwjn)
- db2f6b6 rebase to main (wwwjn)
- 40bd901 rebase on to expert parallel changes (wwwjn)
- da672b2 refactor FlexAttention (wwwjn)
- f7b9d84 fix ep (wwwjn)
- 0da857a fix TP (wwwjn)
- c424815 address comments (wwwjn)
- 331e6d8 address comments (wwwjn)
- e26190c fix args (wwwjn)
- 9db1e98 add scaled bias (wwwjn)
- d1dff5f optimize using col major experts (wwwjn)
- 237f07d lint (wwwjn)
- 3b1ba3d add comments (wwwjn)
New file (`@@ -0,0 +1,17 @@`):

# gpt-oss Model in torchtitan

## Quick Start
```bash
CONFIG_FILE="./torchtitan/experiments/gpt_oss/train_configs/debug_model.toml" ./run_train.sh
```

## Supported Features
- FSDP/HSDP, TP, EP, ETP
- Grouped matrix multiplication for efficient computation

## TODO
1. More parallelism support: CP, PP
2. Conversion to/from HF weights (StateDictAdapter)
3. Forward parity verification
4. CI support
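The "grouped matrix multiplication" feature listed above refers to running all per-expert matmuls in an MoE layer over contiguous, expert-sorted token groups, so a grouped/batched kernel can replace a Python loop of small matmuls. A minimal NumPy sketch of the idea (illustrative only, not torchtitan's kernel; top-1 routing and the sizes are arbitrary assumptions):

```python
import numpy as np

rng = np.random.default_rng(0)
num_experts, dim, hidden, n_tokens = 4, 8, 16, 32
tokens = rng.normal(size=(n_tokens, dim))
# Hypothetical router output: one expert id per token (top-1 for simplicity).
expert_ids = rng.integers(0, num_experts, size=n_tokens)
w = rng.normal(size=(num_experts, dim, hidden))  # per-expert weights

# Naive form: one masked matmul per expert.
out_loop = np.zeros((n_tokens, hidden))
for e in range(num_experts):
    mask = expert_ids == e
    out_loop[mask] = tokens[mask] @ w[e]

# "Grouped" form: sort tokens by expert so each expert's rows are contiguous,
# then apply each expert's weight to one contiguous slice. A grouped-mm kernel
# fuses this per-slice loop into a single launch.
order = np.argsort(expert_ids, kind="stable")
sorted_tokens = tokens[order]
counts = np.bincount(expert_ids, minlength=num_experts)
offsets = np.concatenate(([0], np.cumsum(counts)))
out_sorted = np.empty((n_tokens, hidden))
for e in range(num_experts):
    s, t = offsets[e], offsets[e + 1]
    out_sorted[s:t] = sorted_tokens[s:t] @ w[e]
out_grouped = np.empty_like(out_sorted)
out_grouped[order] = out_sorted  # scatter back to original token order

assert np.allclose(out_loop, out_grouped)
```

Both forms compute the same output; the sorted layout is what makes the per-expert work batchable.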
New file (`@@ -0,0 +1,87 @@`):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.models.moe import MoEArgs

from torchtitan.protocols.train_spec import TrainSpec

from .infra.parallelize import parallelize_gptoss
from .model.args import GptOssModelArgs
from .model.model import GptOssModel

__all__ = [
    "parallelize_gptoss",
    "GptOssModelArgs",
    "GptOssModel",
    "gptoss_configs",
]


gptoss_configs = {
    "debugmodel": GptOssModelArgs(
        dim=256,
        n_layers=4,
        moe_args=MoEArgs(
            num_experts=8,
            num_shared_experts=0,
            score_func="softmax",
            route_norm=False,
            route_scale=1.0,
            score_before_experts=False,
            top_k=4,
            use_grouped_mm=True,
            load_balance_coeff=1e-3,
        ),
        attn_mask_type="causal",
    ),
    "20b": GptOssModelArgs(
        n_layers=24,
        moe_args=MoEArgs(
            num_experts=32,
            num_shared_experts=0,
            score_func="softmax",
            route_norm=False,
            route_scale=1.0,
            score_before_experts=False,
            top_k=4,
            use_grouped_mm=True,
            load_balance_coeff=1e-3,
        ),
    ),
    "120b": GptOssModelArgs(
        n_layers=36,
        moe_args=MoEArgs(
            num_experts=128,
            num_shared_experts=0,
            score_func="softmax",
            route_norm=False,
            route_scale=1.0,
            score_before_experts=False,
            top_k=4,
            use_grouped_mm=True,
            load_balance_coeff=1e-3,
        ),
    ),
}


def get_train_spec() -> TrainSpec:
    return TrainSpec(
        model_cls=GptOssModel,
        model_args=gptoss_configs,
        parallelize_fn=parallelize_gptoss,
        pipelining_fn=None,
        build_optimizers_fn=build_optimizers_with_moe_load_balancing,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_hf_dataloader,
        build_tokenizer_fn=build_hf_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
    )
```
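The `gptoss_configs` dict above maps a flavor name ("debugmodel", "20b", "120b") to a fully populated argument object, and `get_train_spec` bundles that registry with builder functions. A minimal, self-contained sketch of this registry pattern (plain Python, independent of torchtitan; the class names, fields, and defaults here are illustrative, not the real `GptOssModelArgs`):

```python
from dataclasses import dataclass, field

@dataclass
class MoEArgsSketch:  # stand-in for MoEArgs; fields mirror the ones used above
    num_experts: int = 8
    top_k: int = 4
    score_func: str = "softmax"

@dataclass
class ModelArgsSketch:  # stand-in for GptOssModelArgs; defaults are made up
    dim: int = 1024
    n_layers: int = 24
    moe_args: MoEArgsSketch = field(default_factory=MoEArgsSketch)

# Name -> config registry, analogous to gptoss_configs.
configs = {
    "debugmodel": ModelArgsSketch(dim=256, n_layers=4,
                                  moe_args=MoEArgsSketch(num_experts=8)),
    "20b": ModelArgsSketch(n_layers=24, moe_args=MoEArgsSketch(num_experts=32)),
}

def get_args(flavor: str) -> ModelArgsSketch:
    # Fail loudly on unknown flavors instead of silently defaulting.
    if flavor not in configs:
        raise KeyError(f"unknown flavor {flavor!r}; choose from {sorted(configs)}")
    return configs[flavor]

args = get_args("debugmodel")
assert args.n_layers == 4 and args.moe_args.num_experts == 8
```

In torchtitan the trainer selects the flavor from the TOML config and looks it up in `model_args` the same way, so adding a model size is a one-entry change.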
New file (`@@ -0,0 +1,65 @@`):

```python
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch.nn as nn
from torch.distributed.tensor import distribute_tensor, Replicate, Shard
from torchtitan.distributed.expert_parallel import ExpertTensorParallel, TensorParallel


# implementation of Tensor Parallel for the GroupedExperts in MoE
class GptossTensorParallel(TensorParallel):
    def _partition_fn(self, name, module, device_mesh):
        module.register_parameter(
            "mlp1_weight",
            nn.Parameter(
                distribute_tensor(module.mlp1_weight, device_mesh, [Shard(1)])
            ),
        )  # Column-wise sharding
        module.register_parameter(
            "mlp1_bias",
            nn.Parameter(distribute_tensor(module.mlp1_bias, device_mesh, [Shard(1)])),
        )  # Column-wise sharding
        module.register_parameter(
            "mlp2_weight",
            nn.Parameter(
                distribute_tensor(module.mlp2_weight, device_mesh, [Shard(2)])
            ),
        )  # Row-wise sharding
        module.register_parameter(
            "mlp2_bias",
            nn.Parameter(
                distribute_tensor(module.mlp2_bias, device_mesh, [Replicate()])
            ),
        )  # Replicate


# This class is for dp2ep with TP (without TP we can just use GptossExpertParallel)
class GptossExpertTensorParallel(ExpertTensorParallel):
    def _partition_fn_2d(self, name, mod, ep_tp_mesh):
        mod.register_parameter(
            "mlp1_weight",
            nn.Parameter(
                distribute_tensor(mod.mlp1_weight, ep_tp_mesh, [Shard(0), Shard(1)])
            ),
        )  # Column-wise sharding
        mod.register_parameter(
            "mlp1_bias",
            nn.Parameter(
                distribute_tensor(mod.mlp1_bias, ep_tp_mesh, [Shard(0), Shard(1)])
            ),
        )  # Column-wise sharding
        mod.register_parameter(
            "mlp2_weight",
            nn.Parameter(
                distribute_tensor(mod.mlp2_weight, ep_tp_mesh, [Shard(0), Shard(2)])
            ),
        )  # Row-wise sharding
        mod.register_parameter(
            "mlp2_bias",
            nn.Parameter(
                distribute_tensor(mod.mlp2_bias, ep_tp_mesh, [Shard(0), Replicate()])
            ),
        )  # Replicate
```
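The layout above is the standard Megatron-style TP pairing: column-shard the first projection and its bias, row-shard the second projection, replicate the second bias (note the weight dims here have an extra leading expert dimension, which `Shard(0)` splits across the EP mesh in the 2D case). Why the partial outputs only need a single sum (all-reduce) can be checked numerically; a NumPy sketch (illustrative, not torchtitan code; plain ReLU stands in for the actual activation, which works because any elementwise activation stays local to a column shard):

```python
import numpy as np

rng = np.random.default_rng(0)
dim, hidden, tp = 6, 8, 2
x = rng.normal(size=(4, dim))
w1, b1 = rng.normal(size=(dim, hidden)), rng.normal(size=hidden)
w2, b2 = rng.normal(size=(hidden, dim)), rng.normal(size=dim)

# Reference: full (unsharded) two-layer MLP.
ref = np.maximum(x @ w1 + b1, 0) @ w2 + b2

# Tensor parallel: w1/b1 column-sharded, w2 row-sharded, b2 replicated.
cols = hidden // tp
partials = []
for r in range(tp):                          # simulate tp ranks
    w1_r = w1[:, r * cols:(r + 1) * cols]    # shard w1 on its output dim
    b1_r = b1[r * cols:(r + 1) * cols]       # bias shards with the columns
    w2_r = w2[r * cols:(r + 1) * cols, :]    # shard w2 on its input dim
    partials.append(np.maximum(x @ w1_r + b1_r, 0) @ w2_r)

# "All-reduce": sum the partial outputs; the replicated b2 is added once.
out = sum(partials) + b2

assert np.allclose(ref, out)
```

Column-sharding the first bias (rather than replicating it) is what keeps the pre-activation exact per shard; replicating `mlp2_bias` works because it must be added exactly once, after the reduction.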