From 7f13eea7e3b1cb4bc5644805869e53ce972d4460 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Sun, 12 Oct 2025 06:03:48 -0700 Subject: [PATCH 01/17] test nvfuser 5230 Signed-off-by: Masaki Kozuki --- thunder/benchmarks/benchmark_inference.py | 8 +- .../layers_for_inference_benchmark.py | 107 ++++++++++++++---- 2 files changed, 90 insertions(+), 25 deletions(-) diff --git a/thunder/benchmarks/benchmark_inference.py b/thunder/benchmarks/benchmark_inference.py index 34186857e8..b7c8f8fb0c 100644 --- a/thunder/benchmarks/benchmark_inference.py +++ b/thunder/benchmarks/benchmark_inference.py @@ -40,9 +40,9 @@ import thunder from thunder.dynamo.compiler import thunderfx from thunder.benchmarks.layers_for_inference_benchmark import ( - GroupedLinear, + GroupedSwiGLU, Llama4MoE, - NVFP4InferenceGroupedLinear, + NVFP4InferenceGroupedSwiGLU, NVFP4InferenceLinear, nvfuser_f16a_nvfp4weight_scaled_grouped_mm, nvfuser_f16a_nvfp4weight_scaled_mm, @@ -125,8 +125,8 @@ def _quantize_llama4(model: nn.Module) -> None: ) _replace_with_custom_fn_if_matches_filter_with_name( model, - NVFP4InferenceGroupedLinear.from_grouped_linear, - lambda model, cur_fqn: isinstance(model, GroupedLinear), + NVFP4InferenceGroupedSwiGLU.from_grouped_swiglu, + lambda model, cur_fqn: isinstance(model, GroupedSwiGLU), ) diff --git a/thunder/benchmarks/layers_for_inference_benchmark.py b/thunder/benchmarks/layers_for_inference_benchmark.py index 32ea0a574c..4bbc578ae8 100644 --- a/thunder/benchmarks/layers_for_inference_benchmark.py +++ b/thunder/benchmarks/layers_for_inference_benchmark.py @@ -28,8 +28,10 @@ __all__ = [ "GroupedLinear", + "GroupedSwiGLU", "Llama4MoE", "NVFP4InferenceGroupedLinear", + "NVFP4InferenceGroupedSwiGLU", "NVFP4InferenceLinear", "nvfuser_f16a_nvfp4weight_scaled_grouped_mm", "nvfuser_f16a_nvfp4weight_scaled_mm", @@ -301,14 +303,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: @staticmethod def from_linear(linear: nn.Linear, fqn: str | None = None) -> NVFP4InferenceLinear: - """ - Creates an NVFP4InferenceLinear layer from a standard nn.Linear layer. - - Args: - linear (nn.Linear): The source linear layer. - fqn (str | None, optional): Fully qualified name of the layer. Currently unused, - but retained for compatibility with interfaces that require it or for future use. - """ weight = linear.weight bias = linear.bias out_features, in_features = weight.size() @@ -435,18 +429,47 @@ def __init__( self.register_buffer("ab_strides", ab_strides) self.register_buffer("c_strides", c_strides) - # TODO: Update this accordingly to the progress of nvfp4 kernel implementation. - def forward(self, hidden_states: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor: + @staticmethod + def compute_auxiliary_tensors( + hidden_states: torch.Tensor, + offsets: torch.Tensor, + out_features: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + """Compute blockscale_offsets and problem_sizes for grouped mm. + + These can be computed once and reused across multiple forward calls with the same offsets. 
+ """ tokens_per_group = offsets[1:] - offsets[:-1] problem_sizes = torch.stack( [ tokens_per_group, - torch.full_like(tokens_per_group, hidden_states.size(0)), - torch.full_like(tokens_per_group, self.fp4_weight.size(2) * 2), + torch.full_like(tokens_per_group, hidden_states.size(1)), + torch.full_like(tokens_per_group, out_features), ], dim=1, ) - blockscale_offsets = torch.cumsum(torch.ceil(tokens_per_group, 128) * 128) + # Calculate block-scale offsets: round up to 128, then cumsum with initial 0 + rounded_tokens = ((tokens_per_group + 127) // 128) * 128 + blockscale_offsets = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=tokens_per_group.device), + torch.cumsum(rounded_tokens, 0, dtype=torch.int32), + ] + ) + return blockscale_offsets, problem_sizes + + # TODO: Update this accordingly to the progress of nvfp4 kernel implementation. + def forward( + self, + hidden_states: torch.Tensor, + offsets: torch.Tensor, + blockscale_offsets: torch.Tensor | None = None, + problem_sizes: torch.Tensor | None = None, + ) -> torch.Tensor: + if blockscale_offsets is None or problem_sizes is None: + # Compute them if not provided (backward compatibility) + out_features = self.fp4_weight.size(2) * 2 + blockscale_offsets, problem_sizes = self.compute_auxiliary_tensors(hidden_states, offsets, out_features) return torch.ops.nvf_cutlass.f16a_nvfp4weight_scaled_grouped_mm( hidden_states, self.fp4_weight, @@ -461,13 +484,6 @@ def forward(self, hidden_states: torch.Tensor, offsets: torch.Tensor) -> torch.T @staticmethod def from_grouped_linear(grouped_linear: GroupedLinear, fqn: str | None = None) -> NVFP4InferenceGroupedLinear: - """ - Create an NVFP4InferenceGroupedLinear from a GroupedLinear. - - Args: - grouped_linear (GroupedLinear): The source GroupedLinear. - fqn (str or None): Fully qualified name. Currently unused; reserved for future use or compatibility. 
- """ weight = grouped_linear.weight ( fp4_weight, @@ -499,6 +515,49 @@ def forward(self, hidden_states: torch.Tensor, offsets: torch.Tensor) -> torch.T ) +class NVFP4InferenceGroupedSwiGLU(nn.Module): + """NVFP4 GroupedSwiGLU that efficiently reuses auxiliary tensor computations.""" + + def __init__( + self, + gate_proj: NVFP4InferenceGroupedLinear, + up_proj: NVFP4InferenceGroupedLinear, + down_proj: NVFP4InferenceGroupedLinear, + ): + super().__init__() + self.gate_proj = gate_proj + self.up_proj = up_proj + self.down_proj = down_proj + + def forward(self, hidden_states: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor: + # Compute auxiliary tensors once for all three operations + intermediate_features = self.gate_proj.fp4_weight.size(2) * 2 + blockscale_offsets_gate, problem_sizes_gate = NVFP4InferenceGroupedLinear.compute_auxiliary_tensors( + hidden_states, offsets, intermediate_features + ) + + gate_out = self.gate_proj(hidden_states, offsets, blockscale_offsets_gate, problem_sizes_gate) + up_out = self.up_proj(hidden_states, offsets, blockscale_offsets_gate, problem_sizes_gate) + + intermediate = torch.nn.functional.silu(gate_out) * up_out + + # For down_proj, we need different problem_sizes (different output features) + hidden_features = self.down_proj.fp4_weight.size(2) * 2 + blockscale_offsets_down, problem_sizes_down = NVFP4InferenceGroupedLinear.compute_auxiliary_tensors( + intermediate, offsets, hidden_features + ) + + return self.down_proj(intermediate, offsets, blockscale_offsets_down, problem_sizes_down) + + @staticmethod + def from_grouped_swiglu(grouped_swiglu: GroupedSwiGLU, fqn: str | None = None) -> NVFP4InferenceGroupedSwiGLU: + """Convert a GroupedSwiGLU to NVFP4InferenceGroupedSwiGLU.""" + gate_proj = NVFP4InferenceGroupedLinear.from_grouped_linear(grouped_swiglu.gate_proj) + up_proj = NVFP4InferenceGroupedLinear.from_grouped_linear(grouped_swiglu.up_proj) + down_proj = NVFP4InferenceGroupedLinear.from_grouped_linear(grouped_swiglu.down_proj) + return NVFP4InferenceGroupedSwiGLU(gate_proj, up_proj, down_proj) + + # Slightly modified version of `thunder.tests.test_networks.Llama4MoE` # to have the same singature as transformers' Llama4TextMoe -- in this file # return values include `router_logits`. @@ -597,7 +656,13 @@ def run_routed_experts(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, # Without `torch.int32`, we see `RuntimeError: Offsets tensor must be integer (int32) tensor, but got torch.int64.` # from PyTorch when calling _grouped_mm. - offsets = torch.cumsum(tokens_per_expert, 0, dtype=torch.int32) # [n] + # Prepend 0 to offsets for correct grouping + offsets = torch.cat( + [ + torch.zeros(1, dtype=torch.int32, device=tokens_per_expert.device), + torch.cumsum(tokens_per_expert, 0, dtype=torch.int32), + ] + ) # [n+1] outs_sorted_by_expert_id = self.routed_experts(tokens_sorted_by_expert_id, offsets) # [s, h] token_ids_sorted_by_expert_inverse_id = torch.argsort(token_ids_sorted_by_expert_id) From 3014c4283dbe8e41c1d6b277f9ed380641efea29 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Sun, 12 Oct 2025 07:28:52 -0700 Subject: [PATCH 02/17] Implement NVFP4 custom operations registration and quantization options in inference benchmark. Enhance `_quantize_llama4` to conditionally quantize linear layers. Update command-line arguments for NVFP4 registration and quantization control. Adjust custom operations to ensure correct tensor shapes and handling. 
--- thunder/benchmarks/benchmark_inference.py | 199 +++++++++++++++--- .../layers_for_inference_benchmark.py | 72 +++---- 2 files changed, 208 insertions(+), 63 deletions(-) diff --git a/thunder/benchmarks/benchmark_inference.py b/thunder/benchmarks/benchmark_inference.py index b7c8f8fb0c..a21277631d 100644 --- a/thunder/benchmarks/benchmark_inference.py +++ b/thunder/benchmarks/benchmark_inference.py @@ -17,7 +17,6 @@ import json import os import statistics -import sys import time import warnings from typing import Any @@ -46,10 +45,13 @@ NVFP4InferenceLinear, nvfuser_f16a_nvfp4weight_scaled_grouped_mm, nvfuser_f16a_nvfp4weight_scaled_mm, + FLOAT4_E2M1_MAX, + FLOAT8_E4M3_EPS, + FLOAT8_E4M3_MAX, ) -from thunder.torch.custom_op import _register_custom_op from thunder.tests.distributed.test_moe import GroupedLinearColwiseParallel, GroupedLinearRowwiseParallel from thunder.transforms.cudagraph import CUDAGraphTransform +from thunder.torch.custom_op import _register_custom_op, _register_nvfuser_translator if TYPE_CHECKING: from typing import Any @@ -73,6 +75,114 @@ LLAMA4_MAVERICK_MODEL_ID: str = "meta-llama/Llama-4-Maverick-17B-128E" +# Register nvfp4 custom ops with Thunder and nvFuser +def _register_nvfp4_ops(): + """Register nvfp4 custom operations with Thunder.""" + # Register f16a_nvfp4weight_scaled_mm + _nvfp4_mm_symbol = _register_custom_op(nvfuser_f16a_nvfp4weight_scaled_mm) + + def nvfp4_mm_translator( + activation, + fp4_weight, + weight_scaling_factor, + weight_global_scale, + bias, + *, + fd, + lc_to_nv_map, + ): + """Translator for nvfp4 matmul to nvfuser.""" + from thunder.executors.nvfuserex_impl import getnv + + nv_activation = getnv(activation, fd, lc_to_nv_map) + nv_fp4_weight = getnv(fp4_weight, fd, lc_to_nv_map) + nv_weight_sf = getnv(weight_scaling_factor, fd, lc_to_nv_map) + nv_weight_gs = getnv(weight_global_scale, fd, lc_to_nv_map) + + if bias is not None: + nv_bias = getnv(bias, fd, lc_to_nv_map) + else: + nv_bias = None + + # Call nvfuser's nvfp4 operation + result = fd.ops.nvfp4_matmul( + nv_activation, + nv_fp4_weight, + nv_weight_sf, + nv_weight_gs, + ) + + if nv_bias is not None: + result = fd.ops.add(result, nv_bias) + + return result + + _register_nvfuser_translator(_nvfp4_mm_symbol, nvfp4_mm_translator) + + # Register f16a_nvfp4weight_scaled_grouped_mm + _nvfp4_grouped_mm_symbol = _register_custom_op(nvfuser_f16a_nvfp4weight_scaled_grouped_mm) + + def nvfp4_grouped_mm_translator( + activation, + fp4_weight, + weight_scaling_factor, + global_scale, + offsets, + blockscale_offsets, + problem_sizes, + *, + fd, + lc_to_nv_map, + ): + from nvfuser_direct import DataType + from thunder.executors.nvfuserex_impl import getnv + + nv_act = getnv(activation, fd, lc_to_nv_map) + nv_fp4_w = getnv(fp4_weight, fd, lc_to_nv_map) + nv_sf_w = getnv(weight_scaling_factor, fd, lc_to_nv_map) + nv_alpha = getnv(global_scale, fd, lc_to_nv_map) + nv_offsets = getnv(offsets, fd, lc_to_nv_map) + nv_blocksf_offsets = getnv(blockscale_offsets, fd, lc_to_nv_map) + nv_problem_sizes = getnv(problem_sizes, fd, lc_to_nv_map) + # dynamic shape support has some concretization issue + m_size = activation.shape[0] + k_size = activation.shape[1] + k_tile_size = k_size // 16 + + reshaped_mat1 = fd.ops.reshape(nv_act, [m_size, k_tile_size, 16]) + scale1 = fd.ops.abs(reshaped_mat1) + scale1 = fd.ops.max(scale1, 2) + scale1 = fd.ops.div(scale1, FLOAT4_E2M1_MAX) + scale1 = fd.ops.clamp(scale1, FLOAT8_E4M3_EPS, FLOAT8_E4M3_MAX) + + broadcast_scale1 = fd.ops.broadcast(scale1, [False, False, True]) + 
reshaped_scaled_mat1 = fd.ops.div(reshaped_mat1, broadcast_scale1) + reshaped_scaled_mat1 = fd.ops.clamp(reshaped_scaled_mat1, -FLOAT8_E4M3_MAX, FLOAT8_E4M3_MAX) + + scaled_mat1 = fd.ops.reshape(reshaped_scaled_mat1, [m_size, k_size]) + fp4_mat1 = fd.ops.cast(scaled_mat1, DataType.Float4_e2m1fn) + fp8_scale1 = fd.ops.cast(scale1, DataType.Float8_e4m3fn) + layout_fp8_scale1 = fd.ops.preprocess_grouped_matmul_input_sf(fp8_scale1, nv_offsets, nv_blocksf_offsets) + out = fd.ops.cutlass_nvfp4_grouped_mm( + fp4_mat1, + nv_fp4_w, + layout_fp8_scale1, + nv_sf_w, + nv_alpha, + # NOTE: we might need to call contiguous on problem_sizes + nv_problem_sizes, + nv_offsets, + nv_blocksf_offsets, + DataType.BFloat16, + ) + return out + + _register_nvfuser_translator(_nvfp4_grouped_mm_symbol, nvfp4_grouped_mm_translator) + + +# Note: _register_nvfp4_ops() is called conditionally in main() when --enable-nvfp4 is specified + + # The logic is based on https://github.com/pytorch/ao/blob/b34c1037/torchao/quantization/quant_api.py#L230 def _replace_with_custom_fn_if_matches_filter_with_name( model, @@ -116,18 +226,26 @@ def _replace_llama4_moe(model: nn.Module) -> None: ) -def _quantize_llama4(model: nn.Module) -> None: - """Replace linear and moe with nvfp4 inference version.""" - _replace_with_custom_fn_if_matches_filter_with_name( - model, - NVFP4InferenceLinear.from_linear, - lambda model, cur_fqn: isinstance(model, nn.Linear), - ) - _replace_with_custom_fn_if_matches_filter_with_name( - model, - NVFP4InferenceGroupedSwiGLU.from_grouped_swiglu, - lambda model, cur_fqn: isinstance(model, GroupedSwiGLU), - ) +def _quantize_llama4(model: nn.Module, quantize_linear: bool = True, quantize_grouped_linear: bool = True) -> None: + """Replace linear and moe with nvfp4 inference version. 
+ + Args: + model: The model to quantize + quantize_linear: Whether to quantize regular nn.Linear layers + quantize_grouped_linear: Whether to quantize GroupedSwiGLU layers + """ + if quantize_linear: + _replace_with_custom_fn_if_matches_filter_with_name( + model, + NVFP4InferenceLinear.from_linear, + lambda model, cur_fqn: isinstance(model, nn.Linear), + ) + if quantize_grouped_linear: + _replace_with_custom_fn_if_matches_filter_with_name( + model, + NVFP4InferenceGroupedSwiGLU.from_grouped_swiglu, + lambda model, cur_fqn: isinstance(model, GroupedSwiGLU), + ) @contextmanager @@ -151,6 +269,10 @@ class InferenceBenchmarkConfig: num_iterations: int warmup_iterations: int enable_nvfp4: bool # Enable NVFP4 quantization + dtensor_single_gpu: bool + enable_nvfp4: bool # Enable NVFP4 registration (required for any nvfp4 quantization) + quantize_linear: bool # Quantize regular nn.Linear layers to NVFP4 + quantize_grouped_linear: bool # Quantize GroupedLinear/GroupedSwiGLU to NVFP4 fx_report_folder: str | None enable_nv_linear: bool mode: str @@ -278,7 +400,11 @@ def __init__(self, config: InferenceBenchmarkConfig): self.vocab_size = model.vocab_size if self.config.enable_nvfp4: - _quantize_llama4(model) + _quantize_llama4( + model, + quantize_linear=self.config.quantize_linear, + quantize_grouped_linear=self.config.quantize_grouped_linear, + ) self.model = self._compile_model(model) @property @@ -671,6 +797,26 @@ def parse_args() -> argparse.Namespace: ) parser.add_argument("--enable-nvfp4", action="store_true", help="Enable NVFP4 quantization for linear layers") + parser.add_argument( + "--dtensor-single-gpu", + action="store_true", + help="Use DTensor for single GPU", + ) + parser.add_argument( + "--enable-nvfp4", + action="store_true", + help="Enable NVFP4 custom op registration (required for --quantize-linear or --quantize-grouped-linear)", + ) + parser.add_argument( + "--quantize-linear", + action="store_true", + help="Quantize regular nn.Linear layers to NVFP4 (requires --enable-nvfp4)", + ) + parser.add_argument( + "--quantize-grouped-linear", + action="store_true", + help="Quantize GroupedLinear/GroupedSwiGLU to NVFP4 for MoE layers (requires --enable-nvfp4)", + ) parser.add_argument( "--enable-nv-linear", action="store_true", @@ -702,13 +848,17 @@ def main(): if args.save_results: os.makedirs(args.output_dir, exist_ok=True) - # TODO: Override the forward with nvfuser_direct based implementation like - # https://github.com/Lightning-AI/lightning-thunder/blob/8b72715d/thunder/tests/test_torch_library_custom_op.py#L250-L266 does. - # Note that the linked code is in a draft pull request of https://github.com/Lightning-AI/lightning-thunder/pull/2481 - # so we might want to do it more clumsily by copying the code in the pull request for now. 
+ # Validate quantization flags + if (args.quantize_linear or args.quantize_grouped_linear) and not args.enable_nvfp4: + raise ValueError("--quantize-linear or --quantize-grouped-linear requires --enable-nvfp4 to be set") + + # Register NVFP4 custom ops with nvfuser translators when enabled if args.enable_nvfp4: - sym_of_nvfp4_scaled_mm = _register_custom_op(nvfuser_f16a_nvfp4weight_scaled_mm) # noqa: F841 - sym_of_nvfp4_scaled_grouped_mm = _register_custom_op(nvfuser_f16a_nvfp4weight_scaled_grouped_mm) # noqa: F841 + try: + _register_nvfp4_ops() + except Exception as e: + # If registration fails (e.g., nvfuser not available), warn and continue + warnings.warn(f"Failed to register nvfp4 custom ops: {e}") config = InferenceBenchmarkConfig( model_name=args.model_name, @@ -720,6 +870,8 @@ def main(): warmup_iterations=args.warmup_iterations, mode=args.mode, enable_nvfp4=args.enable_nvfp4, + quantize_linear=args.quantize_linear, + quantize_grouped_linear=args.quantize_grouped_linear, fx_report_folder=args.fx_report_folder, enable_nv_linear=args.enable_nv_linear, disable_moe_replacement=args.disable_moe_replacement, @@ -730,11 +882,6 @@ def main(): ) benchmark = InferenceBenchmark(config) - if args.enable_nvfp4: - msg = "NVFP4 kernels are not yet available. `--enable-nvfp4` runs only quantization but not benchmark" - warnings.warn(msg) - sys.exit(0) - benchmark.run_benchmark() benchmark.print_results() if args.save_results: diff --git a/thunder/benchmarks/layers_for_inference_benchmark.py b/thunder/benchmarks/layers_for_inference_benchmark.py index 4bbc578ae8..0436c7c4e6 100644 --- a/thunder/benchmarks/layers_for_inference_benchmark.py +++ b/thunder/benchmarks/layers_for_inference_benchmark.py @@ -201,9 +201,8 @@ def dequantize_to_dtype(tensor_fp4, tensor_sf, global_scale, dtype, device, bloc return out -# TODO: Update this accordingly to the progress of nvfp4 kernel implementation. -# An alternative is to use `_register_nvfuser_translator` of https://github.com/Lightning-AI/lightning-thunder/pull/2481 -# instead of updating this function itself. +# NOTE: This custom op is registered with nvfuser translator in benchmark_inference.py +# using _register_nvfuser_translator. See benchmark_inference._register_nvfp4_ops(). 
@torch.library.custom_op("nvf_cutlass::f16a_nvfp4weight_scaled_mm", mutates_args=()) def nvfuser_f16a_nvfp4weight_scaled_mm( activation: torch.Tensor, @@ -212,10 +211,16 @@ def nvfuser_f16a_nvfp4weight_scaled_mm( weight_global_scale: torch.Tensor, bias: torch.Tensor | None, ) -> torch.Tensor: + # fp4_weight shape: (out_features, in_features // 2) - stored like nn.Linear weight hp_weight = dequantize_to_dtype( fp4_weight, weight_scaling_factor, weight_global_scale, activation.dtype, fp4_weight.device, 16 ) - return activation @ hp_weight + bias + # hp_weight shape after unpack: (out_features, in_features) + # Need to transpose to match nn.Linear: activation @ weight.T + result = activation @ hp_weight.T + if bias is not None: + result = result + bias + return result @torch.library.register_fake("nvf_cutlass::f16a_nvfp4weight_scaled_mm") @@ -226,20 +231,26 @@ def _( weight_global_scale: torch.Tensor, bias: torch.Tensor | None, ) -> torch.Tensor: - return torch.empty((activation.size(0), fp4_weight.size(0)), device=activation.device, dtype=activation.dtype) + # fp4_weight shape: (out_features, in_features // 2) + # Validate that activation has at least 1 dimension + if activation.ndim == 0: + raise ValueError(f"Expected activation to have at least 1 dimension, got {activation.ndim}") + # Output shape should match activation.shape[:-1] + (out_features,) + # This handles both 2D (tokens, hidden) and 3D (batch, seq_len, hidden) inputs + out_features = fp4_weight.size(0) + output_shape = activation.shape[:-1] + (out_features,) + return torch.empty(output_shape, device=activation.device, dtype=activation.dtype) -# TODO: Update this accordingly to the progress of nvfp4 kernel implementation. -# An alternative is to use `_register_nvfuser_translator` of https://github.com/Lightning-AI/lightning-thunder/pull/2481 -# instead of updating this function itself. + +# NOTE: This custom op is registered with nvfuser translator in benchmark_inference.py +# using _register_nvfuser_translator. See benchmark_inference._register_nvfp4_ops(). 
@torch.library.custom_op("nvf_cutlass::f16a_nvfp4weight_scaled_grouped_mm", mutates_args=()) def nvfuser_f16a_nvfp4weight_scaled_grouped_mm( activation: torch.Tensor, fp4_weight: torch.Tensor, weight_scaling_factor: torch.Tensor, weight_global_scale: torch.Tensor, - ab_strides: torch.Tensor, - c_strides: torch.Tensor, offsets: torch.Tensor, blockscale_offsets: torch.Tensor, problem_sizes: torch.Tensor, @@ -262,13 +273,21 @@ def _( fp4_weight: torch.Tensor, weight_scaling_factor: torch.Tensor, weight_global_scale: torch.Tensor, - ab_strides: torch.Tensor, - c_strides: torch.Tensor, offsets: torch.Tensor, blockscale_offsets: torch.Tensor, problem_sizes: torch.Tensor, ) -> torch.Tensor: - return torch.empty((activation.size(0), fp4_weight.size(1)), device=activation.device, dtype=activation.dtype) + # fp4_weight shape: (groups, in_features, out_features // 2) + # Validate that activation has at least 1 dimension + if activation.ndim == 0: + raise ValueError(f"Expected activation to have at least 1 dimension, got {activation.ndim}") + + # After unpacking: (groups, in_features, out_features) + # Output shape should match activation.shape[:-1] + (out_features,) + # This handles both 2D (tokens, hidden) and 3D (batch, seq_len, hidden) inputs + out_features = fp4_weight.size(2) * 2 # Unpack from FP4 + output_shape = activation.shape[:-1] + (out_features,) + return torch.empty(output_shape, device=activation.device, dtype=activation.dtype) class NVFP4InferenceLinear(nn.Module): @@ -375,20 +394,16 @@ def forward(self, hidden_states: torch.Tensor, offsets: torch.Tensor) -> torch.T @torch.inference_mode() def quantize_grouped_linear_weight_to_nvfp4( weight: torch.Tensor | nn.Parameter, -) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: +) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Quantize grouped linear's weight to nvfp4 Args: weight: Parameter of `GroupedLinear` of [g, n, k] - m: hidden_states.size(0) - tokens_per_expert_neg_one: Returns: fp4_weight: [g, n, k // 2] scale_factors: [g, n, k // 16] global_scales: [g] - ab_strides: [g] - c_strides: [g] """ assert weight.ndim == 3, "Weight must be a 3D tensor" @@ -396,9 +411,6 @@ def quantize_grouped_linear_weight_to_nvfp4( g, n, k = weight.size() with device: - ab_strides = torch.full((g,), k, dtype=torch.int32) - c_strides = torch.full((g,), n, dtype=torch.int32) - fp4_weight = torch.empty((g, n, k // 2), dtype=torch.float4_e2m1fn_x2) global_scales = torch.empty((g,), dtype=torch.float32) scale_factors = torch.empty((g, n, k // 16), dtype=torch.float8_e4m3fn) @@ -410,7 +422,7 @@ def quantize_grouped_linear_weight_to_nvfp4( fp4_weight[i] = cur_fp4_weight scale_factors[i] = linear_to_swizzled_128_4(cur_scale_factors) - return fp4_weight, scale_factors, global_scales, ab_strides, c_strides + return fp4_weight, scale_factors, global_scales class NVFP4InferenceGroupedLinear(nn.Module): @@ -419,15 +431,11 @@ def __init__( fp4_weight: torch.Tensor, weight_scaling_factor: torch.Tensor, weight_global_scale: torch.Tensor, - ab_strides: torch.Tensor, - c_strides: torch.Tensor, ) -> None: super().__init__() self.register_buffer("fp4_weight", fp4_weight) self.register_buffer("weight_scaling_factor", weight_scaling_factor) self.register_buffer("weight_global_scale", weight_global_scale) - self.register_buffer("ab_strides", ab_strides) - self.register_buffer("c_strides", c_strides) @staticmethod def compute_auxiliary_tensors( @@ -475,8 +483,6 @@ def forward( self.fp4_weight, self.weight_scaling_factor, 
self.weight_global_scale, - self.ab_strides, - self.c_strides, offsets, blockscale_offsets, problem_sizes, @@ -485,19 +491,11 @@ def forward( @staticmethod def from_grouped_linear(grouped_linear: GroupedLinear, fqn: str | None = None) -> NVFP4InferenceGroupedLinear: weight = grouped_linear.weight - ( - fp4_weight, - weight_scaling_factor, - weight_global_scale, - ab_strides, - c_strides, - ) = quantize_grouped_linear_weight_to_nvfp4(weight) + fp4_weight, weight_scaling_factor, weight_global_scale = quantize_grouped_linear_weight_to_nvfp4(weight) return NVFP4InferenceGroupedLinear( fp4_weight, weight_scaling_factor, weight_global_scale, - ab_strides=ab_strides, - c_strides=c_strides, ) From a0f82de197f49fd1ea45b1a10b24791ff3d2f43f Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Sun, 12 Oct 2025 07:41:36 -0700 Subject: [PATCH 03/17] Refactor NVFP4 custom operations registration and quantization logic. Update `_quantize_llama4` to simplify linear layer quantization handling. Modify command-line arguments for NVFP4 to clarify usage and remove deprecated options. Add warnings for experimental features and ensure proper registration of custom ops. --- thunder/benchmarks/benchmark_inference.py | 99 ++++++----------------- 1 file changed, 26 insertions(+), 73 deletions(-) diff --git a/thunder/benchmarks/benchmark_inference.py b/thunder/benchmarks/benchmark_inference.py index a21277631d..3f393ff418 100644 --- a/thunder/benchmarks/benchmark_inference.py +++ b/thunder/benchmarks/benchmark_inference.py @@ -78,48 +78,11 @@ # Register nvfp4 custom ops with Thunder and nvFuser def _register_nvfp4_ops(): """Register nvfp4 custom operations with Thunder.""" - # Register f16a_nvfp4weight_scaled_mm + # Register f16a_nvfp4weight_scaled_mm (without nvfuser translator - not yet implemented) _nvfp4_mm_symbol = _register_custom_op(nvfuser_f16a_nvfp4weight_scaled_mm) + # Note: nvfuser translator is not provided as nvfuser has not implemented nvfp4_matmul yet - def nvfp4_mm_translator( - activation, - fp4_weight, - weight_scaling_factor, - weight_global_scale, - bias, - *, - fd, - lc_to_nv_map, - ): - """Translator for nvfp4 matmul to nvfuser.""" - from thunder.executors.nvfuserex_impl import getnv - - nv_activation = getnv(activation, fd, lc_to_nv_map) - nv_fp4_weight = getnv(fp4_weight, fd, lc_to_nv_map) - nv_weight_sf = getnv(weight_scaling_factor, fd, lc_to_nv_map) - nv_weight_gs = getnv(weight_global_scale, fd, lc_to_nv_map) - - if bias is not None: - nv_bias = getnv(bias, fd, lc_to_nv_map) - else: - nv_bias = None - - # Call nvfuser's nvfp4 operation - result = fd.ops.nvfp4_matmul( - nv_activation, - nv_fp4_weight, - nv_weight_sf, - nv_weight_gs, - ) - - if nv_bias is not None: - result = fd.ops.add(result, nv_bias) - - return result - - _register_nvfuser_translator(_nvfp4_mm_symbol, nvfp4_mm_translator) - - # Register f16a_nvfp4weight_scaled_grouped_mm + # Register f16a_nvfp4weight_scaled_grouped_mm with nvfuser translator _nvfp4_grouped_mm_symbol = _register_custom_op(nvfuser_f16a_nvfp4weight_scaled_grouped_mm) def nvfp4_grouped_mm_translator( @@ -180,9 +143,6 @@ def nvfp4_grouped_mm_translator( _register_nvfuser_translator(_nvfp4_grouped_mm_symbol, nvfp4_grouped_mm_translator) -# Note: _register_nvfp4_ops() is called conditionally in main() when --enable-nvfp4 is specified - - # The logic is based on https://github.com/pytorch/ao/blob/b34c1037/torchao/quantization/quant_api.py#L230 def _replace_with_custom_fn_if_matches_filter_with_name( model, @@ -226,13 +186,14 @@ def _replace_llama4_moe(model: 
nn.Module) -> None: ) -def _quantize_llama4(model: nn.Module, quantize_linear: bool = True, quantize_grouped_linear: bool = True) -> None: - """Replace linear and moe with nvfp4 inference version. +def _quantize_llama4(model: nn.Module, quantize_linear: bool = False) -> None: + """Replace linear and/or MoE with nvfp4 inference version. Args: model: The model to quantize - quantize_linear: Whether to quantize regular nn.Linear layers - quantize_grouped_linear: Whether to quantize GroupedSwiGLU layers + quantize_linear: Whether to quantize regular nn.Linear layers (experimental, nvfuser translator not implemented) + + Note: GroupedSwiGLU is always quantized when this function is called. """ if quantize_linear: _replace_with_custom_fn_if_matches_filter_with_name( @@ -240,12 +201,12 @@ def _quantize_llama4(model: nn.Module, quantize_linear: bool = True, quantize_gr NVFP4InferenceLinear.from_linear, lambda model, cur_fqn: isinstance(model, nn.Linear), ) - if quantize_grouped_linear: - _replace_with_custom_fn_if_matches_filter_with_name( - model, - NVFP4InferenceGroupedSwiGLU.from_grouped_swiglu, - lambda model, cur_fqn: isinstance(model, GroupedSwiGLU), - ) + # Always quantize GroupedSwiGLU when this function is called + _replace_with_custom_fn_if_matches_filter_with_name( + model, + NVFP4InferenceGroupedSwiGLU.from_grouped_swiglu, + lambda model, cur_fqn: isinstance(model, GroupedSwiGLU), + ) @contextmanager @@ -270,9 +231,8 @@ class InferenceBenchmarkConfig: warmup_iterations: int enable_nvfp4: bool # Enable NVFP4 quantization dtensor_single_gpu: bool - enable_nvfp4: bool # Enable NVFP4 registration (required for any nvfp4 quantization) - quantize_linear: bool # Quantize regular nn.Linear layers to NVFP4 - quantize_grouped_linear: bool # Quantize GroupedLinear/GroupedSwiGLU to NVFP4 + enable_nvfp4: bool # Enable NVFP4 registration and quantize GroupedSwiGLU in MoE + quantize_linear: bool # [Experimental] Quantize nn.Linear to NVFP4 (nvfuser translator not implemented) fx_report_folder: str | None enable_nv_linear: bool mode: str @@ -400,11 +360,7 @@ def __init__(self, config: InferenceBenchmarkConfig): self.vocab_size = model.vocab_size if self.config.enable_nvfp4: - _quantize_llama4( - model, - quantize_linear=self.config.quantize_linear, - quantize_grouped_linear=self.config.quantize_grouped_linear, - ) + _quantize_llama4(model, quantize_linear=self.config.quantize_linear) self.model = self._compile_model(model) @property @@ -805,17 +761,12 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--enable-nvfp4", action="store_true", - help="Enable NVFP4 custom op registration (required for --quantize-linear or --quantize-grouped-linear)", + help="Enable NVFP4 quantization for MoE GroupedSwiGLU layers (has nvfuser grouped_mm support)", ) parser.add_argument( "--quantize-linear", action="store_true", - help="Quantize regular nn.Linear layers to NVFP4 (requires --enable-nvfp4)", - ) - parser.add_argument( - "--quantize-grouped-linear", - action="store_true", - help="Quantize GroupedLinear/GroupedSwiGLU to NVFP4 for MoE layers (requires --enable-nvfp4)", + help="[Experimental] Quantize nn.Linear to NVFP4. 
Note: nvfuser has not yet implemented nvfp4_matmul translator", ) parser.add_argument( "--enable-nv-linear", @@ -848,12 +799,15 @@ def main(): if args.save_results: os.makedirs(args.output_dir, exist_ok=True) - # Validate quantization flags - if (args.quantize_linear or args.quantize_grouped_linear) and not args.enable_nvfp4: - raise ValueError("--quantize-linear or --quantize-grouped-linear requires --enable-nvfp4 to be set") + # Warn if experimental flag is used + if args.quantize_linear: + warnings.warn( + "--quantize-linear is experimental. nvfuser has not implemented nvfp4_matmul translator yet. " + "The custom op is registered but fusion may not work optimally." + ) # Register NVFP4 custom ops with nvfuser translators when enabled - if args.enable_nvfp4: + if args.enable_nvfp4 or args.quantize_linear: try: _register_nvfp4_ops() except Exception as e: @@ -871,7 +825,6 @@ def main(): mode=args.mode, enable_nvfp4=args.enable_nvfp4, quantize_linear=args.quantize_linear, - quantize_grouped_linear=args.quantize_grouped_linear, fx_report_folder=args.fx_report_folder, enable_nv_linear=args.enable_nv_linear, disable_moe_replacement=args.disable_moe_replacement, From 87e9cd082a8d5f7386b0c747811c6ba72d6b8ac1 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Tue, 11 Nov 2025 16:34:29 +0900 Subject: [PATCH 04/17] dedup args Signed-off-by: Masaki Kozuki --- thunder/benchmarks/benchmark_inference.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/thunder/benchmarks/benchmark_inference.py b/thunder/benchmarks/benchmark_inference.py index 3f393ff418..bb15921c29 100644 --- a/thunder/benchmarks/benchmark_inference.py +++ b/thunder/benchmarks/benchmark_inference.py @@ -229,8 +229,6 @@ class InferenceBenchmarkConfig: num_layers: int | None num_iterations: int warmup_iterations: int - enable_nvfp4: bool # Enable NVFP4 quantization - dtensor_single_gpu: bool enable_nvfp4: bool # Enable NVFP4 registration and quantize GroupedSwiGLU in MoE quantize_linear: bool # [Experimental] Quantize nn.Linear to NVFP4 (nvfuser translator not implemented) fx_report_folder: str | None @@ -752,12 +750,6 @@ def parse_args() -> argparse.Namespace: help="Specify the folder for thunderfx_benchmark_report.", ) - parser.add_argument("--enable-nvfp4", action="store_true", help="Enable NVFP4 quantization for linear layers") - parser.add_argument( - "--dtensor-single-gpu", - action="store_true", - help="Use DTensor for single GPU", - ) parser.add_argument( "--enable-nvfp4", action="store_true", From 4002e04bf4b926b5d9000d702bd70de8095bfe8f Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 22 Oct 2025 14:47:00 -0700 Subject: [PATCH 05/17] small fixes on the model --- .../layers_for_inference_benchmark.py | 27 ++++++++++--------- 1 file changed, 14 insertions(+), 13 deletions(-) diff --git a/thunder/benchmarks/layers_for_inference_benchmark.py b/thunder/benchmarks/layers_for_inference_benchmark.py index 0436c7c4e6..a94c09427d 100644 --- a/thunder/benchmarks/layers_for_inference_benchmark.py +++ b/thunder/benchmarks/layers_for_inference_benchmark.py @@ -256,7 +256,7 @@ def nvfuser_f16a_nvfp4weight_scaled_grouped_mm( problem_sizes: torch.Tensor, ) -> torch.Tensor: hp_weight = torch.empty( - (fp4_weight.size(0), fp4_weight.size(1), fp4_weight.size(2) * 2), + (fp4_weight.size(0), fp4_weight.size(1) * 2, fp4_weight.size(2)), device=activation.device, dtype=activation.dtype, ) @@ -277,7 +277,7 @@ def _( blockscale_offsets: torch.Tensor, problem_sizes: torch.Tensor, ) -> torch.Tensor: - # fp4_weight shape: (groups, in_features, 
out_features // 2) + # fp4_weight shape: (groups, in_features // 2, out_features) # Validate that activation has at least 1 dimension if activation.ndim == 0: raise ValueError(f"Expected activation to have at least 1 dimension, got {activation.ndim}") @@ -285,7 +285,7 @@ def _( # After unpacking: (groups, in_features, out_features) # Output shape should match activation.shape[:-1] + (out_features,) # This handles both 2D (tokens, hidden) and 3D (batch, seq_len, hidden) inputs - out_features = fp4_weight.size(2) * 2 # Unpack from FP4 + out_features = fp4_weight.size(2) output_shape = activation.shape[:-1] + (out_features,) return torch.empty(output_shape, device=activation.device, dtype=activation.dtype) @@ -398,23 +398,24 @@ def quantize_grouped_linear_weight_to_nvfp4( """Quantize grouped linear's weight to nvfp4 Args: - weight: Parameter of `GroupedLinear` of [g, n, k] + weight: Parameter of `GroupedLinear` of [g, k, n] Returns: - fp4_weight: [g, n, k // 2] + fp4_weight: [g, k // 2, n] scale_factors: [g, n, k // 16] global_scales: [g] """ assert weight.ndim == 3, "Weight must be a 3D tensor" device: torch.device = weight.device - g, n, k = weight.size() + g, k, n = weight.size() with device: fp4_weight = torch.empty((g, n, k // 2), dtype=torch.float4_e2m1fn_x2) global_scales = torch.empty((g,), dtype=torch.float32) scale_factors = torch.empty((g, n, k // 16), dtype=torch.float8_e4m3fn) + weight = weight.transpose(-1, -2).contiguous() for i in range(g): cur_weight = weight[i] global_scales[i] = cur_weight.abs().amax() @@ -422,7 +423,7 @@ def quantize_grouped_linear_weight_to_nvfp4( fp4_weight[i] = cur_fp4_weight scale_factors[i] = linear_to_swizzled_128_4(cur_scale_factors) - return fp4_weight, scale_factors, global_scales + return fp4_weight.transpose(-1, -2), scale_factors, global_scales class NVFP4InferenceGroupedLinear(nn.Module): @@ -451,8 +452,8 @@ def compute_auxiliary_tensors( problem_sizes = torch.stack( [ tokens_per_group, - torch.full_like(tokens_per_group, hidden_states.size(1)), torch.full_like(tokens_per_group, out_features), + torch.full_like(tokens_per_group, hidden_states.size(1)), ], dim=1, ) @@ -463,7 +464,7 @@ def compute_auxiliary_tensors( torch.zeros(1, dtype=torch.int32, device=tokens_per_group.device), torch.cumsum(rounded_tokens, 0, dtype=torch.int32), ] - ) + )[0:-1] return blockscale_offsets, problem_sizes # TODO: Update this accordingly to the progress of nvfp4 kernel implementation. 
@@ -476,14 +477,14 @@ def forward( ) -> torch.Tensor: if blockscale_offsets is None or problem_sizes is None: # Compute them if not provided (backward compatibility) - out_features = self.fp4_weight.size(2) * 2 + out_features = self.fp4_weight.size(2) blockscale_offsets, problem_sizes = self.compute_auxiliary_tensors(hidden_states, offsets, out_features) return torch.ops.nvf_cutlass.f16a_nvfp4weight_scaled_grouped_mm( hidden_states, self.fp4_weight, self.weight_scaling_factor, self.weight_global_scale, - offsets, + offsets[:-1], blockscale_offsets, problem_sizes, ) @@ -529,7 +530,7 @@ def __init__( def forward(self, hidden_states: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor: # Compute auxiliary tensors once for all three operations - intermediate_features = self.gate_proj.fp4_weight.size(2) * 2 + intermediate_features = self.gate_proj.fp4_weight.size(2) blockscale_offsets_gate, problem_sizes_gate = NVFP4InferenceGroupedLinear.compute_auxiliary_tensors( hidden_states, offsets, intermediate_features ) @@ -540,7 +541,7 @@ def forward(self, hidden_states: torch.Tensor, offsets: torch.Tensor) -> torch.T intermediate = torch.nn.functional.silu(gate_out) * up_out # For down_proj, we need different problem_sizes (different output features) - hidden_features = self.down_proj.fp4_weight.size(2) * 2 + hidden_features = self.down_proj.fp4_weight.size(2) blockscale_offsets_down, problem_sizes_down = NVFP4InferenceGroupedLinear.compute_auxiliary_tensors( intermediate, offsets, hidden_features ) From ca09d94de6d34e7f106d342d333235d4026bace6 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Tue, 11 Nov 2025 04:38:13 -0800 Subject: [PATCH 06/17] changes to fix meta info Signed-off-by: Masaki Kozuki --- .../layers_for_inference_benchmark.py | 24 ++++++++++++------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/thunder/benchmarks/layers_for_inference_benchmark.py b/thunder/benchmarks/layers_for_inference_benchmark.py index a94c09427d..a66b5acdd0 100644 --- a/thunder/benchmarks/layers_for_inference_benchmark.py +++ b/thunder/benchmarks/layers_for_inference_benchmark.py @@ -256,7 +256,7 @@ def nvfuser_f16a_nvfp4weight_scaled_grouped_mm( problem_sizes: torch.Tensor, ) -> torch.Tensor: hp_weight = torch.empty( - (fp4_weight.size(0), fp4_weight.size(1) * 2, fp4_weight.size(2)), + (fp4_weight.size(0), fp4_weight.size(1), fp4_weight.size(2) * 2), device=activation.device, dtype=activation.dtype, ) @@ -277,7 +277,7 @@ def _( blockscale_offsets: torch.Tensor, problem_sizes: torch.Tensor, ) -> torch.Tensor: - # fp4_weight shape: (groups, in_features // 2, out_features) + # fp4_weight shape: (groups, in_features, out_features // 2) # Validate that activation has at least 1 dimension if activation.ndim == 0: raise ValueError(f"Expected activation to have at least 1 dimension, got {activation.ndim}") @@ -285,7 +285,7 @@ def _( # After unpacking: (groups, in_features, out_features) # Output shape should match activation.shape[:-1] + (out_features,) # This handles both 2D (tokens, hidden) and 3D (batch, seq_len, hidden) inputs - out_features = fp4_weight.size(2) + out_features = fp4_weight.size(2) * 2 output_shape = activation.shape[:-1] + (out_features,) return torch.empty(output_shape, device=activation.device, dtype=activation.dtype) @@ -401,7 +401,7 @@ def quantize_grouped_linear_weight_to_nvfp4( weight: Parameter of `GroupedLinear` of [g, k, n] Returns: - fp4_weight: [g, k // 2, n] + fp4_weight: [g, n, k // 2] scale_factors: [g, n, k // 16] global_scales: [g] """ @@ -423,7 +423,7 @@ def 
quantize_grouped_linear_weight_to_nvfp4( fp4_weight[i] = cur_fp4_weight scale_factors[i] = linear_to_swizzled_128_4(cur_scale_factors) - return fp4_weight.transpose(-1, -2), scale_factors, global_scales + return fp4_weight, scale_factors, global_scales class NVFP4InferenceGroupedLinear(nn.Module): @@ -438,6 +438,14 @@ def __init__( self.register_buffer("weight_scaling_factor", weight_scaling_factor) self.register_buffer("weight_global_scale", weight_global_scale) + @property + def out_features(self) -> int: + return self.fp4_weight.size(2) * 2 + + @property + def in_features(self) -> int: + return self.fp4_weight.size(1) + @staticmethod def compute_auxiliary_tensors( hidden_states: torch.Tensor, @@ -477,7 +485,7 @@ def forward( ) -> torch.Tensor: if blockscale_offsets is None or problem_sizes is None: # Compute them if not provided (backward compatibility) - out_features = self.fp4_weight.size(2) + out_features = self.out_features blockscale_offsets, problem_sizes = self.compute_auxiliary_tensors(hidden_states, offsets, out_features) return torch.ops.nvf_cutlass.f16a_nvfp4weight_scaled_grouped_mm( hidden_states, @@ -530,7 +538,7 @@ def __init__( def forward(self, hidden_states: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor: # Compute auxiliary tensors once for all three operations - intermediate_features = self.gate_proj.fp4_weight.size(2) + intermediate_features = self.gate_proj.out_features blockscale_offsets_gate, problem_sizes_gate = NVFP4InferenceGroupedLinear.compute_auxiliary_tensors( hidden_states, offsets, intermediate_features ) @@ -541,7 +549,7 @@ def forward(self, hidden_states: torch.Tensor, offsets: torch.Tensor) -> torch.T intermediate = torch.nn.functional.silu(gate_out) * up_out # For down_proj, we need different problem_sizes (different output features) - hidden_features = self.down_proj.fp4_weight.size(2) + hidden_features = self.down_proj.out_features blockscale_offsets_down, problem_sizes_down = NVFP4InferenceGroupedLinear.compute_auxiliary_tensors( intermediate, offsets, hidden_features ) From c5c8b761ebe6a8870ac39cbb8b51e58db29c3078 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 12 Nov 2025 14:47:36 -0800 Subject: [PATCH 07/17] trimming the last element in offsets --- thunder/benchmarks/layers_for_inference_benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/benchmarks/layers_for_inference_benchmark.py b/thunder/benchmarks/layers_for_inference_benchmark.py index a66b5acdd0..16ce41efd2 100644 --- a/thunder/benchmarks/layers_for_inference_benchmark.py +++ b/thunder/benchmarks/layers_for_inference_benchmark.py @@ -669,7 +669,7 @@ def run_routed_experts(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.zeros(1, dtype=torch.int32, device=tokens_per_expert.device), torch.cumsum(tokens_per_expert, 0, dtype=torch.int32), ] - ) # [n+1] + )[:-1] # [n] outs_sorted_by_expert_id = self.routed_experts(tokens_sorted_by_expert_id, offsets) # [s, h] token_ids_sorted_by_expert_inverse_id = torch.argsort(token_ids_sorted_by_expert_id) From 00f0066e7963a6f7bb2970acdb5de6523604a72f Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 13 Nov 2025 10:54:43 -0800 Subject: [PATCH 08/17] fixing weight layout --- .../layers_for_inference_benchmark.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/thunder/benchmarks/layers_for_inference_benchmark.py b/thunder/benchmarks/layers_for_inference_benchmark.py index 16ce41efd2..9fac5eee62 100644 --- a/thunder/benchmarks/layers_for_inference_benchmark.py 
+++ b/thunder/benchmarks/layers_for_inference_benchmark.py @@ -256,11 +256,12 @@ def nvfuser_f16a_nvfp4weight_scaled_grouped_mm( problem_sizes: torch.Tensor, ) -> torch.Tensor: hp_weight = torch.empty( - (fp4_weight.size(0), fp4_weight.size(1), fp4_weight.size(2) * 2), + (fp4_weight.size(0), fp4_weight.size(1) * 2, fp4_weight.size(2)), device=activation.device, dtype=activation.dtype, ) for i in range(fp4_weight.size(0)): + # NOTE: dequantize here doesn't look right, since we have (g, k, n) hp_weight[i] = dequantize_to_dtype( fp4_weight[i], weight_scaling_factor[i], weight_global_scale[i], activation.dtype, fp4_weight.device, 16 ) @@ -277,7 +278,7 @@ def _( blockscale_offsets: torch.Tensor, problem_sizes: torch.Tensor, ) -> torch.Tensor: - # fp4_weight shape: (groups, in_features, out_features // 2) + # fp4_weight shape: (groups, in_features // 2, out_features) # Validate that activation has at least 1 dimension if activation.ndim == 0: raise ValueError(f"Expected activation to have at least 1 dimension, got {activation.ndim}") @@ -285,9 +286,9 @@ def _( # After unpacking: (groups, in_features, out_features) # Output shape should match activation.shape[:-1] + (out_features,) # This handles both 2D (tokens, hidden) and 3D (batch, seq_len, hidden) inputs - out_features = fp4_weight.size(2) * 2 + out_features = fp4_weight.size(2) output_shape = activation.shape[:-1] + (out_features,) - return torch.empty(output_shape, device=activation.device, dtype=activation.dtype) + return torch.empty(output_shape, device=activation.device, dtype=torch.bfloat16) class NVFP4InferenceLinear(nn.Module): @@ -401,7 +402,7 @@ def quantize_grouped_linear_weight_to_nvfp4( weight: Parameter of `GroupedLinear` of [g, k, n] Returns: - fp4_weight: [g, n, k // 2] + fp4_weight: [g, k // 2, n] scale_factors: [g, n, k // 16] global_scales: [g] """ @@ -423,7 +424,7 @@ def quantize_grouped_linear_weight_to_nvfp4( fp4_weight[i] = cur_fp4_weight scale_factors[i] = linear_to_swizzled_128_4(cur_scale_factors) - return fp4_weight, scale_factors, global_scales + return fp4_weight.transpose(-1, -2), scale_factors, global_scales class NVFP4InferenceGroupedLinear(nn.Module): @@ -440,11 +441,11 @@ def __init__( @property def out_features(self) -> int: - return self.fp4_weight.size(2) * 2 + return self.fp4_weight.size(2) @property def in_features(self) -> int: - return self.fp4_weight.size(1) + return self.fp4_weight.size(1) * 2 @staticmethod def compute_auxiliary_tensors( From 4ed1ee3dfc6446934ba30f3ab349f824f65fa19a Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Thu, 13 Nov 2025 12:04:08 -0800 Subject: [PATCH 09/17] fix more transpose on the weight --- thunder/benchmarks/layers_for_inference_benchmark.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/thunder/benchmarks/layers_for_inference_benchmark.py b/thunder/benchmarks/layers_for_inference_benchmark.py index 9fac5eee62..388d3056de 100644 --- a/thunder/benchmarks/layers_for_inference_benchmark.py +++ b/thunder/benchmarks/layers_for_inference_benchmark.py @@ -387,9 +387,10 @@ def __init__(self, groups: int, in_features: int, out_features: int, dtype: torc self.weight = nn.Parameter(torch.empty(groups, out_features, in_features, dtype=dtype, device=device)) # Initialize the weight in the same way as nn.Linear nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) + self.weight.data = self.weight.transpose(-1, -2) def forward(self, hidden_states: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor: - return grouped_mm(hidden_states, 
self.weight.transpose(-1, -2), offsets) + return grouped_mm(hidden_states, self.weight, offsets) @torch.inference_mode() @@ -631,13 +632,13 @@ def from_transformers_llama4textmoe(moe: Llama4TextMoe) -> Llama4MoE: # Split into gate and up projections gate_proj_w, up_proj_w = moe.experts.gate_up_proj.chunk(2, dim=2) - new_moe.routed_experts.gate_proj.weight.data.copy_(gate_proj_w.transpose(-1, -2)) - new_moe.routed_experts.up_proj.weight.data.copy_(up_proj_w.transpose(-1, -2)) + new_moe.routed_experts.gate_proj.weight.data.copy_(gate_proj_w) + new_moe.routed_experts.up_proj.weight.data.copy_(up_proj_w) # Handle down_proj # HF format: (groups, intermediate_size, hidden_size) # Our format: (groups, hidden, intermediate_size) - new_moe.routed_experts.down_proj.weight.data.copy_(moe.experts.down_proj.transpose(-1, -2)) + new_moe.routed_experts.down_proj.weight.data.copy_(moe.experts.down_proj) return new_moe From 1ead05dd7a67567214e07049dad6aade9ac6890f Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Mon, 17 Nov 2025 15:24:58 +0900 Subject: [PATCH 10/17] Remove NVFP4 GEMM related code Signed-off-by: Masaki Kozuki --- thunder/benchmarks/benchmark_inference.py | 34 +++-------------------- 1 file changed, 4 insertions(+), 30 deletions(-) diff --git a/thunder/benchmarks/benchmark_inference.py b/thunder/benchmarks/benchmark_inference.py index bb15921c29..0fcbd30670 100644 --- a/thunder/benchmarks/benchmark_inference.py +++ b/thunder/benchmarks/benchmark_inference.py @@ -42,9 +42,7 @@ GroupedSwiGLU, Llama4MoE, NVFP4InferenceGroupedSwiGLU, - NVFP4InferenceLinear, nvfuser_f16a_nvfp4weight_scaled_grouped_mm, - nvfuser_f16a_nvfp4weight_scaled_mm, FLOAT4_E2M1_MAX, FLOAT8_E4M3_EPS, FLOAT8_E4M3_MAX, @@ -75,13 +73,10 @@ LLAMA4_MAVERICK_MODEL_ID: str = "meta-llama/Llama-4-Maverick-17B-128E" +# TODO: Add mm quantization once nvfuser implements nvfp4 gemm # Register nvfp4 custom ops with Thunder and nvFuser def _register_nvfp4_ops(): """Register nvfp4 custom operations with Thunder.""" - # Register f16a_nvfp4weight_scaled_mm (without nvfuser translator - not yet implemented) - _nvfp4_mm_symbol = _register_custom_op(nvfuser_f16a_nvfp4weight_scaled_mm) - # Note: nvfuser translator is not provided as nvfuser has not implemented nvfp4_matmul yet - # Register f16a_nvfp4weight_scaled_grouped_mm with nvfuser translator _nvfp4_grouped_mm_symbol = _register_custom_op(nvfuser_f16a_nvfp4weight_scaled_grouped_mm) @@ -186,21 +181,14 @@ def _replace_llama4_moe(model: nn.Module) -> None: ) -def _quantize_llama4(model: nn.Module, quantize_linear: bool = False) -> None: +def _quantize_llama4(model: nn.Module) -> None: """Replace linear and/or MoE with nvfp4 inference version. Args: model: The model to quantize - quantize_linear: Whether to quantize regular nn.Linear layers (experimental, nvfuser translator not implemented) Note: GroupedSwiGLU is always quantized when this function is called. 
""" - if quantize_linear: - _replace_with_custom_fn_if_matches_filter_with_name( - model, - NVFP4InferenceLinear.from_linear, - lambda model, cur_fqn: isinstance(model, nn.Linear), - ) # Always quantize GroupedSwiGLU when this function is called _replace_with_custom_fn_if_matches_filter_with_name( model, @@ -230,7 +218,6 @@ class InferenceBenchmarkConfig: num_iterations: int warmup_iterations: int enable_nvfp4: bool # Enable NVFP4 registration and quantize GroupedSwiGLU in MoE - quantize_linear: bool # [Experimental] Quantize nn.Linear to NVFP4 (nvfuser translator not implemented) fx_report_folder: str | None enable_nv_linear: bool mode: str @@ -358,7 +345,7 @@ def __init__(self, config: InferenceBenchmarkConfig): self.vocab_size = model.vocab_size if self.config.enable_nvfp4: - _quantize_llama4(model, quantize_linear=self.config.quantize_linear) + _quantize_llama4(model) self.model = self._compile_model(model) @property @@ -755,11 +742,6 @@ def parse_args() -> argparse.Namespace: action="store_true", help="Enable NVFP4 quantization for MoE GroupedSwiGLU layers (has nvfuser grouped_mm support)", ) - parser.add_argument( - "--quantize-linear", - action="store_true", - help="[Experimental] Quantize nn.Linear to NVFP4. Note: nvfuser has not yet implemented nvfp4_matmul translator", - ) parser.add_argument( "--enable-nv-linear", action="store_true", @@ -791,15 +773,8 @@ def main(): if args.save_results: os.makedirs(args.output_dir, exist_ok=True) - # Warn if experimental flag is used - if args.quantize_linear: - warnings.warn( - "--quantize-linear is experimental. nvfuser has not implemented nvfp4_matmul translator yet. " - "The custom op is registered but fusion may not work optimally." - ) - # Register NVFP4 custom ops with nvfuser translators when enabled - if args.enable_nvfp4 or args.quantize_linear: + if args.enable_nvfp4: try: _register_nvfp4_ops() except Exception as e: @@ -816,7 +791,6 @@ def main(): warmup_iterations=args.warmup_iterations, mode=args.mode, enable_nvfp4=args.enable_nvfp4, - quantize_linear=args.quantize_linear, fx_report_folder=args.fx_report_folder, enable_nv_linear=args.enable_nv_linear, disable_moe_replacement=args.disable_moe_replacement, From 942504964eca0164add5035ba241246033e0ae8b Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Mon, 17 Nov 2025 15:34:29 +0900 Subject: [PATCH 11/17] remove quantized linear Signed-off-by: Masaki Kozuki --- .../layers_for_inference_benchmark.py | 90 ------------------- 1 file changed, 90 deletions(-) diff --git a/thunder/benchmarks/layers_for_inference_benchmark.py b/thunder/benchmarks/layers_for_inference_benchmark.py index 388d3056de..91ea55868f 100644 --- a/thunder/benchmarks/layers_for_inference_benchmark.py +++ b/thunder/benchmarks/layers_for_inference_benchmark.py @@ -32,9 +32,7 @@ "Llama4MoE", "NVFP4InferenceGroupedLinear", "NVFP4InferenceGroupedSwiGLU", - "NVFP4InferenceLinear", "nvfuser_f16a_nvfp4weight_scaled_grouped_mm", - "nvfuser_f16a_nvfp4weight_scaled_mm", ] @@ -201,48 +199,6 @@ def dequantize_to_dtype(tensor_fp4, tensor_sf, global_scale, dtype, device, bloc return out -# NOTE: This custom op is registered with nvfuser translator in benchmark_inference.py -# using _register_nvfuser_translator. See benchmark_inference._register_nvfp4_ops(). 
-@torch.library.custom_op("nvf_cutlass::f16a_nvfp4weight_scaled_mm", mutates_args=()) -def nvfuser_f16a_nvfp4weight_scaled_mm( - activation: torch.Tensor, - fp4_weight: torch.Tensor, - weight_scaling_factor: torch.Tensor, - weight_global_scale: torch.Tensor, - bias: torch.Tensor | None, -) -> torch.Tensor: - # fp4_weight shape: (out_features, in_features // 2) - stored like nn.Linear weight - hp_weight = dequantize_to_dtype( - fp4_weight, weight_scaling_factor, weight_global_scale, activation.dtype, fp4_weight.device, 16 - ) - # hp_weight shape after unpack: (out_features, in_features) - # Need to transpose to match nn.Linear: activation @ weight.T - result = activation @ hp_weight.T - if bias is not None: - result = result + bias - return result - - -@torch.library.register_fake("nvf_cutlass::f16a_nvfp4weight_scaled_mm") -def _( - activation: torch.Tensor, - fp4_weight: torch.Tensor, - weight_scaling_factor: torch.Tensor, - weight_global_scale: torch.Tensor, - bias: torch.Tensor | None, -) -> torch.Tensor: - # fp4_weight shape: (out_features, in_features // 2) - # Validate that activation has at least 1 dimension - if activation.ndim == 0: - raise ValueError(f"Expected activation to have at least 1 dimension, got {activation.ndim}") - - # Output shape should match activation.shape[:-1] + (out_features,) - # This handles both 2D (tokens, hidden) and 3D (batch, seq_len, hidden) inputs - out_features = fp4_weight.size(0) - output_shape = activation.shape[:-1] + (out_features,) - return torch.empty(output_shape, device=activation.device, dtype=activation.dtype) - - # NOTE: This custom op is registered with nvfuser translator in benchmark_inference.py # using _register_nvfuser_translator. See benchmark_inference._register_nvfp4_ops(). @torch.library.custom_op("nvf_cutlass::f16a_nvfp4weight_scaled_grouped_mm", mutates_args=()) @@ -291,52 +247,6 @@ def _( return torch.empty(output_shape, device=activation.device, dtype=torch.bfloat16) -class NVFP4InferenceLinear(nn.Module): - """NVFP4 Linear layer for Inference. - - Weight, its scaling factor, its global scale, and bias are registered as a buffer, not a parameter. 
- """ - - def __init__( - self, - in_features: int, - out_features: int, - *, - fp4_weight: torch.Tensor | nn.Parameter, - weight_scaling_factor: torch.Tensor | nn.Parameter, - weight_global_scale: torch.Tensor | nn.Parameter | None, - bias: torch.Tensor | nn.Parameter | None, - ) -> None: - super().__init__() - self.in_features = in_features - self.out_features = out_features - - self.register_buffer("fp4_weight", fp4_weight) - self.register_buffer("weight_scaling_factor", weight_scaling_factor) - self.register_buffer("weight_global_scale", weight_global_scale) - self.register_buffer("bias", bias) - - def forward(self, x: torch.Tensor) -> torch.Tensor: - return torch.ops.nvf_cutlass.f16a_nvfp4weight_scaled_mm( - x, self.fp4_weight, self.weight_scaling_factor, self.weight_global_scale, self.bias - ) - - @staticmethod - def from_linear(linear: nn.Linear, fqn: str | None = None) -> NVFP4InferenceLinear: - weight = linear.weight - bias = linear.bias - out_features, in_features = weight.size() - fp4_weight, weight_scaling_factor, weight_global_scale = quantize_linear_weight_to_nvfp4(weight) - return NVFP4InferenceLinear( - in_features, - out_features, - fp4_weight=fp4_weight, - weight_scaling_factor=weight_scaling_factor, - weight_global_scale=weight_global_scale, - bias=bias, - ) - - class SwiGLU(nn.Module): def __init__(self, hidden_size: int, intermediate_size: int, dtype: torch.dtype, device: str): super().__init__() From 58482e88e4432b5c99b7b656d0e3666875a2c4c4 Mon Sep 17 00:00:00 2001 From: Masaki Kozuki Date: Mon, 17 Nov 2025 15:34:46 +0900 Subject: [PATCH 12/17] bring back docstirng of `from_` methods Signed-off-by: Masaki Kozuki --- .../benchmarks/layers_for_inference_benchmark.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/thunder/benchmarks/layers_for_inference_benchmark.py b/thunder/benchmarks/layers_for_inference_benchmark.py index 91ea55868f..d4c8c95ca1 100644 --- a/thunder/benchmarks/layers_for_inference_benchmark.py +++ b/thunder/benchmarks/layers_for_inference_benchmark.py @@ -411,6 +411,12 @@ def forward( @staticmethod def from_grouped_linear(grouped_linear: GroupedLinear, fqn: str | None = None) -> NVFP4InferenceGroupedLinear: + """Create an NVFP4InferenceGroupedLinear from a GroupedLinear. + + Args: + grouped_linear (GroupedLinear): The source GroupedLinear. + fqn (str or None): Fully qualified name. Currently unused; reserved for future use or compatibility. + """ weight = grouped_linear.weight fp4_weight, weight_scaling_factor, weight_global_scale = quantize_grouped_linear_weight_to_nvfp4(weight) return NVFP4InferenceGroupedLinear( @@ -470,7 +476,12 @@ def forward(self, hidden_states: torch.Tensor, offsets: torch.Tensor) -> torch.T @staticmethod def from_grouped_swiglu(grouped_swiglu: GroupedSwiGLU, fqn: str | None = None) -> NVFP4InferenceGroupedSwiGLU: - """Convert a GroupedSwiGLU to NVFP4InferenceGroupedSwiGLU.""" + """Create an NVFP4InferenceGroupedSwiGLU from a GroupedSwiGLU. + + Args: + grouped_swiglu (GroupedSwiGLU): The source GroupedSwiGLU. + fqn (str or None): Fully qualified name. Currently unused; reserved for future use or compatibility. 
+ """ gate_proj = NVFP4InferenceGroupedLinear.from_grouped_linear(grouped_swiglu.gate_proj) up_proj = NVFP4InferenceGroupedLinear.from_grouped_linear(grouped_swiglu.up_proj) down_proj = NVFP4InferenceGroupedLinear.from_grouped_linear(grouped_swiglu.down_proj) From b6f426258825dd3bccb265b87cb47a9f509979d9 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 18 Nov 2025 14:48:50 -0800 Subject: [PATCH 13/17] reverting the layout change for bf16 --- thunder/benchmarks/benchmark_inference.py | 14 ++++++++++++++ .../benchmarks/layers_for_inference_benchmark.py | 16 ++++++++-------- 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/thunder/benchmarks/benchmark_inference.py b/thunder/benchmarks/benchmark_inference.py index 0fcbd30670..0ae7c5c527 100644 --- a/thunder/benchmarks/benchmark_inference.py +++ b/thunder/benchmarks/benchmark_inference.py @@ -753,6 +753,11 @@ def parse_args() -> argparse.Namespace: help="Wrap each non-warmup iteration with cudaProfilerStart() and cudaProfilerStop(). This allows us to run `nsys profile --capture-range=cudaProfilerApi --capture-range-end=repeat: ... --profile` to record only the non-warmup iterations.", ) + parser.add_argument( + "--thunder-trace", + action="store_true", + help="Enable debug dump of thunder trace", + ) parser.add_argument("--save-results", action="store_true", help="Save results to JSON file") parser.add_argument("--output-dir", type=str, default="./results", help="Directory to save results") parser.add_argument( @@ -803,6 +808,15 @@ def main(): benchmark.run_benchmark() benchmark.print_results() + + if args.thunder_trace and args.mode == "thunder": + backend = benchmark.model._backend + for subgraph_info in backend.subgraph_infos: + assert isinstance(subgraph_info.original_graph_module, torch.fx.GraphModule) + assert len(subgraph_info.thunder_compiled_fns) + for thunder_fn in subgraph_info.thunder_compiled_fns: + print(thunder.last_traces(thunder_fn)[-1]) + if args.save_results: timestamp = time.strftime("%Y%m%d_%H%M%S") filename = f"thunder_inference_{args.model_name.replace('/', '_')}_{timestamp}.json" diff --git a/thunder/benchmarks/layers_for_inference_benchmark.py b/thunder/benchmarks/layers_for_inference_benchmark.py index d4c8c95ca1..8890883872 100644 --- a/thunder/benchmarks/layers_for_inference_benchmark.py +++ b/thunder/benchmarks/layers_for_inference_benchmark.py @@ -297,10 +297,10 @@ def __init__(self, groups: int, in_features: int, out_features: int, dtype: torc self.weight = nn.Parameter(torch.empty(groups, out_features, in_features, dtype=dtype, device=device)) # Initialize the weight in the same way as nn.Linear nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) - self.weight.data = self.weight.transpose(-1, -2) + # self.weight.data = self.weight.transpose(-1, -2) def forward(self, hidden_states: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor: - return grouped_mm(hidden_states, self.weight, offsets) + return grouped_mm(hidden_states, self.weight.transpose(-1, -2), offsets) @torch.inference_mode() @@ -310,7 +310,7 @@ def quantize_grouped_linear_weight_to_nvfp4( """Quantize grouped linear's weight to nvfp4 Args: - weight: Parameter of `GroupedLinear` of [g, k, n] + weight: Parameter of `GroupedLinear` of [g, n, k] Returns: fp4_weight: [g, k // 2, n] @@ -320,14 +320,14 @@ def quantize_grouped_linear_weight_to_nvfp4( assert weight.ndim == 3, "Weight must be a 3D tensor" device: torch.device = weight.device - g, k, n = weight.size() + g, n, k = weight.size() with device: fp4_weight = torch.empty((g, n, k // 
2), dtype=torch.float4_e2m1fn_x2) global_scales = torch.empty((g,), dtype=torch.float32) scale_factors = torch.empty((g, n, k // 16), dtype=torch.float8_e4m3fn) - weight = weight.transpose(-1, -2).contiguous() + weight = weight.contiguous() for i in range(g): cur_weight = weight[i] global_scales[i] = cur_weight.abs().amax() @@ -553,13 +553,13 @@ def from_transformers_llama4textmoe(moe: Llama4TextMoe) -> Llama4MoE: # Split into gate and up projections gate_proj_w, up_proj_w = moe.experts.gate_up_proj.chunk(2, dim=2) - new_moe.routed_experts.gate_proj.weight.data.copy_(gate_proj_w) - new_moe.routed_experts.up_proj.weight.data.copy_(up_proj_w) + new_moe.routed_experts.gate_proj.weight.data.copy_(gate_proj_w.transpose(-1, -2)) + new_moe.routed_experts.up_proj.weight.data.copy_(up_proj_w.transpose(-1, -2)) # Handle down_proj # HF format: (groups, intermediate_size, hidden_size) # Our format: (groups, hidden, intermediate_size) - new_moe.routed_experts.down_proj.weight.data.copy_(moe.experts.down_proj) + new_moe.routed_experts.down_proj.weight.data.copy_(moe.experts.down_proj.transpose(-1, -2)) return new_moe From 34c940637a5894e20e8ad29cf21415f4bdaa04b1 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 18 Nov 2025 15:06:20 -0800 Subject: [PATCH 14/17] adding assert on input devices --- thunder/benchmarks/layers_for_inference_benchmark.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/thunder/benchmarks/layers_for_inference_benchmark.py b/thunder/benchmarks/layers_for_inference_benchmark.py index 8890883872..441e59f28c 100644 --- a/thunder/benchmarks/layers_for_inference_benchmark.py +++ b/thunder/benchmarks/layers_for_inference_benchmark.py @@ -239,6 +239,9 @@ def _( if activation.ndim == 0: raise ValueError(f"Expected activation to have at least 1 dimension, got {activation.ndim}") + if len(set(t.device for t in [activation, fp4_weight, weight_scaling_factor, weight_global_scale, offsets, block_scale_offsets, problem_sizes])) != 1: + raise ValueError(f"Expected all inputs to be on the same device.") + # After unpacking: (groups, in_features, out_features) # Output shape should match activation.shape[:-1] + (out_features,) # This handles both 2D (tokens, hidden) and 3D (batch, seq_len, hidden) inputs @@ -316,6 +319,11 @@ def quantize_grouped_linear_weight_to_nvfp4( fp4_weight: [g, k // 2, n] scale_factors: [g, n, k // 16] global_scales: [g] + + Note: + The reason we choose different layout of weight is to avoid performance + regression for bf16. 
See + https://github.com/Lightning-AI/lightning-thunder/pull/2659 """ assert weight.ndim == 3, "Weight must be a 3D tensor" From 49498aea47aa29e66daf04dc3679ffec6e6e845c Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Tue, 18 Nov 2025 15:10:50 -0800 Subject: [PATCH 15/17] typo --- thunder/benchmarks/layers_for_inference_benchmark.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/thunder/benchmarks/layers_for_inference_benchmark.py b/thunder/benchmarks/layers_for_inference_benchmark.py index 441e59f28c..d2b19980a9 100644 --- a/thunder/benchmarks/layers_for_inference_benchmark.py +++ b/thunder/benchmarks/layers_for_inference_benchmark.py @@ -239,7 +239,7 @@ def _( if activation.ndim == 0: raise ValueError(f"Expected activation to have at least 1 dimension, got {activation.ndim}") - if len(set(t.device for t in [activation, fp4_weight, weight_scaling_factor, weight_global_scale, offsets, block_scale_offsets, problem_sizes])) != 1: + if len(set(t.device for t in [activation, fp4_weight, weight_scaling_factor, weight_global_scale, offsets, blockscale_offsets, problem_sizes])) != 1: raise ValueError(f"Expected all inputs to be on the same device.") # After unpacking: (groups, in_features, out_features) From b335832deaaca172ad80ab3517cde3c21e025300 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 18 Nov 2025 23:11:14 +0000 Subject: [PATCH 16/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../layers_for_inference_benchmark.py | 20 +++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/thunder/benchmarks/layers_for_inference_benchmark.py b/thunder/benchmarks/layers_for_inference_benchmark.py index d2b19980a9..1a0d18a3db 100644 --- a/thunder/benchmarks/layers_for_inference_benchmark.py +++ b/thunder/benchmarks/layers_for_inference_benchmark.py @@ -239,8 +239,24 @@ def _( if activation.ndim == 0: raise ValueError(f"Expected activation to have at least 1 dimension, got {activation.ndim}") - if len(set(t.device for t in [activation, fp4_weight, weight_scaling_factor, weight_global_scale, offsets, blockscale_offsets, problem_sizes])) != 1: - raise ValueError(f"Expected all inputs to be on the same device.") + if ( + len( + { + t.device + for t in [ + activation, + fp4_weight, + weight_scaling_factor, + weight_global_scale, + offsets, + blockscale_offsets, + problem_sizes, + ] + } + ) + != 1 + ): + raise ValueError("Expected all inputs to be on the same device.") # After unpacking: (groups, in_features, out_features) # Output shape should match activation.shape[:-1] + (out_features,) From 42c2a9d736593bd7f8d91d94e6bde1524e640918 Mon Sep 17 00:00:00 2001 From: jjsjann123 Date: Wed, 19 Nov 2025 08:45:47 -0800 Subject: [PATCH 17/17] remove commented code --- thunder/benchmarks/layers_for_inference_benchmark.py | 1 - 1 file changed, 1 deletion(-) diff --git a/thunder/benchmarks/layers_for_inference_benchmark.py b/thunder/benchmarks/layers_for_inference_benchmark.py index 1a0d18a3db..87d1f3533d 100644 --- a/thunder/benchmarks/layers_for_inference_benchmark.py +++ b/thunder/benchmarks/layers_for_inference_benchmark.py @@ -316,7 +316,6 @@ def __init__(self, groups: int, in_features: int, out_features: int, dtype: torc self.weight = nn.Parameter(torch.empty(groups, out_features, in_features, dtype=dtype, device=device)) # Initialize the weight in the same way as nn.Linear nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5)) - # 
self.weight.data = self.weight.transpose(-1, -2) def forward(self, hidden_states: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor: return grouped_mm(hidden_states, self.weight.transpose(-1, -2), offsets)
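
For readers following the layout revert in PATCH 13/17 and the bf16 note added in PATCH 14/17, the short Python sketch below illustrates the weight-layout convention the series settles on. It is illustrative only and not part of the patches: the names (GroupedLinear, quantize_grouped_linear_weight_to_nvfp4) and shapes are taken from the diffs above, while the concrete sizes are made up for the example.

import math

import torch

# Illustrative sketch (not from the patches) of the layout PATCH 13/17 reverts to.
# GroupedLinear stores its parameter as (groups, out_features, in_features),
# i.e. [g, n, k], matching nn.Linear, and transposes only at the grouped-mm call site.
g, k, n = 4, 32, 64  # groups, in_features (k), out_features (n)
weight = torch.empty(g, n, k)
torch.nn.init.kaiming_uniform_(weight, a=math.sqrt(5))  # same init as GroupedLinear

# Call-site view: each group's matmul needs weight[i] as (in_features, out_features),
# so forward passes weight.transpose(-1, -2) instead of storing a transposed parameter.
tokens = 8
x = torch.randn(tokens, k)
y = x @ weight.transpose(-1, -2)[0]  # one group's projection -> (tokens, out_features)
assert y.shape == (tokens, n)

# The nvfp4 quantizer consumes the same [g, n, k] layout: per the diff it allocates
# fp4_weight as (g, n, k // 2) (two fp4 values packed per element along k) and
# per-16-element scale factors as (g, n, k // 16).
assert weight.shape == (g, n, k)

The stored parameter therefore keeps the nn.Linear orientation for the bf16 path while the quantized path reads the same [g, n, k] layout; the performance rationale is discussed in the linked PR (Lightning-AI/lightning-thunder#2659).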