Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions TraceLens/PerfModel/perf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4590,6 +4590,53 @@ def bytes_bwd(self):
return bytes


# 9. TransformerEngine standalone LayerNorm
# TE's _Linear and _LayerNormLinear are handled by the Megatron extension
# (example_megatron_extension.py) which decomposes them into pseudo-ops.
# LayerNormFn is a standalone op not covered by the extension.


class te_layer_norm_fn(Normalization):
    """
    TransformerEngine LayerNormFn: standalone LayerNorm (no fused GEMM).

    Memory-bound; reports both FLOPS (via inherited Normalization.flops())
    and TB/s.

    Trace layout (from the event's "Input Dims"/"Input type"/"Input Strides"):
        Input[0] = X with shape [..., hidden]
        Input[1] = gamma with shape [hidden]
        Input[2] = beta with shape [hidden] (optional; detected via has_bias)
    """

    @staticmethod
    def get_param_details(event):
        """Extract normalization parameters from a profiler trace event.

        Args:
            event: trace event dict whose ``args`` carries "Input Dims",
                "Input type" and "Input Strides" lists, indexed per the
                trace layout documented on the class.

        Returns:
            dict of parameters consumed by the Normalization base class.
        """
        input_dims = event["args"]["Input Dims"]
        op_shape = tuple(input_dims[0])  # X: [..., hidden]
        gamma_shape = input_dims[1]  # gamma: [hidden]
        num_channels = prod(gamma_shape)
        dtype_in = event["args"]["Input type"][0]
        stride_input = tuple(event["args"]["Input Strides"][0])
        # beta is optional: an absent bias shows up as a missing, None, or
        # empty entry at Input Dims[2].
        has_bias = (
            len(input_dims) > 2 and input_dims[2] is not None and len(input_dims[2]) > 0
        )
        return {
            "op_shape": op_shape,
            # Output dtype is not recorded in the trace event.
            "dtype_in_out": (dtype_in, None),
            "stride_input": stride_input,
            "stride_output": None,
            "num_channels": num_channels,
            "has_bias": has_bias,
            # gamma (Input[1]) is always present for LayerNormFn.
            "is_affine": True,
            "is_training": True,
        }

    def flops_bwd(self):
        # Backward is reported via the separate LayerNormFnBackward op.
        raise NotImplementedError("Backward pass for te_layer_norm_fn is not defined.")

    def bytes_bwd(self, bytes_per_element):
        # Backward is reported via the separate LayerNormFnBackward op.
        raise NotImplementedError("Backward pass for te_layer_norm_fn is not defined.")


# ==============================================================================
# MoE Communication – MoEDispatch / MoECombine (token routing)
# ==============================================================================
Expand Down
12 changes: 10 additions & 2 deletions TraceLens/PerfModel/torch_op_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@
"FusedRoPEFunc": perf_model.fused_rope_fwd,
# CrossEntropy (fused softmax + nll loss)
"CrossEntropyFunction": perf_model.cross_entropy_fwd,
# TE standalone LayerNorm (memory-bound; reports both FLOPS and TB/s).
# _Linear / _LayerNormLinear are handled by the Megatron extension
# (example_megatron_extension.py) which decomposes them into pseudo-ops.
"LayerNormFn": perf_model.te_layer_norm_fn,
# Mamba SSD (fused conv1d + selective scan, issue #552)
"MambaSplitConv1dScanCombinedFn": perf_model.mamba_ssd_fwd,
}
Expand Down Expand Up @@ -221,8 +225,10 @@ def categorize_torch_op(row):
"ConvBiasReLU_Backward",
]:
return "CONV_bwd"
elif row["name"] in norm_ops.keys():
if row["name"].endswith("_backward"):
elif row["name"] in norm_ops.keys() or row["name"] in dict_cat2names.get(
"Normalization", []
):
if row["name"].endswith("_backward") or row["name"].endswith("Backward"):
return "NORM_bwd"
else:
return "NORM_fwd"
Expand All @@ -247,6 +253,8 @@ def categorize_torch_op(row):
return "MoE_fused"
elif row["name"] in dict_cat2names.get("MoE_unfused", []):
return "MoE_unfused"
elif row["name"] == "LayerNormFnBackward":
return "NORM_bwd"
elif row["name"] in dict_cat2names.get("SSM", []) or row["name"] in [
"MambaSplitConv1dScanCombinedFn",
]:
Expand Down
86 changes: 86 additions & 0 deletions tests/test_te_linear_ops.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
###############################################################################
# Copyright (c) 2024 - 2025 Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.
###############################################################################

