134 changes: 107 additions & 27 deletions thunder/benchmarks/benchmark_inference.py
@@ -17,7 +17,6 @@
import json
import os
import statistics
-import sys
import time
import warnings
from typing import Any
@@ -40,16 +39,17 @@
import thunder
from thunder.dynamo.compiler import thunderfx
from thunder.benchmarks.layers_for_inference_benchmark import (
-    GroupedLinear,
    GroupedSwiGLU,
    Llama4MoE,
-    NVFP4InferenceGroupedLinear,
-    NVFP4InferenceLinear,
    NVFP4InferenceGroupedSwiGLU,
    nvfuser_f16a_nvfp4weight_scaled_grouped_mm,
-    nvfuser_f16a_nvfp4weight_scaled_mm,
    FLOAT4_E2M1_MAX,
    FLOAT8_E4M3_EPS,
    FLOAT8_E4M3_MAX,
)
-from thunder.torch.custom_op import _register_custom_op
from thunder.tests.distributed.test_moe import GroupedLinearColwiseParallel, GroupedLinearRowwiseParallel
from thunder.transforms.cudagraph import CUDAGraphTransform
from thunder.torch.custom_op import _register_custom_op, _register_nvfuser_translator

if TYPE_CHECKING:
from typing import Any
@@ -73,6 +73,71 @@
LLAMA4_MAVERICK_MODEL_ID: str = "meta-llama/Llama-4-Maverick-17B-128E"


# TODO: Add mm quantization once nvfuser implements nvfp4 gemm
# Register nvfp4 custom ops with Thunder and nvFuser
def _register_nvfp4_ops():
"""Register nvfp4 custom operations with Thunder."""
# Register f16a_nvfp4weight_scaled_grouped_mm with nvfuser translator
_nvfp4_grouped_mm_symbol = _register_custom_op(nvfuser_f16a_nvfp4weight_scaled_grouped_mm)

def nvfp4_grouped_mm_translator(
activation,
fp4_weight,
weight_scaling_factor,
global_scale,
offsets,
blockscale_offsets,
problem_sizes,
*,
fd,
lc_to_nv_map,
):
from nvfuser_direct import DataType
from thunder.executors.nvfuserex_impl import getnv

nv_act = getnv(activation, fd, lc_to_nv_map)
nv_fp4_w = getnv(fp4_weight, fd, lc_to_nv_map)
nv_sf_w = getnv(weight_scaling_factor, fd, lc_to_nv_map)
nv_alpha = getnv(global_scale, fd, lc_to_nv_map)
nv_offsets = getnv(offsets, fd, lc_to_nv_map)
nv_blocksf_offsets = getnv(blockscale_offsets, fd, lc_to_nv_map)
nv_problem_sizes = getnv(problem_sizes, fd, lc_to_nv_map)
        # NOTE: dynamic-shape support currently hits a concretization issue, so shapes are read statically here
m_size = activation.shape[0]
k_size = activation.shape[1]
k_tile_size = k_size // 16

reshaped_mat1 = fd.ops.reshape(nv_act, [m_size, k_tile_size, 16])
scale1 = fd.ops.abs(reshaped_mat1)
scale1 = fd.ops.max(scale1, 2)
scale1 = fd.ops.div(scale1, FLOAT4_E2M1_MAX)
scale1 = fd.ops.clamp(scale1, FLOAT8_E4M3_EPS, FLOAT8_E4M3_MAX)

broadcast_scale1 = fd.ops.broadcast(scale1, [False, False, True])
reshaped_scaled_mat1 = fd.ops.div(reshaped_mat1, broadcast_scale1)
reshaped_scaled_mat1 = fd.ops.clamp(reshaped_scaled_mat1, -FLOAT8_E4M3_MAX, FLOAT8_E4M3_MAX)

scaled_mat1 = fd.ops.reshape(reshaped_scaled_mat1, [m_size, k_size])
fp4_mat1 = fd.ops.cast(scaled_mat1, DataType.Float4_e2m1fn)
fp8_scale1 = fd.ops.cast(scale1, DataType.Float8_e4m3fn)
layout_fp8_scale1 = fd.ops.preprocess_grouped_matmul_input_sf(fp8_scale1, nv_offsets, nv_blocksf_offsets)
out = fd.ops.cutlass_nvfp4_grouped_mm(
fp4_mat1,
nv_fp4_w,
layout_fp8_scale1,
nv_sf_w,
nv_alpha,
# NOTE: we might need to call contiguous on problem_sizes
nv_problem_sizes,
nv_offsets,
nv_blocksf_offsets,
DataType.BFloat16,
)
return out

_register_nvfuser_translator(_nvfp4_grouped_mm_symbol, nvfp4_grouped_mm_translator)
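
# For intuition, a minimal eager-PyTorch sketch of the per-16-element block
# quantization that the translator above builds out of fd.ops
# (reshape -> amax -> clamp -> divide). A readable reference only, not the
# code path the benchmark executes; the helper name is ours, and it reuses
# the FLOAT4/FLOAT8 constants imported at the top of this file.
def _reference_nvfp4_block_quantize(activation: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    m_size, k_size = activation.shape
    # One scale per 16-element block along K.
    blocks = activation.reshape(m_size, k_size // 16, 16)
    scale = blocks.abs().amax(dim=2).div(FLOAT4_E2M1_MAX).clamp(FLOAT8_E4M3_EPS, FLOAT8_E4M3_MAX)
    # Scale each block and clamp to the representable range before the fp4 cast.
    scaled = (blocks / scale.unsqueeze(-1)).clamp(-FLOAT8_E4M3_MAX, FLOAT8_E4M3_MAX)
    return scaled.reshape(m_size, k_size), scale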


# The logic is based on https://github.com/pytorch/ao/blob/b34c1037/torchao/quantization/quant_api.py#L230
def _replace_with_custom_fn_if_matches_filter_with_name(
model,
@@ -117,16 +182,18 @@ def _replace_llama4_moe(model: nn.Module) -> None:


def _quantize_llama4(model: nn.Module) -> None:
-    """Replace linear and moe with nvfp4 inference version."""
-    _replace_with_custom_fn_if_matches_filter_with_name(
-        model,
-        NVFP4InferenceLinear.from_linear,
-        lambda model, cur_fqn: isinstance(model, nn.Linear),
-    )
    """Replace MoE GroupedSwiGLU modules with their nvfp4 inference version.

    Args:
        model: The model to quantize

    Note: GroupedSwiGLU is always quantized when this function is called.
    """
    # Always quantize GroupedSwiGLU when this function is called
    _replace_with_custom_fn_if_matches_filter_with_name(
        model,
-        NVFP4InferenceGroupedLinear.from_grouped_linear,
-        lambda model, cur_fqn: isinstance(model, GroupedLinear),
        NVFP4InferenceGroupedSwiGLU.from_grouped_swiglu,
        lambda model, cur_fqn: isinstance(model, GroupedSwiGLU),
    )
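
# A simplified sketch of the filter-replace traversal used above, following
# the torchao logic linked earlier; the helper name below is ours, and the
# real implementation's FQN handling and edge cases are elided.
def _sketch_replace(model: nn.Module, replacement_fn, filter_fn, cur_fqn: str = "") -> nn.Module:
    if filter_fn(model, cur_fqn):
        # Replace the matching module wholesale.
        return replacement_fn(model)
    for name, child in model.named_children():
        new_child = _sketch_replace(child, replacement_fn, filter_fn, f"{cur_fqn}{name}.")
        if new_child is not child:
            setattr(model, name, new_child)
    return model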


@@ -150,7 +217,7 @@ class InferenceBenchmarkConfig:
num_layers: int | None
num_iterations: int
warmup_iterations: int
-    enable_nvfp4: bool # Enable NVFP4 quantization
    enable_nvfp4: bool # Enable NVFP4 registration and quantize GroupedSwiGLU in MoE
fx_report_folder: str | None
enable_nv_linear: bool
mode: str
@@ -670,7 +737,11 @@ def parse_args() -> argparse.Namespace:
help="Specify the folder for thunderfx_benchmark_report.",
)

parser.add_argument("--enable-nvfp4", action="store_true", help="Enable NVFP4 quantization for linear layers")
parser.add_argument(
"--enable-nvfp4",
action="store_true",
help="Enable NVFP4 quantization for MoE GroupedSwiGLU layers (has nvfuser grouped_mm support)",
)
Comment on lines +740 to +744

Collaborator (Author): This seems to require NVFUSER_ENABLE="id_model(all)" at the moment. We might want to set the env var when this option is set.

Collaborator: Linking the nvfuser issue: NVIDIA/Fuser#5200
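
A sketch of the suggested follow-up (a hypothetical helper, not part of this diff), assuming nvFuser picks up NVFUSER_ENABLE when it initializes:

import os

def _maybe_enable_id_model(enable_nvfp4: bool) -> None:
    # Hypothetical: set the option early, without clobbering a value the user already exported.
    if enable_nvfp4:
        os.environ.setdefault("NVFUSER_ENABLE", "id_model(all)")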

parser.add_argument(
"--enable-nv-linear",
action="store_true",
@@ -682,6 +753,11 @@ def parse_args() -> argparse.Namespace:
help="Wrap each non-warmup iteration with cudaProfilerStart() and cudaProfilerStop(). This allows us to run `nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat:<N> ... --profile` to record only the non-warmup iterations.",
)

    parser.add_argument(
        "--thunder-trace",
        action="store_true",
        help="Dump the final thunder trace of each ThunderFX-compiled subgraph (effective only with --mode thunder)",
    )
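    # Example invocation combining the new flags (module path assumed from this
    # file's location; --mode is defined elsewhere in this parser):
    #   python -m thunder.benchmarks.benchmark_inference --mode thunder --enable-nvfp4 --thunder-trace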
parser.add_argument("--save-results", action="store_true", help="Save results to JSON file")
parser.add_argument("--output-dir", type=str, default="./results", help="Directory to save results")
parser.add_argument(
@@ -702,13 +778,13 @@ def main():
if args.save_results:
os.makedirs(args.output_dir, exist_ok=True)

-    # TODO: Override the forward with nvfuser_direct based implementation like
-    # https://github.com/Lightning-AI/lightning-thunder/blob/8b72715d/thunder/tests/test_torch_library_custom_op.py#L250-L266 does.
-    # Note that the linked code is in a draft pull request of https://github.com/Lightning-AI/lightning-thunder/pull/2481
-    # so we might want to do it more clumsily by copying the code in the pull request for now.
    # Register NVFP4 custom ops with nvfuser translators when enabled
    if args.enable_nvfp4:
-        sym_of_nvfp4_scaled_mm = _register_custom_op(nvfuser_f16a_nvfp4weight_scaled_mm) # noqa: F841
-        sym_of_nvfp4_scaled_grouped_mm = _register_custom_op(nvfuser_f16a_nvfp4weight_scaled_grouped_mm) # noqa: F841
        try:
            _register_nvfp4_ops()
        except Exception as e:
            # If registration fails (e.g., nvfuser not available), warn and continue
            warnings.warn(f"Failed to register nvfp4 custom ops: {e}")

config = InferenceBenchmarkConfig(
model_name=args.model_name,
@@ -730,13 +806,17 @@
)
benchmark = InferenceBenchmark(config)

-    if args.enable_nvfp4:
-        msg = "NVFP4 kernels are not yet available. `--enable-nvfp4` runs only quantization but not benchmark"
-        warnings.warn(msg)
-        sys.exit(0)

benchmark.run_benchmark()
benchmark.print_results()

    if args.thunder_trace and args.mode == "thunder":
        # ThunderFX keeps per-subgraph compilation info on the model's backend;
        # print the final (execution) trace of each compiled function.
        backend = benchmark.model._backend
        for subgraph_info in backend.subgraph_infos:
            assert isinstance(subgraph_info.original_graph_module, torch.fx.GraphModule)
            assert len(subgraph_info.thunder_compiled_fns)
            for thunder_fn in subgraph_info.thunder_compiled_fns:
                print(thunder.last_traces(thunder_fn)[-1])

if args.save_results:
timestamp = time.strftime("%Y%m%d_%H%M%S")
filename = f"thunder_inference_{args.model_name.replace('/', '_')}_{timestamp}.json"