
[Feature] add support for gpt-oss #616

Open
GeneDer wants to merge 11 commits into main from feature/gpt-oss

Conversation

GeneDer (Member) commented Mar 18, 2026

Copilot AI review requested due to automatic review settings March 18, 2026 19:46
GeneDer (Member, Author) commented Mar 18, 2026

cc @JohnQinAMD

Copilot AI (Contributor) left a comment

Pull request overview

Adds GPT-OSS (20B/120B) support to the Megatron + Primus-Turbo stack by introducing GPT-OSS model configs and enabling GPT-OSS-style “sink attention” wiring in the Primus-Turbo attention path, plus example configs/scripts for MI355X runs.

Changes:

  • Add sink attention configuration knobs to Primus-Turbo Megatron module config.
  • Implement learned sink parameters + (optional) layer-parity sliding-window selection in PrimusTurboAttention.
  • Add GPT-OSS 120B model config and MI355X example pretrain configs/scripts for GPT-OSS 20B/120B.
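GPT-OSS-style sink attention adds one learned logit per head that competes in the softmax but never weights a value, so some attention mass is silently absorbed rather than forced onto real tokens. A minimal sketch of that normalization for a single (head, query) row — the function name and list-based shapes are illustrative, not the PR's API:

```python
import math

def sink_softmax(scores, sink):
    """Softmax over one query row's attention scores plus a learned sink logit.

    scores: list of k raw attention logits for a single (head, query) row
    sink:   that head's learned sink scalar
    Returns k attention weights summing to < 1; the missing probability
    mass went to the sink slot, which is dropped and never weights a value.
    """
    logits = list(scores) + [sink]
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]  # subtract max for stability
    z = sum(exps)
    return [e / z for e in exps[:-1]]  # drop the sink column; do not renormalize

# Uniform scores with a zero sink logit: 5 equal slots, so each of the
# 4 keys gets 1/5 of the mass and the sink absorbs the remaining 1/5.
weights = sink_softmax([0.0, 0.0, 0.0, 0.0], 0.0)
```

This is why the PR only needs a `(num_attention_heads,)` parameter: the sink is a per-head scalar logit, not a key/value pair.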

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 7 comments.

| File | Description |
| --- | --- |
| primus/configs/modules/megatron/primus_turbo.yaml | Adds sink-attention-related config flags to the Primus-Turbo module config. |
| primus/configs/models/megatron/gpt_oss_120B.yaml | Introduces the GPT-OSS 120B model definition, extending the DeepSeek v2 base. |
| primus/backends/megatron/core/extensions/primus_turbo.py | Adds sink attention support to Primus-Turbo attention (learned sinks + windowing) and passes the new args into the flash-attn op. |
| examples/run_pretrain.sh | Adds a commented pip install hint for a Primus-Turbo wheel. |
| examples/moe_package/run_gpt_oss_120B_mi355x.sh | New run script for GPT-OSS 120B on an MI355X cluster setup. |
| examples/megatron/configs/MI355X/gpt_oss_20B-FP8-pretrain.yaml | New GPT-OSS 20B FP8 example config (enables sink attention). |
| examples/megatron/configs/MI355X/gpt_oss_20B-BF16-pretrain.yaml | New GPT-OSS 20B BF16 example config (enables sink attention). |
| examples/megatron/configs/MI355X/gpt_oss_120B-FP8-pretrain.yaml | New GPT-OSS 120B FP8 example config (enables sink attention + pipeline settings). |
| examples/megatron/configs/MI355X/gpt_oss_120B-BF16-pretrain.yaml | New GPT-OSS 120B BF16 example config (enables sink attention + pipeline settings). |
Comments suppressed due to low confidence (1)

primus/backends/megatron/core/extensions/primus_turbo.py:462

  • Sink attention adds new behavior (learned sinks, layer-parity windowing, CP fallback) but there are no unit tests covering it. Add tests that exercise (1) CP mode to ensure no unsupported kwargs are passed to the USP attention op, and (2) non-CP mode to ensure window_size follows the even-layer-only policy when configured.
        # Sink attention configuration (PR 208) - GPT-OSS style learned sinks
        # Reference: Primus-Turbo/primus_turbo/pytorch/ops/attention/flash_attn_interface.py
        # Note: We store config here but create self.sinks AFTER super().__init__()
        # because PyTorch requires Module.__init__() to be called before assigning parameters
        _use_sink_attention = getattr(args, 'use_sink_attention', False)
        # Sliding window size (gpt-oss uses 128, applied to even layers only)
        self.sink_sliding_window = getattr(args, 'sink_sliding_window', 0)
        # Whether to apply sliding window only to even layers (gpt-oss pattern)
        self.sink_window_even_layers_only = getattr(args, 'sink_window_even_layers_only', True)

        # Note: Sink attention is currently only supported in non-CP mode
        # (flash_attn_usp_func does not support sink parameter yet)
        if _use_sink_attention and self.config.context_parallel_size > 1:
            import warnings
            warnings.warn(
                "Sink attention is not supported with Context Parallel (CP > 1). "
                "Disabling sink attention for this configuration."
            )
            _use_sink_attention = False

        # Store for later use after super().__init__()
        self._init_sink_attention = _use_sink_attention
        self._num_heads_for_sinks = self.config.num_attention_heads

        self.offload = args.offload and "attn" in args.offload_ops
        if args.enable_turbo_attention_float8:
            self.attn = (
                pt.ops.flash_attn_fp8_usp_func
                if self.config.context_parallel_size > 1
                else pt.ops.flash_attn_fp8_func
            )
        else:
            self.attn = (
                pt.ops.flash_attn_usp_func
                if self.config.context_parallel_size > 1
                else pt.ops.flash_attn_func
            )
        if pg_collection is None:
            # For backward compatibility, remove in v0.14 and raise error
            # raise ValueError("TEDotProductAttention was called without ProcessGroupCollection")
            pg_collection = ProcessGroupCollection(
                tp=get_tensor_model_parallel_group(check_initialized=False),
                cp=get_context_parallel_group(check_initialized=False),
                hcp=get_hierarchical_context_parallel_groups(check_initialized=False),
            )
        else:
            assert hasattr(pg_collection, "tp"), "TEDotProductAttention pg_collection must have tp pg"
            assert hasattr(pg_collection, "cp"), "TEDotProductAttention pg_collection must have cp pg"
            if cp_comm_type == "a2a+p2p":
                assert hasattr(
                    pg_collection, "hcp"
                ), "TEDotProductAttention pg_collection must have hierarchical cp pg"

        self.attn_kwargs = {}
        if self.config.context_parallel_size > 1:
            self.attn_kwargs["ulysses_group"] = pg_collection.cp
            # TODO (limou)
            # enable ring attention
            self.attn_kwargs["ring_group"] = dist.new_group(ranks=[dist.get_rank()])

        assert config.window_size is None, "primus_turbo does not support sliding window attention"
        # Check version

        kv_channels = (
            (k_channels, v_channels)
            if k_channels is not None and v_channels is not None
            else self.config.kv_channels
        )

        super().__init__(
            num_attention_heads=self.config.num_attention_heads,
            kv_channels=kv_channels,
            num_gqa_groups=self.config.num_query_groups,
            attention_dropout=(
                self.config.attention_dropout if attention_dropout is None else attention_dropout
            ),
            qkv_format="sbhd",
            attn_mask_type=attn_mask_type.name,
            window_size=None,
            sequence_parallel=self.config.sequence_parallel,
            tp_size=self.config.tensor_model_parallel_size,
            get_rng_state_tracker=None,
            tp_group=pg_collection.tp,
            layer_number=layer_number,
            attention_type=attention_type,
            # cp is not supported
            softmax_scale=softmax_scale,
        )

        # Initialize learned sink parameters AFTER super().__init__()
        # Shape: (num_attention_heads,) - one sink value per head
        # This matches gpt-oss model: self.sinks = torch.nn.Parameter(torch.empty(num_attention_heads))
        self.use_sink_attention = self._init_sink_attention
        if self.use_sink_attention:
            self.sinks = torch.nn.Parameter(
                torch.zeros(self._num_heads_for_sinks, dtype=torch.bfloat16)
            )
        else:
            self.sinks = None
        # Clean up temporary attributes
        del self._init_sink_attention
        del self._num_heads_for_sinks

    def forward(
        self,
        query: Tensor,
        key: Tensor,
        value: Tensor,
        attention_mask: Tensor,
        attn_mask_type: AttnMaskType,
        attention_bias: Tensor = None,
        packed_seq_params: PackedSeqParams = None,
    ):
        """Forward."""
        packed_seq_kwargs = (
            {key: getattr(packed_seq_params, key) for key in self.kept_packed_seq_params}


Comment on lines 463 to 477
            if packed_seq_params is not None
            else {}
        )

        qkv_format = packed_seq_kwargs.get("qkv_format", self.qkv_format)
        assert qkv_format in ("sbhd", "bhsd"), f"qkv_format only supports sbhd/bhsd, but got {qkv_format}"
        if qkv_format == "sbhd":
            query, key, value = [x.transpose(0, 1).contiguous() for x in (query, key, value)]
        mask_type = attn_mask_type.name
        if mask_type == AttnMaskType.causal.name:
            causal = True
        elif mask_type == AttnMaskType.no_mask.name:
            causal = False
        else:
            raise ValueError(f"Unsupported mask type: {mask_type}")
                "Sink attention is not supported with Context Parallel (CP > 1). "
                "Disabling sink attention for this configuration."
            )
            _use_sink_attention = False
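The suppressed review comment above asks for unit tests around the CP fallback. A hedged sketch of such a test: since importing the real module needs the full GPU stack, `resolve_sink_attention` below is a hypothetical standalone extraction of the guard logic in `PrimusTurboAttention.__init__`, not the PR's API:

```python
import warnings

def resolve_sink_attention(use_sink_attention, context_parallel_size):
    # Hypothetical standalone copy of the CP guard in PrimusTurboAttention.__init__:
    # sink attention is disabled (with a warning) whenever CP > 1.
    if use_sink_attention and context_parallel_size > 1:
        warnings.warn(
            "Sink attention is not supported with Context Parallel (CP > 1). "
            "Disabling sink attention for this configuration."
        )
        return False
    return use_sink_attention

# pytest-style checks
def test_cp_disables_sink():
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        assert resolve_sink_attention(True, 2) is False
        assert len(caught) == 1  # exactly one warning emitted

def test_non_cp_keeps_sink():
    assert resolve_sink_attention(True, 1) is True
    assert resolve_sink_attention(False, 4) is False
```

A test against the real class would additionally assert that no `sink` kwarg reaches `flash_attn_usp_func` when CP > 1, per the comment's point (1).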
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Copilot AI review requested due to automatic review settings March 18, 2026 20:51
GeneDer and others added 5 commits March 18, 2026 13:51
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Copilot Autofix powered by AI <175728472+Copilot@users.noreply.github.com>
Signed-off-by: Gene Der Su <e870252314@gmail.com>
This reverts commit 84065bd.
Copilot AI (Contributor) left a comment

Pull request overview

Adds GPT-OSS GPU model support by introducing a new 120B Megatron model config and enabling GPT-OSS-style “sink attention” plumbing in the Primus-Turbo attention wrapper, along with runnable MI355X example configs/scripts.

Changes:

  • Add sink-attention configuration knobs to the Primus-Turbo Megatron module config and wire learned sink parameters through PrimusTurboAttention.
  • Add a new Megatron model definition for gpt_oss_120B (MoE scaling vs existing gpt_oss_20B).
  • Add MI355X example pretrain configs/scripts for GPT-OSS 20B/120B (BF16/FP8), and a helper run script.

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 3 comments.

| File | Description |
| --- | --- |
| primus/configs/modules/megatron/primus_turbo.yaml | Adds sink-attention config flags (enable + sliding-window controls). |
| primus/configs/models/megatron/gpt_oss_120B.yaml | Introduces the GPT-OSS 120B model YAML, extending the DeepSeek v2 base. |
| primus/backends/megatron/core/extensions/primus_turbo.py | Implements learned sink parameters and forwards sink/window_size into the Turbo flash-attn op. |
| examples/run_pretrain.sh | Adds a commented wheel-install hint for Primus-Turbo. |
| examples/moe_package/run_gpt_oss_120B_mi355x.sh | New runnable MI355X script for GPT-OSS 120B experiments. |
| examples/megatron/configs/MI355X/gpt_oss_20B-FP8-pretrain.yaml | New MI355X FP8 pretrain config enabling Primus-Turbo + sink attention. |
| examples/megatron/configs/MI355X/gpt_oss_20B-BF16-pretrain.yaml | New MI355X BF16 pretrain config enabling Primus-Turbo + sink attention. |
| examples/megatron/configs/MI355X/gpt_oss_120B-FP8-pretrain.yaml | New MI355X FP8 pretrain config for GPT-OSS 120B enabling sink attention. |
| examples/megatron/configs/MI355X/gpt_oss_120B-BF16-pretrain.yaml | New MI355X BF16 pretrain config for GPT-OSS 120B enabling sink attention. |
Comments suppressed due to low confidence (1)

primus/backends/megatron/core/extensions/primus_turbo.py:351

  • The attention module now introduces its own sliding-window control via sink_sliding_window, but the initializer still contains an assert config.window_size is None with the message "primus_turbo does not support sliding window attention". That assert/message becomes misleading and will also hard-fail if someone enables Megatron's window_size expecting it to work with sink attention. Consider updating the assert/message to clearly differentiate between Megatron's config.window_size (unsupported) vs the sink-attention-specific windowing path, and point users to sink_sliding_window if that's the intended mechanism.
        attention_dropout: Optional[float] = None,
        softmax_scale: Optional[float] = None,
        k_channels: Optional[int] = None,
        v_channels: Optional[int] = None,
        cp_comm_type: str = "p2p",


Comment on lines 489 to 493
                for key in self.kept_packed_seq_params
            }
            if packed_seq_params is not None
            else {}
        )
Comment on lines 441 to 443
kv_channels=kv_channels,
num_gqa_groups=self.config.num_query_groups,
attention_dropout=(
GeneDer added 2 commits March 18, 2026 14:00
Signed-off-by: Gene Der Su <e870252314@gmail.com>
Signed-off-by: Gene Der Su <e870252314@gmail.com>
Copilot AI review requested due to automatic review settings March 18, 2026 21:01
Copilot AI (Contributor) left a comment

Pull request overview

Adds GPT-OSS GPU model support by introducing Megatron configs for GPT-OSS 20B/120B and extending the Primus-Turbo Megatron attention wrapper with optional GPT-OSS-style learned sink attention parameters.

Changes:

  • Add sink-attention configuration knobs to the Megatron Primus-Turbo module config.
  • Implement learned sink parameters + optional sliding-window selection logic in PrimusTurboAttention.
  • Add GPT-OSS 120B model config plus MI355X example training configs/scripts for GPT-OSS 20B/120B.

Reviewed changes

Copilot reviewed 9 out of 9 changed files in this pull request and generated 4 comments.

| File | Description |
| --- | --- |
| primus/configs/modules/megatron/primus_turbo.yaml | Adds sink-attention-related config fields to Primus-Turbo Megatron module defaults. |
| primus/configs/models/megatron/gpt_oss_120B.yaml | Introduces a GPT-OSS 120B Megatron model configuration. |
| primus/backends/megatron/core/extensions/primus_turbo.py | Adds sink attention support (learned sinks + optional sliding-window behavior) to PrimusTurboAttention. |
| examples/run_pretrain.sh | Adds a commented Primus-Turbo wheel install hint. |
| examples/moe_package/run_gpt_oss_120B_mi355x.sh | New MI355X example script to launch GPT-OSS 120B training runs. |
| examples/megatron/configs/MI355X/gpt_oss_20B-FP8-pretrain.yaml | New MI355X GPT-OSS 20B FP8 pretrain example config (enables sink attention). |
| examples/megatron/configs/MI355X/gpt_oss_20B-BF16-pretrain.yaml | New MI355X GPT-OSS 20B BF16 pretrain example config (enables sink attention). |
| examples/megatron/configs/MI355X/gpt_oss_120B-FP8-pretrain.yaml | New MI355X GPT-OSS 120B FP8 pretrain example config (enables sink attention). |
| examples/megatron/configs/MI355X/gpt_oss_120B-BF16-pretrain.yaml | New MI355X GPT-OSS 120B BF16 pretrain example config (enables sink attention). |


Comment on lines +10 to +17
# Sink attention settings (PR 208) - GPT-OSS style learned sinks
# Reference: gpt-oss/gpt_oss/triton/attention.py
use_sink_attention: false
# Sliding window size for sink attention (gpt-oss uses 128)
sink_sliding_window: 0
# Whether to apply sliding window only to even layers (gpt-oss pattern)
sink_window_even_layers_only: true
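As a usage sketch, an example pretrain config would flip these defaults on. The values below are assumptions inferred from the comments in the hunk above (gpt-oss's 128-token window), not copied from the PR's example files:

```yaml
# Hypothetical override in an example pretrain config (e.g. a gpt_oss_20B YAML)
use_sink_attention: true            # enable learned per-head sinks
sink_sliding_window: 128            # gpt-oss uses a 128-token window
sink_window_even_layers_only: true  # window even layers only, gpt-oss pattern
```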

Comment on lines +362 to +365
            warnings.warn(
                "Sink attention is not supported with Context Parallel (CP > 1). "
                "Disabling sink attention for this configuration."
            )
Comment on lines +515 to 522
            window_size=window_size,
            bias=None,
            alibi_slopes=None,
            deterministic=False,
            return_lse=False,
            return_attn_probs=False,
            sink=sink_tensor,  # PR 208: pass sink tensor to Primus-Turbo
            **self.attn_kwargs,
        if (self.layer_number - 1) % 2 == 0:
            window_size = (self.sink_sliding_window, 0)
        else:
            window_size = (self.sink_sliding_window, 0)
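Both branches of the quoted hunk assign the same window, which defeats the even-layers-only switch. A hedged sketch of what the parity selection presumably intends — the helper name and the `(-1, -1)` "no window" sentinel (the flash-attn convention for full attention) are assumptions, not the PR's code:

```python
def select_window_size(layer_number, sink_sliding_window, even_layers_only=True):
    """Pick a flash-attn style (left, right) window for a 1-indexed layer.

    Hypothetical helper: layers whose 0-based index (layer_number - 1) is
    even get the sliding window (the gpt-oss pattern); all other layers
    keep full attention, signalled by the conventional (-1, -1) sentinel.
    """
    if sink_sliding_window <= 0:
        return (-1, -1)  # windowing not configured
    if not even_layers_only or (layer_number - 1) % 2 == 0:
        return (sink_sliding_window, 0)  # causal left window; gpt-oss uses 128
    return (-1, -1)
```

With this shape, layer 1 (index 0) gets `(128, 0)` while layer 2 (index 1) falls back to full attention, matching the "even layers only" behavior the config comments describe.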
GeneDer (Member, Author) commented Mar 19, 2026

@wenxie-amd I can confirm this worked end-to-end in this workflow: https://github.com/ROCm/unified-training-dockers/actions/runs/23308475943. Can you help review and merge it in?


Labels: none yet · Projects: none yet · 3 participants