
Commit ba8abea

[OMNIML-2336][feat] add W4A8 NVFP4 FP8 fused moe (#7968)

Signed-off-by: Shiyang Chen <[email protected]>

1 parent b77f19f commit ba8abea

6 files changed: +312 -26 lines

tensorrt_llm/_torch/modules/fused_moe/create_moe.py  (1 addition, 0 deletions)

@@ -41,6 +41,7 @@ def get_moe_cls(
             quant_config.quant_mode.has_fp8_block_scales()
             or quant_config.quant_mode.has_nvfp4()
             or quant_config.quant_mode.has_w4a16_mxfp4()
+            or quant_config.quant_mode.has_w4a8_nvfp4_fp8()
             or quant_config.quant_mode.has_w4a8_mxfp4_fp8()
             or quant_config.quant_mode.has_w4a8_mxfp4_mxfp8()):
         return TRTLLMGenFusedMoE
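
The effect of this one-line change is that checkpoints whose quant mode reports W4A8 NVFP4 FP8 are now routed to the TRTLLMGen MoE backend. Below is a minimal, hypothetical sketch of that mode-driven dispatch; FakeQuantMode and uses_trtllm_gen_backend are illustrative stand-ins, not TensorRT-LLM APIs (the real check lives inside get_moe_cls and uses quant_config.quant_mode).

# Hypothetical stand-ins to illustrate the dispatch pattern only.
class FakeQuantMode:
    def __init__(self, mode: str):
        self._mode = mode
    def has_fp8_block_scales(self): return self._mode == "fp8_block_scales"
    def has_nvfp4(self): return self._mode == "nvfp4"
    def has_w4a16_mxfp4(self): return self._mode == "w4a16_mxfp4"
    def has_w4a8_nvfp4_fp8(self): return self._mode == "w4a8_nvfp4_fp8"
    def has_w4a8_mxfp4_fp8(self): return self._mode == "w4a8_mxfp4_fp8"
    def has_w4a8_mxfp4_mxfp8(self): return self._mode == "w4a8_mxfp4_mxfp8"

def uses_trtllm_gen_backend(quant_mode) -> bool:
    # Mirrors the condition in get_moe_cls after this commit.
    return (quant_mode.has_fp8_block_scales() or quant_mode.has_nvfp4()
            or quant_mode.has_w4a16_mxfp4() or quant_mode.has_w4a8_nvfp4_fp8()
            or quant_mode.has_w4a8_mxfp4_fp8() or quant_mode.has_w4a8_mxfp4_mxfp8())

assert uses_trtllm_gen_backend(FakeQuantMode("w4a8_nvfp4_fp8"))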

tensorrt_llm/_torch/modules/fused_moe/fused_moe_trtllm_gen.py  (46 additions, 3 deletions)

@@ -15,6 +15,7 @@
                           NVFP4TRTLLMGenFusedMoEMethod,
                           W4A8MXFP4FP8TRTLLMGenFusedMoEMethod,
                           W4A8MXFP4MXFP8TRTLLMGenFusedMoEMethod,
+                          W4A8NVFP4FP8TRTLLMGenFusedMoEMethod,
                           W4A16MXFP4TRTLLMGenFusedMoEMethod)
 from .routing import BaseMoeRoutingMethod, DeepSeekV3MoeRoutingMethod

@@ -111,7 +112,7 @@ def __init__(

     def _check_configs(self):
         assert self.has_deepseek_fp8_block_scales \
-            or self.has_nvfp4 or self.has_w4a16_mxfp4 \
+            or self.has_nvfp4 or self.has_w4a16_mxfp4 or self.has_w4a8_nvfp4_fp8 \
             or self.has_w4a8_mxfp4_fp8 or self.has_w4a8_mxfp4_mxfp8, "TRTLLMGenFusedMoE only supports fp8_block_scaling, nvfp4, w4a16_mxfp4, w4a8_mxfp4_fp8 and w4a8_mxfp4_mxfp8 dtypes."

         if self.bias or self.swiglu_alpha is not None or self.swiglu_beta is not None or self.swiglu_limit is not None:

@@ -125,6 +126,8 @@ def _get_quant_method(self):
             return NVFP4TRTLLMGenFusedMoEMethod()
         elif self.quant_config.layer_quant_mode.has_w4a16_mxfp4():
             return W4A16MXFP4TRTLLMGenFusedMoEMethod()
+        elif self.quant_config.layer_quant_mode.has_w4a8_nvfp4_fp8():
+            return W4A8NVFP4FP8TRTLLMGenFusedMoEMethod()
         elif self.quant_config.layer_quant_mode.has_w4a8_mxfp4_fp8():
             return W4A8MXFP4FP8TRTLLMGenFusedMoEMethod()
         elif self.quant_config.layer_quant_mode.has_w4a8_mxfp4_mxfp8():

@@ -147,8 +150,8 @@ def create_weights(self):
         self._weights_created = True
         self._check_configs()

-        # TODO: FIX this.
-        if (self.has_w4a16_mxfp4 or self.has_w4a8_mxfp4_fp8
+        if (self.has_w4a16_mxfp4 or self.has_w4a8_nvfp4_fp8
+                or self.has_w4a8_mxfp4_fp8
                 or self.has_w4a8_mxfp4_mxfp8) and not self.bias:
            self.w3_w1_bias = nn.Parameter(torch.zeros(
                (self.w3_w1_weight.shape[0], self.w3_w1_weight.shape[1]),

@@ -378,6 +381,46 @@ def forward_impl(
             )
             final_hidden_states = final_hidden_states[:, :self.
                                                       hidden_size].contiguous()
+        elif self.has_w4a8_nvfp4_fp8:
+
+            if not run_post_quant_allgather:
+                hidden_states_fp8, _ = torch.ops.tensorrt_llm.static_quantize_e4m3_per_tensor(
+                    x, 1.0 / self.fc31_input_scale)
+            else:
+                hidden_states_fp8 = x
+
+            outputs = torch.ops.trtllm.fp8_fp4_block_scale_moe_runner(
+                router_logits,
+                routing_bias,
+                hidden_states_fp8,
+                self.w3_w1_weight,
+                self.w3_w1_weight_scale.view(torch.float8_e4m3fn),
+                self.w2_weight,
+                self.w2_weight_scale.view(torch.float8_e4m3fn),
+                self.fc31_scale_c.data,
+                self.fc31_alpha.data,
+                self.fc2_alpha.data,
+                self.num_slots,
+                top_k,
+                n_group,
+                topk_group,
+                self.intermediate_size_per_partition,
+                self.slot_start,  # local_expert_start; use ep_rank if stride!=1
+                self.expert_size_per_partition,  # local_expert_size
+                routed_scaling_factor,
+                self.routing_method.routing_method_type,
+                do_finalize=do_finalize,
+                act_type=0,
+                topk_ids=token_selected_experts,
+                topk_weights=token_final_scales,
+            )
+
+            if not do_finalize:
+                assert not self.reduce_results, "reduce_results must be False when do_finalize is False"
+                return outputs
+            else:
+                final_hidden_states = outputs[0]
         elif self.has_w4a8_mxfp4_fp8:
             pad_size = self.w3_w1_weight.shape[-1] * 2 - x.shape[-1]
             if not run_post_quant_allgather:
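
In this new branch, when run_post_quant_allgather is False the activations are quantized to FP8 per tensor before the fused kernel runs. A minimal sketch of the presumed semantics of static_quantize_e4m3_per_tensor follows, assuming a PyTorch build with float8_e4m3fn; the convention that the second argument is an inverse scale applied by multiplication is an assumption, and the reference function is illustrative, not the actual fused op.

import torch

E4M3_MAX = 448.0  # largest finite magnitude representable in float8_e4m3fn

def quantize_e4m3_per_tensor_ref(x: torch.Tensor, inv_scale: float):
    # Assumed convention: multiply by the inverse of the calibrated input
    # scale, saturate to the e4m3 range, then cast. The real op also returns
    # the scale it used, mirrored here by returning inv_scale.
    q = (x.float() * inv_scale).clamp_(-E4M3_MAX, E4M3_MAX).to(torch.float8_e4m3fn)
    return q, inv_scale

x = torch.randn(4, 8)
fc31_input_scale = 0.05  # illustrative per-tensor activation scale
hidden_states_fp8, _ = quantize_e4m3_per_tensor_ref(x, 1.0 / fc31_input_scale)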

tensorrt_llm/_torch/modules/fused_moe/interface.py  (6 additions, 0 deletions)

@@ -301,6 +301,12 @@ def has_nvfp4(self):
         return self.quant_config is not None and self.quant_config.layer_quant_mode.has_nvfp4(
         )

+    @property
+    def has_w4a8_nvfp4_fp8(self):
+        assert self._weights_created
+        return self.quant_config is not None and self.quant_config.layer_quant_mode.has_w4a8_nvfp4_fp8(
+        )
+
     @property
     def has_w4a8_mxfp4_fp8(self):
         assert self._weights_created

tensorrt_llm/_torch/modules/fused_moe/quantization.py  (71 additions, 12 deletions)

@@ -96,7 +96,7 @@ def trtllmgen_maybe_get_cached_w3_w1_permute_indices(
                                        torch.Tensor],
         epilogue_tile_m: int,
         num_elts_per_sf: Union[None, int] = None) -> torch.Tensor:
-    key = (dst_w3_w1_weight.shape, "w31")
+    key = (dst_w3_w1_weight.shape, "w31", int(num_elts_per_sf or -1))
     if key not in cache_permute_indices:
         # Get permute indices and chain them together
         permute0 = get_reorder_rows_for_gated_act_gemm_row_indices(

@@ -122,7 +122,7 @@ def trtllmgen_maybe_get_cached_w2_permute_indices(
                                     torch.Tensor],
         epilogue_tile_m: int,
         num_elts_per_sf: Union[None, int] = None) -> torch.Tensor:
-    key = (dst_w2_weight.shape, "w2")
+    key = (dst_w2_weight.shape, "w2", int(num_elts_per_sf or -1))
     if key not in cache_permute_indices:
         if num_elts_per_sf is None:
             permute_indices = (get_shuffle_matrix_a_row_indices(
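
Adding num_elts_per_sf to these cache keys matters because the same process-wide cache is now queried for both 16-element (plain NVFP4) and 32-element (W4A8 NVFP4 FP8) scale blocks; keying on shape and tag alone would hand back indices computed for the other block width. A generic sketch of that keying pattern, with stand-in names rather than the real TensorRT-LLM helpers:

from typing import Dict, List, Optional, Tuple

_permute_cache: Dict[Tuple, List[int]] = {}

def cached_permute_indices(shape: Tuple[int, ...],
                           tag: str,
                           num_elts_per_sf: Optional[int] = None) -> List[int]:
    # Mirror `int(num_elts_per_sf or -1)` from the patch: -1 marks "no block scales".
    key = (shape, tag, int(num_elts_per_sf or -1))
    if key not in _permute_cache:
        # Placeholder; the real code derives row-permute indices for the kernel layout.
        _permute_cache[key] = list(range(shape[0]))
    return _permute_cache[key]

a = cached_permute_indices((256, 64), "w31", num_elts_per_sf=16)
b = cached_permute_indices((256, 64), "w31", num_elts_per_sf=32)
assert a is not b  # distinct cache entries, no stale reuse across block widths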

@@ -1478,11 +1478,15 @@ class NVFP4FusedMoEMethod(FusedMoEMethodBase):
     Base class for NVFP4 fused MoE methods for all backends.
     """

-    def create_weights(self, module: torch.nn.Module, weight_dtype,
-                       weight_vec_size, block_scales_dtype,
-                       block_scales_vec_size):
+    def create_weights(self,
+                       module: torch.nn.Module,
+                       weight_dtype,
+                       weight_vec_size,
+                       block_scales_dtype,
+                       block_scales_vec_size,
+                       scaling_vector_size=16):

-        module.scaling_vector_size = 16
+        module.scaling_vector_size = scaling_vector_size
         # Divide by 16 because we use int64 to pack 16 fp4 values
         w3_w1_weight_shape = (module.expert_size_per_partition,
                               module.intermediate_size_per_partition * 2,
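
Exposing scaling_vector_size lets the same weight-creation path lay out block scales for either the 16-element NVFP4 block or the 32-element block this W4A8 NVFP4 FP8 method passes in; only the number of scale factors along the hidden dimension changes. A small sketch of that arithmetic, using illustrative sizes rather than a real model config:

# One block scale covers `scaling_vector_size` consecutive elements along the
# hidden dimension, so the scale tensor narrows as the block widens.
hidden_size = 7168  # illustrative
for scaling_vector_size in (16, 32):
    scales_per_row = hidden_size // scaling_vector_size
    print(f"scaling_vector_size={scaling_vector_size}: {scales_per_row} block scales per row")
# scaling_vector_size=16: 448 block scales per row
# scaling_vector_size=32: 224 block scales per row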

@@ -1893,9 +1897,12 @@ def load_expert_w2_weight(self, module: torch.nn.Module,
                                non_blocking=True)

     def load_expert_w3_w1_weight_scale_nvfp4(
-            self, module: torch.nn.Module, w1_weight_scale: torch.Tensor,
+            self,
+            module: torch.nn.Module,
+            w1_weight_scale: torch.Tensor,
             w3_weight_scale: torch.Tensor,
-            dst_w3_w1_weight_scale: torch.Tensor):
+            dst_w3_w1_weight_scale: torch.Tensor,
+            num_elts_per_sf: int = 16):
         device = dst_w3_w1_weight_scale.device
         assert device.type == "cuda"
         w1_weight_scale = load_weight_shard(w1_weight_scale,

@@ -1933,7 +1940,7 @@ def load_expert_w3_w1_weight_scale_nvfp4(
             dst_w3_w1_weight_scale.view(float4_sf_dtype),
             self._cache_permute_indices,
             epilogue_tile_m,
-            num_elts_per_sf=16)
+            num_elts_per_sf=num_elts_per_sf)

         # Shuffle the weight according to permute indices
         w3_w1_weight_scale = torch.ops.trtllm.shuffle_matrix(

@@ -1949,9 +1956,11 @@ def load_expert_w3_w1_weight_scale_nvfp4(
             processed_w3_w1_weight_scale.view(
                 self.block_scales_dtype).reshape(orig_shape))

-    def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
+    def load_expert_w2_weight_scale_nvfp4(self,
+                                          module: torch.nn.Module,
                                           w2_weight_scale: torch.Tensor,
-                                          dst_w2_weight_scale: torch.Tensor):
+                                          dst_w2_weight_scale: torch.Tensor,
+                                          num_elts_per_sf: int = 16):
         device = dst_w2_weight_scale.device
         assert device.type == "cuda"
         w2_weight_scale = load_weight_shard(w2_weight_scale,

@@ -1976,7 +1985,7 @@ def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
             dst_w2_weight_scale.view(float4_sf_dtype),
             self._cache_permute_indices,
             epilogue_tile_m,
-            num_elts_per_sf=16)
+            num_elts_per_sf=num_elts_per_sf)

         # Shuffle the weight according to permute indices
         w_shuffled = torch.ops.trtllm.shuffle_matrix(

@@ -1998,6 +2007,56 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
                              non_blocking=True)


+class W4A8NVFP4FP8TRTLLMGenFusedMoEMethod(NVFP4TRTLLMGenFusedMoEMethod):
+
+    def create_weights(self, module: torch.nn.Module):
+        weight_vec_size = torch.iinfo(self.weight_dtype).bits // 4
+        block_scales_vec_size = 1
+
+        NVFP4FusedMoEMethod.create_weights(self, module, self.weight_dtype,
+                                           weight_vec_size,
+                                           self.block_scales_dtype,
+                                           block_scales_vec_size, 32)
+
+        fc31_scale_c = nn.Parameter(torch.ones(module.expert_size_per_partition,
+                                               dtype=torch.float32),
+                                    requires_grad=False)
+        module.register_parameter("fc31_scale_c", fc31_scale_c)
+
+        self.setup_quant_scales(module)
+
+    def load_expert_w3_w1_weight_scale_nvfp4(
+            self, module: torch.nn.Module, w1_weight_scale: torch.Tensor,
+            w3_weight_scale: torch.Tensor,
+            dst_w3_w1_weight_scale: torch.Tensor):
+        return super().load_expert_w3_w1_weight_scale_nvfp4(
+            module, w1_weight_scale, w3_weight_scale, dst_w3_w1_weight_scale,
+            32)
+
+    def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
+                                          w2_weight_scale: torch.Tensor,
+                                          dst_w2_weight_scale: torch.Tensor):
+        return super().load_expert_w2_weight_scale_nvfp4(
+            module, w2_weight_scale, dst_w2_weight_scale, 32)
+
+    def load_all_fp4_weight_scales_and_alphas(
+            self, module: torch.nn.Module, weights: Dict,
+            load_expert_ids: List[int], dst_w3_w1_weight_scale: torch.Tensor,
+            dst_w2_weight_scale: torch.Tensor, dst_fc31_alpha: torch.Tensor,
+            dst_fc2_alpha: torch.Tensor):
+        super().load_all_fp4_weight_scales_and_alphas(
+            module, weights, load_expert_ids, dst_w3_w1_weight_scale,
+            dst_w2_weight_scale, dst_fc31_alpha, dst_fc2_alpha)
+        # The kernel we use will convert nvfp4 to e4m3 before matmul,
+        # so the range of the scale factor can only be [0,448/6].
+        dst_w3_w1_weight_scale.copy_((dst_w3_w1_weight_scale.to(torch.float32) /
+                                      6.0).to(torch.float8_e4m3fn))
+        dst_w2_weight_scale.copy_((dst_w2_weight_scale.to(torch.float32) /
+                                   6.0).to(torch.float8_e4m3fn))
+        dst_fc31_alpha.copy_(dst_fc31_alpha * 6.0)
+        dst_fc2_alpha.copy_(dst_fc2_alpha * 6.0)
+
+
 def _get_weight_alignment(weight_alignment, scaling_vector_size, tp_size,
                           shard_dim_size):
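
As the in-code comment above notes, the kernel converts the NVFP4 weights to e4m3 before the matmul, so the block scale factors must stay within [0, 448/6]; dividing each scale by 6 while multiplying the matching alpha by 6 keeps the overall dequantization factor unchanged up to FP8 rounding of the rescaled scale. A small numeric check of that invariant, with illustrative values rather than real checkpoint data:

import torch

# Illustrative values, not taken from a real checkpoint.
block_scale = torch.tensor(288.0)  # NVFP4 block scale as produced by the base loader
alpha = torch.tensor(0.02)         # per-expert alpha paired with it

# Rescaling applied by load_all_fp4_weight_scales_and_alphas above.
block_scale_fp8 = (block_scale / 6.0).to(torch.float8_e4m3fn)  # 48.0, within [0, 448/6]
alpha_adjusted = alpha * 6.0

# The combined dequantization factor is preserved (exactly here, since 48 is
# representable in e4m3; in general only up to FP8 rounding of the scale).
assert torch.allclose(block_scale * alpha, block_scale_fp8.float() * alpha_adjusted)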