from TraceLens.PerfModel.perf_model import te_layer_norm_fn
from TraceLens.PerfModel.torch_op_mapping import (
categorize_torch_op,
op_to_perf_model_class_map,
)

# ---------------------------------------------------------------------------
# Mapping and categorization
# ---------------------------------------------------------------------------


def test_te_layer_norm_fn_mapped():
    """The TE LayerNormFn op name resolves to the te_layer_norm_fn model."""
    mapped_model = op_to_perf_model_class_map["LayerNormFn"]
    assert mapped_model is te_layer_norm_fn


def test_layer_norm_fn_categorizes_as_norm_fwd():
    """Forward LayerNormFn rows land in the NORM_fwd category."""
    fwd_row = {"name": "LayerNormFn", "kernel_details": []}
    category = categorize_torch_op(fwd_row)
    assert category == "NORM_fwd"


def test_layer_norm_fn_backward_categorizes_as_norm_bwd():
    """Backward LayerNormFnBackward rows land in the NORM_bwd category."""
    bwd_row = {"name": "LayerNormFnBackward", "kernel_details": []}
    category = categorize_torch_op(bwd_row)
    assert category == "NORM_bwd"


# ---------------------------------------------------------------------------
# Helpers — synthetic events matching real TE trace format
# ---------------------------------------------------------------------------


def _layer_norm_fn_event(X_shape, gamma_shape, dtype="c10::BFloat16"):
"""LayerNormFn: Input[0]=X, Input[1]=gamma."""
strides_X = []
s = 1
for d in reversed(X_shape):
strides_X.insert(0, s)
s *= d
return {
"name": "LayerNormFn",
"args": {
"Input Dims": [X_shape, gamma_shape, [], [], [], [], [], []],
"Input type": [
dtype,
dtype,
"",
"",
"Scalar",
"Scalar",
"Scalar",
"Scalar",
],
"Input Strides": [strides_X, [1], [], [], [], [], [], []],
"Concrete Inputs": ["", "", "", "", "1e-05", "256", "False", "True"],
},
}


# ---------------------------------------------------------------------------
# te_layer_norm_fn — normalization model
# ---------------------------------------------------------------------------


def test_te_layer_norm_fn_instantiates():
    """Construction picks up element and channel counts from the event."""
    x_shape = [2048, 4, 2048]
    event = _layer_norm_fn_event(X_shape=x_shape, gamma_shape=[2048])
    model = te_layer_norm_fn(event)
    expected_elems = 1
    for dim in x_shape:
        expected_elems *= dim
    assert model.num_elems == expected_elems
    assert model.num_channels == 2048


def test_te_layer_norm_fn_bytes():
    """bytes() = read + write of activations plus 5 per-channel weight tensors."""
    event = _layer_norm_fn_event(X_shape=[2048, 4, 2048], gamma_shape=[2048])
    model = te_layer_norm_fn(event)
    bpe = 2  # bf16
    num_elems = 2048 * 4 * 2048
    num_channels = 2048
    # is_affine=True, is_training=True, has_bias=False
    # num_weight_tensors = 2 (mean+var) + 1 (gamma) + 2 (training) = 5
    expected_bytes = (2 * num_elems + 5 * num_channels) * bpe
    assert model.bytes() == expected_bytes
Loading