@@ -96,7 +96,7 @@ def trtllmgen_maybe_get_cached_w3_w1_permute_indices(
9696 torch .Tensor ],
9797 epilogue_tile_m : int ,
9898 num_elts_per_sf : Union [None , int ] = None ) -> torch .Tensor :
99- key = (dst_w3_w1_weight .shape , "w31" )
99+ key = (dst_w3_w1_weight .shape , "w31" , int ( num_elts_per_sf or - 1 ) )
100100 if key not in cache_permute_indices :
101101 # Get permute indices and chain them together
102102 permute0 = get_reorder_rows_for_gated_act_gemm_row_indices (
@@ -122,7 +122,7 @@ def trtllmgen_maybe_get_cached_w2_permute_indices(
122122 torch .Tensor ],
123123 epilogue_tile_m : int ,
124124 num_elts_per_sf : Union [None , int ] = None ) -> torch .Tensor :
125- key = (dst_w2_weight .shape , "w2" )
125+ key = (dst_w2_weight .shape , "w2" , int ( num_elts_per_sf or - 1 ) )
126126 if key not in cache_permute_indices :
127127 if num_elts_per_sf is None :
128128 permute_indices = (get_shuffle_matrix_a_row_indices (
@@ -1478,11 +1478,15 @@ class NVFP4FusedMoEMethod(FusedMoEMethodBase):
14781478 Base class for NVFP4 fused MoE methods for all backends.
14791479 """
14801480
1481- def create_weights (self , module : torch .nn .Module , weight_dtype ,
1482- weight_vec_size , block_scales_dtype ,
1483- block_scales_vec_size ):
1481+ def create_weights (self ,
1482+ module : torch .nn .Module ,
1483+ weight_dtype ,
1484+ weight_vec_size ,
1485+ block_scales_dtype ,
1486+ block_scales_vec_size ,
1487+ scaling_vector_size = 16 ):
14841488
1485- module .scaling_vector_size = 16
1489+ module .scaling_vector_size = scaling_vector_size
14861490 # Divide by 16 because we use int64 to pack 16 fp4 values
14871491 w3_w1_weight_shape = (module .expert_size_per_partition ,
14881492 module .intermediate_size_per_partition * 2 ,
@@ -1893,9 +1897,12 @@ def load_expert_w2_weight(self, module: torch.nn.Module,
18931897 non_blocking = True )
18941898
18951899 def load_expert_w3_w1_weight_scale_nvfp4 (
1896- self , module : torch .nn .Module , w1_weight_scale : torch .Tensor ,
1900+ self ,
1901+ module : torch .nn .Module ,
1902+ w1_weight_scale : torch .Tensor ,
18971903 w3_weight_scale : torch .Tensor ,
1898- dst_w3_w1_weight_scale : torch .Tensor ):
1904+ dst_w3_w1_weight_scale : torch .Tensor ,
1905+ num_elts_per_sf : int = 16 ):
18991906 device = dst_w3_w1_weight_scale .device
19001907 assert device .type == "cuda"
19011908 w1_weight_scale = load_weight_shard (w1_weight_scale ,
@@ -1933,7 +1940,7 @@ def load_expert_w3_w1_weight_scale_nvfp4(
19331940 dst_w3_w1_weight_scale .view (float4_sf_dtype ),
19341941 self ._cache_permute_indices ,
19351942 epilogue_tile_m ,
1936- num_elts_per_sf = 16 )
1943+ num_elts_per_sf = num_elts_per_sf )
19371944
19381945 # Shuffle the weight according to permute indices
19391946 w3_w1_weight_scale = torch .ops .trtllm .shuffle_matrix (
@@ -1949,9 +1956,11 @@ def load_expert_w3_w1_weight_scale_nvfp4(
19491956 processed_w3_w1_weight_scale .view (
19501957 self .block_scales_dtype ).reshape (orig_shape ))
19511958
1952- def load_expert_w2_weight_scale_nvfp4 (self , module : torch .nn .Module ,
1959+ def load_expert_w2_weight_scale_nvfp4 (self ,
1960+ module : torch .nn .Module ,
19531961 w2_weight_scale : torch .Tensor ,
1954- dst_w2_weight_scale : torch .Tensor ):
1962+ dst_w2_weight_scale : torch .Tensor ,
1963+ num_elts_per_sf : int = 16 ):
19551964 device = dst_w2_weight_scale .device
19561965 assert device .type == "cuda"
19571966 w2_weight_scale = load_weight_shard (w2_weight_scale ,
@@ -1976,7 +1985,7 @@ def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
19761985 dst_w2_weight_scale .view (float4_sf_dtype ),
19771986 self ._cache_permute_indices ,
19781987 epilogue_tile_m ,
1979- num_elts_per_sf = 16 )
1988+ num_elts_per_sf = num_elts_per_sf )
19801989
19811990 # Shuffle the weight according to permute indices
19821991 w_shuffled = torch .ops .trtllm .shuffle_matrix (
@@ -1998,6 +2007,56 @@ def load_quant_scales(self, module: torch.nn.Module, weights: Dict):
19982007 non_blocking = True )
19992008
20002009
class W4A8NVFP4FP8TRTLLMGenFusedMoEMethod(NVFP4TRTLLMGenFusedMoEMethod):
    """TRTLLM-Gen fused-MoE method for W4A8 (NVFP4 weights, FP8 activation).

    Differs from the plain NVFP4 path in two ways:
    * weight block scales use a scaling vector size of 32 (not 16), and
    * block scales are pre-divided by 6.0 (with the matching 6.0 folded
      back into the alphas) because the kernel converts nvfp4 to e4m3
      before the matmul, so scale factors must fit in [0, 448/6].
    """

    def create_weights(self, module: torch.nn.Module):
        """Allocate fused-MoE weights/scales for the w4a8 layout.

        Calls the shared NVFP4 weight creation with a scaling vector size
        of 32, then registers the extra per-expert fc31 scale tensor.
        """
        # 4-bit values packed into the storage dtype.
        packed_vals_per_weight = torch.iinfo(self.weight_dtype).bits // 4
        # Block scales are stored unpacked (one value per element).
        packed_vals_per_scale = 1

        # Reuse the generic NVFP4 allocation, overriding the scaling
        # vector size to 32 for the w4a8 kernel.
        NVFP4FusedMoEMethod.create_weights(self,
                                           module,
                                           self.weight_dtype,
                                           packed_vals_per_weight,
                                           self.block_scales_dtype,
                                           packed_vals_per_scale,
                                           scaling_vector_size=32)

        # Per-expert scale applied to the fc31 (gate/up) output.
        per_expert_ones = torch.ones(module.expert_size_per_partition,
                                     dtype=torch.float32)
        module.register_parameter(
            "fc31_scale_c",
            nn.Parameter(per_expert_ones, requires_grad=False))

        self.setup_quant_scales(module)

    def load_expert_w3_w1_weight_scale_nvfp4(
            self, module: torch.nn.Module, w1_weight_scale: torch.Tensor,
            w3_weight_scale: torch.Tensor,
            dst_w3_w1_weight_scale: torch.Tensor):
        """Load w3/w1 block scales using the w4a8 vector size of 32."""
        return super().load_expert_w3_w1_weight_scale_nvfp4(
            module,
            w1_weight_scale,
            w3_weight_scale,
            dst_w3_w1_weight_scale,
            num_elts_per_sf=32)

    def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
                                          w2_weight_scale: torch.Tensor,
                                          dst_w2_weight_scale: torch.Tensor):
        """Load w2 block scales using the w4a8 vector size of 32."""
        return super().load_expert_w2_weight_scale_nvfp4(
            module, w2_weight_scale, dst_w2_weight_scale, num_elts_per_sf=32)

    def load_all_fp4_weight_scales_and_alphas(
            self, module: torch.nn.Module, weights: Dict,
            load_expert_ids: List[int], dst_w3_w1_weight_scale: torch.Tensor,
            dst_w2_weight_scale: torch.Tensor, dst_fc31_alpha: torch.Tensor,
            dst_fc2_alpha: torch.Tensor):
        """Load scales/alphas, then rescale for the e4m3 conversion.

        The kernel we use will convert nvfp4 to e4m3 before matmul, so
        the range of the scale factor can only be [0, 448/6]. Dividing
        the block scales by 6.0 and multiplying the alphas by 6.0 keeps
        the overall product unchanged.
        """
        super().load_all_fp4_weight_scales_and_alphas(
            module, weights, load_expert_ids, dst_w3_w1_weight_scale,
            dst_w2_weight_scale, dst_fc31_alpha, dst_fc2_alpha)

        for scale_tensor in (dst_w3_w1_weight_scale, dst_w2_weight_scale):
            rescaled = scale_tensor.to(torch.float32) / 6.0
            scale_tensor.copy_(rescaled.to(torch.float8_e4m3fn))

        # Fold the 6.0 back into the alphas so the net scaling is unchanged.
        dst_fc31_alpha.copy_(dst_fc31_alpha * 6.0)
        dst_fc2_alpha.copy_(dst_fc2_alpha * 6.0)
2058+
2059+
20012060def _get_weight_alignment (weight_alignment , scaling_vector_size , tp_size ,
20022061 shard_dim_size ):
20032062
0 commit comments