From c063911384fed32d4cfc0f61aa8ce739f669f691 Mon Sep 17 00:00:00 2001 From: Shu Wang Date: Thu, 9 Oct 2025 20:34:31 -0700 Subject: [PATCH 1/4] Add flashinfer_cutedsl grouped gemm Signed-off-by: Shu Wang --- tests/kernels/moe/test_cutedsl_moe.py | 527 ++++++++++++++++++ vllm/envs.py | 11 +- .../fused_moe/deepep_ll_prepare_finalize.py | 46 +- .../fused_moe/flashinfer_cutedsl_moe.py | 367 ++++++++++++ .../layers/quantization/modelopt.py | 12 +- .../quantization/utils/flashinfer_fp4_moe.py | 41 +- .../quantization/utils/flashinfer_utils.py | 18 +- .../quantization/utils/nvfp4_moe_support.py | 6 +- vllm/utils/flashinfer.py | 41 ++ 9 files changed, 1031 insertions(+), 38 deletions(-) create mode 100644 tests/kernels/moe/test_cutedsl_moe.py create mode 100644 vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py diff --git a/tests/kernels/moe/test_cutedsl_moe.py b/tests/kernels/moe/test_cutedsl_moe.py new file mode 100644 index 000000000000..f3c7b281b086 --- /dev/null +++ b/tests/kernels/moe/test_cutedsl_moe.py @@ -0,0 +1,527 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import pytest +import torch +from flashinfer import fp4_quantize +from torch.nn import functional as F + +from vllm.model_executor.layers.activation import SiluAndMul +from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import ( + flashinfer_cutedsl_moe_masked, + scaled_fp4_grouped_quant, +) +from vllm.utils.flashinfer import ( + flashinfer_cutedsl_grouped_gemm_nt_masked as cutedsl_gmm_masked, +) + +if torch.cuda.get_device_capability() < (10, 0): + pytest.skip( + reason="Nvfp4 Requires compute capability of 10 or above.", + allow_module_level=True, + ) + +kE2M1ToFloat = torch.tensor( + [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32 +) + +FLOAT8_E4M3_MAX = 448.0 +FLOAT4_E2M1_MAX = 6.0 + + +def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size): + m_tiles = (m + 128 - 1) // 128 + f = block_size * 4 + k_tiles = (k + f - 1) // f + tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4)) + tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5)) + out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size) + return out[0:m, 0:k] + + +def dequantize_nvfp4_to_dtype( + tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16 +): + """Dequantize the fp4 tensor back to high precision.""" + # Two fp4 values are packed into one uint8. 
+ assert tensor_fp4.dtype == torch.uint8 + m, packed_k = tensor_fp4.shape + k = packed_k * 2 + tensor_f32 = break_fp4_bytes(tensor_fp4, dtype) + tensor_f32 = tensor_f32.reshape(m, k // block_size, block_size) + tensor_sf = tensor_sf.view(torch.float8_e4m3fn) + tensor_sf = convert_swizzled_to_linear(tensor_sf, m, k, block_size) + tensor_sf_dtype = tensor_sf.to(torch.float32) / global_scale + + # scale the tensor + out = (tensor_f32 * tensor_sf_dtype.unsqueeze(-1)).reshape(m, k) + return out.to(dtype=dtype) + + +def break_fp4_bytes(a, dtype): + assert a.dtype == torch.uint8 + m, n = a.shape + + # Vectorized nibble processing + a_flat = a.flatten() + high = (a_flat & 0xF0) >> 4 # Upper nibbles + low = a_flat & 0x0F # Lower nibbles + + # Combine nibbles for batch processing + combined = torch.stack((low, high), dim=1).flatten() + + # Vectorized sign and magnitude extraction + signs = (combined & 0x08).to(torch.bool) # Sign bits + abs_vals = (combined & 0x07).to(torch.long) # Magnitude indices + + # Device-aware lookup and sign application + kE2M1 = kE2M1ToFloat.to(device=a.device) + values = kE2M1[abs_vals] * torch.where(signs, -1.0, 1.0) + + # Reshape to final form + return values.reshape(m, n * 2).to(dtype=dtype) + + +def generate_balanced_routing( + hidden_states: torch.Tensor, num_experts: int, top_k: int +): + """ + Generate routing weights and topk indices such that every expert is active. + Returns routing_weights, topk_idx + """ + + num_tokens, hidden_dim = hidden_states.shape + # num_tokens = batch_size * seq_len + + # First, assign at least one token per expert + tokens_per_expert = torch.arange(num_tokens) % num_experts + tokens_per_expert = tokens_per_expert[torch.randperm(num_tokens)] # shuffle + + # Each token has top_k experts — start with one guaranteed expert + topk_idx = torch.full((num_tokens, top_k), -1, dtype=torch.long) + topk_idx[:, 0] = tokens_per_expert + + # For remaining top_k - 1 experts, pick randomly (allowing repeats) + if top_k > 1: + random_choices = torch.randint(0, num_experts, (num_tokens, top_k - 1)) + topk_idx[:, 1:] = random_choices + + # Normalize routing weights so each token's weights sum to 1 + routing_weights = torch.rand(num_tokens, top_k) + routing_weights /= routing_weights.sum(dim=-1, keepdim=True) + + # Reshape back if needed + routing_weights = routing_weights.view(num_tokens, top_k) + topk_idx = topk_idx.view(num_tokens, top_k) + + return routing_weights, topk_idx + + +def prepare_inputs( + hidden_states: torch.Tensor, + router_logits: torch.Tensor, + num_experts: int, + topk: int, +): + routing_weights, topk_idx = generate_balanced_routing( + router_logits, num_experts, topk + ) + + masked_m = [] + for i in range(num_experts): + mask = topk_idx.view(-1) == i + masked_m.append(mask.sum()) + + masked_m = torch.tensor(masked_m, dtype=torch.int32) + hidden_states_3d = torch.empty( + (num_experts, max(masked_m), hidden_states.shape[1]), dtype=hidden_states.dtype + ) + for i in range(num_experts): + hidden_states_3d[i, : masked_m[i], :] = hidden_states[topk_idx.view(-1) == i] + + return hidden_states_3d, masked_m, topk_idx, routing_weights + + +MNK_FACTORS = [ + (2, 1024, 1024), + (2, 1024, 1536), + (2, 3072, 1024), + (2, 3072, 1536), + (64, 1024, 1024), + (64, 1024, 1536), + (64, 3072, 1024), + (64, 2048, 1024), + (224, 1024, 1024), + (224, 1024, 1536), +] + + +# Reference implementation of torch_moe +def torch_moe(a, w1, w2, score, topk, expert_map): + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * 
topk, w2.shape[1], dtype=a.dtype, device=a.device) + score = torch.softmax(score, dim=-1, dtype=torch.float32) + topk_weight, topk_ids = torch.topk(score, topk) + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + if expert_map is not None: + topk_ids = expert_map[topk_ids] + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + out[mask] = SiluAndMul()(a[mask] @ w1[i].transpose(0, 1)) @ w2[i].transpose( + 0, 1 + ) + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) + + +def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids): + B, D = a.shape + a = a.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + out = torch.zeros(B * topk, w2.shape[1], dtype=a.dtype, device=a.device) + + topk_weight = topk_weight.view(-1) + topk_ids = topk_ids.view(-1) + + for i in range(w1.shape[0]): + mask = topk_ids == i + if mask.sum(): + m = w1[i].shape[0] + assert m % 2 == 0 + # Note: w1 and w3 are swapped! + w3_expert, w1_expert = w1[i][m // 2 :, :], w1[i][: m // 2, :] + inter = F.silu(a[mask] @ w1_expert.t()) * (a[mask] @ w3_expert.t()) + inter_gs = torch.tensor(1.0).cuda() + inter_q, inter_blockscale = fp4_quantize(inter, inter_gs) + inter = dequantize_nvfp4_to_dtype( + inter_q, + inter_blockscale, + inter_gs, + dtype=inter.dtype, + device=inter.device, + block_size=16, + ).cuda() + out[mask] = inter @ w2[i].transpose(0, 1) + return ( + out.view(B, -1, w2.shape[1]) * topk_weight.view(B, -1, 1).to(out.dtype) + ).sum(dim=1) + + +def flashinfer_cutedsl_grouped_gemm_nt_masked( + hidden_states: torch.Tensor, # 3d + input_global_scale: torch.Tensor, # (l,) + weights: torch.Tensor, + w_global_scale: torch.Tensor, # (l,) + masked_m: torch.Tensor, +): + # hidden_states: [l, m, k] + # weights: [l, n, k] + aq, aq_sf = scaled_fp4_grouped_quant( + hidden_states, + input_global_scale, + masked_m.to(hidden_states.device), + ) + num_experts, n, k = weights.shape + bq, bq_sf = scaled_fp4_grouped_quant( + weights, + w_global_scale, + torch.full((num_experts,), n, device=weights.device, dtype=torch.int32), + ) + + out = torch.zeros( + (num_experts, max(masked_m), n), dtype=weights.dtype, device=aq.device + ) + out = out.permute(1, 2, 0) # requirement of kernel + sf_vec_size = 16 + ab_dtype = "float4_e2m1fn" + sf_dtype = "float8_e4m3fn" + c_dtype = "bfloat16" + alpha = 1.0 / (input_global_scale * w_global_scale).to(out.dtype).view( + 1, 1, num_experts + ) + + def get_cute_dtype(input: torch.Tensor) -> str: + if input.dtype == torch.bfloat16: + return "bfloat16" + elif input.dtype == torch.float16: + return "float16" + elif input.dtype == torch.float32: + return "float32" + else: + raise ValueError(f"Unsupported cute dtype {input.dtype}") + + cutedsl_gmm_masked( + (aq, aq_sf), + (bq, bq_sf), + out, + masked_m.to(aq.device), + ab_dtype=ab_dtype, + sf_dtype=sf_dtype, + c_dtype=c_dtype, + sf_vec_size=sf_vec_size, + alpha=alpha, + alpha_dtype=get_cute_dtype(alpha), + ) + + return out + + +@pytest.mark.parametrize("bs, hidden_dim, inter_dim", [(2, 128, 256), (16, 128, 512)]) +@pytest.mark.parametrize("topk", [1, 2, 4]) +@torch.inference_mode() +def test_flashinfer_cutedsl_moe_masked( + bs: int, hidden_dim: int, inter_dim: int, topk: int +): + torch.manual_seed(42) + device = "cuda" + num_experts = 8 + hidden_states = ( + torch.randn(bs, hidden_dim, dtype=torch.bfloat16, device=device) / 5.0 + ) + w1 = ( + torch.randn( + num_experts, 2 * inter_dim, hidden_dim, dtype=torch.bfloat16, device=device + ) + / 10.0 + ) + w2 = ( + torch.randn( + num_experts, 
hidden_dim, inter_dim, dtype=torch.bfloat16, device=device + ) + / 10.0 + ) + router_logits = torch.randn(bs, num_experts, dtype=torch.float32) + + hidden_states_expanded = ( + hidden_states.view(bs, -1, hidden_dim) + .repeat(1, topk, 1) + .reshape(-1, hidden_dim) + ) + hidden_states_3d, masked_m, topk_idx, routing_weights = prepare_inputs( + hidden_states_expanded, router_logits, num_experts, topk + ) + + w1_amax = w1.abs().amax(dim=(1, 2)).to(torch.float32).to(w1.device) + w2_amax = w2.abs().amax(dim=(1, 2)).to(torch.float32).to(w2.device) + input_global_scale = torch.ones( + (num_experts,), dtype=torch.float32, device=hidden_states.device + ) + + w1_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax + w2_global_scale = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax + a2_global_scale = torch.ones( + (num_experts,), dtype=torch.float32, device=hidden_states.device + ) # assume intermediate scale is 1.0 + + w1_fp4, w1_blockscale = scaled_fp4_grouped_quant( + w1, + w1_global_scale, + torch.ones(num_experts, dtype=torch.int32, device=w1.device) * 2 * inter_dim, + ) + w2_fp4, w2_blockscale = scaled_fp4_grouped_quant( + w2, + w2_global_scale, + torch.ones(num_experts, dtype=torch.int32, device=w2.device) * hidden_dim, + ) + + w1_alpha = 1.0 / (input_global_scale * w1_global_scale) + w2_alpha = 1.0 / (a2_global_scale * w2_global_scale) + + out = torch.empty_like(hidden_states_3d) + # Note: the 1st dim shouldn't be bs + wk = torch.empty( + num_experts, + hidden_states_3d.shape[1], + inter_dim * 2, + dtype=hidden_states_3d.dtype, + device=hidden_states.device, + ) + flashinfer_cutedsl_moe_masked( + hidden_states_3d.to(hidden_states.device), + input_global_scale, + w1_fp4.permute(2, 0, 1), + w1_blockscale, + w1_alpha, + w2_fp4.permute(2, 0, 1), + a2_global_scale, + w2_blockscale, + w2_alpha, + masked_m.to(hidden_states.device), + wk, + out, + ) + + # reference + a_fp4, a_scale_interleaved = fp4_quantize(hidden_states, input_global_scale) + a_in_dtype = dequantize_nvfp4_to_dtype( + a_fp4, + a_scale_interleaved, + input_global_scale, + dtype=hidden_states.dtype, + device=hidden_states.device, + block_size=16, + ) + w1_d = torch.empty( + (num_experts, 2 * inter_dim, hidden_dim), device=w1.device, dtype=w1.dtype + ) + w2_d = torch.empty( + (num_experts, hidden_dim, inter_dim), device=w2.device, dtype=w2.dtype + ) + + for idx in range(0, num_experts): + w1_fp4_sliced, w1_blockscale_sliced = fp4_quantize( + w1[idx], w1_global_scale[idx] + ) + w2_fp4_sliced, w2_blockscale_sliced = fp4_quantize( + w2[idx], w2_global_scale[idx] + ) + w1_d[idx] = dequantize_nvfp4_to_dtype( + w1_fp4_sliced, + w1_blockscale_sliced, + w1_global_scale[idx], + dtype=w1.dtype, + device=w1.device, + block_size=16, + ) + w2_d[idx] = dequantize_nvfp4_to_dtype( + w2_fp4_sliced, + w2_blockscale_sliced, + w2_global_scale[idx], + dtype=w2.dtype, + device=w2.device, + block_size=16, + ) + + ref_output = torch_moe_nvfp4( + a_in_dtype, + w1_d, + w2_d, + topk, + routing_weights.to(a_in_dtype.device), + topk_idx.to(a_in_dtype.device), + ) + out_weighted = torch.zeros_like(ref_output, device=out.device, dtype=out.dtype) + + positions = torch.nonzero(masked_m[topk_idx], as_tuple=False) + rows, cols = positions[:, 0], positions[:, 1] + experts = topk_idx[rows, cols] + for i in range(num_experts): + mask = experts == i + if mask.any(): + idx = torch.nonzero(mask, as_tuple=False).squeeze(-1) + r, c = rows[idx], cols[idx] + out_weighted[r] += out[i, : len(r), :] * routing_weights[r, c].to( + out.device + ).unsqueeze(-1) + 
torch.testing.assert_close( + out_weighted.cpu(), ref_output.cpu(), atol=1e-1, rtol=1e-1 + ) + + +@pytest.mark.parametrize( + "bs, hidden_dim, inter_dim, topk", [(2, 128, 256, 2), (16, 128, 512, 5)] +) +@torch.inference_mode() +def test_grouped_gemm_nt_masked( + bs: int, hidden_dim: int, inter_dim: int, topk: int +) -> None: + torch.manual_seed(42) + B = bs + D = hidden_dim + N = inter_dim + # CuteDSL group gemm has issue when not all experts are active. + # i.e. masked = [2, 3, 0, 0, 1] where the 2nd and 3rd experts are inactive + # see https://github.com/flashinfer-ai/flashinfer/issues/1856 + num_experts = bs + hidden_states = torch.randn(B, D, dtype=torch.bfloat16, device="cuda") + weights = torch.randn(num_experts, N, D, dtype=torch.bfloat16, device="cuda") + router_logits = torch.randn(B, num_experts, dtype=torch.float32) + + hidden_states_expanded = ( + hidden_states.view(B, -1, D).repeat(1, topk, 1).reshape(-1, D) + ) + hidden_states_3d, masked_m, topk_idx, _ = prepare_inputs( + hidden_states_expanded, router_logits, num_experts, topk + ) + + # reference + out = torch.zeros( + (B * topk, weights.shape[1]), dtype=weights.dtype, device=weights.device + ) + for i in range(num_experts): + mask = topk_idx.view(-1) == i + if mask.sum(): + lhs = hidden_states_expanded[mask] + rhs = weights[i] + a_amax = lhs.abs().max().to(torch.float32).to(hidden_states.device) + b_amax = rhs.abs().max().to(torch.float32).to(weights.device) + a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax + b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax + + lhsq, lhsq_sf = fp4_quantize( + lhs, + a_gs, + ) + rhsq, rhsq_sf = fp4_quantize( + rhs, + b_gs, + ) + + lhs_in_dtype = dequantize_nvfp4_to_dtype( + lhsq, + lhsq_sf, + a_gs, + dtype=hidden_states.dtype, + device=hidden_states.device, + block_size=16, + ) + + rhs_in_dtype = dequantize_nvfp4_to_dtype( + rhsq, + rhsq_sf, + b_gs, + dtype=hidden_states.dtype, + device=hidden_states.device, + block_size=16, + ) + out[mask] = lhs_in_dtype @ rhs_in_dtype.t() + + a_amax = ( + hidden_states_3d.abs() + .amax(dim=(1, 2)) + .to(torch.float32) + .to(hidden_states.device) + ) + b_amax = weights.abs().amax(dim=(1, 2)).to(torch.float32).to(weights.device) + a_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / a_amax + b_gs = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / b_amax + out_flashinfer = flashinfer_cutedsl_grouped_gemm_nt_masked( + hidden_states_3d.to(hidden_states.device), a_gs, weights, b_gs, masked_m + ) + + # re-pack out into [num_experts, max_m, n] + out_ref = torch.zeros( + (num_experts, max(masked_m), weights.shape[1]), dtype=out.dtype + ) + expert_slot = [0] * num_experts + for i, expert_id in enumerate(topk_idx.view(-1).tolist()): + out_ref[expert_id, expert_slot[expert_id], :] = out[i] + expert_slot[expert_id] += 1 + + # Note: just to compare the masked position due to cutedsl may write nan + # into unmasked position. 
+ for i in range(num_experts): + torch.testing.assert_close( + out_flashinfer.permute(2, 0, 1)[i, : masked_m[i]], + out_ref.to(out_flashinfer.device)[i, : masked_m[i]], + atol=1e-1, + rtol=1e-1, + ) + + +if __name__ == "__main__": + test_flashinfer_cutedsl_moe_masked(16, 128, 512, 4) + test_grouped_gemm_nt_masked(16, 128, 512, 4) diff --git a/vllm/envs.py b/vllm/envs.py index d93ae8b9c225..c4b35a3c115e 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -156,7 +156,9 @@ VLLM_USE_FLASHINFER_MOE_FP16: bool = False VLLM_USE_FLASHINFER_MOE_FP8: bool = False VLLM_USE_FLASHINFER_MOE_FP4: bool = False - VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency"] = "throughput" + VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency", "cutedsl"] = ( + "throughput" + ) VLLM_XGRAMMAR_CACHE_MB: int = 0 VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256 VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False @@ -1051,6 +1053,9 @@ def get_vllm_port() -> int | None: "VLLM_MARLIN_USE_ATOMIC_ADD", "0" ) == "1", + "VLLM_DEEPEPLL_BF16_DISPATCH": lambda: bool( + int(os.getenv("VLLM_DEEPEPLL_BF16_DISPATCH", "0")) + ), # Whether to use marlin kernel in mxfp4 quantization method "VLLM_MXFP4_USE_MARLIN": lambda: maybe_convert_bool( os.environ.get("VLLM_MXFP4_USE_MARLIN", None) @@ -1199,7 +1204,9 @@ def get_vllm_port() -> int | None: # - "latency": # Uses TensorRT-LLM kernels optimized for low-latency inference. "VLLM_FLASHINFER_MOE_BACKEND": env_with_choices( - "VLLM_FLASHINFER_MOE_BACKEND", "throughput", ["throughput", "latency"] + "VLLM_FLASHINFER_MOE_BACKEND", + "throughput", + ["throughput", "latency", "cutedsl"], ), # Control the maximum number of tokens per expert supported by the # NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index b3ba2e308953..450ac57e2b3b 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -6,6 +6,8 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm import envs +from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( TopKWeightAndReduceDelegate, @@ -24,6 +26,8 @@ DEEPEP_QUANT_BLOCK_SIZE = 128 DEEPEP_QUANT_BLOCK_SHAPE = [DEEPEP_QUANT_BLOCK_SIZE, DEEPEP_QUANT_BLOCK_SIZE] +logger = init_logger(__name__) + def dequant_fp8( expert_x_fp8: torch.Tensor, expert_x_scales: torch.Tensor @@ -110,21 +114,31 @@ def _do_quant( assert isinstance(x, torch.Tensor) num_experts, max_tokens, hidden_dim = x.size() - - # TODO (varun): Optimization - Use a batched version of quant - x = x.view((-1, hidden_dim)) - x, x_scales = moe_kernel_quantize_input( - x, - quant_config.a1_scale, - quant_config.quant_dtype, - quant_config.per_act_token_quant, - quant_config.block_shape, - ) - x = x.view((num_experts, -1, hidden_dim)) - - if quant_config.quant_dtype is not None: - assert x_scales is not None - x_scales = normalize_batched_scales_shape(x_scales, num_experts) + if not envs.VLLM_DEEPEPLL_BF16_DISPATCH: + # TODO (varun): Optimization - Use a batched version of quant + x = x.view((-1, hidden_dim)) + x, x_scales = moe_kernel_quantize_input( + x, + quant_config.a1_scale, + quant_config.quant_dtype, + quant_config.per_act_token_quant, + quant_config.block_shape, + ) + x = 
x.view((num_experts, -1, hidden_dim)) + + if quant_config.quant_dtype is not None: + assert x_scales is not None + x_scales = normalize_batched_scales_shape(x_scales, num_experts) + else: + # BF16 dispatch path - no quantization + # TODO(shuw@nvidia.com): enable nvfp4 dispatch once DEEPEP is ready. + logger.info_once("Using BF16 dispatch path for DeepEPLLPrepareAndFinalize") + assert x.dtype == torch.bfloat16, ( + "BF16 dispatch requires input to be in BF16" + ) + x_scales = None + x = x.view((num_experts, -1, hidden_dim)) + # print(f"after deepepll: x.shape = {x.shape}") return x, x_scales @@ -262,6 +276,8 @@ def _finalize( # TODO (varun) : Enable zero copy mode dbo_maybe_run_recv_hook() + # print("xxx"*100, fused_expert_output.shape) + # print("ttt"*100, fused_expert_output.dtype) _, _, recv_hook = self.buffer.low_latency_combine( fused_expert_output, topk_ids, diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py new file mode 100644 index 000000000000..096a6c8dbf49 --- /dev/null +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py @@ -0,0 +1,367 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from typing import Optional + +import torch + +import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.logger import init_logger +from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig +from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( + TopKWeightAndReduceDelegate, +) +from vllm.utils.flashinfer import ( + flashinfer_cutedsl_grouped_gemm_nt_masked, + has_flashinfer_cutedsl_grouped_gemm_nt_masked, + nvfp4_batched_quantize, + silu_and_mul, +) + +logger = init_logger(__name__) + + +def is_valid_flashinfer_cutedsl_fused_moe( + hidden_states: torch.Tensor, w1: torch.Tensor, w2: torch.Tensor +) -> bool: + """ + Check if the given problem size is supported by the FlashInfer CuteDSL MoE + kernel. + """ + if not has_flashinfer_cutedsl_grouped_gemm_nt_masked(): + logger.debug_once( + "FlashInferCuteDSLExperts disabled: " + "flashinfer_cutedsl_fused_moe not available." + ) + return False + # Data type checks + if ( + w1.dtype != torch.uint8 + or w2.dtype != torch.uint8 + or hidden_states.dtype not in [torch.float32, torch.float16, torch.bfloat16] + ): + logger.debug_once( + "FlashInferCuteDSLExperts disabled: w1/w2 must be torch.uint8 " + f"(got w1={w1.dtype}, w2={w2.dtype}), hidden_states must be " + f"float32, float16, or bfloat16 (got {hidden_states.dtype})." + ) + return False + return True + + +class FlashInferCuteDSLExperts(mk.FusedMoEPermuteExpertsUnpermute): + def __init__( + self, + out_dtype: torch.dtype, + quant_config: FusedMoEQuantConfig, + ): + super().__init__(quant_config) + assert quant_config.quant_dtype == "nvfp4", ( + "Only nvfp4 quantization are currently supported." + ) + self.out_dtype = out_dtype + + @property + def activation_formats( + self, + ) -> tuple[mk.FusedMoEActivationFormat, mk.FusedMoEActivationFormat]: + return ( + mk.FusedMoEActivationFormat.BatchedExperts, + mk.FusedMoEActivationFormat.BatchedExperts, + ) + + def supports_expert_map(self) -> bool: + return False + + def supports_chunking(self) -> bool: + # This refers to TP chunking; DP chunking is handled separately. 
+ # TODO(shuw@nvidia.com): Set to False to be consistent with + # batched_deep_gemm_moe + return False + + def finalize_weight_and_reduce_impl(self) -> mk.TopKWeightAndReduce: + # Let PrepareAndFinalize::finalize() decide the impl. + return TopKWeightAndReduceDelegate() + + def workspace_shapes( + self, + M: int, + N: int, + K: int, + topk: int, + global_num_experts: int, + local_num_experts: int, + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: + # We use global_num_experts due to how moe_align_block_size handles + # expert_maps. + """ + Compute the shapes for the temporary and final outputs of the two gemms + and activation in the fused expert function. Since the gemms are + independent, the workspace for the first gemm can be shared with the + workspace for the last gemm. + + Returns a tuple of: + - workspace13 shape tuple: must be large enough to hold the + result of either expert gemm. + - workspace2 shape tuple: must be large enough to hold the + result of the activation function. + - output shape tuple: must be exact size of the final gemm output. + - Workspace type: The dtype to use for the workspace tensors. + - Note: in order for activation chunking to work, the first dimension + of each tuple must be the number of tokens. + """ + # assert a.dim() == 2 + # assert aq.dim() == 3 + # output_shape = aq.shape + # workspace_dtype = a.dtype + # E = aq.size(0) + # workspace2 = (E, M, N) + # workspace1 = output_shape + output_shape = (local_num_experts, M, K) + workspace2 = (local_num_experts, M, N) + workspace1 = output_shape + # The workspace is determined by `aq`, since it comes after any + # potential communication op and is involved in the expert computation. + return (workspace1, workspace2, output_shape) + + def apply( + self, + output: torch.Tensor, + hidden_states: torch.Tensor, + w1: torch.Tensor, + w2: torch.Tensor, + topk_weights: torch.Tensor, + topk_ids: torch.Tensor, + activation: str, + global_num_experts: int, + expert_map: Optional[torch.Tensor], + a1q_scale: Optional[torch.Tensor], + a2_scale: Optional[torch.Tensor], # Not used + workspace13: Optional[torch.Tensor], + workspace2: Optional[torch.Tensor], + expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + apply_router_weight_on_input: Optional[bool], + ): + assert self.quant_dtype == "nvfp4", ( + "Only nvfp4 quantization are currently supported." 
+ ) + # Ensure w1_scale and w2_scale are not None before calling view + assert self.w1_scale is not None and self.w2_scale is not None, ( + "w1_scale and w2_scale must not be None for FlashInferExperts" + ) + assert expert_tokens_meta is not None + expert_num_tokens = expert_tokens_meta.expert_num_tokens + assert hidden_states.ndim == 3 + assert self.w1_scale.ndim == 3 + assert self.w2_scale.ndim == 3 + + flashinfer_cutedsl_moe_masked( + hidden_states=hidden_states, + input_global_scale=self.a1_gscale, + w1=w1, + w1_blockscale=self.w1_scale, + w1_alpha=self.g1_alphas, + w2=w2, + a2_global_scale=self.a2_gscale, + w2_blockscale=self.w2_scale, + w2_alpha=self.g2_alphas, + masked_m=expert_num_tokens, + workspace=workspace2, + out=output, + ) + + +def get_cute_dtype(input: torch.Tensor) -> str: + if input.dtype == torch.bfloat16: + return "bfloat16" + elif input.dtype == torch.float16: + return "float16" + elif input.dtype == torch.float32: + return "float32" + else: + raise ValueError(f"Unsupported cute dtype {input.dtype}") + + +def scaled_fp4_grouped_quant( + input_tensor: torch.Tensor, + input_global_scale: torch.Tensor, + mask: torch.Tensor, +): + """ + Wrapper around nvfp4_batched_quantize + + Args: + input_tensor (Tensor): + - Shape (l, m, k) + input_global_scale (Tensor): Shape (l,) + mask (Tensor): Mask tensor, broadcastable + + Returns: + output (Tensor): Quantized tensor, logical shape (m, k//2, l) + output_scales (Tensor): Blockscale tensor, logical shape + (32, 4, rm, 4, rk, l) + """ + num_experts, m, k = input_tensor.shape + + sf_vec_size = 16 + assert k % sf_vec_size == 0, f"k must be multiple of 16, but got {k}." + + scale_k = k // sf_vec_size + padded_k = (scale_k + (4 - 1)) // 4 * 4 + padded_m = (m + (128 - 1)) // 128 * 128 + + aq, aq_sf = nvfp4_batched_quantize( + input_tensor, + input_global_scale, + mask=mask, + ) + + # --- re-layout quantized tensor --- + # physical (l, m, k//2) -> logical (m, k//2, l) + output = aq.permute(1, 2, 0) + + # --- re-layout blockscales --- + # physical (l, rm, rk, 32, 4, 4) -> logical (32, 4, rm, 4, rk, l) + output_scales = aq_sf.view(torch.float8_e4m3fn).view( + num_experts, padded_m // 128, padded_k // 4, 32, 4, 4 + ) + output_scales = output_scales.permute(3, 4, 1, 5, 2, 0) + + return output, output_scales + + +def flashinfer_cutedsl_moe_masked( + hidden_states: torch.Tensor, + input_global_scale: torch.Tensor, + w1: torch.Tensor, + w1_blockscale: torch.Tensor, + w1_alpha, + w2: torch.Tensor, + a2_global_scale: torch.Tensor, + w2_blockscale: torch.Tensor, + w2_alpha, + masked_m: torch.Tensor, + workspace: torch.Tensor, + out: torch.Tensor, +): + """ + Perform masked Mixture-of-Experts computation with FlashInfer's CuteDSL + kernels. + + Args: + hidden_states (torch.Tensor): [num_experts, m, k], bf16 + input_global_scale (torch.Tensor): (l,) + w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8 + w1_blockscale (torch.Tensor): blockscale factors, e4m3, + w1_alpha (torch.Tensor): (l,) + w2 (torch.Tensor): fp4 weights, [l, k, n // 2], uint8 + a2_global_scale (torch.Tensor): (l,) + w2_blockscale (torch.Tensor): blockscale factors, e4m3, + w2_alpha (torch.Tensor): (l,) + masked_m (torch.Tensor): Masked dimension indices + workspace (torch.Tensor): For gateup_output + + Notes: + - Assumes max(masked_m) <= m. 
+ """ + + # === Assertions on dtypes === + assert input_global_scale.dtype == torch.float32, ( + f"input_global_scale must be float32, got {input_global_scale.dtype}" + ) + assert w1.dtype == torch.uint8, f"w1 must be uint8, got {w1.dtype}" + assert w1_blockscale.dtype == torch.float8_e4m3fn, ( + f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}" + ) + assert w1_alpha.dtype == torch.float32, ( + f"w1_alpha must be float32, got {w1_alpha.dtype}" + ) + assert w2.dtype == torch.uint8, f"w2 must be uint8, got {w2.dtype}" + assert a2_global_scale.dtype == torch.float32, ( + f"a2_global_scale must be float32, got {a2_global_scale.dtype}" + ) + assert w2_blockscale.dtype == torch.float8_e4m3fn, ( + f"w2_blockscale must be float8_e4m3fn, got {w2_blockscale.dtype}" + ) + assert w2_alpha.dtype == torch.float32, ( + f"w2_alpha must be float32, got {w2_alpha.dtype}" + ) + + # === Assertions on shapes === + n = w2.shape[-1] * 2 # intermediate dimension + num_experts, m, k = hidden_states.shape + + assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}" + assert w1.shape[-1] * 2 == k, ( + f"w1 last dim * 2 must equal k, got {w1.shape[-1]} vs k={k}" + ) + assert w2.shape[-2:] == ( + k, + n // 2, + ), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n // 2)}" + + assert input_global_scale.shape == (num_experts,), ( + f"input_global_scale must be (l,), got {input_global_scale.shape}" + ) + assert w1_alpha.shape == (num_experts,), ( + f"w1_alpha must be (l,), got {w1_alpha.shape}" + ) + assert a2_global_scale.shape == (num_experts,), ( + f"a2_global_scale must be (l,), got {a2_global_scale.shape}" + ) + assert w2_alpha.shape == (num_experts,), ( + f"w2_alpha must be (l,), got {w2_alpha.shape}" + ) + + aq, aq_sf = scaled_fp4_grouped_quant( + hidden_states, + input_global_scale, + masked_m, + ) + + workspace = workspace.permute(1, 2, 0) # requirement of kernel + sf_vec_size = 16 + assert aq_sf.dtype == torch.float8_e4m3fn + assert aq.dtype == torch.uint8 + ab_dtype = "float4_e2m1fn" + sf_dtype = "float8_e4m3fn" + + c_dtype = get_cute_dtype(hidden_states) + + # Gemm1 + flashinfer_cutedsl_grouped_gemm_nt_masked( + (aq, aq_sf), + (w1.permute(1, 2, 0), w1_blockscale), + workspace, + masked_m, + ab_dtype=ab_dtype, + sf_dtype=sf_dtype, + c_dtype=c_dtype, + sf_vec_size=sf_vec_size, + alpha=w1_alpha.view(1, 1, num_experts), + alpha_dtype=get_cute_dtype(w1_alpha), + ) # in logical [m, n, l] + + # SILU and quantization + + diq, diq_sf = scaled_fp4_grouped_quant( + silu_and_mul(workspace.permute(2, 0, 1)), + a2_global_scale, + masked_m, + ) + + # Gemm2 + out = out.permute(1, 2, 0) # requirement of kernel + flashinfer_cutedsl_grouped_gemm_nt_masked( + (diq, diq_sf), + (w2.permute(1, 2, 0), w2_blockscale), + out, + masked_m, + ab_dtype=ab_dtype, + sf_dtype=sf_dtype, + c_dtype=c_dtype, + sf_vec_size=sf_vec_size, + alpha=w2_alpha.view(1, 1, num_experts), + alpha_dtype=get_cute_dtype(w2_alpha), + ) # in logical [m, k, l] + out = out.permute(2, 0, 1) diff --git a/vllm/model_executor/layers/quantization/modelopt.py b/vllm/model_executor/layers/quantization/modelopt.py index 0f0638899bf1..7c24cd8870ba 100644 --- a/vllm/model_executor/layers/quantization/modelopt.py +++ b/vllm/model_executor/layers/quantization/modelopt.py @@ -1443,7 +1443,10 @@ def process_weights_after_loading(self, layer: torch.nn.Module) -> None: gemm1_weight = layer.w13_weight.data gemm1_weight_scale = layer.w13_weight_scale.data - if self.allow_flashinfer: + if ( + self.allow_flashinfer + and 
self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS + ): gemm1_weight, gemm1_weight_scale = reorder_w1w3_to_w3w1( gemm1_weight, gemm1_weight_scale, dim=-2 ) @@ -1723,9 +1726,10 @@ def apply( ) elif self.fused_experts is not None: - assert ( - self.allow_flashinfer - and self.flashinfer_moe_backend == FlashinferMoeBackend.CUTLASS + assert self.allow_flashinfer + assert self.flashinfer_moe_backend in ( + FlashinferMoeBackend.CUTLASS, + FlashinferMoeBackend.CUTEDSL, ) assert is_valid_flashinfer_cutlass_fused_moe( diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py index ddb74a27dc12..473a932ee6e7 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py @@ -10,6 +10,9 @@ FusedMoEConfig, FusedMoEQuantConfig, ) +from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import ( + FlashInferCuteDSLExperts, +) from vllm.model_executor.layers.fused_moe.flashinfer_cutlass_moe import ( FlashInferExperts, ) @@ -17,10 +20,14 @@ create_flashinfer_prepare_finalize, ) from vllm.platforms import current_platform -from vllm.utils.flashinfer import has_flashinfer_cutlass_fused_moe +from vllm.utils.flashinfer import ( + has_flashinfer_cutedsl_grouped_gemm_nt_masked, + has_flashinfer_cutlass_fused_moe, +) __all__ = [ "is_flashinfer_fp4_cutlass_moe_available", + "is_flashinfer_fp4_cutedsl_moe_available", "reorder_w1w3_to_w3w1", "build_flashinfer_fp4_cutlass_moe_prepare_finalize", ] @@ -36,6 +43,16 @@ def is_flashinfer_fp4_cutlass_moe_available() -> bool: ) +def is_flashinfer_fp4_cutedsl_moe_available() -> bool: + """Return ``True`` when FlashInfer CUTEDSL NV-FP4 kernels can be used.""" + return ( + envs.VLLM_USE_FLASHINFER_MOE_FP4 + and has_flashinfer_cutedsl_grouped_gemm_nt_masked() + and current_platform.is_cuda() + and current_platform.is_device_capability(100) + ) + + def reorder_w1w3_to_w3w1( weight: torch.Tensor, scale: torch.Tensor, dim: int = -2 ) -> tuple[torch.Tensor, torch.Tensor]: @@ -72,14 +89,20 @@ def select_nvfp4_gemm_impl( """Return a GEMM *experts* implementation for NV-FP4 fused-MoE layers""" if allow_flashinfer: - return FlashInferExperts( - out_dtype=moe.in_dtype, - quant_config=moe_quant_config, - ep_rank=moe.moe_parallel_config.ep_rank, - ep_size=moe.moe_parallel_config.ep_size, - tp_rank=moe.moe_parallel_config.tp_rank, - tp_size=moe.moe_parallel_config.tp_size, - ) + if envs.VLLM_FLASHINFER_MOE_BACKEND == "cutedsl": + return FlashInferCuteDSLExperts( + out_dtype=moe.in_dtype, + quant_config=moe_quant_config, + ) + elif envs.VLLM_FLASHINFER_MOE_BACKEND == "throughput": + return FlashInferExperts( + out_dtype=moe.in_dtype, + quant_config=moe_quant_config, + ep_rank=moe.moe_parallel_config.ep_rank, + ep_size=moe.moe_parallel_config.ep_size, + tp_rank=moe.moe_parallel_config.tp_rank, + tp_size=moe.moe_parallel_config.tp_size, + ) # native cutlass experts currently don't support DP; TP case won't call this raise ValueError( diff --git a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py index 8fce7235bdde..500d5f4a426e 100644 --- a/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +++ b/vllm/model_executor/layers/quantization/utils/flashinfer_utils.py @@ -24,6 +24,7 @@ class FlashinferMoeBackend(Enum): TENSORRT_LLM = "TensorRT-LLM" CUTLASS = "CUTLASS" + CUTEDSL = "CUTEDSL" def 
calculate_tile_tokens_dim(num_tokens, top_k, num_experts): @@ -252,14 +253,17 @@ def flashinfer_cutlass_moe_fp8( def get_flashinfer_moe_backend() -> FlashinferMoeBackend: + backend_map = { + "throughput": FlashinferMoeBackend.CUTLASS, + "latency": FlashinferMoeBackend.TENSORRT_LLM, + "cutedsl": FlashinferMoeBackend.CUTEDSL, + } + flashinfer_moe_backend = envs.VLLM_FLASHINFER_MOE_BACKEND - if flashinfer_moe_backend == "throughput": - return FlashinferMoeBackend.CUTLASS - elif flashinfer_moe_backend == "latency": - return FlashinferMoeBackend.TENSORRT_LLM + if flashinfer_moe_backend in backend_map: + return backend_map[flashinfer_moe_backend] - allowed_backends = ["throughput", "latency"] raise ValueError( - f"Unknown flashinfer moe backend: {flashinfer_moe_backend}" - f" expected one of {allowed_backends}" + f"Unknown flashinfer moe backend: {flashinfer_moe_backend!r}. " + f"Expected one of {list(backend_map.keys())}." ) diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py b/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py index c3f26cc77411..44c5b027daf4 100644 --- a/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py +++ b/vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py @@ -5,6 +5,7 @@ import vllm.envs as envs from vllm.logger import init_logger from vllm.model_executor.layers.quantization.utils.flashinfer_fp4_moe import ( + is_flashinfer_fp4_cutedsl_moe_available, is_flashinfer_fp4_cutlass_moe_available, ) from vllm.model_executor.layers.quantization.utils.marlin_utils_fp4 import ( @@ -32,7 +33,10 @@ def detect_nvfp4_moe_support(class_name: str = "") -> NvFp4Support: """Detect platform support for NV-FP4 fused-MoE path""" cutlass_supported = cutlass_fp4_supported() - allow_flashinfer = cutlass_supported and is_flashinfer_fp4_cutlass_moe_available() + allow_flashinfer = cutlass_supported and ( + is_flashinfer_fp4_cutlass_moe_available() + or is_flashinfer_fp4_cutedsl_moe_available() + ) if allow_flashinfer: _logger.info_once( diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 24b80e389e83..3128768b4f17 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -96,7 +96,15 @@ def wrapper(*args, **kwargs): flashinfer_cutlass_fused_moe = _lazy_import_wrapper( "flashinfer.fused_moe", "cutlass_fused_moe" ) +flashinfer_cutedsl_grouped_gemm_nt_masked = _lazy_import_wrapper( + "flashinfer.cute_dsl.blockscaled_gemm", "grouped_gemm_nt_masked" +) flashinfer_fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize") +nvfp4_batched_quantize = _lazy_import_wrapper("flashinfer", "nvfp4_batched_quantize") +silu_and_mul_nvfp4_batched_quantize = _lazy_import_wrapper( + "flashinfer", "silu_and_mul_nvfp4_batched_quantize" +) +silu_and_mul = _lazy_import_wrapper("flashinfer", "silu_and_mul") nvfp4_block_scale_interleave = _lazy_import_wrapper( "flashinfer", "nvfp4_block_scale_interleave" ) @@ -148,6 +156,14 @@ def has_flashinfer_moe() -> bool: ) +@functools.cache +def has_flashinfer_cutedsl() -> bool: + """Return ``True`` if FlashInfer cutedsl module is available.""" + return ( + has_flashinfer() and importlib.util.find_spec("flashinfer.cute_dsl") is not None + ) + + @functools.cache def has_flashinfer_cutlass_fused_moe() -> bool: """Return ``True`` if FlashInfer CUTLASS fused MoE is available.""" @@ -169,6 +185,26 @@ def has_flashinfer_cutlass_fused_moe() -> bool: return True +@functools.cache +def has_flashinfer_cutedsl_grouped_gemm_nt_masked() -> bool: + """Return ``True`` if FlashInfer 
CUTLASS fused MoE is available.""" + if not has_flashinfer_cutedsl(): + return False + + # Check if all required functions are available + required_functions = [ + ("flashinfer.cute_dsl.blockscaled_gemm", "grouped_gemm_nt_masked"), + ("flashinfer", "silu_and_mul"), + ("flashinfer", "nvfp4_batched_quantize"), + ] + + for module_name, attr_name in required_functions: + mod = _get_submodule(module_name) + if not mod or not hasattr(mod, attr_name): + return False + return True + + @functools.cache def has_nvidia_artifactory() -> bool: """Return ``True`` if NVIDIA's artifactory is accessible. @@ -444,7 +480,11 @@ def flashinfer_disable_q_quantization() -> bool: "has_flashinfer", "flashinfer_trtllm_fp8_block_scale_moe", "flashinfer_cutlass_fused_moe", + "flashinfer_cutedsl_grouped_gemm_nt_masked", "flashinfer_fp4_quantize", + "silu_and_mul_nvfp4_batched_quantize", + "silu_and_mul", + "nvfp4_batched_quantize", "nvfp4_block_scale_interleave", "trtllm_fp4_block_scale_moe", "autotune", @@ -452,6 +492,7 @@ def flashinfer_disable_q_quantization() -> bool: "has_flashinfer_comm", "has_flashinfer_all2all", "has_flashinfer_cutlass_fused_moe", + "has_flashinfer_cutedsl_grouped_gemm_nt_masked", "has_nvidia_artifactory", "supports_trtllm_attention", "can_use_trtllm_attention", From 8a224daeb8227fc4623af26007154e6c161bec55 Mon Sep 17 00:00:00 2001 From: Shu Wang Date: Tue, 14 Oct 2025 03:26:24 +0000 Subject: [PATCH 2/4] Make fused version work with cuda graph Signed-off-by: Shu Wang --- vllm/envs.py | 3 - .../fused_moe/deepep_ll_prepare_finalize.py | 49 ++++++------- .../fused_moe/flashinfer_cutedsl_moe.py | 72 ++----------------- vllm/utils/flashinfer.py | 17 ++--- 4 files changed, 39 insertions(+), 102 deletions(-) diff --git a/vllm/envs.py b/vllm/envs.py index c4b35a3c115e..f265943562a8 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -1053,9 +1053,6 @@ def get_vllm_port() -> int | None: "VLLM_MARLIN_USE_ATOMIC_ADD", "0" ) == "1", - "VLLM_DEEPEPLL_BF16_DISPATCH": lambda: bool( - int(os.getenv("VLLM_DEEPEPLL_BF16_DISPATCH", "0")) - ), # Whether to use marlin kernel in mxfp4 quantization method "VLLM_MXFP4_USE_MARLIN": lambda: maybe_convert_bool( os.environ.get("VLLM_MXFP4_USE_MARLIN", None) diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 450ac57e2b3b..236ef54cf06f 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -114,31 +114,30 @@ def _do_quant( assert isinstance(x, torch.Tensor) num_experts, max_tokens, hidden_dim = x.size() - if not envs.VLLM_DEEPEPLL_BF16_DISPATCH: - # TODO (varun): Optimization - Use a batched version of quant - x = x.view((-1, hidden_dim)) - x, x_scales = moe_kernel_quantize_input( - x, - quant_config.a1_scale, - quant_config.quant_dtype, - quant_config.per_act_token_quant, - quant_config.block_shape, - ) - x = x.view((num_experts, -1, hidden_dim)) - - if quant_config.quant_dtype is not None: - assert x_scales is not None - x_scales = normalize_batched_scales_shape(x_scales, num_experts) - else: - # BF16 dispatch path - no quantization - # TODO(shuw@nvidia.com): enable nvfp4 dispatch once DEEPEP is ready. 
- logger.info_once("Using BF16 dispatch path for DeepEPLLPrepareAndFinalize") - assert x.dtype == torch.bfloat16, ( - "BF16 dispatch requires input to be in BF16" + + # TODO (varun): Optimization - Use a batched version of quant + x = x.view((-1, hidden_dim)) + q_dtype = quant_config.quant_dtype + + if envs.VLLM_FLASHINFER_MOE_BACKEND == "cutedsl": + logger.info_once( + "Skip quantization when using FlashInfer CUTEDSL for " + "ModelOptNvFp4FusedMoE." ) - x_scales = None - x = x.view((num_experts, -1, hidden_dim)) - # print(f"after deepepll: x.shape = {x.shape}") + q_dtype = None + + x, x_scales = moe_kernel_quantize_input( + x, + quant_config.a1_scale, + q_dtype, + quant_config.per_act_token_quant, + quant_config.block_shape, + ) + x = x.view((num_experts, -1, hidden_dim)) + + if q_dtype is not None: + assert x_scales is not None + x_scales = normalize_batched_scales_shape(x_scales, num_experts) return x, x_scales @@ -276,8 +275,6 @@ def _finalize( # TODO (varun) : Enable zero copy mode dbo_maybe_run_recv_hook() - # print("xxx"*100, fused_expert_output.shape) - # print("ttt"*100, fused_expert_output.dtype) _, _, recv_hook = self.buffer.low_latency_combine( fused_expert_output, topk_ids, diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py index 096a6c8dbf49..550f6f82ae1a 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py @@ -13,8 +13,8 @@ from vllm.utils.flashinfer import ( flashinfer_cutedsl_grouped_gemm_nt_masked, has_flashinfer_cutedsl_grouped_gemm_nt_masked, - nvfp4_batched_quantize, - silu_and_mul, + scaled_fp4_grouped_quantize, + silu_and_mul_scaled_nvfp4_experts_quantize, ) logger = init_logger(__name__) @@ -110,18 +110,9 @@ def workspace_shapes( - Note: in order for activation chunking to work, the first dimension of each tuple must be the number of tokens. """ - # assert a.dim() == 2 - # assert aq.dim() == 3 - # output_shape = aq.shape - # workspace_dtype = a.dtype - # E = aq.size(0) - # workspace2 = (E, M, N) - # workspace1 = output_shape output_shape = (local_num_experts, M, K) workspace2 = (local_num_experts, M, N) workspace1 = output_shape - # The workspace is determined by `aq`, since it comes after any - # potential communication op and is involved in the expert computation. return (workspace1, workspace2, output_shape) def apply( @@ -182,54 +173,6 @@ def get_cute_dtype(input: torch.Tensor) -> str: raise ValueError(f"Unsupported cute dtype {input.dtype}") -def scaled_fp4_grouped_quant( - input_tensor: torch.Tensor, - input_global_scale: torch.Tensor, - mask: torch.Tensor, -): - """ - Wrapper around nvfp4_batched_quantize - - Args: - input_tensor (Tensor): - - Shape (l, m, k) - input_global_scale (Tensor): Shape (l,) - mask (Tensor): Mask tensor, broadcastable - - Returns: - output (Tensor): Quantized tensor, logical shape (m, k//2, l) - output_scales (Tensor): Blockscale tensor, logical shape - (32, 4, rm, 4, rk, l) - """ - num_experts, m, k = input_tensor.shape - - sf_vec_size = 16 - assert k % sf_vec_size == 0, f"k must be multiple of 16, but got {k}." 
- - scale_k = k // sf_vec_size - padded_k = (scale_k + (4 - 1)) // 4 * 4 - padded_m = (m + (128 - 1)) // 128 * 128 - - aq, aq_sf = nvfp4_batched_quantize( - input_tensor, - input_global_scale, - mask=mask, - ) - - # --- re-layout quantized tensor --- - # physical (l, m, k//2) -> logical (m, k//2, l) - output = aq.permute(1, 2, 0) - - # --- re-layout blockscales --- - # physical (l, rm, rk, 32, 4, 4) -> logical (32, 4, rm, 4, rk, l) - output_scales = aq_sf.view(torch.float8_e4m3fn).view( - num_experts, padded_m // 128, padded_k // 4, 32, 4, 4 - ) - output_scales = output_scales.permute(3, 4, 1, 5, 2, 0) - - return output, output_scales - - def flashinfer_cutedsl_moe_masked( hidden_states: torch.Tensor, input_global_scale: torch.Tensor, @@ -313,10 +256,10 @@ def flashinfer_cutedsl_moe_masked( f"w2_alpha must be (l,), got {w2_alpha.shape}" ) - aq, aq_sf = scaled_fp4_grouped_quant( + aq, aq_sf = scaled_fp4_grouped_quantize( hidden_states, - input_global_scale, masked_m, + input_global_scale, ) workspace = workspace.permute(1, 2, 0) # requirement of kernel @@ -343,11 +286,10 @@ def flashinfer_cutedsl_moe_masked( ) # in logical [m, n, l] # SILU and quantization - - diq, diq_sf = scaled_fp4_grouped_quant( - silu_and_mul(workspace.permute(2, 0, 1)), - a2_global_scale, + diq, diq_sf = silu_and_mul_scaled_nvfp4_experts_quantize( + workspace.permute(2, 0, 1), masked_m, + a2_global_scale, ) # Gemm2 diff --git a/vllm/utils/flashinfer.py b/vllm/utils/flashinfer.py index 3128768b4f17..7a2d17ac1c9e 100644 --- a/vllm/utils/flashinfer.py +++ b/vllm/utils/flashinfer.py @@ -101,10 +101,12 @@ def wrapper(*args, **kwargs): ) flashinfer_fp4_quantize = _lazy_import_wrapper("flashinfer", "fp4_quantize") nvfp4_batched_quantize = _lazy_import_wrapper("flashinfer", "nvfp4_batched_quantize") -silu_and_mul_nvfp4_batched_quantize = _lazy_import_wrapper( - "flashinfer", "silu_and_mul_nvfp4_batched_quantize" +silu_and_mul_scaled_nvfp4_experts_quantize = _lazy_import_wrapper( + "flashinfer", "silu_and_mul_scaled_nvfp4_experts_quantize" +) +scaled_fp4_grouped_quantize = _lazy_import_wrapper( + "flashinfer", "scaled_fp4_grouped_quantize" ) -silu_and_mul = _lazy_import_wrapper("flashinfer", "silu_and_mul") nvfp4_block_scale_interleave = _lazy_import_wrapper( "flashinfer", "nvfp4_block_scale_interleave" ) @@ -194,8 +196,8 @@ def has_flashinfer_cutedsl_grouped_gemm_nt_masked() -> bool: # Check if all required functions are available required_functions = [ ("flashinfer.cute_dsl.blockscaled_gemm", "grouped_gemm_nt_masked"), - ("flashinfer", "silu_and_mul"), - ("flashinfer", "nvfp4_batched_quantize"), + ("flashinfer", "scaled_fp4_grouped_quantize"), + ("flashinfer", "silu_and_scaled_nvfp4_experts_quantize"), ] for module_name, attr_name in required_functions: @@ -482,9 +484,8 @@ def flashinfer_disable_q_quantization() -> bool: "flashinfer_cutlass_fused_moe", "flashinfer_cutedsl_grouped_gemm_nt_masked", "flashinfer_fp4_quantize", - "silu_and_mul_nvfp4_batched_quantize", - "silu_and_mul", - "nvfp4_batched_quantize", + "silu_and_mul_scaled_nvfp4_experts_quantize", + "scaled_fp4_grouped_quantize", "nvfp4_block_scale_interleave", "trtllm_fp4_block_scale_moe", "autotune", From ec6acfdb7b52d3a99b5251ac7b24cd0b9cb9aeaf Mon Sep 17 00:00:00 2001 From: Shu Wang Date: Tue, 14 Oct 2025 03:51:21 +0000 Subject: [PATCH 3/4] fix pre-commit Signed-off-by: Shu Wang --- .../layers/fused_moe/flashinfer_cutedsl_moe.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git 
a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py index 550f6f82ae1a..8e4a6df4fc5e 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from typing import Optional import torch @@ -90,7 +89,7 @@ def workspace_shapes( topk: int, global_num_experts: int, local_num_experts: int, - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], + expert_tokens_meta: mk.ExpertTokensMetadata | None, ) -> tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]]: # We use global_num_experts due to how moe_align_block_size handles # expert_maps. @@ -125,13 +124,13 @@ def apply( topk_ids: torch.Tensor, activation: str, global_num_experts: int, - expert_map: Optional[torch.Tensor], - a1q_scale: Optional[torch.Tensor], - a2_scale: Optional[torch.Tensor], # Not used - workspace13: Optional[torch.Tensor], - workspace2: Optional[torch.Tensor], - expert_tokens_meta: Optional[mk.ExpertTokensMetadata], - apply_router_weight_on_input: Optional[bool], + expert_map: torch.Tensor | None, + a1q_scale: torch.Tensor | None, + a2_scale: torch.Tensor | None, # Not used + workspace13: torch.Tensor | None, + workspace2: torch.Tensor | None, + expert_tokens_meta: mk.ExpertTokensMetadata | None, + apply_router_weight_on_input: bool | None, ): assert self.quant_dtype == "nvfp4", ( "Only nvfp4 quantization are currently supported." From 65548dd5f4a130f2188895980aa18421b5683d6e Mon Sep 17 00:00:00 2001 From: root Date: Wed, 15 Oct 2025 17:26:10 +0000 Subject: [PATCH 4/4] Add DeepEP LL nvfp4 dispatch. Signed-off-by: Shu Wang. 
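With VLLM_DEEPEPLL_NVFP4_DISPATCH=1, the DeepEP low-latency dispatch returns activations that are already nvfp4-quantized, so flashinfer_cutedsl_moe_masked now accepts either a bf16 tensor or a pre-quantized (data, blockscale) tuple. The sketch below is a rough shape illustration of the two accepted input forms in plain PyTorch (sizes are hypothetical and tensors are uninitialized; shapes follow the updated docstring in this patch, not a tested recipe):

    import torch

    l, m, k = 4, 32, 128  # experts, max tokens per expert, hidden dim

    # Unquantized path (VLLM_DEEPEPLL_NVFP4_DISPATCH=0): bf16 activations,
    # quantization is performed inside flashinfer_cutedsl_moe_masked.
    hidden_bf16 = torch.empty(l, m, k, dtype=torch.bfloat16)

    # Fused-dispatch path (VLLM_DEEPEPLL_NVFP4_DISPATCH=1): DeepEP already returns
    # packed fp4 data (two values per uint8 byte) plus one e4m3 blockscale per
    # 16 elements along k.
    aq = torch.empty(l, m, k // 2, dtype=torch.uint8)
    aq_sf = torch.empty(l, m, k // 16, dtype=torch.float8_e4m3fn)
    hidden_nvfp4 = (aq, aq_sf)

Enabling the path end to end (environment variable names are taken from the diffs in this series; the exact launch flow is an assumption): set VLLM_USE_FLASHINFER_MOE_FP4=1, VLLM_FLASHINFER_MOE_BACKEND=cutedsl and VLLM_DEEPEPLL_NVFP4_DISPATCH=1 on a Blackwell GPU (compute capability 10.x) with a FlashInfer build that ships flashinfer.cute_dsl and a DeepEP build that includes deepseek-ai/DeepEP#341.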
--- tests/kernels/moe/test_cutedsl_moe.py | 19 ++--- vllm/envs.py | 6 ++ .../fused_moe/deepep_ll_prepare_finalize.py | 78 +++++++++++++------ .../fused_moe/flashinfer_cutedsl_moe.py | 67 +++++++++++----- 4 files changed, 117 insertions(+), 53 deletions(-) diff --git a/tests/kernels/moe/test_cutedsl_moe.py b/tests/kernels/moe/test_cutedsl_moe.py index f3c7b281b086..8fe440bc5a4b 100644 --- a/tests/kernels/moe/test_cutedsl_moe.py +++ b/tests/kernels/moe/test_cutedsl_moe.py @@ -9,10 +9,11 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe.flashinfer_cutedsl_moe import ( flashinfer_cutedsl_moe_masked, - scaled_fp4_grouped_quant, ) from vllm.utils.flashinfer import ( flashinfer_cutedsl_grouped_gemm_nt_masked as cutedsl_gmm_masked, + scaled_fp4_grouped_quantize, + silu_and_mul_scaled_nvfp4_experts_quantize, ) if torch.cuda.get_device_capability() < (10, 0): @@ -219,16 +220,16 @@ def flashinfer_cutedsl_grouped_gemm_nt_masked( ): # hidden_states: [l, m, k] # weights: [l, n, k] - aq, aq_sf = scaled_fp4_grouped_quant( + aq, aq_sf = scaled_fp4_grouped_quantize( hidden_states, - input_global_scale, masked_m.to(hidden_states.device), + input_global_scale, ) num_experts, n, k = weights.shape - bq, bq_sf = scaled_fp4_grouped_quant( + bq, bq_sf = scaled_fp4_grouped_quantize( weights, - w_global_scale, torch.full((num_experts,), n, device=weights.device, dtype=torch.int32), + w_global_scale, ) out = torch.zeros( @@ -316,15 +317,15 @@ def test_flashinfer_cutedsl_moe_masked( (num_experts,), dtype=torch.float32, device=hidden_states.device ) # assume intermediate scale is 1.0 - w1_fp4, w1_blockscale = scaled_fp4_grouped_quant( + w1_fp4, w1_blockscale = scaled_fp4_grouped_quantize( w1, + torch.ones(num_experts, dtype=torch.int32, device=w1.device) * 2 * inter_dim, w1_global_scale, - torch.ones(num_experts, dtype=torch.int32, device=w1.device) * 2 * inter_dim, ) - w2_fp4, w2_blockscale = scaled_fp4_grouped_quant( + w2_fp4, w2_blockscale = scaled_fp4_grouped_quantize( w2, + torch.ones(num_experts, dtype=torch.int32, device=w2.device) * hidden_dim, w2_global_scale, - torch.ones(num_experts, dtype=torch.int32, device=w2.device) * hidden_dim, ) w1_alpha = 1.0 / (input_global_scale * w1_global_scale) diff --git a/vllm/envs.py b/vllm/envs.py index f265943562a8..02793b7b339e 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -1057,6 +1057,12 @@ def get_vllm_port() -> int | None: "VLLM_MXFP4_USE_MARLIN": lambda: maybe_convert_bool( os.environ.get("VLLM_MXFP4_USE_MARLIN", None) ), + # Whether to use DeepEPLL kernels for NVFP4 quantization and dispatch method + # only supported on Blackwell GPUs and with + # https://github.com/deepseek-ai/DeepEP/pull/341 + "VLLM_DEEPEPLL_NVFP4_DISPATCH": lambda: bool( + int(os.getenv("VLLM_DEEPEPLL_NVFP4_DISPATCH", "0")) + ), # Whether to turn on the outlines cache for V0 # This cache is unbounded and on disk, so it's not safe to use in # an environment with potentially malicious users. 
diff --git a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py index 236ef54cf06f..aa9646c97161 100644 --- a/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +++ b/vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py @@ -111,33 +111,45 @@ def _do_quant( x_fp8, x_scales = x x = dequant_fp8(x_fp8, x_scales).to(dtype=a1_dtype) - assert isinstance(x, torch.Tensor) - - num_experts, max_tokens, hidden_dim = x.size() - - # TODO (varun): Optimization - Use a batched version of quant - x = x.view((-1, hidden_dim)) + assert isinstance(x, (torch.Tensor, tuple)) q_dtype = quant_config.quant_dtype - if envs.VLLM_FLASHINFER_MOE_BACKEND == "cutedsl": + if q_dtype == "nvfp4" and envs.VLLM_DEEPEPLL_NVFP4_DISPATCH: + assert isinstance(x, tuple) + x_scales = x[1] + x = x[0].permute(2, 0, 1) + num_experts, max_tokens, hidden_dim_by_2 = x.shape + hidden_dim = hidden_dim_by_2 * 2 + assert envs.VLLM_FLASHINFER_MOE_BACKEND == "cutedsl" logger.info_once( - "Skip quantization when using FlashInfer CUTEDSL for " - "ModelOptNvFp4FusedMoE." + "Quantization is fused with DeepEP nvfp4 dispatch for " + "FlashInfer CUTEDSL as VLLM_DEEPEPLL_NVFP4_DISPATCH==1" ) - q_dtype = None - - x, x_scales = moe_kernel_quantize_input( - x, - quant_config.a1_scale, - q_dtype, - quant_config.per_act_token_quant, - quant_config.block_shape, - ) - x = x.view((num_experts, -1, hidden_dim)) + else: + if q_dtype == "nvfp4": + q_dtype = None + logger.info_once( + "Using DeepEP bfloat16 dispatch for FlashInfer CUTEDSL as " + "VLLM_DEEPEPLL_NVFP4_DISPATCH==0" + ) + assert isinstance(x, torch.Tensor) + num_experts, max_tokens, hidden_dim = x.size() + + # TODO (varun): Optimization - Use a batched version of quant + x = x.view((-1, hidden_dim)) + x, x_scales = moe_kernel_quantize_input( + x, + quant_config.a1_scale, + q_dtype, + quant_config.per_act_token_quant, + quant_config.block_shape, + ) + x = x.view((num_experts, -1, hidden_dim)) if q_dtype is not None: assert x_scales is not None - x_scales = normalize_batched_scales_shape(x_scales, num_experts) + if q_dtype != "nvfp4": + x_scales = normalize_batched_scales_shape(x_scales, num_experts) return x, x_scales @@ -167,18 +179,28 @@ def prepare_async( "DeepEP kernels quantize the inputs in blocks of shape 128" ) + use_nvfp4 = False + nvfp4_dispatch = ( + quant_config.quant_dtype == "nvfp4" and envs.VLLM_DEEPEPLL_NVFP4_DISPATCH + ) + if nvfp4_dispatch: + use_nvfp4 = True + qc_a1_gscale_or_scale = ( + quant_config.a1_gscale if nvfp4_dispatch else quant_config.a1_scale + ) has_per_token_scales = ( - quant_config.a1_scale.numel() != 1 - if quant_config.a1_scale is not None + qc_a1_gscale_or_scale.numel() != 1 + if qc_a1_gscale_or_scale is not None else ( quant_config.a2_scale.numel() != 1 if quant_config.a2_scale is not None else False ) ) - assert not has_per_token_scales, ( - "low_latency kernels doesn't support dispatching per-token scales" - ) + if not use_nvfp4: + assert not has_per_token_scales, ( + "low_latency kernels doesn't support dispatching per-token scales" + ) if apply_router_weight_on_input: topk = topk_ids.size(1) @@ -195,6 +217,12 @@ def prepare_async( self.max_tokens_per_rank, num_experts, use_fp8=self.use_fp8_dispatch, + **(dict(use_nvfp4=True) if use_nvfp4 else dict()), + **( + dict(x_global_scale=qc_a1_gscale_or_scale) + if qc_a1_gscale_or_scale is not None + else dict() + ), async_finish=False, return_recv_hook=True, ) diff --git 
a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py index 8e4a6df4fc5e..9f7d4b513ab7 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py @@ -4,6 +4,7 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm import envs from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import ( @@ -109,7 +110,8 @@ def workspace_shapes( - Note: in order for activation chunking to work, the first dimension of each tuple must be the number of tokens. """ - output_shape = (local_num_experts, M, K) + K_dim = K * 2 if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH else K + output_shape = (local_num_experts, M, K_dim) workspace2 = (local_num_experts, M, N) workspace1 = output_shape return (workspace1, workspace2, output_shape) @@ -145,9 +147,17 @@ def apply( assert self.w1_scale.ndim == 3 assert self.w2_scale.ndim == 3 + input_global_scale = ( + None if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH else self.a1_gscale + ) + flashinfer_hidden_states = ( + (hidden_states, a1q_scale) + if envs.VLLM_DEEPEPLL_NVFP4_DISPATCH + else hidden_states + ) flashinfer_cutedsl_moe_masked( - hidden_states=hidden_states, - input_global_scale=self.a1_gscale, + hidden_states=flashinfer_hidden_states, + input_global_scale=input_global_scale, w1=w1, w1_blockscale=self.w1_scale, w1_alpha=self.g1_alphas, @@ -173,7 +183,7 @@ def get_cute_dtype(input: torch.Tensor) -> str: def flashinfer_cutedsl_moe_masked( - hidden_states: torch.Tensor, + hidden_states: torch.Tensor | tuple[torch.Tensor, torch.Tensor], input_global_scale: torch.Tensor, w1: torch.Tensor, w1_blockscale: torch.Tensor, @@ -191,7 +201,10 @@ def flashinfer_cutedsl_moe_masked( kernels. 
Args: - hidden_states (torch.Tensor): [num_experts, m, k], bf16 + hidden_states: Either of the following case + * torch.Tensor: [num_experts, m, k], bf16 + * tuple[torch.Tensor, torch.Tensor]: [num_experts, m, k // 2], + uint8, [num_experts, m, k // 16], float8_e4m3fn input_global_scale (torch.Tensor): (l,) w1 (torch.Tensor): fp4 weights, [l, 2 * n, k // 2], uint8 w1_blockscale (torch.Tensor): blockscale factors, e4m3, @@ -208,9 +221,9 @@ def flashinfer_cutedsl_moe_masked( """ # === Assertions on dtypes === - assert input_global_scale.dtype == torch.float32, ( - f"input_global_scale must be float32, got {input_global_scale.dtype}" - ) + # assert input_global_scale.dtype == torch.float32, ( + # f"input_global_scale must be float32, got {input_global_scale.dtype}" + # ) assert w1.dtype == torch.uint8, f"w1 must be uint8, got {w1.dtype}" assert w1_blockscale.dtype == torch.float8_e4m3fn, ( f"w1_blockscale must be float8_e4m3fn, got {w1_blockscale.dtype}" @@ -231,7 +244,32 @@ def flashinfer_cutedsl_moe_masked( # === Assertions on shapes === n = w2.shape[-1] * 2 # intermediate dimension - num_experts, m, k = hidden_states.shape + if isinstance(hidden_states, tuple): + assert input_global_scale is None, ( + "input_global_scale is needed when input needs quant" + ) + + aq = hidden_states[0].view(torch.uint8) + aq_sf = hidden_states[1].view(torch.float8_e4m3fn) + # m, k_by_2, num_experts = aq.shape + num_experts, m, k_by_2 = aq.shape + k = k_by_2 * 2 + aq = aq.permute(1, 2, 0) + else: + num_experts, m, k = hidden_states.shape + + assert input_global_scale.dtype == torch.float32, ( + f"input_global_scale must be float32, got {input_global_scale.dtype}" + ) + assert input_global_scale.shape == (num_experts,), ( + f"input_global_scale must be (l,), got {input_global_scale.shape}" + ) + + aq, aq_sf = scaled_fp4_grouped_quantize( + hidden_states, + masked_m, + input_global_scale, + ) assert w1.shape[-2] == 2 * n, f"w1 last-2 dim must be 2*n, got {w1.shape}" assert w1.shape[-1] * 2 == k, ( @@ -242,9 +280,6 @@ def flashinfer_cutedsl_moe_masked( n // 2, ), f"w2 shape mismatch, got {w2.shape[-2:]}, expected {(k, n // 2)}" - assert input_global_scale.shape == (num_experts,), ( - f"input_global_scale must be (l,), got {input_global_scale.shape}" - ) assert w1_alpha.shape == (num_experts,), ( f"w1_alpha must be (l,), got {w1_alpha.shape}" ) @@ -255,12 +290,6 @@ def flashinfer_cutedsl_moe_masked( f"w2_alpha must be (l,), got {w2_alpha.shape}" ) - aq, aq_sf = scaled_fp4_grouped_quantize( - hidden_states, - masked_m, - input_global_scale, - ) - workspace = workspace.permute(1, 2, 0) # requirement of kernel sf_vec_size = 16 assert aq_sf.dtype == torch.float8_e4m3fn @@ -268,7 +297,7 @@ def flashinfer_cutedsl_moe_masked( ab_dtype = "float4_e2m1fn" sf_dtype = "float8_e4m3fn" - c_dtype = get_cute_dtype(hidden_states) + c_dtype = "bfloat16" # Gemm1 flashinfer_cutedsl_grouped_gemm_nt_masked(