diff --git a/examples/megatron/configs/MI355X/gpt_oss_120B-BF16-pretrain.yaml b/examples/megatron/configs/MI355X/gpt_oss_120B-BF16-pretrain.yaml new file mode 100644 index 000000000..33f5ba517 --- /dev/null +++ b/examples/megatron/configs/MI355X/gpt_oss_120B-BF16-pretrain.yaml @@ -0,0 +1,123 @@ +work_group: ${PRIMUS_TEAM:tas} +user_name: ${PRIMUS_USER:qyy} +exp_name: ${PRIMUS_EXP_NAME:gpt_oss_120B-pretrain} +workspace: ./output + +modules: + pre_trainer: + framework: megatron + config: pre_trainer.yaml + + # model to run + model: ${PRIMUS_MODEL:gpt_oss_120B}.yaml + overrides: + # log + wandb_project: "Primus_GPT_OSS_120B_Pretrain" + stderr_sink_level: DEBUG + + # debug + moe_router_force_load_balancing: true + log_avg_skip_iterations: 2 + log_avg_reset_interval: 50 + + # mla false + # multi_latent_attention: false + # # attn uses "bshd" layout, enabling AMD optimized kernel. + # apply_rope_fusion: true + + enable_primus_turbo: true + use_turbo_attention: true + use_turbo_grouped_mlp: false + + # Sink attention (PR 208) - GPT-OSS style learned sinks + # Reference: gpt-oss/gpt_oss/triton/attention.py + use_sink_attention: true + # Note: sliding window not yet supported by aiter Triton backend + # Set to 0 to disable, or wait for backend support + sink_sliding_window: 0 # gpt-oss default is 128, but disabled for now + sink_window_even_layers_only: true # apply sliding window only to even layers + + apply_rope_fusion: true + + # profile + profile: true + use_pytorch_profiler: true + profile_step_end: 7 + profile_step_start: 6 + + # hyper parameters + train_iters: 10 + micro_batch_size: 8 + global_batch_size: 2048 + seq_length: ${PRIMUS_SEQ_LENGTH:4096} + max_position_embeddings: ${PRIMUS_MAX_POSITION_EMBEDDINGS:4096} + lr: 1.0e-5 + min_lr: 0.0 + lr_warmup_iters: 2 + lr_decay_iters: null + lr_decay_style: cosine + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + eod_mask_loss: true + init_method_std: 0.008 + norm_epsilon: 1.0e-6 + + # parallel + tensor_model_parallel_size: ${PRIMUS_TP:1} + pipeline_model_parallel_size: ${PRIMUS_PP:2} + virtual_pipeline_model_parallel_size: ${PRIMUS_VP:2} + expert_model_parallel_size: ${PRIMUS_EP:8} + overlap_grad_reduce: true + overlap_param_gather: true + + # data + mock_data: true + # train_data_path: data + train_data_path: ${TOKENIZED_DATA_PATH:null} + valid_data_path: null + test_data_path: null + + # fusion + # 20250321: need latest megatron docker image + moe_permute_fusion: false + # fused wgrad gemm and accumulation + gradient_accumulation_fusion: false + # recommend set `false` in fp8 + moe_use_legacy_grouped_gemm: false + # fused topk router with aux score + moe_use_fused_router_with_aux_score: false + # pad 192/128 for deepseek attention + # fused_padded_mla_attention: false + + multi_latent_attention: false + + # ckpt + finetune: false + auto_continue_train: false + load: null + no_load_optim: null + no_load_rng: null + save: null + save_interval: 20000 + no_save_optim: null + no_save_rng: null + disable_last_saving: true + ckpt_format: torch + eval_iters: 0 + + cross_entropy_loss_fusion: true + + # recompute + recompute_granularity: full # full, selective + recompute_method: block # uniform, block + recompute_num_layers: 4 # int + + # Turbo + # fp8: hybrid + # enable_primus_turbo: true + # use_turbo_attention: true + # use_turbo_grouped_mlp: false + # enable_primus_turbo: false + # enable_turbo_attention_float8 : false + # enable_turbo_gemm_float8 : false \ No newline at end of file diff --git a/examples/megatron/configs/MI355X/gpt_oss_120B-FP8-pretrain.yaml b/examples/megatron/configs/MI355X/gpt_oss_120B-FP8-pretrain.yaml new file mode 100644 index 000000000..7e483e604 --- /dev/null +++ b/examples/megatron/configs/MI355X/gpt_oss_120B-FP8-pretrain.yaml @@ -0,0 +1,122 @@ +work_group: ${PRIMUS_TEAM:tas} +user_name: ${PRIMUS_USER:qyy} +exp_name: ${PRIMUS_EXP_NAME:gpt_oss_120B-pretrain} +workspace: ./output + +modules: + pre_trainer: + framework: megatron + config: pre_trainer.yaml + + # model to run + model: ${PRIMUS_MODEL:gpt_oss_120B}.yaml + overrides: + # log + wandb_project: "Primus_GPT_OSS_120B_Pretrain" + stderr_sink_level: DEBUG + + # debug + moe_router_force_load_balancing: true + log_avg_skip_iterations: 2 + log_avg_reset_interval: 50 + + # mla false + # multi_latent_attention: false + # # attn uses "bshd" layout, enabling AMD optimized kernel. + # apply_rope_fusion: true + + enable_primus_turbo: true + use_turbo_attention: true + use_turbo_grouped_mlp: false + + # Sink attention (PR 208) - GPT-OSS style learned sinks + # Reference: gpt-oss/gpt_oss/triton/attention.py + use_sink_attention: true + # Note: sliding window not yet supported by aiter Triton backend + # Set to 0 to disable, or wait for backend support + sink_sliding_window: 0 # gpt-oss default is 128, but disabled for now + sink_window_even_layers_only: true # apply sliding window only to even layers + + apply_rope_fusion: true + + # profile + profile: true + use_pytorch_profiler: true + profile_step_end: 7 + profile_step_start: 6 + + # hyper parameters + train_iters: 10 + micro_batch_size: 8 + global_batch_size: 2048 + seq_length: ${PRIMUS_SEQ_LENGTH:4096} + max_position_embeddings: ${PRIMUS_MAX_POSITION_EMBEDDINGS:4096} + lr: 1.0e-5 + min_lr: 0.0 + lr_warmup_iters: 2 + lr_decay_iters: null + lr_decay_style: cosine + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + eod_mask_loss: true + init_method_std: 0.008 + norm_epsilon: 1.0e-6 + + # parallel + tensor_model_parallel_size: ${PRIMUS_TP:1} + pipeline_model_parallel_size: ${PRIMUS_PP:2} + virtual_pipeline_model_parallel_size: ${PRIMUS_VP:2} + expert_model_parallel_size: ${PRIMUS_EP:8} + overlap_grad_reduce: true + overlap_param_gather: true + + # data + mock_data: true + # train_data_path: data + train_data_path: ${TOKENIZED_DATA_PATH:null} + valid_data_path: null + test_data_path: null + + # fusion + # 20250321: need latest megatron docker image + moe_permute_fusion: false + # fused wgrad gemm and accumulation + gradient_accumulation_fusion: false + # recommend set `false` in fp8 + moe_use_legacy_grouped_gemm: false + # fused topk router with aux score + moe_use_fused_router_with_aux_score: false + # pad 192/128 for deepseek attention + # fused_padded_mla_attention: false + + multi_latent_attention: false + + # ckpt + finetune: false + auto_continue_train: false + load: null + no_load_optim: null + no_load_rng: null + save: null + save_interval: 20000 + no_save_optim: null + no_save_rng: null + disable_last_saving: true + ckpt_format: torch + eval_iters: 0 + + cross_entropy_loss_fusion: true + + # recompute + recompute_granularity: full # full, selective + recompute_method: block # uniform, block + recompute_num_layers: 4 # int + + fp8: hybrid + # enable_primus_turbo: true + # use_turbo_attention: true + # use_turbo_grouped_mlp: false + # enable_primus_turbo: false + # enable_turbo_attention_float8 : false + # enable_turbo_gemm_float8 : false \ No newline at end of file diff --git a/examples/megatron/configs/MI355X/gpt_oss_20B-BF16-pretrain.yaml b/examples/megatron/configs/MI355X/gpt_oss_20B-BF16-pretrain.yaml new file mode 100644 index 000000000..5f8771042 --- /dev/null +++ b/examples/megatron/configs/MI355X/gpt_oss_20B-BF16-pretrain.yaml @@ -0,0 +1,116 @@ +work_group: ${PRIMUS_TEAM:tas} +user_name: ${PRIMUS_USER:qyy} +exp_name: ${PRIMUS_EXP_NAME:gpt_oss_20B-pretrain} +workspace: ./output + +modules: + pre_trainer: + framework: megatron + config: pre_trainer.yaml + + # model to run + model: ${PRIMUS_MODEL:gpt_oss_20B}.yaml + overrides: + # log + wandb_project: "Primus_GPT_OSS_20B_Pretrain" + stderr_sink_level: DEBUG + + # debug + moe_router_force_load_balancing: true + log_avg_skip_iterations: 2 + log_avg_reset_interval: 50 + + # mla false + # multi_latent_attention: false + # # attn uses "bshd" layout, enabling AMD optimized kernel. + # apply_rope_fusion: true + + enable_primus_turbo: true + use_turbo_attention: true + use_turbo_grouped_mlp: false + + # Sink attention (PR 208) - GPT-OSS style learned sinks + # Reference: gpt-oss/gpt_oss/triton/attention.py + use_sink_attention: true + # Note: sliding window not yet supported by aiter Triton backend + # Set to 0 to disable, or wait for backend support + sink_sliding_window: 0 # gpt-oss default is 128, but disabled for now + sink_window_even_layers_only: true # apply sliding window only to even layers + + apply_rope_fusion: true + + # profile + profile: true + use_pytorch_profiler: true + profile_step_end: 7 + profile_step_start: 6 + + # hyper parameters + train_iters: 10 + micro_batch_size: 8 + global_batch_size: 512 + seq_length: ${PRIMUS_SEQ_LENGTH:4096} + max_position_embeddings: ${PRIMUS_MAX_POSITION_EMBEDDINGS:4096} + lr: 1.0e-5 + min_lr: 0.0 + lr_warmup_iters: 2 + lr_decay_iters: null + lr_decay_style: cosine + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + eod_mask_loss: true + init_method_std: 0.008 + norm_epsilon: 1.0e-6 + + # parallel + tensor_model_parallel_size: ${PRIMUS_TP:1} + pipeline_model_parallel_size: ${PRIMUS_PP:1} + expert_model_parallel_size: ${PRIMUS_EP:8} + overlap_grad_reduce: true + overlap_param_gather: true + + # data + mock_data: true + # train_data_path: data + train_data_path: ${TOKENIZED_DATA_PATH:null} + valid_data_path: null + test_data_path: null + + # fusion + # 20250321: need latest megatron docker image + moe_permute_fusion: false + # fused wgrad gemm and accumulation + gradient_accumulation_fusion: false + # recommend set `false` in fp8 + moe_use_legacy_grouped_gemm: false + # fused topk router with aux score + moe_use_fused_router_with_aux_score: false + # pad 192/128 for deepseek attention + # fused_padded_mla_attention: false + + multi_latent_attention: false + + # ckpt + finetune: false + auto_continue_train: false + load: null + no_load_optim: null + no_load_rng: null + save: null + save_interval: 20000 + no_save_optim: null + no_save_rng: null + disable_last_saving: true + ckpt_format: torch + eval_iters: 0 + + cross_entropy_loss_fusion: true + + # fp8: hybrid + # enable_primus_turbo: true + # use_turbo_attention: true + # use_turbo_grouped_mlp: false + # enable_primus_turbo: false + # enable_turbo_attention_float8 : false + # enable_turbo_gemm_float8 : false \ No newline at end of file diff --git a/examples/megatron/configs/MI355X/gpt_oss_20B-FP8-pretrain.yaml b/examples/megatron/configs/MI355X/gpt_oss_20B-FP8-pretrain.yaml new file mode 100644 index 000000000..b3d85b7bd --- /dev/null +++ b/examples/megatron/configs/MI355X/gpt_oss_20B-FP8-pretrain.yaml @@ -0,0 +1,116 @@ +work_group: ${PRIMUS_TEAM:tas} +user_name: ${PRIMUS_USER:qyy} +exp_name: ${PRIMUS_EXP_NAME:gpt_oss_20B-pretrain} +workspace: ./output + +modules: + pre_trainer: + framework: megatron + config: pre_trainer.yaml + + # model to run + model: ${PRIMUS_MODEL:gpt_oss_20B}.yaml + overrides: + # log + wandb_project: "Primus_GPT_OSS_20B_Pretrain" + stderr_sink_level: DEBUG + + # debug + moe_router_force_load_balancing: true + log_avg_skip_iterations: 2 + log_avg_reset_interval: 50 + + # mla false + # multi_latent_attention: false + # # attn uses "bshd" layout, enabling AMD optimized kernel. + # apply_rope_fusion: true + + enable_primus_turbo: true + use_turbo_attention: true + use_turbo_grouped_mlp: false + + # Sink attention (PR 208) - GPT-OSS style learned sinks + # Reference: gpt-oss/gpt_oss/triton/attention.py + use_sink_attention: true + # Note: sliding window not yet supported by aiter Triton backend + # Set to 0 to disable, or wait for backend support + sink_sliding_window: 0 # gpt-oss default is 128, but disabled for now + sink_window_even_layers_only: true # apply sliding window only to even layers + + apply_rope_fusion: true + + # profile + profile: true + use_pytorch_profiler: true + profile_step_end: 7 + profile_step_start: 6 + + # hyper parameters + train_iters: 10 + micro_batch_size: 8 + global_batch_size: 512 + seq_length: ${PRIMUS_SEQ_LENGTH:4096} + max_position_embeddings: ${PRIMUS_MAX_POSITION_EMBEDDINGS:4096} + lr: 1.0e-5 + min_lr: 0.0 + lr_warmup_iters: 2 + lr_decay_iters: null + lr_decay_style: cosine + weight_decay: 0.1 + adam_beta1: 0.9 + adam_beta2: 0.95 + eod_mask_loss: true + init_method_std: 0.008 + norm_epsilon: 1.0e-6 + + # parallel + tensor_model_parallel_size: ${PRIMUS_TP:1} + pipeline_model_parallel_size: ${PRIMUS_PP:1} + expert_model_parallel_size: ${PRIMUS_EP:8} + overlap_grad_reduce: true + overlap_param_gather: true + + # data + mock_data: true + # train_data_path: data + train_data_path: ${TOKENIZED_DATA_PATH:null} + valid_data_path: null + test_data_path: null + + # fusion + # 20250321: need latest megatron docker image + moe_permute_fusion: false + # fused wgrad gemm and accumulation + gradient_accumulation_fusion: false + # recommend set `false` in fp8 + moe_use_legacy_grouped_gemm: false + # fused topk router with aux score + moe_use_fused_router_with_aux_score: false + # pad 192/128 for deepseek attention + # fused_padded_mla_attention: false + + multi_latent_attention: false + + # ckpt + finetune: false + auto_continue_train: false + load: null + no_load_optim: null + no_load_rng: null + save: null + save_interval: 20000 + no_save_optim: null + no_save_rng: null + disable_last_saving: true + ckpt_format: torch + eval_iters: 0 + + cross_entropy_loss_fusion: true + + fp8: hybrid + # enable_primus_turbo: true + # use_turbo_attention: true + # use_turbo_grouped_mlp: false + # enable_primus_turbo: false + # enable_turbo_attention_float8 : false + # enable_turbo_gemm_float8 : false \ No newline at end of file diff --git a/examples/moe_package/run_gpt_oss_120B_mi355x.sh b/examples/moe_package/run_gpt_oss_120B_mi355x.sh new file mode 100755 index 000000000..4519729ed --- /dev/null +++ b/examples/moe_package/run_gpt_oss_120B_mi355x.sh @@ -0,0 +1,244 @@ +#!/bin/bash +############################################################################### +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# +# See LICENSE for license information. +############################################################################### + +######################### Training Docker and Variables ######################### +# export DOCKER_IMAGE="docker.io/tasimage/primus:pr-316-gfx950-ainic" +export DOCKER_IMAGE="docker.io/tasimage/primus:pr-316-v25.10-ainic" +# export DOCKER_IMAGE="docker.io/rocm/megatron-lm:v25.10" +export CLEAN_DOCKER_CONTAINER=1 +export SKIP_TRAIN=0 +export CPUS_PER_TASK=96 + +######################### Training Environment Variables ######################### +export HF_TOKEN=${HF_TOKEN:-"your_hf_token"} +export WANDB_API_KEY=${WANDB_API_KEY:-"your_wandb_api_key"} +export GPU_MAX_HW_QUEUES=${GPU_MAX_HW_QUEUES:-2} +export HSA_NO_SCRATCH_RECLAIM=${HSA_NO_SCRATCH_RECLAIM:-1} +export NVTE_CK_USES_BWD_V3=${NVTE_CK_USES_BWD_V3:-1} +# export USE_ROCM_AITER_ROPE_BACKEND=0 + +# Set on Primus-Safe Platform +# export MASTER_ADDR=${MASTER_ADDR:-localhost} +# export MASTER_PORT=${MASTER_PORT:-1234} +# export NNODES=${PET_NNODES:-1} +# export NODE_RANK=${PET_NODE_RANK:-0} +# export GPUS_PER_NODE=${GPUS_PER_NODE:-8} + +# Set on Vultr cluster +export NNODES=4 +export USING_AINIC=1 +export NCCL_IB_HCA="ionic_0,ionic_1,ionic_2,ionic_3,ionic_4,ionic_5,ionic_6,ionic_7" # modify based on the GPU NiC settings +export NCCL_SOCKET_IFNAME="enp193s0f1np1" +export GLOO_SOCKET_IFNAME="enp193s0f1np1" +export NCCL_IB_RETRY_CNT=20 +export NCCL_IB_TIMEOUT=300 + +######################### Training Config ######################### +export MBS=${MBS:-8} +export GBS=${GBS:-2048} +export SEQ_LENGTH=${SEQ_LENGTH:-4096} +export TP=${TP:-1} +export ETP=${ETP:-1} +export PP=${PP:-2} +export VPP=${VPP:-2} +export EP=${EP:-8} +export CP=${CP:-1} +export CP_COMM_TYPE=${CP_COMM_TYPE:-"a2a"} # p2p, a2a, allgather or a2a+p2p +export ENABLE_MLA=${ENABLE_MLA:-False} +export ENABLE_MTP=${ENABLE_MTP:-False} +export LOAD_BALANCE=${LOAD_BALANCE:-True} +export OPTIMIZER=${OPTIMIZER:-adam} +export RECOMPUTE_LAYERS=${RECOMPUTE_LAYERS:-4} +export LEGACY_GG=${LEGACY_GG:-False} +export FP8=${FP8:-True} # True for fp8, False for bf16 +export PROFILE=${PROFILE:-False} +export DISABLE_CPU_TRACE=${DISABLE_CPU_TRACE:-False} +export PROFILE_STEP_START=${PROFILE_STEP_START:-5} +export PROFILE_STEP_END=${PROFILE_STEP_END:-6} +export TRAIN_ITERS=${TRAIN_ITERS:-10} + +# MoE_Features legend: +# 0 - Baseline (no extra optimization toggles) +# 1 - Turbo attention acceleration +# 2 - Turbo grouped GEMM / MLP fusion +# 3 - Loss fusion helper +# 4 - DeepEP acceleration +# 5 - Sync-free MoE (stage 1/2) +# 6 - CPU NUMA binding helper +# 7 - Manual GC helper +if [ -z "${MoE_Features}" ]; then + # MoE_Features=(0 7) + MoE_Features=(1 3 7) + # MoE_Features=(3 4 7) + # MoE_Features=(3 4 6 7) + # MoE_Features=(3 4 5 6 7) +else + # Convert string to array + # shellcheck disable=SC2128 + read -ra MoE_Features <<< "$MoE_Features" +fi + +FEATURE_ARGS=() +PRIMUS_TURBO_ENABLED="False" +ensure_primus_turbo() { + if [ "$PRIMUS_TURBO_ENABLED" = "False" ]; then + FEATURE_ARGS+=("--enable_primus_turbo" "True") + PRIMUS_TURBO_ENABLED="True" + fi +} + +for feature in "${MoE_Features[@]}"; do + case "$feature" in + 0) ;; + 1) + ensure_primus_turbo + FEATURE_ARGS+=("--use_turbo_attention" "True") + ;; + 2) + ensure_primus_turbo + FEATURE_ARGS+=("--use_turbo_grouped_mlp" "True") + ;; + 3) + FEATURE_ARGS+=("--cross_entropy_fusion_impl" "te") + FEATURE_ARGS+=("--cross_entropy_loss_fusion" "True") + ;; + 4) + ensure_primus_turbo + FEATURE_ARGS+=("--use_turbo_deepep" "True") + FEATURE_ARGS+=("--turbo_deepep_num_cu" "64") + FEATURE_ARGS+=("--turbo_deepep_use_comm_stream" "False") + FEATURE_ARGS+=("--moe_shared_expert_overlap" "False") + FEATURE_ARGS+=("--moe_router_dtype" "fp32") + ;; + 5) + ensure_primus_turbo + # mi355 + # sync_free moe stage 1 will open router and permutation fusion + FEATURE_ARGS+=("--turbo_sync_free_moe_stage" "1") + + # mi300/mi325 + # sync_free moe stage 2 will open deepep automatically + # FEATURE_ARGS+=("--turbo_sync_free_moe_stage" "2") + # FEATURE_ARGS+=("--moe_shared_expert_overlap" "False") + # FEATURE_ARGS+=("--moe_use_legacy_grouped_gemm" "True") + # FEATURE_ARGS+=("--moe_router_dtype" "fp32") + ;; + 6) + # Enable NUMA binding for better memory locality (increase stability for large models) + export ENABLE_NUMA_BINDING=1 + export HSA_KERNARG_POOL_SIZE=12582912 + ;; + 7) + FEATURE_ARGS+=("--manual_gc" "True") + FEATURE_ARGS+=("--manual_gc_interval" "1") + ;; + *) ;; + esac +done + +FEATURE_LIST="${MoE_Features[*]}" +FEATURE_TAG=$(printf "%s" "${FEATURE_LIST}" | tr ' ' '-') + +MLA_ARGS=() +if [ "$ENABLE_MLA" = "True" ]; then + MLA_ARGS+=("--multi_latent_attention" "True") +else + MLA_ARGS+=("--multi_latent_attention" "False") +fi + +MTP_ARGS=() +if [ "$ENABLE_MTP" = "True" ]; then + MTP_ARGS+=("--mtp_num_layers" "1") + MTP_ARGS+=("--mtp_loss_scaling_factor" "0.1") +else + MTP_ARGS+=("--mtp_num_layers" "None") +fi + +VPP_ARGS=() +if [ "$VPP" -gt 1 ]; then + VPP_ARGS+=("--num_virtual_stages_per_pipeline_rank" "$VPP") +fi + +FP8_ARGS=() +if [ "$FP8" = "True" ]; then + FP8_ARGS+=("--fp8" "hybrid") +fi + +RECOMPUTE_ARGS=() +if [ "$RECOMPUTE_LAYERS" -gt 0 ]; then + RECOMPUTE_ARGS+=("--recompute_granularity" "full") + RECOMPUTE_ARGS+=("--recompute_method" "block") + RECOMPUTE_ARGS+=("--recompute_num_layers" "${RECOMPUTE_LAYERS}") +fi + +PROFILE_ARGS=() +if [ "$PROFILE" = "True" ]; then + # --profile-ranks 0 1 2 3 4 5 6 7 + PROFILE_ARGS+=("--profile" "True") + PROFILE_ARGS+=("--disable_profiler_activity_cpu" "${DISABLE_CPU_TRACE}") + PROFILE_ARGS+=("--use_pytorch_profiler" "True") + PROFILE_ARGS+=("--profile_step_start" "${PROFILE_STEP_START}") + PROFILE_ARGS+=("--profile_step_end" "${PROFILE_STEP_END}") +fi + +######################### Training Experiments ######################### +PRIMUS_TEAM="date-$(date +%Y%m%d)-GPT_OSS_120B-Vultr-MI355X" +export PRIMUS_TEAM +PRIMUS_USER=user-tas +export PRIMUS_USER +# export PRIMUS_EXP_NAME="debug" +export PRIMUS_EXP_NAME="GPT_OSS_120B_MI355X_FP8${FP8}_MBS${MBS}_GBS${GBS}_SEQ${SEQ_LENGTH}_MLA${ENABLE_MLA}_MTP${ENABLE_MTP}_REC${RECOMPUTE_LAYERS}_TP${TP}_ETP${ETP}_PP${PP}_VPP${VPP}_EP${EP}_CP${CP}_Balance${LOAD_BALANCE}_LegacyGG${LEGACY_GG}_Profile${PROFILE}-${PROFILE_STEP_START}-${PROFILE_STEP_END}_NoCPUTrace${DISABLE_CPU_TRACE}_Queue${GPU_MAX_HW_QUEUES}_Features${FEATURE_TAG}" + +LOG_DIR=./output/$PRIMUS_TEAM/$PRIMUS_USER/$PRIMUS_EXP_NAME +export DUMP_PP_DIR=$LOG_DIR/pp_dump +export LOG_FILE=$LOG_DIR/training.log +export EXPORT_CONFIG=$LOG_DIR/config.yaml +mkdir -p "$LOG_DIR" +rm -rf "$LOG_FILE" + +######################### Training Job ######################### +# export EXP="examples/megatron/configs/MI355X/gpt_oss_120B-BF16-pretrain.yaml" +export EXP="examples/megatron/configs/MI355X/gpt_oss_120B-FP8-pretrain.yaml" + +echo "--------------------------------" | tee -a "$LOG_FILE" +echo "Begin Training... $(date +%Y%m%d_%H%M%S)" | tee -a "$LOG_FILE" +echo "Training Config: $EXP" | tee -a "$LOG_FILE" +echo "LOG_DIR=${LOG_DIR}" | tee -a "$LOG_FILE" +echo "LOG_FILE=${LOG_FILE}" | tee -a "$LOG_FILE" +echo "FEATURE_ARGS=${FEATURE_ARGS[*]}" | tee -a "$LOG_FILE" +echo "MoE_Features=${FEATURE_LIST}" | tee -a "$LOG_FILE" +echo "MLA_ARGS=${MLA_ARGS[*]}" | tee -a "$LOG_FILE" +echo "MTP_ARGS=${MTP_ARGS[*]}" | tee -a "$LOG_FILE" +echo "FP8_ARGS=${FP8_ARGS[*]}" | tee -a "$LOG_FILE" +echo "RECOMPUTE_ARGS=${RECOMPUTE_ARGS[*]}" | tee -a "$LOG_FILE" +echo "PROFILE_ARGS=${PROFILE_ARGS[*]}" | tee -a "$LOG_FILE" +echo "--------------------------------" | tee -a "$LOG_FILE" + +bash ./examples/run_slurm_pretrain.sh \ + --micro_batch_size "$MBS" \ + --global_batch_size "$GBS" \ + --seq_length "$SEQ_LENGTH" \ + --tensor_model_parallel_size "$TP" \ + --expert_tensor_parallel_size "$ETP" \ + --pipeline_model_parallel_size "$PP" \ + --expert_model_parallel_size "$EP" \ + --context_parallel_size "$CP" \ + --cp_comm_type "$CP_COMM_TYPE" \ + --mock_data True \ + --pp_warmup True \ + --moe_router_force_load_balancing "$LOAD_BALANCE" \ + --optimizer "$OPTIMIZER" \ + --moe_use_legacy_grouped_gemm "$LEGACY_GG" \ + --torch_profiler_use_gzip False \ + "${MLA_ARGS[@]}" \ + "${MTP_ARGS[@]}" \ + "${VPP_ARGS[@]}" \ + "${FEATURE_ARGS[@]}" \ + "${RECOMPUTE_ARGS[@]}" \ + "${FP8_ARGS[@]}" \ + "${PROFILE_ARGS[@]}" \ + --train_iters "$TRAIN_ITERS" 2>&1 | tee -a "$LOG_FILE" diff --git a/examples/run_pretrain.sh b/examples/run_pretrain.sh index ea504fa86..25cf3e510 100755 --- a/examples/run_pretrain.sh +++ b/examples/run_pretrain.sh @@ -398,6 +398,7 @@ fi # install primus turbo from source export REBUILD_PRIMUS_TURBO=${REBUILD_PRIMUS_TURBO:-0} if [ "$REBUILD_PRIMUS_TURBO" == "1" ]; then + # pip3 install --extra-index-url https://test.pypi.org/simple primus_turbo-0.2.0+69d2386-cp310-cp310-linux_x86_64.whl LOG_INFO "Rebuilding Primus Turbo from source..." mkdir -p "/workspace/turbo" cd "/workspace/turbo" || exit diff --git a/primus/backends/megatron/core/extensions/primus_turbo.py b/primus/backends/megatron/core/extensions/primus_turbo.py index fb70b093c..688c8561e 100644 --- a/primus/backends/megatron/core/extensions/primus_turbo.py +++ b/primus/backends/megatron/core/extensions/primus_turbo.py @@ -309,6 +309,19 @@ class PrimusTurboAttention(te.pytorch.DotProductAttention): Note that if Megatron's parallel_state has not been initialized yet, the tp_group and cp_group passed to TE will be None and must be set later via set_tensor_parallel_group() and set_context_parallel_group(). + + Supports sink attention (PR 208) when use_sink_attention is enabled. + GPT-OSS style sink attention uses learned sink parameters per attention head, + which act as virtual attention targets that help stabilize attention patterns + especially with sliding window attention. + + Primus-Turbo API (flash_attn_interface.py): + flash_attn_func(..., sink: Optional[torch.Tensor] = None) + - sink: learned sink parameters, shape (num_attention_heads,) + - When sink is provided, the Triton backend is automatically used + (C++ backend does not support sink attention) + + Reference: gpt-oss/gpt_oss/triton/attention.py """ def __init__( @@ -327,8 +340,35 @@ def __init__( self.config = config self.qkv_format: str = "sbhd" self.softmax_scale = softmax_scale + self.layer_number = layer_number args = get_args() + + # Sink attention configuration (PR 208) - GPT-OSS style learned sinks + # Reference: Primus-Turbo/primus_turbo/pytorch/ops/attention/flash_attn_interface.py + # Note: We store config here but create self.sinks AFTER super().__init__() + # because PyTorch requires Module.__init__() to be called before assigning parameters + _use_sink_attention = getattr(args, "use_sink_attention", False) + # Sliding window size (gpt-oss uses 128, applied to even layers only) + self.sink_sliding_window = getattr(args, "sink_sliding_window", 0) + # Whether to apply sliding window only to even layers (gpt-oss pattern) + self.sink_window_even_layers_only = getattr(args, "sink_window_even_layers_only", True) + + # Note: Sink attention is currently only supported in non-CP mode + # (flash_attn_usp_func does not support sink parameter yet) + if _use_sink_attention and self.config.context_parallel_size > 1: + import warnings + + warnings.warn( + "Sink attention is not supported with Context Parallel (CP > 1). " + "Disabling sink attention for this configuration." + ) + _use_sink_attention = False + + # Store for later use after super().__init__() + self._init_sink_attention = _use_sink_attention + self._num_heads_for_sinks = self.config.num_attention_heads + self.offload = args.offload and "attn" in args.offload_ops if args.enable_turbo_attention_float8: self.attn = ( @@ -394,6 +434,18 @@ def __init__( softmax_scale=softmax_scale, ) + # Initialize learned sink parameters AFTER super().__init__() + # Shape: (num_attention_heads,) - one sink value per head + # This matches gpt-oss model: self.sinks = torch.nn.Parameter(torch.empty(num_attention_heads)) + self.use_sink_attention = self._init_sink_attention + if self.use_sink_attention: + self.sinks = torch.nn.Parameter(torch.zeros(self._num_heads_for_sinks, dtype=torch.bfloat16)) + else: + self.sinks = None + # Clean up temporary attributes + del self._init_sink_attention + del self._num_heads_for_sinks + def forward( self, query: Tensor, @@ -423,6 +475,31 @@ def forward( else: raise ValueError(f"Unsupported mask type: {mask_type}") + # Sink attention support (PR 208) - GPT-OSS style + # Learned sinks act as virtual attention targets that help stabilize + # attention patterns, especially with sliding window attention. + # + # Primus-Turbo API (flash_attn_interface.py line 316-348): + # flash_attn_func(..., sink: Optional[torch.Tensor] = None) + # - sink: learned sink parameters, shape (num_attention_heads,) + # - When sink is provided, Triton backend is automatically used + # + # Reference: gpt-oss/gpt_oss/triton/attention.py + sink_tensor = None + window_size = (-1, -1) + + if self.use_sink_attention and self.sinks is not None: + sink_tensor = self.sinks + + # Apply sliding window based on layer pattern (gpt-oss: even layers only) + # gpt-oss pattern: self.sliding_window = config.sliding_window if layer_idx % 2 == 0 else 0 + if self.sink_sliding_window > 0: + if self.sink_window_even_layers_only: + # Only apply sliding window to even layers (layer_number is 1-indexed in Megatron) + if (self.layer_number - 1) % 2 == 0: + window_size = (self.sink_sliding_window, 0) + else: + window_size = (self.sink_sliding_window, 0) if self.offload: OFFLOAD_BUFFER.add_offload_tensor(f"attn_q", query) OFFLOAD_BUFFER.add_offload_tensor(f"attn_k", key) @@ -435,12 +512,13 @@ def forward( dropout_p=0.0, softmax_scale=self.softmax_scale, causal=causal, - window_size=(-1, -1), + window_size=window_size, bias=None, alibi_slopes=None, deterministic=False, return_lse=False, return_attn_probs=False, + sink=sink_tensor, # PR 208: pass sink tensor to Primus-Turbo **self.attn_kwargs, ) diff --git a/primus/configs/models/megatron/gpt_oss_120B.yaml b/primus/configs/models/megatron/gpt_oss_120B.yaml new file mode 100644 index 000000000..44551f932 --- /dev/null +++ b/primus/configs/models/megatron/gpt_oss_120B.yaml @@ -0,0 +1,33 @@ +extends: + - deepseek_v2_base.yaml + +tokenizer_type: DeepSeekV2Tokenizer +tokenizer_model: deepseek-ai/DeepSeek-V2-Lite + +# model +num_layers: 36 +hidden_size: 2880 +ffn_hidden_size: 10944 +num_attention_heads: 64 + +# GQA +multi_latent_attention: false +apply_rope_fusion: false +qk_head_dim: 128 +kv_channels: 64 +group_query_attention: true +num_query_groups: 8 + +# moe +moe_layer_freq: 1 +num_experts: 128 +moe_router_topk: 4 +moe_ffn_hidden_size: 2880 # moe_intermediate_size +moe_shared_expert_intermediate_size: 2880 # num_shared_experts * moe_ffn_hidden_size + + +# device limited routing +expert_model_parallel_size: 8 +moe_router_num_groups: null # int +moe_router_group_topk: null # int +moe_router_topk_scaling_factor: 1.0 # float diff --git a/primus/configs/modules/megatron/primus_turbo.yaml b/primus/configs/modules/megatron/primus_turbo.yaml index ee85b80ec..4b3603908 100644 --- a/primus/configs/modules/megatron/primus_turbo.yaml +++ b/primus/configs/modules/megatron/primus_turbo.yaml @@ -7,6 +7,14 @@ use_turbo_parallel_linear: false use_turbo_grouped_mlp: false moe_use_fused_router_with_aux_score: false +# Sink attention settings (PR 208) - GPT-OSS style learned sinks +# Reference: gpt-oss/gpt_oss/triton/attention.py +use_sink_attention: false +# Sliding window size for sink attention (gpt-oss uses 128) +sink_sliding_window: 0 +# Whether to apply sliding window only to even layers (gpt-oss pattern) +sink_window_even_layers_only: true + # inner features flags enable_turbo_attention_float8 : false