528 changes: 528 additions & 0 deletions tests/kernels/moe/test_cutedsl_moe.py

Large diffs are not rendered by default.

21 changes: 19 additions & 2 deletions vllm/envs.py
@@ -156,7 +156,9 @@
VLLM_USE_FLASHINFER_MOE_FP16: bool = False
VLLM_USE_FLASHINFER_MOE_FP8: bool = False
VLLM_USE_FLASHINFER_MOE_FP4: bool = False
VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency"] = "throughput"
VLLM_FLASHINFER_MOE_BACKEND: Literal["throughput", "latency", "cutedsl"] = (
"throughput"
)
VLLM_XGRAMMAR_CACHE_MB: int = 0
VLLM_MSGPACK_ZERO_COPY_THRESHOLD: int = 256
VLLM_ALLOW_INSECURE_SERIALIZATION: bool = False
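
The annotation above and the env_with_choices list further down (the @@ -1192 hunk) both enumerate the allowed backends and must stay in sync. A small sketch, not part of the PR, of one way the runtime choice list could be derived from the Literal so the two cannot drift:

# Sketch only (not vLLM code): derive the runtime choices from the Literal.
from typing import Literal, get_args

FlashInferMoEBackend = Literal["throughput", "latency", "cutedsl"]
BACKEND_CHOICES = list(get_args(FlashInferMoEBackend))
assert BACKEND_CHOICES == ["throughput", "latency", "cutedsl"]
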
@@ -1055,6 +1057,19 @@ def get_vllm_port() -> int | None:
"VLLM_MXFP4_USE_MARLIN": lambda: maybe_convert_bool(
os.environ.get("VLLM_MXFP4_USE_MARLIN", None)
),
# Whether to use DeepEP low-latency (DeepEPLL) kernels for NVFP4 quantization
# and dispatch. Only supported on Blackwell GPUs and requires
# https://github.com/deepseek-ai/DeepEP/pull/341
"VLLM_DEEPEPLL_NVFP4_DISPATCH": lambda: bool(
int(os.getenv("VLLM_DEEPEPLL_NVFP4_DISPATCH", "0"))
),
# Whether to turn on the outlines cache for V0
# This cache is unbounded and on disk, so it's not safe to use in
# an environment with potentially malicious users.
"VLLM_V0_USE_OUTLINES_CACHE": lambda: os.environ.get(
"VLLM_V0_USE_OUTLINES_CACHE", "0"
)
== "1",
# Whether to turn on the outlines cache for V1
# This cache is unbounded and on disk, so it's not safe to use in
# an environment with potentially malicious users.
@@ -1192,7 +1207,9 @@ def get_vllm_port() -> int | None:
# - "latency":
# Uses TensorRT-LLM kernels optimized for low-latency inference.
"VLLM_FLASHINFER_MOE_BACKEND": env_with_choices(
"VLLM_FLASHINFER_MOE_BACKEND", "throughput", ["throughput", "latency"]
"VLLM_FLASHINFER_MOE_BACKEND",
"throughput",
["throughput", "latency", "cutedsl"],
),
# Control the maximum number of tokens per expert supported by the
# NVFP4 MoE CUTLASS Kernel. This value is used to create a buffer for
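
Taken together, the envs.py changes add one new flag (VLLM_DEEPEPLL_NVFP4_DISPATCH, parsed with bool(int(...))) and one new value ("cutedsl") for VLLM_FLASHINFER_MOE_BACKEND. A minimal sketch, not part of the PR, of how a caller might set and read them; the assertion mirrors the check added in _do_quant below:

# Sketch only: environment wiring for the fused NVFP4 dispatch path.
import os

os.environ["VLLM_FLASHINFER_MOE_BACKEND"] = "cutedsl"  # new choice added above
os.environ["VLLM_DEEPEPLL_NVFP4_DISPATCH"] = "1"       # new flag added above

backend = os.getenv("VLLM_FLASHINFER_MOE_BACKEND", "throughput")
nvfp4_dispatch = bool(int(os.getenv("VLLM_DEEPEPLL_NVFP4_DISPATCH", "0")))

assert backend in ("throughput", "latency", "cutedsl")
if nvfp4_dispatch:
    # _do_quant (below) asserts this combination when the flag is on.
    assert backend == "cutedsl"
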
@@ -6,6 +6,8 @@
import torch

import vllm.model_executor.layers.fused_moe.modular_kernel as mk
from vllm import envs
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.config import FusedMoEQuantConfig
from vllm.model_executor.layers.fused_moe.topk_weight_and_reduce import (
TopKWeightAndReduceDelegate,
@@ -24,6 +26,8 @@
DEEPEP_QUANT_BLOCK_SIZE = 128
DEEPEP_QUANT_BLOCK_SHAPE = [DEEPEP_QUANT_BLOCK_SIZE, DEEPEP_QUANT_BLOCK_SIZE]

logger = init_logger(__name__)


def dequant_fp8(
expert_x_fp8: torch.Tensor, expert_x_scales: torch.Tensor
@@ -131,24 +135,45 @@ def _do_quant(
x_fp8, x_scales = x
x = dequant_fp8(x_fp8, x_scales).to(dtype=a1_dtype)

assert isinstance(x, torch.Tensor)

num_experts, max_tokens, hidden_dim = x.size()

# TODO (varun): Optimization - Use a batched version of quant
x = x.view((-1, hidden_dim))
x, x_scales = moe_kernel_quantize_input(
x,
quant_config.a1_scale,
quant_config.quant_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
)
x = x.view((num_experts, -1, hidden_dim))
assert isinstance(x, (torch.Tensor, tuple))
q_dtype = quant_config.quant_dtype

if q_dtype == "nvfp4" and envs.VLLM_DEEPEPLL_NVFP4_DISPATCH:
assert isinstance(x, tuple)
x_scales = x[1]
x = x[0].permute(2, 0, 1)
num_experts, max_tokens, hidden_dim_by_2 = x.shape
hidden_dim = hidden_dim_by_2 * 2
assert envs.VLLM_FLASHINFER_MOE_BACKEND == "cutedsl"
logger.info_once(
"Quantization is fused with DeepEP nvfp4 dispatch for "
"FlashInfer CUTEDSL as VLLM_DEEPEPLL_NVFP4_DISPATCH==1"
)
else:
if q_dtype == "nvfp4":
q_dtype = None
logger.info_once(
"Using DeepEP bfloat16 dispatch for FlashInfer CUTEDSL as "
"VLLM_DEEPEPLL_NVFP4_DISPATCH==0"
)
assert isinstance(x, torch.Tensor)
num_experts, max_tokens, hidden_dim = x.size()

# TODO (varun): Optimization - Use a batched version of quant
x = x.view((-1, hidden_dim))
x, x_scales = moe_kernel_quantize_input(
x,
quant_config.a1_scale,
q_dtype,
quant_config.per_act_token_quant,
quant_config.block_shape,
)
x = x.view((num_experts, -1, hidden_dim))

if quant_config.quant_dtype is not None:
if q_dtype is not None:
assert x_scales is not None
x_scales = normalize_batched_scales_shape(x_scales, num_experts)
if q_dtype != "nvfp4":
x_scales = normalize_batched_scales_shape(x_scales, num_experts)

return x, x_scales
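
The branch above handles two activation layouts after DeepEP low-latency dispatch: with VLLM_DEEPEPLL_NVFP4_DISPATCH=1 the activations arrive already quantized as a (packed_fp4, scales) tuple and are only permuted, while with the flag off they arrive in bfloat16 and are quantized here. A shape-only sketch of that bookkeeping, using dummy tensors and shapes inferred from the permute/view calls above (the actual DeepEP output layout is an assumption):

# Shape-only sketch; contents are dummy, layouts inferred from the diff above.
import torch

num_experts, max_tokens, hidden_dim = 4, 8, 256

# Fused NVFP4 dispatch path: permute(2, 0, 1) implies the packed tensor comes
# back as (max_tokens, hidden_dim // 2, num_experts), two FP4 values per byte.
packed = torch.zeros(max_tokens, hidden_dim // 2, num_experts, dtype=torch.uint8)
x_nvfp4 = packed.permute(2, 0, 1)
assert x_nvfp4.shape == (num_experts, max_tokens, hidden_dim // 2)
assert x_nvfp4.shape[-1] * 2 == hidden_dim

# bfloat16 dispatch path: flatten, quantize row-wise, then restore the
# (num_experts, max_tokens, hidden_dim) batching, as in the else-branch above.
x_bf16 = torch.zeros(num_experts, max_tokens, hidden_dim, dtype=torch.bfloat16)
flat = x_bf16.view(-1, hidden_dim)
restored = flat.view(num_experts, -1, hidden_dim)
assert restored.shape == (num_experts, max_tokens, hidden_dim)
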

@@ -178,18 +203,28 @@ def prepare_async(
"DeepEP kernels quantize the inputs in blocks of shape 128"
)

use_nvfp4 = False
nvfp4_dispatch = (
quant_config.quant_dtype == "nvfp4" and envs.VLLM_DEEPEPLL_NVFP4_DISPATCH
)
if nvfp4_dispatch:
use_nvfp4 = True
qc_a1_gscale_or_scale = (
quant_config.a1_gscale if nvfp4_dispatch else quant_config.a1_scale
)
has_per_token_scales = (
quant_config.a1_scale.numel() != 1
if quant_config.a1_scale is not None
qc_a1_gscale_or_scale.numel() != 1
if qc_a1_gscale_or_scale is not None
else (
quant_config.a2_scale.numel() != 1
if quant_config.a2_scale is not None
else False
)
)
assert not has_per_token_scales, (
"low_latency kernels doesn't support dispatching per-token scales"
)
if not use_nvfp4:
assert not has_per_token_scales, (
"low_latency kernels doesn't support dispatching per-token scales"
)

if apply_router_weight_on_input:
topk = topk_ids.size(1)
@@ -206,6 +241,12 @@
self.max_tokens_per_rank,
num_experts,
use_fp8=self.use_fp8_dispatch,
**(dict(use_nvfp4=True) if use_nvfp4 else dict()),
**(
dict(x_global_scale=qc_a1_gscale_or_scale)
if qc_a1_gscale_or_scale is not None
else dict()
),
async_finish=False,
return_recv_hook=True,
)
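
The dict-unpacking above passes use_nvfp4 and x_global_scale to the dispatch call only when they apply, so the call stays compatible with DeepEP builds whose dispatch does not accept those keywords. A standalone sketch of the pattern; the callee below is a stand-in stub, not DeepEP's actual low-latency dispatch signature:

# Sketch of conditional keyword forwarding; dispatch_stub is hypothetical.
def dispatch_stub(x, use_fp8=False, **extra_kwargs):
    return {"use_fp8": use_fp8, **extra_kwargs}

use_nvfp4 = True
x_global_scale = 2.0  # stands in for quant_config.a1_gscale above

kwargs = {
    **(dict(use_nvfp4=True) if use_nvfp4 else dict()),
    **(dict(x_global_scale=x_global_scale) if x_global_scale is not None else dict()),
}
print(dispatch_stub(None, use_fp8=False, **kwargs))
# -> {'use_fp8': False, 'use_nvfp4': True, 'x_global_scale': 2.0}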