diff --git a/tensorrt_llm/_torch/modules/fused_moe/interface.py b/tensorrt_llm/_torch/modules/fused_moe/interface.py
index ca1e134bf96..24e8912c5ad 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/interface.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/interface.py
@@ -198,6 +198,7 @@ def __init__(
         self.parallel_rank = self.mapping.tp_rank
         self.parallel_size = self.mapping.tp_size
         self.intermediate_size_per_partition = intermediate_size // self.tp_size
+        self.expand_intermediate_size_per_partition = self.intermediate_size_per_partition * self.intermediate_size_expand_ratio
 
         self.all_reduce = None
         if not self.use_dp and self.mapping.tp_size > 1:
diff --git a/tensorrt_llm/_torch/modules/fused_moe/quantization.py b/tensorrt_llm/_torch/modules/fused_moe/quantization.py
index 5e80d4840cf..fd4ef11d078 100644
--- a/tensorrt_llm/_torch/modules/fused_moe/quantization.py
+++ b/tensorrt_llm/_torch/modules/fused_moe/quantization.py
@@ -219,9 +219,9 @@ def create_weights(
         # bias
         if module.bias:
             if w3_w1_bias_shape is None:
-                w3_w1_bias_shape = (module.expert_size_per_partition,
-                                    module.intermediate_size_per_partition *
-                                    module.intermediate_size_expand_ratio)
+                w3_w1_bias_shape = (
+                    module.expert_size_per_partition,
+                    module.expand_intermediate_size_per_partition)
             if w2_bias_shape is None:
                 w2_bias_shape = (module.expert_size_per_partition,
                                  module.hidden_size)
@@ -518,8 +518,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
     def create_weights(self, module: torch.nn.Module):
         weight_dtype = module.dtype
         w3_w1_weight_shape = (module.expert_size_per_partition,
-                              module.intermediate_size_per_partition *
-                              module.intermediate_size_expand_ratio,
+                              module.expand_intermediate_size_per_partition,
                               module.hidden_size)
         w2_weight_shape = (
             module.expert_size_per_partition,
@@ -584,7 +583,7 @@ def requantize_expert_w3_w1_weight_fp8_qdq(module: torch.nn.Module,
         w3_weight_scale = w3_weight_scale[...].reshape([])
         max_w3_w1_weight_scale = max(w1_weight_scale, w3_weight_scale)
 
-        split_length = module.intermediate_size_per_partition * module.intermediate_size_expand_ratio // 2
+        split_length = module.expand_intermediate_size_per_partition // 2
         w3_weight = dst_w3_w1_weight.narrow(
             dim=0, start=0, length=split_length).to(dtype=module.dtype)
         w1_weight = dst_w3_w1_weight.narrow(
@@ -608,8 +607,7 @@ def create_weights(self, module: torch.nn.Module):
         weight_dtype = torch.float8_e4m3fn
 
         w3_w1_weight_shape = (module.expert_size_per_partition,
-                              module.intermediate_size_per_partition *
-                              module.intermediate_size_expand_ratio,
+                              module.expand_intermediate_size_per_partition,
                               module.hidden_size)
         w2_weight_shape = (
             module.expert_size_per_partition,
@@ -1658,6 +1656,38 @@ class NVFP4FusedMoEMethod(FusedMoEMethodBase):
     Base class for NVFP4 fused MoE methods for all backends.
""" + def get_weights_shapes(self, module: torch.nn.Module, weight_vec_size: int, + block_scales_vec_size: int): + # Divide by 16 because we use int64 to pack 16 fp4 values + w3_w1_weight_shape = (module.expert_size_per_partition, + module.expand_intermediate_size_per_partition, + module.hidden_size // weight_vec_size) + w2_weight_shape = (module.expert_size_per_partition, module.hidden_size, + module.intermediate_size_per_partition // + weight_vec_size) + + w3_w1_weight_scale_shape = ( + module.expert_size_per_partition, + module.expand_intermediate_size_per_partition, module.hidden_size // + module.scaling_vector_size // block_scales_vec_size) + w2_weight_scale_shape = (module.expert_size_per_partition, + module.hidden_size, + module.intermediate_size_per_partition // + module.scaling_vector_size // + block_scales_vec_size) + + if module.bias: + w3_w1_bias_shape = (module.expert_size_per_partition, + module.expand_intermediate_size_per_partition) + w2_bias_shape = (module.expert_size_per_partition, + module.hidden_size) + else: + w3_w1_bias_shape = None + w2_bias_shape = None + + return (w3_w1_weight_shape, w2_weight_shape, w3_w1_bias_shape, + w2_bias_shape, w3_w1_weight_scale_shape, w2_weight_scale_shape) + def create_weights(self, module: torch.nn.Module, weight_dtype, @@ -1667,35 +1697,23 @@ def create_weights(self, scaling_vector_size=16): module.scaling_vector_size = scaling_vector_size - # Divide by 16 because we use int64 to pack 16 fp4 values - w3_w1_weight_shape = (module.expert_size_per_partition, - module.intermediate_size_per_partition * - module.intermediate_size_expand_ratio, - module.hidden_size // weight_vec_size) - w2_weight_shape = (module.expert_size_per_partition, module.hidden_size, - module.intermediate_size_per_partition // - weight_vec_size) + + (w3_w1_weight_shape, w2_weight_shape, w3_w1_bias_shape, w2_bias_shape, + w3_w1_weight_scale_shape, + w2_weight_scale_shape) = self.get_weights_shapes( + module, weight_vec_size, block_scales_vec_size) # Divide by 4 because we use int32 to pack 4 fp8 values # column parallel - w3_w1_weight_scale = nn.Parameter( - torch.ones(module.expert_size_per_partition, - module.intermediate_size_per_partition * - module.intermediate_size_expand_ratio, - module.hidden_size // module.scaling_vector_size // - block_scales_vec_size, - dtype=block_scales_dtype), - requires_grad=False) + w3_w1_weight_scale = nn.Parameter(torch.ones(w3_w1_weight_scale_shape, + dtype=block_scales_dtype), + requires_grad=False) module.register_parameter("w3_w1_weight_scale", w3_w1_weight_scale) # row parallel - w2_weight_scale = nn.Parameter( - torch.ones(module.expert_size_per_partition, - module.hidden_size, - module.intermediate_size_per_partition // - module.scaling_vector_size // block_scales_vec_size, - dtype=block_scales_dtype), - requires_grad=False) + w2_weight_scale = nn.Parameter(torch.ones(w2_weight_scale_shape, + dtype=block_scales_dtype), + requires_grad=False) module.register_parameter("w2_weight_scale", w2_weight_scale) fc31_input_scale = nn.Parameter(torch.tensor(1., dtype=torch.float32), @@ -1716,8 +1734,12 @@ def create_weights(self, requires_grad=False) module.register_parameter("fc2_alpha", fc2_alpha) - super().create_weights(module, weight_dtype, w3_w1_weight_shape, - w2_weight_shape) + super().create_weights(module, + weight_dtype, + w3_w1_weight_shape=w3_w1_weight_shape, + w2_weight_shape=w2_weight_shape, + w3_w1_bias_shape=w3_w1_bias_shape, + w2_bias_shape=w2_bias_shape) self.setup_quant_scales(module) @@ -1926,6 +1948,55 @@ def 
 class NVFP4CutlassFusedMoEMethod(NVFP4FusedMoEMethod):
     weight_dtype = FUSED_MOE_NVFP4_WEIGHT_DTYPE
     block_scales_dtype = FUSED_MOE_NVFP4_WEIGHT_BLOCK_SCALE_DTYPE
+    NVFP4_ROW_ALIGNMENT = 128
+    NVFP4_COL_ALIGNMENT = 4
+
+    def get_weights_shapes(self, module: torch.nn.Module, weight_vec_size: int,
+                           block_scales_vec_size: int):
+        """Override the base method to return weight shapes padded to the Cutlass NVFP4 alignment requirements."""
+        intermediate_size_expand_aligned = (
+            module.expand_intermediate_size_per_partition +
+            self.NVFP4_ROW_ALIGNMENT -
+            1) // self.NVFP4_ROW_ALIGNMENT * self.NVFP4_ROW_ALIGNMENT
+
+        if module.hidden_size % self.NVFP4_COL_ALIGNMENT != 0:
+            raise ValueError(
+                f"hidden_size {module.hidden_size} must be divisible by {self.NVFP4_COL_ALIGNMENT}"
+            )
+        hidden_size_aligned = module.hidden_size
+
+        w3_w1_weight_shape = (module.expert_size_per_partition,
+                              intermediate_size_expand_aligned,
+                              hidden_size_aligned // weight_vec_size)
+        w2_weight_shape = (module.expert_size_per_partition,
+                           hidden_size_aligned,
+                           intermediate_size_expand_aligned //
+                           module.intermediate_size_expand_ratio //
+                           weight_vec_size)
+
+        w3_w1_weight_scale_shape = (module.expert_size_per_partition,
+                                    intermediate_size_expand_aligned,
+                                    hidden_size_aligned //
+                                    module.scaling_vector_size //
+                                    block_scales_vec_size)
+        w2_weight_scale_shape = (module.expert_size_per_partition,
+                                 hidden_size_aligned,
+                                 intermediate_size_expand_aligned //
+                                 module.intermediate_size_expand_ratio //
+                                 module.scaling_vector_size //
+                                 block_scales_vec_size)
+
+        if module.bias:
+            w3_w1_bias_shape = (module.expert_size_per_partition,
+                                intermediate_size_expand_aligned)
+            w2_bias_shape = (module.expert_size_per_partition,
+                             hidden_size_aligned)
+        else:
+            w3_w1_bias_shape = None
+            w2_bias_shape = None
+
+        return (w3_w1_weight_shape, w2_weight_shape, w3_w1_bias_shape,
+                w2_bias_shape, w3_w1_weight_scale_shape, w2_weight_scale_shape)
 
     def create_weights(self, module: torch.nn.Module):
         weight_vec_size = torch.iinfo(self.weight_dtype).bits // 4
@@ -1950,21 +2021,16 @@ def load_expert_w3_w1_weight_scale_nvfp4(
                                              module.tp_rank,
                                              TensorParallelMode.COLUMN,
                                              device=device)
-        # Keep weights in device buffer
-        # w3
-        split_length = module.intermediate_size_per_partition * module.intermediate_size_expand_ratio // 2
-        dst_w3_weight_scale = dst_w3_w1_weight_scale.narrow(dim=0,
-                                                            start=0,
-                                                            length=split_length)
-        dst_w3_weight_scale.copy_(
-            w3_weight_scale.view(dst_w3_weight_scale.dtype))
-        # w1
-        dst_w1_weight_scale = dst_w3_w1_weight_scale.narrow(dim=0,
-                                                            start=split_length,
-                                                            length=split_length)
-        dst_w1_weight_scale.copy_(
-            w1_weight_scale.view(dst_w1_weight_scale.dtype))
+        cast_w3_weight_scale = w3_weight_scale.view(
+            dst_w3_w1_weight_scale.dtype)
+        cast_w1_weight_scale = w1_weight_scale.view(
+            dst_w3_w1_weight_scale.dtype)
+        cast_w31_weight_scale = torch.cat(
+            [cast_w3_weight_scale, cast_w1_weight_scale], dim=0)
+        cast_w31_weight_scale = self._maybe_padding_shape(
+            cast_w31_weight_scale, dst_w3_w1_weight_scale)
+        dst_w3_w1_weight_scale.copy_(cast_w31_weight_scale)
 
         orig_shape = dst_w3_w1_weight_scale.shape
 
@@ -1986,9 +2052,12 @@ def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
                                          module.tp_rank,
                                          TensorParallelMode.ROW,
                                          device=device)
+
+        cast_w2_weight_scale = w2_weight_scale.view(dst_w2_weight_scale.dtype)
+        cast_w2_weight_scale = self._maybe_padding_shape(
+            cast_w2_weight_scale, dst_w2_weight_scale)
         # Keep weights in device buffer
-        dst_w2_weight_scale.copy_(
-            w2_weight_scale.view(dst_w2_weight_scale.dtype))
+        dst_w2_weight_scale.copy_(cast_w2_weight_scale)
 
         orig_shape = dst_w2_weight_scale.shape
 
@@ -2000,6 +2069,60 @@ def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
 
         dst_w2_weight_scale.copy_(dst_w2_weight_scale_interleaved)
 
+    def load_expert_w3_w1_weight(self, module: torch.nn.Module,
+                                 w1_weight: torch.Tensor,
+                                 w3_weight: torch.Tensor,
+                                 dst_w3_w1_weight: torch.Tensor):
+        """Load and pad the w1 and w3 weights for each expert to match the Cutlass NVFP4 alignment requirements."""
+        device = dst_w3_w1_weight.device
+        w1_weight_shard = load_weight_shard(w1_weight,
+                                            module.tp_size,
+                                            module.tp_rank,
+                                            TensorParallelMode.COLUMN,
+                                            device=device)
+        w3_weight_shard = load_weight_shard(w3_weight,
+                                            module.tp_size,
+                                            module.tp_rank,
+                                            TensorParallelMode.COLUMN,
+                                            device=device)
+
+        cast_w1_weight_shard = w1_weight_shard.view(dst_w3_w1_weight.dtype)
+        cast_w3_weight_shard = w3_weight_shard.view(dst_w3_w1_weight.dtype)
+        cast_w31_weight_shard = torch.cat(
+            [cast_w3_weight_shard, cast_w1_weight_shard], dim=0)
+        cast_w31_weight_shard = self._maybe_padding_shape(
+            cast_w31_weight_shard, dst_w3_w1_weight)
+        dst_w3_w1_weight.copy_(cast_w31_weight_shard, non_blocking=True)
+
+    def load_expert_w2_weight(self, module: torch.nn.Module,
+                              w2_weight: torch.Tensor,
+                              dst_w2_weight: torch.Tensor):
+        """Load and pad the w2 weight for each expert to match the Cutlass NVFP4 alignment requirements."""
+        device = dst_w2_weight.device
+        w2_weight_shard = load_weight_shard(w2_weight,
+                                            module.tp_size,
+                                            module.tp_rank,
+                                            TensorParallelMode.ROW,
+                                            device=device)
+        cast_w2_weight_shard = w2_weight_shard.view(dst_w2_weight.dtype)
+        cast_w2_weight_shard = self._maybe_padding_shape(
+            cast_w2_weight_shard, dst_w2_weight)
+        dst_w2_weight.copy_(cast_w2_weight_shard, non_blocking=True)
+
+    def _maybe_padding_shape(self, source_tensor, dst_tensor):
+        """Pad the source tensor to match the shape of the destination tensor."""
+        # In `get_weights_shapes`, the weight and weight-scale shapes may have been
+        # rounded up to `NVFP4_ROW_ALIGNMENT`; zero-pad `source_tensor` to match `dst_tensor` here.
+        assert len(source_tensor.shape) == 2 and len(
+            dst_tensor.shape) == 2, "Only 2D weight padding is supported for now."
+        dst_row, dst_col = dst_tensor.shape
+        _row, _col = source_tensor.shape
+        if _row != dst_row or _col != dst_col:
+            source_tensor = torch.nn.functional.pad(
+                source_tensor, (0, dst_col - _col, 0, dst_row - _row),
+                "constant", 0).contiguous()
+        return source_tensor
+
 
 class NVFP4CuteDslFusedMoEMethod(NVFP4CutlassFusedMoEMethod):
diff --git a/tests/integration/test_lists/test-db/l0_a10.yml b/tests/integration/test_lists/test-db/l0_a10.yml
index 1b79caaec78..b7475e2da88 100644
--- a/tests/integration/test_lists/test-db/l0_a10.yml
+++ b/tests/integration/test_lists/test-db/l0_a10.yml
@@ -19,6 +19,8 @@ l0_a10:
       - unittest/utils/test_util.py
       - unittest/_torch/modeling/test_modeling_mistral.py
       - unittest/_torch/modeling/test_modeling_pixtral.py
+      - unittest/_torch/modules/test_fused_moe.py::test_nvfp4_cutlass_get_weights_shapes
+      - unittest/_torch/modules/test_fused_moe.py::test_nvfp4_cutlass_get_weights_shapes_error_cases
       - unittest/_torch/sampler/test_trtllm_sampler.py
       # NOTE: this is a CPU-only test, but we do not have a dedicated job for this (and therefore no
       # test list either).
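Reviewer note on the quantization.py changes above: the new Cutlass path boils down to two operations, a ceil-to-multiple round-up of the intermediate dimension and a bottom/right zero pad of each 2D shard. Below is a minimal standalone sketch of that behavior; the helper names `round_up` and `pad_to` are illustrative, not part of the patch.

import torch
import torch.nn.functional as F

NVFP4_ROW_ALIGNMENT = 128


def round_up(x: int, multiple: int) -> int:
    # Same ceil-to-multiple idiom as get_weights_shapes:
    # (x + multiple - 1) // multiple * multiple
    return (x + multiple - 1) // multiple * multiple


def pad_to(src: torch.Tensor, dst_rows: int, dst_cols: int) -> torch.Tensor:
    # Mirrors _maybe_padding_shape: zero-pad a 2D tensor on the bottom and
    # right so it matches the aligned destination shape.
    rows, cols = src.shape
    if (rows, cols) == (dst_rows, dst_cols):
        return src
    return F.pad(src, (0, dst_cols - cols, 0, dst_rows - rows), "constant",
                 0).contiguous()


# A 240-row w3/w1 shard (e.g. 2 * 120) lands in a 256-row aligned buffer.
src = torch.ones(240, 64)
aligned_rows = round_up(src.shape[0], NVFP4_ROW_ALIGNMENT)  # 256
padded = pad_to(src, aligned_rows, 64)
assert padded.shape == (256, 64) and padded[240:].abs().sum() == 0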
diff --git a/tests/unittest/_torch/modules/test_fused_moe.py b/tests/unittest/_torch/modules/test_fused_moe.py
index 14acccf452f..b56ec9e8deb 100644
--- a/tests/unittest/_torch/modules/test_fused_moe.py
+++ b/tests/unittest/_torch/modules/test_fused_moe.py
@@ -37,6 +37,8 @@
     BaseMoeRoutingMethod, CutlassFusedMoE, TRTLLMGenFusedMoE,
     DefaultMoeRoutingMethod, RenormalizeMoeRoutingMethod, TritonFusedMoE,
     create_moe, WideEPMoE)
+from tensorrt_llm._torch.modules.fused_moe.quantization import \
+    NVFP4CutlassFusedMoEMethod
 # isort: on
 from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import \
     IS_TRITON_KERNELS_AVAILABLE
@@ -2735,3 +2737,123 @@ def load_weights(self, weights: List[Dict]):
 
             self.experts[expert].gate_up_proj.load_weights(gate_up_proj_weights)
             self.experts[expert].down_proj.load_weights(down_proj_weights)
+
+
+# Mock module with the attributes required by the NVFP4CutlassFusedMoEMethod.get_weights_shapes tests.
+class MockModule:
+
+    def __init__(self, hidden_size, intermediate_size, expand_ratio,
+                 expert_size, bias):
+        self.hidden_size = hidden_size
+        self.intermediate_size_per_partition = intermediate_size
+        self.intermediate_size_expand_ratio = expand_ratio
+        self.expand_intermediate_size_per_partition = intermediate_size * self.intermediate_size_expand_ratio
+        self.expert_size_per_partition = expert_size
+        self.bias = bias
+        # Constants for NVFP4.
+        self.scaling_vector_size = 16  # Standard for NVFP4
+        self.weight_vec_size = 16  # 16 fp4 values packed into int64
+        self.block_scales_vec_size = 4  # 4 fp8 values packed into int32
+
+
+def test_nvfp4_cutlass_get_weights_shapes_error_cases():
+    """Test NVFP4CutlassFusedMoEMethod.get_weights_shapes for error cases."""
+    method = NVFP4CutlassFusedMoEMethod()
+    module = MockModule(hidden_size=13,
+                        intermediate_size=16,
+                        expand_ratio=1,
+                        expert_size=4,
+                        bias=False)
+    with pytest.raises(ValueError,
+                       match="hidden_size 13 must be divisible by 4"):
+        method.get_weights_shapes(module, module.weight_vec_size,
+                                  module.block_scales_vec_size)
+
+
+@pytest.mark.parametrize(
+    "hidden_size, intermediate_size, expand_ratio, expert_size, bias", [
+        (512, 1024, 1, 32, True),
+        (512, 1024, 2, 32, True),
+        (256, 512, 1, 16, False),
+        (256, 512, 2, 16, False),
+        (128, 120, 1, 8, False),
+        (128, 120, 2, 8, False),
+        (128, 120, 1, 8, True),
+        (128, 120, 2, 8, True),
+    ])
+def test_nvfp4_cutlass_get_weights_shapes(hidden_size, intermediate_size,
+                                          expand_ratio, expert_size, bias):
+    """Test NVFP4CutlassFusedMoEMethod.get_weights_shapes for alignment requirements."""
+    module = MockModule(hidden_size=hidden_size,
+                        intermediate_size=intermediate_size,
+                        expand_ratio=expand_ratio,
+                        expert_size=expert_size,
+                        bias=bias)
+    method = NVFP4CutlassFusedMoEMethod()
+    NVFP4_ROW_ALIGNMENT = method.NVFP4_ROW_ALIGNMENT
+
+    # Get weight shapes
+    (w3_w1_weight_shape, w2_weight_shape, w3_w1_bias_shape, w2_bias_shape,
+     w3_w1_weight_scale_shape,
+     w2_weight_scale_shape) = method.get_weights_shapes(
+         module, module.weight_vec_size, module.block_scales_vec_size)
+
+    # Calculate expected aligned sizes
+    intermediate_size_expand = intermediate_size * module.intermediate_size_expand_ratio
+    intermediate_size_expand_aligned = (
+        (intermediate_size_expand + NVFP4_ROW_ALIGNMENT - 1) //
+        NVFP4_ROW_ALIGNMENT * NVFP4_ROW_ALIGNMENT)
+    hidden_size_aligned = hidden_size
+
+    expected_w3_w1_weight_shape = (expert_size,
+                                   intermediate_size_expand_aligned,
+                                   hidden_size_aligned //
+                                   module.weight_vec_size)
+    assert w3_w1_weight_shape == expected_w3_w1_weight_shape, (
f"w3_w1_weight_shape mismatch: got {w3_w1_weight_shape}, " + f"expected {expected_w3_w1_weight_shape}") + + expected_w2_weight_shape = (expert_size, hidden_size_aligned, + intermediate_size_expand_aligned // + module.intermediate_size_expand_ratio // + module.weight_vec_size) + assert w2_weight_shape == expected_w2_weight_shape, ( + f"w2_weight_shape mismatch: got {w2_weight_shape}, " + f"expected {expected_w2_weight_shape}") + + expected_w3_w1_weight_scale_shape = (expert_size, + intermediate_size_expand_aligned, + hidden_size_aligned // + module.scaling_vector_size // + module.block_scales_vec_size) + assert w3_w1_weight_scale_shape == expected_w3_w1_weight_scale_shape, ( + f"w3_w1_weight_scale_shape mismatch: got {w3_w1_weight_scale_shape}, " + f"expected {expected_w3_w1_weight_scale_shape}") + + expected_w2_weight_scale_shape = (expert_size, hidden_size_aligned, + intermediate_size_expand_aligned // + module.intermediate_size_expand_ratio // + module.scaling_vector_size // + module.block_scales_vec_size) + assert w2_weight_scale_shape == expected_w2_weight_scale_shape, ( + f"w2_weight_scale_shape mismatch: got {w2_weight_scale_shape}, " + f"expected {expected_w2_weight_scale_shape}") + + # Verify bias shapes + if bias: + expected_w3_w1_bias_shape = (expert_size, + intermediate_size_expand_aligned) + expected_w2_bias_shape = (expert_size, hidden_size_aligned) + assert w3_w1_bias_shape == expected_w3_w1_bias_shape, ( + f"w3_w1_bias_shape mismatch: got {w3_w1_bias_shape}, " + f"expected {expected_w3_w1_bias_shape}") + assert w2_bias_shape == expected_w2_bias_shape, ( + f"w2_bias_shape mismatch: got {w2_bias_shape}, " + f"expected {expected_w2_bias_shape}") + else: + assert w3_w1_bias_shape is None, f"Expected None for w3_w1_bias_shape, got {w3_w1_bias_shape}" + assert w2_bias_shape is None, f"Expected None for w2_bias_shape, got {w2_bias_shape}" + + assert intermediate_size_expand_aligned % NVFP4_ROW_ALIGNMENT == 0, ( + f"intermediate_size_expand_aligned {intermediate_size_expand_aligned} " + f"not aligned to {NVFP4_ROW_ALIGNMENT}")