Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 165 additions & 0 deletions primus/backends/megatron/patches/zebra_llama_flops_patches.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
###############################################################################
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
###############################################################################

"""
Zebra-Llama FLOPs Patch

Monkey-patches megatron.training.training.num_floating_point_operations so that
hybrid models using Multi-Latent Attention (MLA) get an accurate FLOPs estimate
instead of falling through to the GQA-based hybrid_flops path.
"""

from primus.core.patches import PatchContext, get_args, register_patch
from primus.modules.module_utils import log_rank_0


def zebra_llama_flops(args, batch_size):
    """Calculate FLOPs for the hybrid (Zebra-Llama) model with MLA attention."""

    def _layer_counts():
        """Return (attention, mamba, mlp) layer counts for the hybrid stack."""
        pattern = args.hybrid_override_pattern
        if pattern:
            # Explicit layout string: '*' = attention, 'M' = Mamba, '-' = MLP.
            tally = {'M': 0, '*': 0, '-': 0}
            for symbol in pattern:
                if symbol in tally:
                    tally[symbol] += 1
            return tally['*'], tally['M'], tally['-']
        # No explicit pattern: derive counts from the configured ratios.
        n_attn = round(args.num_layers * args.hybrid_attention_ratio)
        n_mlp = round(args.num_layers * args.hybrid_mlp_ratio)
        n_mamba = args.num_layers - n_attn - n_mlp
        return n_attn, n_mamba, n_mlp

    def _mlp_layer_flops(batch_size, seq_len, hidden_size, expansion=4.0, swiglu=False):
        """FLOPs for one MLP layer (SwiGLU adds a third projection -> 3/2 scale)."""
        scale_factor = 3.0 / 2.0 if swiglu else 1.0
        return 4 * expansion * scale_factor * batch_size * seq_len * hidden_size**2

    def _mamba_layer_flops(batch_size, seq_len, hidden_size, state_dim=16,
                           head_dim=64, num_groups=1, num_heads=128):
        """FLOPs for one Mamba layer (in_proj + selective scan + out_proj)."""
        # Note (rwaleffe): flops estimate for scan should be updated based on new SSD
        # kernels, but small percent of overall layer flops
        d_in = 2 * hidden_size
        nheads = num_heads if num_heads else d_in // head_dim
        return (
            (
                2
                * batch_size
                * seq_len
                * hidden_size
                * (2 * d_in + 2 * num_groups * state_dim + nheads)
            )  # in_proj
            + (7 * batch_size * seq_len * d_in * state_dim)  # scan
            + (2 * batch_size * seq_len * d_in * hidden_size)  # out_proj
        )

    def _mla_attn_layer_flops(
        batch_size, seq_len, hidden_size, num_heads,
        q_lora_rank=None, kv_lora_rank=512,
        qk_head_dim=128, qk_pos_emb_head_dim=64, v_head_dim=128,
    ):
        """FLOPs for one Multi-Latent Attention (MLA) layer.

        FLOPs-accounting references:
            https://arxiv.org/abs/2305.10403
            https://arxiv.org/abs/2205.05198
        """
        # Query path: either a direct projection, or a low-rank (LoRA)
        # factorization with the rope head dims and a norm (+1) folded in.
        if q_lora_rank is None:
            q_term = (
                hidden_size
                * num_heads
                * (qk_head_dim + qk_pos_emb_head_dim)
            )
        else:
            q_term = q_lora_rank * (
                hidden_size
                + num_heads * (qk_head_dim + qk_pos_emb_head_dim)
                + 1
            )
        # Key/value low-rank path (+ norm).
        kv_term = kv_lora_rank * (
            hidden_size
            + num_heads * (qk_head_dim + v_head_dim)
            + 1
        )
        # Decoupled rotary position-embedding projection.
        pos_term = hidden_size * qk_pos_emb_head_dim
        # Output projection.
        out_term = (num_heads * v_head_dim) * hidden_size
        # Core attention; the /2 accounts for the causal mask halving the
        # score matrix.
        score_term = seq_len * (num_heads * (qk_head_dim + qk_pos_emb_head_dim)) / 2
        context_term = seq_len * num_heads * v_head_dim / 2
        return (
            2  # FMA: one multiply + one add per MAC
            * (q_term + kv_term + pos_term + out_term + score_term + context_term)
        ) * batch_size * seq_len

    n_attn, n_mamba, n_mlp = _layer_counts()

    attn_flops = n_attn * _mla_attn_layer_flops(
        batch_size, args.seq_length, args.hidden_size, args.num_attention_heads,
        args.q_lora_rank, args.kv_lora_rank,
        args.qk_head_dim, args.qk_pos_emb_head_dim, args.v_head_dim,
    )
    mlp_flops = n_mlp * _mlp_layer_flops(
        batch_size, args.seq_length, args.hidden_size,
        args.ffn_hidden_size / args.hidden_size, args.swiglu,
    )
    mamba_flops = n_mamba * _mamba_layer_flops(
        batch_size, args.seq_length, args.hidden_size,
        args.mamba_state_dim, args.mamba_head_dim,
        args.mamba_num_groups, args.mamba_num_heads,
    )
    # Final logits projection over the padded vocabulary.
    logits_flops = 2 * batch_size * args.seq_length * args.hidden_size * args.padded_vocab_size

    flops_fwd = attn_flops + mlp_flops + mamba_flops + logits_flops
    # Forward + backward is conventionally estimated as 3x the forward pass.
    return flops_fwd * 3


