118 changes: 75 additions & 43 deletions tests/compile/test_silu_mul_quant_fusion.py
@@ -1,6 +1,6 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import cast
import itertools

import pytest
import torch
@@ -16,7 +16,13 @@
from vllm.compilation.fusion import QUANT_OPS
from vllm.compilation.noop_elimination import NoOpEliminationPass
from vllm.compilation.post_cleanup import PostCleanupPass
from vllm.config import CompilationConfig, PassConfig, VllmConfig
from vllm.config import (
CompilationConfig,
CompilationMode,
PassConfig,
VllmConfig,
set_current_vllm_config,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.quantization.utils.quant_utils import (
GroupShape,
@@ -25,7 +31,7 @@
)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
Fp8LinearOp,
cutlass_fp8_supported,
maybe_create_device_identity,
)
from vllm.platforms import current_platform

@@ -54,14 +60,23 @@ def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
act_quant_static=True,
act_quant_group_shape=GroupShape.PER_TENSOR,
)
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()

def forward(self, x):
y = self.silu_and_mul(x)
x2 = self.fp8_linear.apply(y, self.w, self.wscale, input_scale=self.wscale)
return x2

def ops_in_model_before(self):
return [SILU_MUL_OP, QUANT_OPS[kFp8StaticTensorSym]]
return [
SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
(
QUANT_OPS[kFp8StaticTensorSym]
if self.enable_quant_fp8_custom_op
else torch.ops.aten.reciprocal
),
]

def ops_in_model_after(self):
return [FUSED_OPS[kFp8StaticTensorSym]]
@@ -77,6 +92,7 @@ def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs):
assert silu_and_mul_nvfp4_quant_supported

self.silu_and_mul = SiluAndMul()
self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()

# create nvfp4 weight
w = torch.rand((hidden_size, hidden_size))
@@ -101,7 +117,10 @@ def forward(self, x):
return out

def ops_in_model_before(self):
return [SILU_MUL_OP, QUANT_OPS[kNvfp4Quant]]
return [
SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
QUANT_OPS[kNvfp4Quant],
]

def ops_in_model_after(self):
return [FUSED_OPS[kNvfp4Quant]]
@@ -110,67 +129,80 @@ def ops_in_model_after(self):
@pytest.mark.parametrize("num_tokens", [32, 64])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
@pytest.mark.parametrize("enable_silu_mul_custom_op", [True, False])
@pytest.mark.parametrize(
"model_class",
cast(
list[type],
[TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
if is_nvfp4_supported()
else [TestSiluMulFp8QuantModel],
),
"model_class, enable_quant_fp8_custom_op, cuda_force_torch",
list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False]))
+ [(TestSiluMulNvfp4QuantModel, False, False)],
)
# cuda_force_torch used to test torch code path on platforms that
# cutlass_fp8_supported() == True.
@pytest.mark.parametrize(
"cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
)
@pytest.mark.skipif(
envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm"
)
def test_fusion_silu_and_mul_quant(
num_tokens, hidden_size, dtype, model_class, cuda_force_torch
num_tokens: int,
hidden_size: int,
dtype: torch.dtype,
model_class: type[TestSiluMulFp8QuantModel | TestSiluMulNvfp4QuantModel],
enable_silu_mul_custom_op: bool,
enable_quant_fp8_custom_op: bool,
cuda_force_torch: bool,
):
if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch:
pytest.skip("Duplicate tests for NVFP4")
if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported():
pytest.skip("NVFP4 is not supported on this GPU.")

torch.set_default_device("cuda")
torch.set_default_dtype(dtype)
maybe_create_device_identity()

x = torch.rand(num_tokens, hidden_size * 2)

# Reshape pass is needed for the fusion pass to work
config = VllmConfig()
config.compilation_config = CompilationConfig(
pass_config=PassConfig(enable_fusion=True, enable_noop=True)
custom_ops = []
if enable_silu_mul_custom_op:
custom_ops.append("+silu_and_mul")
if enable_quant_fp8_custom_op:
custom_ops.append("+quant_fp8")
config = VllmConfig(
compilation_config=CompilationConfig(
mode=CompilationMode.VLLM_COMPILE,
custom_ops=custom_ops,
pass_config=PassConfig(enable_fusion=True, enable_noop=True),
),
)
fusion_pass = ActivationQuantFusionPass(config)

passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
backend = TestBackend(*passes)
model = model_class(hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x)
with set_current_vllm_config(config):
fusion_pass = ActivationQuantFusionPass(config)

# First dimension dynamic
torch._dynamo.mark_dynamic(x, 0)
passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
backend = TestBackend(*passes)
model = model_class(
hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x
)

result = model(x)
# First dimension dynamic
torch._dynamo.mark_dynamic(x, 0)

model2 = torch.compile(model, backend=backend)
result2 = model2(x)
result = model(x)

# Check that it gives the same answer
if model_class == TestSiluMulFp8QuantModel:
atol, rtol = 1e-3, 1e-3
elif model_class == TestSiluMulNvfp4QuantModel:
atol, rtol = 1e-1, 1e-1
model2 = torch.compile(model, backend=backend)
result2 = model2(x)

torch.testing.assert_close(
result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
)
# Check that it gives the same answer
if model_class == TestSiluMulFp8QuantModel:
atol, rtol = 1e-3, 1e-3
elif model_class == TestSiluMulNvfp4QuantModel:
atol, rtol = 1e-1, 1e-1

torch.testing.assert_close(
result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
)

assert fusion_pass.matched_count == 1
assert fusion_pass.matched_count == 1

# In pre-nodes, quant op should be present and fused kernels should not
backend.check_before_ops(model.ops_in_model_before())
# In pre-nodes, quant op should be present and fused kernels should not
backend.check_before_ops(model.ops_in_model_before())

# In post-nodes, fused kernels should be present and quant op should not
backend.check_after_ops(model.ops_in_model_after())
# In post-nodes, fused kernels should be present and quant op should not
backend.check_after_ops(model.ops_in_model_after())
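
Editor's note (sketch, not part of the diff): model construction now happens inside `set_current_vllm_config(config)` because the custom-op flags are resolved when the layers are built. A minimal sketch of that interaction, assuming `SiluAndMul.enabled()` consults the active `VllmConfig.compilation_config.custom_ops`:

```python
# Minimal sketch, assuming CustomOp enablement is read from the currently
# active vLLM config at layer-construction time (which is why the test wraps
# model construction in set_current_vllm_config).
from vllm.config import (
    CompilationConfig,
    CompilationMode,
    PassConfig,
    VllmConfig,
    set_current_vllm_config,
)
from vllm.model_executor.layers.activation import SiluAndMul

config = VllmConfig(
    compilation_config=CompilationConfig(
        mode=CompilationMode.VLLM_COMPILE,
        custom_ops=["+silu_and_mul"],  # "+" requests the fused CUDA kernel
        pass_config=PassConfig(enable_fusion=True, enable_noop=True),
    )
)

with set_current_vllm_config(config):
    layer = SiluAndMul()
    print(layer.enabled())  # expected True with "+silu_and_mul"; False if omitted
```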
47 changes: 21 additions & 26 deletions vllm/compilation/activation_quant_fusion.py
@@ -18,12 +18,12 @@
QuantKey,
kFp8StaticTensorSym,
kNvfp4Quant,
kStaticTensorScale,
)
from vllm.platforms import current_platform

from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
from .inductor_pass import enable_fake_mode
from .matcher_utils import MatcherQuantFP8, MatcherSiluAndMul
from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass

logger = init_logger(__name__)
@@ -66,6 +66,8 @@ def __init__(
)
self.FUSED_OP = FUSED_OPS[self.quant_key]

self.silu_and_mul_matcher = MatcherSiluAndMul()

def empty_quant(self, *args, **kwargs):
kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
return torch.empty(*args, **kwargs)
@@ -80,42 +82,38 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
Fusion for SiluMul+Fp8StaticQuant Pattern
"""

def __init__(self, symmetric: bool = True):
quant_key = QuantKey(
dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric
)
super().__init__(quant_key)
def __init__(self):
super().__init__(kFp8StaticTensorSym)
self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)

def register(self, pm_pass: PatternMatcherPass):
def pattern(
result: torch.Tensor,
result_silu_mul: torch.Tensor,
input: torch.Tensor,
scale: torch.Tensor,
):
at1 = auto_functionalized(SILU_MUL_OP, result=result_silu_mul, input=input)
at2 = auto_functionalized(
self.QUANT_OP, result=result, input=at1[1], scale=scale
)
return at2[1]
result_silu_mul = self.silu_and_mul_matcher(input)
result_quant = self.quant_matcher(result_silu_mul, scale)
return result_quant[0]

def replacement(
result: torch.Tensor,
result_silu_mul: torch.Tensor,
input: torch.Tensor,
scale: torch.Tensor,
):
d = input.shape[-1] // 2
output_shape = input.shape[:-1] + (d,)
result = torch.empty(
output_shape, device=input.device, dtype=self.quant_dtype
)
at = auto_functionalized(
self.FUSED_OP, result=result, input=input, scale=scale
)
return at[1]

inputs = [
self.empty_quant(5, 4), # result
empty_bf16(5, 4), # result_silu_mul
empty_bf16(5, 4), # input
empty_fp32(1, 1), # scale
*self.silu_and_mul_matcher.inputs(), # input
self.quant_matcher.inputs()[1], # scale
]
pattern(*inputs)

register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)

@@ -132,24 +130,22 @@ def register(self, pm_pass: PatternMatcherPass):
def pattern(
result: torch.Tensor,
output_scale: torch.Tensor,
result_silu_mul: torch.Tensor,
input: torch.Tensor,
scale: torch.Tensor,
):
at1 = auto_functionalized(SILU_MUL_OP, result=result_silu_mul, input=input)
at2 = auto_functionalized(
result_silu_mul = self.silu_and_mul_matcher(input)
at = auto_functionalized(
self.QUANT_OP,
output=result,
input=at1[1],
input=result_silu_mul,
output_scale=output_scale,
input_scale=scale,
)
return at2[1], at2[2]
return at[1], at[2]

def replacement(
result: torch.Tensor,
output_scale: torch.Tensor,
result_silu_mul: torch.Tensor,
input: torch.Tensor,
scale: torch.Tensor,
):
Expand All @@ -165,7 +161,6 @@ def replacement(
inputs = [
self.empty_quant(5, 32), # result
empty_i32(128, 4), # output_scale
empty_bf16(5, 64), # result_silu_mul
empty_bf16(5, 64), # input
empty_fp32(1, 1), # scale
]
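
Editor's note (sketch): the rewritten FP8 pattern above no longer hard-codes `auto_functionalized(SILU_MUL_OP, ...)`; it builds its "before" graph from `MatcherSiluAndMul` and `MatcherQuantFP8`, so a single registered pattern covers both the custom-op and native code paths. The registration mechanics follow torch's inductor pattern matcher, mirrored in the toy, self-contained illustration below (the mul/addcmul ops are placeholders, not vLLM kernels):

```python
# Toy illustration of register_replacement(pattern, replacement, inputs,
# fwd_only, pm_pass) as used in the diff above; the ops here are placeholders.
import torch
from torch._inductor.pattern_matcher import (
    PatternMatcherPass,
    fwd_only,
    register_replacement,
)

pm_pass = PatternMatcherPass()

def pattern(x: torch.Tensor, y: torch.Tensor):
    # The "before" graph to search for: an elementwise mul followed by an add.
    return x * y + y

def replacement(x: torch.Tensor, y: torch.Tensor):
    # The "after" graph: the same math expressed as a single addcmul.
    return torch.addcmul(y, x, y)

# Example inputs are only used to trace both functions into FX graphs.
inputs = [torch.empty(5, 4), torch.empty(5, 4)]
register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)
```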
30 changes: 30 additions & 0 deletions vllm/compilation/matcher_utils.py
@@ -7,6 +7,7 @@
from torch._ops import OpOverload

from vllm.config import get_current_vllm_config
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -31,6 +32,8 @@
if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default # noqa: E501

SILU_MUL_OP = torch.ops._C.silu_and_mul.default


class MatcherCustomOp(ABC):
def __init__(self, enabled: bool):
@@ -206,3 +209,30 @@ def inputs(self) -> list[torch.Tensor]:
return [input, self.empty_f32(1, 1)]

return [input]


class MatcherSiluAndMul(MatcherCustomOp):
def __init__(self, enabled: bool | None = None):
if enabled is None:
enabled = SiluAndMul.enabled()
super().__init__(enabled)

def inputs(self) -> list[torch.Tensor]:
input = self.empty(5, 4)
return [input]

def forward_custom(
self,
x: torch.Tensor,
) -> torch.Tensor:
d = x.shape[-1] // 2
output_shape = x.shape[:-1] + (d,)
out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
result = auto_functionalized(SILU_MUL_OP, result=out, input=x)
return result[1]

def forward_native(
self,
x: torch.Tensor,
) -> torch.Tensor:
return SiluAndMul.forward_native(x)
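
Editor's note (sketch): what `MatcherSiluAndMul` emits depends on whether the custom op is enabled. A standalone sketch of the two graph shapes the fusion pattern has to match, assuming `MatcherCustomOp.__call__` dispatches to `forward_custom` or `forward_native` based on `enabled` (that base class is not shown in this diff):

```python
# Sketch only: the two forms of SiLU-and-mul the registered pattern can see.
import torch
import torch.nn.functional as F

def silu_mul_native(x: torch.Tensor) -> torch.Tensor:
    # What torch.compile traces when "+silu_and_mul" is not enabled:
    # slice, SiLU on the first half, elementwise mul with the second half.
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]

def silu_mul_custom(x: torch.Tensor) -> torch.Tensor:
    # What the graph contains when the custom op is enabled: a single call to
    # the fused kernel (requires vLLM's compiled _C extension to be loaded).
    d = x.shape[-1] // 2
    out = torch.empty(x.shape[:-1] + (d,), dtype=x.dtype, device=x.device)
    torch.ops._C.silu_and_mul(out, x)
    return out
```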
3 changes: 2 additions & 1 deletion vllm/model_executor/layers/activation.py
@@ -80,7 +80,8 @@ def __init__(self):
elif current_platform.is_cpu():
self._forward_method = self.forward_native

def forward_native(self, x: torch.Tensor) -> torch.Tensor:
@staticmethod
def forward_native(x: torch.Tensor) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward()."""
d = x.shape[-1] // 2
return F.silu(x[..., :d]) * x[..., d:]
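
Editor's note (sketch): the `@staticmethod` change above is what lets `MatcherSiluAndMul.forward_native` call the reference implementation directly on the class, without constructing a `SiluAndMul` module. A small usage sketch:

```python
import torch
import torch.nn.functional as F
from vllm.model_executor.layers.activation import SiluAndMul

x = torch.randn(4, 8)

# Now valid on the class itself; previously an instance was required.
out = SiluAndMul.forward_native(x)

d = x.shape[-1] // 2
torch.testing.assert_close(out, F.silu(x[..., :d]) * x[..., d:])
```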