3 changes: 2 additions & 1 deletion torchtitan/experiments/README.md
@@ -28,4 +28,5 @@ We provide this `experiments/` folder to host experiments that add significant v
| [vlm](./vlm/) | [![VLM 8 GPU Integration Tests](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_vlm.yaml/badge.svg?branch=main)](https://github.com/pytorch/torchtitan/actions/workflows/integration_test_8gpu_vlm.yaml?query=branch%3Amain) | [@lkhphuc](https://github.com/lkhphuc) |
| [forge](./forge/) | TBA | [@allenwang28](https://github.com/allenwang28) [@ebsmothers](https://github.com/ebsmothers) [@joecummings](https://github.com/joecummings) [@pbontrager](https://github.com/pbontrager) |
| [torchcomms](./torchcomms/) | TBA | [@d4l3k](https://github.com/d4l3k) [@fduwjj](https://github.com/fduwjj) [@mori360](https://github.com/mori360) |
- | [moe_symm_mem_kernels](./moe_symm_mem_kernels/) | TBA | [@kwen2501](https://github.com/pytorch/torchtitan/pulls/kwen2501) |
+ | [moe_symm_mem_kernels](./moe_symm_mem_kernels/) | TBA | [@kwen2501](https://github.com/kwen2501) |
+ | [gpt_oss](./gpt_oss/) | TBA | [@jianiw](https://github.com/jianiw) |
2 changes: 1 addition & 1 deletion torchtitan/experiments/__init__.py
@@ -5,5 +5,5 @@
# LICENSE file in the root directory of this source tree.

_supported_experiments = frozenset(
["flux", "simple_fsdp.llama3", "simple_fsdp.deepseek_v3", "vlm"]
["flux", "gpt_oss", "simple_fsdp.llama3", "simple_fsdp.deepseek_v3", "vlm"]
)
17 changes: 17 additions & 0 deletions torchtitan/experiments/gpt_oss/README.md
@@ -0,0 +1,17 @@
# gpt-oss Model in torchtitan

## Quick Start
```bash
CONFIG_FILE="./torchtitan/experiments/gpt_oss/train_configs/debug_model.toml" ./run_train.sh
```
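
The config file passed above selects this experiment through torchtitan's standard `[model]` table. A minimal sketch of what the relevant section plausibly looks like, following torchtitan's usual TOML layout; the keys below are illustrative and not copied from the actual `debug_model.toml`:

```toml
# Illustrative only; consult the real debug_model.toml for the full config.
[model]
name = "gpt_oss"       # the experiment name registered in _supported_experiments
flavor = "debugmodel"  # picks gptoss_configs["debugmodel"] from the experiment's __init__.py
```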

## Supported Features
- FSDP/HSDP, TP, EP, ETP
- Grouped matrix multiplication for efficient expert computation (see the sketch below)
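
To make the grouped-matmul bullet concrete, here is a minimal, self-contained sketch (plain PyTorch, not this experiment's kernels) of what grouping replaces: a per-expert loop of small matmuls versus one batched call over tokens sorted by expert. All names and shapes are illustrative, and `torch.bmm` only covers the equal-tokens-per-expert case; the real grouped kernel (enabled by `use_grouped_mm=True` in the model configs) also handles ragged groups via offsets.

```python
import torch

# Toy sizes; all names and shapes are illustrative.
num_experts, tokens_per_expert, dim, hidden = 4, 8, 16, 32
w = torch.randn(num_experts, dim, hidden)              # one weight matrix per expert
x = torch.randn(num_experts * tokens_per_expert, dim)  # tokens pre-sorted by expert

# Naive routing: one small matmul per expert (many kernel launches).
out_loop = torch.cat(
    [x[e * tokens_per_expert : (e + 1) * tokens_per_expert] @ w[e] for e in range(num_experts)]
)

# Grouped: a single batched matmul over expert-contiguous token blocks.
out_grouped = torch.bmm(x.view(num_experts, tokens_per_expert, dim), w)

torch.testing.assert_close(out_loop, out_grouped.reshape(-1, hidden), rtol=1e-4, atol=1e-5)
```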


## TODO
1. More parallelism support: CP, PP
2. Weight conversion to and from HF checkpoints (StateDictAdapter)
3. Forward parity verification
4. CI support
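
TODO items 2 and 3 go together: once weights can be converted, forward parity is typically checked by comparing logits on a fixed batch. A hedged sketch with hypothetical model handles; nothing here is from this PR, and `check_forward_parity` is not an existing torchtitan helper:

```python
import torch

@torch.no_grad()
def check_forward_parity(titan_model, hf_model, vocab_size=1024, seq_len=16):
    """Compare logits from the two implementations on one random batch.

    Set vocab_size to the real tokenizer's vocabulary size.
    """
    tokens = torch.randint(0, vocab_size, (2, seq_len))
    titan_logits = titan_model(tokens)   # assumed to return raw logits
    hf_logits = hf_model(tokens).logits  # HF causal-LM output convention
    torch.testing.assert_close(titan_logits, hf_logits, rtol=1e-4, atol=1e-4)
```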
87 changes: 87 additions & 0 deletions torchtitan/experiments/gpt_oss/__init__.py
@@ -0,0 +1,87 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from torchtitan.components.loss import build_cross_entropy_loss
from torchtitan.components.lr_scheduler import build_lr_schedulers
from torchtitan.components.optimizer import build_optimizers_with_moe_load_balancing
from torchtitan.components.tokenizer import build_hf_tokenizer
from torchtitan.datasets.hf_datasets import build_hf_dataloader
from torchtitan.models.moe import MoEArgs

from torchtitan.protocols.train_spec import TrainSpec

from .infra.parallelize import parallelize_gptoss
from .model.args import GptOssModelArgs
from .model.model import GptOssModel

__all__ = [
    "parallelize_gptoss",
    "GptOssModelArgs",
    "GptOssModel",
    "gptoss_configs",
]


gptoss_configs = {
    "debugmodel": GptOssModelArgs(
        dim=256,
        n_layers=4,
        moe_args=MoEArgs(
            num_experts=8,
            num_shared_experts=0,
            score_func="softmax",
            route_norm=False,
            route_scale=1.0,
            score_before_experts=False,
            top_k=4,
            use_grouped_mm=True,
            load_balance_coeff=1e-3,
        ),
        attn_mask_type="causal",
    ),
    "20b": GptOssModelArgs(
        n_layers=24,
        moe_args=MoEArgs(
            num_experts=32,
            num_shared_experts=0,
            score_func="softmax",
            route_norm=False,
            route_scale=1.0,
            score_before_experts=False,
            top_k=4,
            use_grouped_mm=True,
            load_balance_coeff=1e-3,
        ),
    ),
    "120b": GptOssModelArgs(
        n_layers=36,
        moe_args=MoEArgs(
            num_experts=128,
            num_shared_experts=0,
            score_func="softmax",
            route_norm=False,
            route_scale=1.0,
            score_before_experts=False,
            top_k=4,
            use_grouped_mm=True,
            load_balance_coeff=1e-3,
        ),
    ),
}


def get_train_spec() -> TrainSpec:
    return TrainSpec(
        model_cls=GptOssModel,
        model_args=gptoss_configs,
        parallelize_fn=parallelize_gptoss,
        pipelining_fn=None,
        build_optimizers_fn=build_optimizers_with_moe_load_balancing,
        build_lr_schedulers_fn=build_lr_schedulers,
        build_dataloader_fn=build_hf_dataloader,
        build_tokenizer_fn=build_hf_tokenizer,
        build_loss_fn=build_cross_entropy_loss,
    )
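
For orientation, a sketch of how this spec is typically consumed. The flavor lookup mirrors torchtitan's standard `model_cls(model_args)` pattern; the actual call site lives in the core trainer, not in this file:

```python
spec = get_train_spec()
args = spec.model_args["debugmodel"]  # flavor selected by the TOML's [model] section
model = spec.model_cls(args)          # i.e., GptOssModel(GptOssModelArgs(...))
```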
65 changes: 65 additions & 0 deletions torchtitan/experiments/gpt_oss/infra/expert_parallel.py
@@ -0,0 +1,65 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.


import torch.nn as nn
from torch.distributed.tensor import distribute_tensor, Replicate, Shard
from torchtitan.distributed.expert_parallel import ExpertTensorParallel, TensorParallel

# Implementation of Tensor Parallel for the GroupedExperts in MoE
class GptossTensorParallel(TensorParallel):
    def _partition_fn(self, name, module, device_mesh):
        module.register_parameter(
            "mlp1_weight",
            nn.Parameter(
                distribute_tensor(module.mlp1_weight, device_mesh, [Shard(1)])
            ),
        )  # Column-wise sharding
        module.register_parameter(
            "mlp1_bias",
            nn.Parameter(distribute_tensor(module.mlp1_bias, device_mesh, [Shard(1)])),
        )  # Column-wise sharding
        module.register_parameter(
            "mlp2_weight",
            nn.Parameter(
                distribute_tensor(module.mlp2_weight, device_mesh, [Shard(2)])
            ),
        )  # Row-wise sharding
        module.register_parameter(
            "mlp2_bias",
            nn.Parameter(
                distribute_tensor(module.mlp2_bias, device_mesh, [Replicate()])
            ),
        )  # Replicate


# This class is for dp2ep with TP (without TP we can just use GptossExpertParallel)
class GptossExpertTensorParallel(ExpertTensorParallel):
    def _partition_fn_2d(self, name, mod, ep_tp_mesh):
        mod.register_parameter(
            "mlp1_weight",
            nn.Parameter(
                distribute_tensor(mod.mlp1_weight, ep_tp_mesh, [Shard(0), Shard(1)])
            ),
        )  # Column-wise sharding
        mod.register_parameter(
            "mlp1_bias",
            nn.Parameter(
                distribute_tensor(mod.mlp1_bias, ep_tp_mesh, [Shard(0), Shard(1)])
            ),
        )  # Column-wise sharding
        mod.register_parameter(
            "mlp2_weight",
            nn.Parameter(
                distribute_tensor(mod.mlp2_weight, ep_tp_mesh, [Shard(0), Shard(2)])
            ),
        )  # Row-wise sharding
        mod.register_parameter(
            "mlp2_bias",
            nn.Parameter(
                distribute_tensor(mod.mlp2_bias, ep_tp_mesh, [Shard(0), Replicate()])
            ),
        )  # Replicate
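
To read these placements concretely: for an expert parameter of shape `(num_experts, d1, d2)`, `[Shard(0), Shard(1)]` on an `(ep, tp)` mesh splits experts across the EP ranks and each expert's matrix across the TP ranks. A minimal runnable sketch (launch with `torchrun --nproc-per-node 4`; the mesh sizes and toy shapes are illustrative, not this experiment's real dimensions):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import distribute_tensor, Replicate, Shard

# 2-way expert parallel x 2-way tensor parallel on 4 ranks (CPU/gloo for the sketch).
mesh = init_device_mesh("cpu", (2, 2), mesh_dim_names=("ep", "tp"))

weight = torch.randn(8, 6, 4)  # (num_experts, d1, d2)
dweight = distribute_tensor(weight, mesh, [Shard(0), Shard(1)])
print(dweight.to_local().shape)  # torch.Size([4, 3, 4]): 4 experts per EP rank, half of d1 each

bias = torch.randn(8, 4)
dbias = distribute_tensor(bias, mesh, [Shard(0), Replicate()])
print(dbias.to_local().shape)  # torch.Size([4, 4]): sharded over experts, replicated over TP
```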