# ---------------------------------------------------------------------------
# Patch registration
# ---------------------------------------------------------------------------


@register_patch(
    "megatron.training.zebra_llama_flops",
    backend="megatron",
    phase="before_train",
    description="Use MLA-aware FLOPs estimate for hybrid (Zebra-Llama) models.",
    condition=lambda ctx: (
        getattr(get_args(ctx), "is_hybrid_model", False)
        and getattr(get_args(ctx), "multi_latent_attention", False)
    ),
)
def patch_zebra_llama_flops(ctx: PatchContext):
    """Monkey-patch ``megatron.training.training.num_floating_point_operations``.

    Wraps the original function so that hybrid models using Multi-Latent
    Attention (MLA) are routed to :func:`zebra_llama_flops`; all other
    configurations fall through to megatron's original estimate.

    Idempotent: applying the patch more than once is a no-op, so repeated
    ``before_train`` invocations cannot stack wrappers (which would make the
    "original" captured below be the already-patched function).
    """
    import megatron.training.training as orig_training

    # Guard against double application: a second call would otherwise capture
    # the wrapper as _orig_num_flops and wrap it again.
    if getattr(
        orig_training.num_floating_point_operations, "_zebra_llama_patched", False
    ):
        log_rank_0("MegatronPatches: zebra_llama_flops already applied; skipping.")
        return

    _orig_num_flops = orig_training.num_floating_point_operations

    def _patched_num_floating_point_operations(args, batch_size):
        # Only hybrid + MLA models take the MLA-aware path.
        if getattr(args, "is_hybrid_model", False) and getattr(
            args, "multi_latent_attention", False
        ):
            return zebra_llama_flops(args, batch_size)
        return _orig_num_flops(args, batch_size)

    # Marker checked by the idempotency guard above.
    _patched_num_floating_point_operations._zebra_llama_patched = True

    orig_training.num_floating_point_operations = _patched_num_floating_point_operations
    log_rank_0("MegatronPatches: using zebra_llama_flops for hybrid + MLA FLOPs.")
18 changes: 8 additions & 10 deletions primus/modules/trainer/megatron/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1461,11 +1461,10 @@ def get_e2e_base_metrics():
else:
assert num_skipped_samples_in_batch == 0
args.skipped_train_samples += num_skipped_samples_in_batch
flops_calc = (
num_floating_point_operations
if not args.multi_latent_attention
else self.num_floating_point_operations_mla_moe
)
if args.multi_latent_attention and not getattr(args, "is_hybrid_model", False):
flops_calc = self.num_floating_point_operations_mla_moe
else:
flops_calc = num_floating_point_operations
num_floating_point_operations_in_batch = flops_calc(args, batch_size)
num_floating_point_operations_so_far += num_floating_point_operations_in_batch
num_floating_point_operations_since_last_log_event += num_floating_point_operations_in_batch
Expand Down Expand Up @@ -2000,11 +1999,10 @@ def training_log(
elapsed_time = timers("interval-time").elapsed(barrier=True)
elapsed_time_per_iteration = elapsed_time / total_iterations

flops_calc = (
num_floating_point_operations
if not args.multi_latent_attention
else self.num_floating_point_operations_mla_moe
)
if args.multi_latent_attention and not getattr(args, "is_hybrid_model", False):
flops_calc = self.num_floating_point_operations_mla_moe
else:
flops_calc = num_floating_point_operations
throughput = flops_calc(args, batch_size) / (
elapsed_time_per_iteration * 10**12 * args.world_size
)
Expand Down
Loading