From 935659682f4da446837da1c06397de02398f55a5 Mon Sep 17 00:00:00 2001
From: zjy0516
Date: Sat, 18 Oct 2025 15:56:33 +0800
Subject: [PATCH 1/7] silu_mul_fp8_quant

Signed-off-by: zjy0516
---
 tests/compile/test_silu_mul_quant_fusion.py | 101 ++++++++++++++------
 vllm/compilation/activation_quant_fusion.py |  37 ++++---
 vllm/compilation/matcher_utils.py           |  30 ++++++
 vllm/model_executor/layers/activation.py    |   3 +-
 4 files changed, 120 insertions(+), 51 deletions(-)

diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py
index 16a4271655ef..4f5a521e47cc 100644
--- a/tests/compile/test_silu_mul_quant_fusion.py
+++ b/tests/compile/test_silu_mul_quant_fusion.py
@@ -16,7 +16,13 @@
 from vllm.compilation.fusion import QUANT_OPS
 from vllm.compilation.noop_elimination import NoOpEliminationPass
 from vllm.compilation.post_cleanup import PostCleanupPass
-from vllm.config import CompilationConfig, PassConfig, VllmConfig
+from vllm.config import (
+    CompilationConfig,
+    CompilationMode,
+    PassConfig,
+    VllmConfig,
+    set_current_vllm_config,
+)
 from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     GroupShape,
@@ -26,6 +32,7 @@
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     Fp8LinearOp,
     cutlass_fp8_supported,
+    maybe_create_device_identity,
 )
 from vllm.platforms import current_platform
 
@@ -54,6 +61,8 @@ def __init__(self, hidden_size: int, cuda_force_torch: bool, **kwargs):
             act_quant_static=True,
             act_quant_group_shape=GroupShape.PER_TENSOR,
         )
+        self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
+        self.enable_quant_fp8_custom_op = self.fp8_linear.quant_fp8.enabled()
 
     def forward(self, x):
         y = self.silu_and_mul(x)
@@ -61,7 +70,14 @@ def forward(self, x):
         return x2
 
     def ops_in_model_before(self):
-        return [SILU_MUL_OP, QUANT_OPS[kFp8StaticTensorSym]]
+        return [
+            SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
+            (
+                QUANT_OPS[kFp8StaticTensorSym]
+                if self.enable_quant_fp8_custom_op
+                else torch.ops.aten.reciprocal
+            ),
+        ]
 
     def ops_in_model_after(self):
         return [FUSED_OPS[kFp8StaticTensorSym]]
@@ -110,13 +126,17 @@ def ops_in_model_after(self):
 @pytest.mark.parametrize("num_tokens", [32, 64])
 @pytest.mark.parametrize("hidden_size", [128, 256])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
+@pytest.mark.parametrize("enable_silu_mul_custom_op", [True, False])
+@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
 @pytest.mark.parametrize(
     "model_class",
     cast(
         list[type],
-        [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
-        if is_nvfp4_supported()
-        else [TestSiluMulFp8QuantModel],
+        (
+            [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
+            if is_nvfp4_supported()
+            else [TestSiluMulFp8QuantModel]
+        ),
     ),
 )
 # cuda_force_torch used to test torch code path on platforms that
@@ -128,49 +148,68 @@ def ops_in_model_after(self):
     envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm"
 )
 def test_fusion_silu_and_mul_quant(
-    num_tokens, hidden_size, dtype, model_class, cuda_force_torch
+    num_tokens,
+    hidden_size,
+    dtype,
+    model_class,
+    enable_silu_mul_custom_op,
+    enable_quant_fp8_custom_op,
+    cuda_force_torch,
 ):
     if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch:
         pytest.skip("Duplicate tests for NVFP4")
 
     torch.set_default_device("cuda")
     torch.set_default_dtype(dtype)
+    maybe_create_device_identity()
 
     x = torch.rand(num_tokens, hidden_size * 2)
 
     # Reshape pass is needed for the fusion pass to work
-    config = VllmConfig()
-    config.compilation_config = CompilationConfig(
-        pass_config=PassConfig(enable_fusion=True, enable_noop=True)
+    custom_ops = []
+    if enable_silu_mul_custom_op:
+        custom_ops.append("+silu_and_mul")
+    if enable_quant_fp8_custom_op:
+        custom_ops.append("+quant_fp8")
+    config = VllmConfig(
+        compilation_config=CompilationConfig(
+            mode=CompilationMode.VLLM_COMPILE,
+            custom_ops=custom_ops,
+            pass_config=PassConfig(enable_fusion=True, enable_noop=True),
+        ),
     )
-    fusion_pass = ActivationQuantFusionPass(config)
 
-    passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
-    backend = TestBackend(*passes)
-    model = model_class(hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x)
+    with set_current_vllm_config(config):
+        fusion_pass = ActivationQuantFusionPass(config)
 
-    # First dimension dynamic
-    torch._dynamo.mark_dynamic(x, 0)
+        passes = [NoOpEliminationPass(config), fusion_pass, PostCleanupPass(config)]
+        backend = TestBackend(*passes)
+        model = model_class(
+            hidden_size=hidden_size, cuda_force_torch=cuda_force_torch, x=x
+        )
 
-    result = model(x)
+        # First dimension dynamic
+        torch._dynamo.mark_dynamic(x, 0)
 
-    model2 = torch.compile(model, backend=backend)
-    result2 = model2(x)
+        result = model(x)
 
-    # Check that it gives the same answer
-    if model_class == TestSiluMulFp8QuantModel:
-        atol, rtol = 1e-3, 1e-3
-    elif model_class == TestSiluMulNvfp4QuantModel:
-        atol, rtol = 1e-1, 1e-1
+        model2 = torch.compile(model, backend=backend)
+        result2 = model2(x)
 
-    torch.testing.assert_close(
-        result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
-    )
+        # Check that it gives the same answer
+        if model_class == TestSiluMulFp8QuantModel:
+            atol, rtol = 1e-3, 1e-3
+        elif model_class == TestSiluMulNvfp4QuantModel:
+            atol, rtol = 1e-1, 1e-1
+
+        torch.testing.assert_close(
+            result[0].to(dtype=dtype), result2[0].to(dtype=dtype), atol=atol, rtol=rtol
+        )
 
-    assert fusion_pass.matched_count == 1
+        assert fusion_pass.matched_count == 1
 
-    # In pre-nodes, quant op should be present and fused kernels should not
-    backend.check_before_ops(model.ops_in_model_before())
+        # In pre-nodes, quant op should be present and fused kernels should not
+        backend.check_before_ops(model.ops_in_model_before())
 
-    # In post-nodes, fused kernels should be present and quant op should not
-    backend.check_after_ops(model.ops_in_model_after())
+        # In post-nodes, fused kernels should be present and quant op should not
+        backend.check_after_ops(model.ops_in_model_after())
diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py
index 7448bb122152..d404ec7f7410 100644
--- a/vllm/compilation/activation_quant_fusion.py
+++ b/vllm/compilation/activation_quant_fusion.py
@@ -18,12 +18,12 @@
     QuantKey,
     kFp8StaticTensorSym,
     kNvfp4Quant,
-    kStaticTensorScale,
 )
 from vllm.platforms import current_platform
 
 from .fusion import QUANT_OPS, empty_bf16, empty_fp32, empty_i32
 from .inductor_pass import enable_fake_mode
+from .matcher_utils import MatcherQuantFP8, MatcherSiluAndMul
 from .vllm_inductor_pass import VllmInductorPass, VllmPatternMatcherPass
 
 logger = init_logger(__name__)
@@ -66,6 +66,8 @@ def __init__(
         )
         self.FUSED_OP = FUSED_OPS[self.quant_key]
 
+        self.silu_and_mul_matcher = MatcherSiluAndMul()
+
     def empty_quant(self, *args, **kwargs):
         kwargs = {"dtype": self.quant_dtype, "device": "cuda", **kwargs}
         return torch.empty(*args, **kwargs)
@@ -80,42 +82,39 @@ class SiluMulFp8StaticQuantPattern(ActivationQuantPattern):
     """
     Fusion for SiluMul+Fp8StaticQuant Pattern
     """
 
-    def __init__(self, symmetric: bool = True):
-        quant_key = QuantKey(
-            dtype=FP8_DTYPE, scale=kStaticTensorScale, symmetric=symmetric
-        )
-        super().__init__(quant_key)
+    def __init__(self):
+        super().__init__(kFp8StaticTensorSym)
+        self.quant_matcher = MatcherQuantFP8(kFp8StaticTensorSym)
 
     def register(self, pm_pass: PatternMatcherPass):
         def pattern(
-            result: torch.Tensor,
-            result_silu_mul: torch.Tensor,
             input: torch.Tensor,
             scale: torch.Tensor,
         ):
-            at1 = auto_functionalized(SILU_MUL_OP, result=result_silu_mul, input=input)
-            at2 = auto_functionalized(
-                self.QUANT_OP, result=result, input=at1[1], scale=scale
-            )
-            return at2[1]
+            result_silu_mul = self.silu_and_mul_matcher(input)
+            result_quant = self.quant_matcher(result_silu_mul, scale)
+            return result_quant[0]
 
         def replacement(
-            result: torch.Tensor,
-            result_silu_mul: torch.Tensor,
             input: torch.Tensor,
             scale: torch.Tensor,
         ):
+            d = input.shape[-1] // 2
+            output_shape = input.shape[:-1] + (d,)
+            result = torch.empty(
+                output_shape, device=input.device, dtype=self.quant_dtype
+            )
             at = auto_functionalized(
                 self.FUSED_OP, result=result, input=input, scale=scale
             )
             return at[1]
 
         inputs = [
-            self.empty_quant(5, 4),  # result
-            empty_bf16(5, 4),  # result_silu_mul
-            empty_bf16(5, 4),  # input
-            empty_fp32(1, 1),  # scale
+            # input, weight
+            *self.silu_and_mul_matcher.inputs(),
+            self.quant_matcher.inputs()[1],  # scale
         ]
+        pattern(*inputs)
 
         register_replacement(pattern, replacement, inputs, fwd_only, pm_pass)
diff --git a/vllm/compilation/matcher_utils.py b/vllm/compilation/matcher_utils.py
index c4eb463de1d2..383fe6033a6d 100644
--- a/vllm/compilation/matcher_utils.py
+++ b/vllm/compilation/matcher_utils.py
@@ -7,6 +7,7 @@
 from torch._ops import OpOverload
 
 from vllm.config import get_current_vllm_config
+from vllm.model_executor.layers.activation import SiluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
@@ -31,6 +32,8 @@
 if current_platform.is_cuda() and hasattr(torch.ops._C, "scaled_fp4_quant"):
     QUANT_OPS[kNvfp4Quant] = torch.ops._C.scaled_fp4_quant.default  # noqa: E501
 
+SILU_MUL_OP = torch.ops._C.silu_and_mul.default
+
 
 class MatcherCustomOp(ABC):
     def __init__(self, enabled: bool):
@@ -206,3 +209,30 @@ def inputs(self) -> list[torch.Tensor]:
             return [input, self.empty_f32(1, 1)]
 
         return [input]
+
+
+class MatcherSiluAndMul(MatcherCustomOp):
+    def __init__(self, enabled: bool | None = None):
+        if enabled is None:
+            enabled = SiluAndMul.enabled()
+        super().__init__(enabled)
+
+    def inputs(self) -> list[torch.Tensor]:
+        input = self.empty(5, 4)
+        return [input]
+
+    def forward_custom(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = x.shape[:-1] + (d,)
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        result = auto_functionalized(SILU_MUL_OP, result=out, input=x)
+        return result[1]
+
+    def forward_native(
+        self,
+        x: torch.Tensor,
+    ) -> torch.Tensor:
+        return SiluAndMul.forward_native(x)
diff --git a/vllm/model_executor/layers/activation.py b/vllm/model_executor/layers/activation.py
index f48fad559efd..e65d6daf9c5e 100644
--- a/vllm/model_executor/layers/activation.py
+++ b/vllm/model_executor/layers/activation.py
@@ -80,7 +80,8 @@ def __init__(self):
         elif current_platform.is_cpu():
             self._forward_method = self.forward_native
 
-    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+    @staticmethod
+    def forward_native(x: torch.Tensor) -> torch.Tensor:
         """PyTorch-native implementation equivalent to forward()."""
         d = x.shape[-1] // 2
         return F.silu(x[..., :d]) * x[..., d:]

From 6e49f7264fd111a934b994943fefc8ab36235141 Mon Sep 17 00:00:00 2001
From: zjy0516
Date: Sat, 18 Oct 2025 16:11:46 +0800
Subject: [PATCH 2/7] update

Signed-off-by: zjy0516
---
 tests/compile/test_silu_mul_quant_fusion.py |  2 ++
 vllm/compilation/activation_quant_fusion.py | 11 ++++-------
 2 files changed, 6 insertions(+), 7 deletions(-)

diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py
index 4f5a521e47cc..4f258dd93b17 100644
--- a/tests/compile/test_silu_mul_quant_fusion.py
+++ b/tests/compile/test_silu_mul_quant_fusion.py
@@ -158,6 +158,8 @@ def test_fusion_silu_and_mul_quant(
 ):
     if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch:
         pytest.skip("Duplicate tests for NVFP4")
+    if not enable_quant_fp8_custom_op:
+        pytest.skip("enable_quant_fp8_custom_op is irrelevant for nvfp4 tests.")
 
     torch.set_default_device("cuda")
     torch.set_default_dtype(dtype)
diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py
index d404ec7f7410..cff871676c83 100644
--- a/vllm/compilation/activation_quant_fusion.py
+++ b/vllm/compilation/activation_quant_fusion.py
@@ -131,24 +131,22 @@ def register(self, pm_pass: PatternMatcherPass):
         def pattern(
             result: torch.Tensor,
             output_scale: torch.Tensor,
-            result_silu_mul: torch.Tensor,
             input: torch.Tensor,
             scale: torch.Tensor,
         ):
-            at1 = auto_functionalized(SILU_MUL_OP, result=result_silu_mul, input=input)
-            at2 = auto_functionalized(
+            result_silu_mul = self.silu_and_mul_matcher(input)
+            at = auto_functionalized(
                 self.QUANT_OP,
                 output=result,
-                input=at1[1],
+                input=result_silu_mul,
                 output_scale=output_scale,
                 input_scale=scale,
             )
-            return at2[1], at2[2]
+            return at[1], at[2]
 
         def replacement(
             result: torch.Tensor,
             output_scale: torch.Tensor,
-            result_silu_mul: torch.Tensor,
             input: torch.Tensor,
             scale: torch.Tensor,
         ):
@@ -164,7 +162,6 @@ def replacement(
         inputs = [
             self.empty_quant(5, 32),  # result
             empty_i32(128, 4),  # output_scale
-            empty_bf16(5, 64),  # result_silu_mul
             empty_bf16(5, 64),  # input
             empty_fp32(1, 1),  # scale
         ]

From 6e83e11e3f06e515448f5e4318f5952d7cd7c1b2 Mon Sep 17 00:00:00 2001
From: zjy0516
Date: Sun, 19 Oct 2025 14:19:58 +0800
Subject: [PATCH 3/7] update

Signed-off-by: zjy0516
---
 tests/compile/test_silu_mul_quant_fusion.py | 46 +++++++++------------
 vllm/compilation/activation_quant_fusion.py |  3 +-
 2 files changed, 21 insertions(+), 28 deletions(-)

diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py
index 4f258dd93b17..6cd7847952db 100644
--- a/tests/compile/test_silu_mul_quant_fusion.py
+++ b/tests/compile/test_silu_mul_quant_fusion.py
@@ -1,6 +1,6 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
-from typing import cast
+import itertools
 
 import pytest
 import torch
@@ -31,7 +31,6 @@
 )
 from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
     Fp8LinearOp,
-    cutlass_fp8_supported,
     maybe_create_device_identity,
 )
 from vllm.platforms import current_platform
@@ -123,43 +122,38 @@ def ops_in_model_after(self):
         return [FUSED_OPS[kNvfp4Quant]]
 
 
+test_cases: list[
+    tuple[type[TestSiluMulFp8QuantModel | TestSiluMulNvfp4QuantModel], bool, bool]
+] = [
+    *list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False])),
+    (TestSiluMulNvfp4QuantModel, False, False),
+]
+
+
 @pytest.mark.parametrize("num_tokens", [32, 64])
 @pytest.mark.parametrize("hidden_size", [128, 256])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
 @pytest.mark.parametrize("enable_silu_mul_custom_op", [True, False])
-@pytest.mark.parametrize("enable_quant_fp8_custom_op", [True, False])
 @pytest.mark.parametrize(
-    "model_class",
-    cast(
-        list[type],
-        (
-            [TestSiluMulFp8QuantModel, TestSiluMulNvfp4QuantModel]
-            if is_nvfp4_supported()
-            else [TestSiluMulFp8QuantModel]
-        ),
-    ),
+    "model_class, enable_quant_fp8_custom_op, cuda_force_torch",
+    test_cases,
 )
 # cuda_force_torch used to test torch code path on platforms that
 # cutlass_fp8_supported() == True.
-@pytest.mark.parametrize(
-    "cuda_force_torch", [True, False] if cutlass_fp8_supported() else [True]
-)
 @pytest.mark.skipif(
     envs.VLLM_TARGET_DEVICE not in ["cuda", "rocm"], reason="Only test on CUDA and ROCm"
 )
 def test_fusion_silu_and_mul_quant(
-    num_tokens,
-    hidden_size,
-    dtype,
-    model_class,
-    enable_silu_mul_custom_op,
-    enable_quant_fp8_custom_op,
-    cuda_force_torch,
+    num_tokens: int,
+    hidden_size: int,
+    dtype: torch.dtype,
+    model_class: type[TestSiluMulFp8QuantModel | TestSiluMulNvfp4QuantModel],
+    enable_silu_mul_custom_op: bool,
+    enable_quant_fp8_custom_op: bool,
+    cuda_force_torch: bool,
 ):
-    if model_class == TestSiluMulNvfp4QuantModel and cuda_force_torch:
-        pytest.skip("Duplicate tests for NVFP4")
-    if not enable_quant_fp8_custom_op:
-        pytest.skip("enable_quant_fp8_custom_op is irrelevant for nvfp4 tests.")
+    if model_class == TestSiluMulNvfp4QuantModel and not is_nvfp4_supported():
+        pytest.skip("NVFP4 is not supported on this GPU.")
 
     torch.set_default_device("cuda")
     torch.set_default_dtype(dtype)
diff --git a/vllm/compilation/activation_quant_fusion.py b/vllm/compilation/activation_quant_fusion.py
index cff871676c83..b5fd67c5b027 100644
--- a/vllm/compilation/activation_quant_fusion.py
+++ b/vllm/compilation/activation_quant_fusion.py
@@ -110,8 +110,7 @@ def replacement(
             return at[1]
 
         inputs = [
-            # input, weight
-            *self.silu_and_mul_matcher.inputs(),
+            *self.silu_and_mul_matcher.inputs(),  # input
             self.quant_matcher.inputs()[1],  # scale
         ]
         pattern(*inputs)

From 0fb6ffef030e58ee565e932ddd9def2ccc79dbea Mon Sep 17 00:00:00 2001
From: zjy0516
Date: Sun, 19 Oct 2025 14:21:06 +0800
Subject: [PATCH 4/7] update

Signed-off-by: zjy0516
---
 tests/compile/test_silu_mul_quant_fusion.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py
index 6cd7847952db..254b5b31576d 100644
--- a/tests/compile/test_silu_mul_quant_fusion.py
+++ b/tests/compile/test_silu_mul_quant_fusion.py
@@ -122,7 +122,7 @@ def ops_in_model_after(self):
         return [FUSED_OPS[kNvfp4Quant]]
 
 
-test_cases: list[
+test_params: list[
     tuple[type[TestSiluMulFp8QuantModel | TestSiluMulNvfp4QuantModel], bool, bool]
 ] = [
     *list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False])),
@@ -136,7 +136,7 @@ def ops_in_model_after(self):
 @pytest.mark.parametrize("enable_silu_mul_custom_op", [True, False])
 @pytest.mark.parametrize(
     "model_class, enable_quant_fp8_custom_op, cuda_force_torch",
-    test_cases,
+    test_params,
 )
 # cuda_force_torch used to test torch code path on platforms that
 # cutlass_fp8_supported() == True.
From 41d4313934086e65607f7dca7152cea57afa8b8d Mon Sep 17 00:00:00 2001
From: zjy0516
Date: Tue, 21 Oct 2025 00:02:09 +0800
Subject: [PATCH 5/7] update

Signed-off-by: zjy0516
---
 tests/compile/test_silu_mul_quant_fusion.py | 13 +++----------
 1 file changed, 3 insertions(+), 10 deletions(-)

diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py
index 254b5b31576d..58fcc3a55a47 100644
--- a/tests/compile/test_silu_mul_quant_fusion.py
+++ b/tests/compile/test_silu_mul_quant_fusion.py
@@ -122,21 +122,14 @@ def ops_in_model_after(self):
         return [FUSED_OPS[kNvfp4Quant]]
 
 
-test_params: list[
-    tuple[type[TestSiluMulFp8QuantModel | TestSiluMulNvfp4QuantModel], bool, bool]
-] = [
-    *list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False])),
-    (TestSiluMulNvfp4QuantModel, False, False),
-]
-
-
 @pytest.mark.parametrize("num_tokens", [32, 64])
 @pytest.mark.parametrize("hidden_size", [128, 256])
 @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
 @pytest.mark.parametrize("enable_silu_mul_custom_op", [True, False])
 @pytest.mark.parametrize(
     "model_class, enable_quant_fp8_custom_op, cuda_force_torch",
-    test_params,
+    list(itertools.product([TestSiluMulFp8QuantModel], [True, False], [True, False]))
+    + [(TestSiluMulNvfp4QuantModel, False, False)],
 )
 # cuda_force_torch used to test torch code path on platforms that
 # cutlass_fp8_supported() == True.
@@ -152,7 +145,7 @@ def test_fusion_silu_and_mul_quant(
     enable_quant_fp8_custom_op: bool,
     cuda_force_torch: bool,
 ):
-    if model_class == TestSiluMulNvfp4QuantModel and not is_nvfp4_supported():
+    if model_class is TestSiluMulNvfp4QuantModel and not is_nvfp4_supported():
         pytest.skip("NVFP4 is not supported on this GPU.")
 
     torch.set_default_device("cuda")
     torch.set_default_dtype(dtype)

From 7162838cf147fcae60850b9497db240d3dedba09 Mon Sep 17 00:00:00 2001
From: zjy0516
Date: Tue, 21 Oct 2025 11:11:59 +0800
Subject: [PATCH 6/7] fix nvfp4 test

Signed-off-by: zjy0516
---
 tests/compile/test_silu_mul_quant_fusion.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py
index 58fcc3a55a47..12d45b8327a3 100644
--- a/tests/compile/test_silu_mul_quant_fusion.py
+++ b/tests/compile/test_silu_mul_quant_fusion.py
@@ -116,7 +116,10 @@ def forward(self, x):
         return out
 
     def ops_in_model_before(self):
-        return [SILU_MUL_OP, QUANT_OPS[kNvfp4Quant]]
+        return [
+            SILU_MUL_OP if self.enable_silu_mul_custom_op else torch.ops.aten.mul,
+            QUANT_OPS[kNvfp4Quant],
+        ]
 
     def ops_in_model_after(self):
         return [FUSED_OPS[kNvfp4Quant]]

From 0e8c6aea5310a99ee5d81ffd08bd8910d5b1fded Mon Sep 17 00:00:00 2001
From: zjy0516
Date: Tue, 21 Oct 2025 12:44:39 +0800
Subject: [PATCH 7/7] fix nvfp4 test

Signed-off-by: zjy0516
---
 tests/compile/test_silu_mul_quant_fusion.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/tests/compile/test_silu_mul_quant_fusion.py b/tests/compile/test_silu_mul_quant_fusion.py
index 12d45b8327a3..0ddb82b7c3fc 100644
--- a/tests/compile/test_silu_mul_quant_fusion.py
+++ b/tests/compile/test_silu_mul_quant_fusion.py
@@ -92,6 +92,7 @@ def __init__(self, hidden_size: int, x: torch.Tensor, **kwargs):
         assert silu_and_mul_nvfp4_quant_supported
 
         self.silu_and_mul = SiluAndMul()
+        self.enable_silu_mul_custom_op = self.silu_and_mul.enabled()
 
         # create nvfp4 weight
         w = torch.rand((hidden_size, hidden_size))
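
Note on usage: the test in PATCH 1 shows the configuration this fusion depends on. Below is a minimal sketch of the same setup, assembled only from names that appear in the diffs above; wrapping the pass construction in set_current_vllm_config, as the test does, is what lets the matchers read the active config.

    from vllm.config import (
        CompilationConfig,
        CompilationMode,
        PassConfig,
        VllmConfig,
        set_current_vllm_config,
    )

    # "+" force-enables the custom-op implementations. The matchers added in
    # this series (MatcherSiluAndMul, MatcherQuantFP8) also match the native
    # implementations, which is what the enable_*_custom_op test parameters
    # exercise by toggling these entries.
    config = VllmConfig(
        compilation_config=CompilationConfig(
            mode=CompilationMode.VLLM_COMPILE,
            custom_ops=["+silu_and_mul", "+quant_fp8"],
            pass_config=PassConfig(enable_fusion=True, enable_noop=True),
        ),
    )

    with set_current_vllm_config(config):
        ...  # build the passes and compile, as in test_fusion_silu_and_mul_quant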
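For reference, the computation that FUSED_OPS[kFp8StaticTensorSym] replaces — SiluAndMul.forward_native followed by static per-tensor FP8 quantization — can be sketched in plain PyTorch. This is an illustrative reference only: the clamp-and-cast quantization is an assumption about the per-tensor scheme, not vLLM's actual fused kernel.

    import torch
    import torch.nn.functional as F

    FP8_DTYPE = torch.float8_e4m3fn
    FP8_MAX = torch.finfo(FP8_DTYPE).max

    def silu_mul_fp8_quant_ref(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        # SiluAndMul.forward_native: split the last dim into gate/up halves.
        d = x.shape[-1] // 2
        y = F.silu(x[..., :d]) * x[..., d:]
        # Static per-tensor quantization with a precomputed scale (assumed scheme).
        return (y / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)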