@@ -1548,6 +1548,42 @@ class NVFP4FusedMoEMethod(FusedMoEMethodBase):
     Base class for NVFP4 fused MoE methods for all backends.
     """
 
+    def get_weights_shapes(self, module: torch.nn.Module, weight_vec_size: int,
+                           block_scales_vec_size: int):
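+        # Returns a tuple of shapes:
+        # (w3_w1_weight, w2_weight, w3_w1_bias, w2_bias,
+        #  w3_w1_weight_scale, w2_weight_scale); the bias shapes are None when
+        #  module.bias is False.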
+        # Divide by weight_vec_size because multiple fp4 values are packed into
+        # one storage element (e.g. 16 fp4 values per int64).
+        w3_w1_weight_shape = (module.expert_size_per_partition,
+                              module.intermediate_size_per_partition *
+                              module.intermediate_size_expand_ratio,
+                              module.hidden_size // weight_vec_size)
+        w2_weight_shape = (module.expert_size_per_partition, module.hidden_size,
+                           module.intermediate_size_per_partition //
+                           weight_vec_size)
+
+        w3_w1_weight_scale_shape = (module.expert_size_per_partition,
+                                    module.intermediate_size_per_partition *
+                                    module.intermediate_size_expand_ratio,
+                                    module.hidden_size //
+                                    module.scaling_vector_size //
+                                    block_scales_vec_size)
+        w2_weight_scale_shape = (module.expert_size_per_partition,
+                                 module.hidden_size,
+                                 module.intermediate_size_per_partition //
+                                 module.scaling_vector_size //
+                                 block_scales_vec_size)
+
+        if module.bias:
+            w3_w1_bias_shape = (module.expert_size_per_partition,
+                                module.intermediate_size_per_partition *
+                                module.intermediate_size_expand_ratio)
+            w2_bias_shape = (module.expert_size_per_partition,
+                             module.hidden_size)
+        else:
+            w3_w1_bias_shape = None
+            w2_bias_shape = None
+
+        return (w3_w1_weight_shape, w2_weight_shape, w3_w1_bias_shape,
+                w2_bias_shape, w3_w1_weight_scale_shape, w2_weight_scale_shape)
+
     def create_weights(self,
                        module: torch.nn.Module,
                        weight_dtype,
@@ -1557,35 +1593,23 @@ def create_weights(self,
                        scaling_vector_size=16):
 
         module.scaling_vector_size = scaling_vector_size
-        # Divide by 16 because we use int64 to pack 16 fp4 values
-        w3_w1_weight_shape = (module.expert_size_per_partition,
-                              module.intermediate_size_per_partition *
-                              module.intermediate_size_expand_ratio,
-                              module.hidden_size // weight_vec_size)
-        w2_weight_shape = (module.expert_size_per_partition, module.hidden_size,
-                           module.intermediate_size_per_partition //
-                           weight_vec_size)
+
+        (w3_w1_weight_shape, w2_weight_shape, w3_w1_bias_shape, w2_bias_shape,
+         w3_w1_weight_scale_shape,
+         w2_weight_scale_shape) = self.get_weights_shapes(
+             module, weight_vec_size, block_scales_vec_size)
 
         # Divide by 4 because we use int32 to pack 4 fp8 values
         # column parallel
-        w3_w1_weight_scale = nn.Parameter(
-            torch.ones(module.expert_size_per_partition,
-                       module.intermediate_size_per_partition *
-                       module.intermediate_size_expand_ratio,
-                       module.hidden_size // module.scaling_vector_size //
-                       block_scales_vec_size,
-                       dtype=block_scales_dtype),
-            requires_grad=False)
+        w3_w1_weight_scale = nn.Parameter(torch.ones(w3_w1_weight_scale_shape,
+                                                     dtype=block_scales_dtype),
+                                          requires_grad=False)
         module.register_parameter("w3_w1_weight_scale", w3_w1_weight_scale)
 
         # row parallel
-        w2_weight_scale = nn.Parameter(
-            torch.ones(module.expert_size_per_partition,
-                       module.hidden_size,
-                       module.intermediate_size_per_partition //
-                       module.scaling_vector_size // block_scales_vec_size,
-                       dtype=block_scales_dtype),
-            requires_grad=False)
+        w2_weight_scale = nn.Parameter(torch.ones(w2_weight_scale_shape,
+                                                  dtype=block_scales_dtype),
+                                       requires_grad=False)
         module.register_parameter("w2_weight_scale", w2_weight_scale)
 
         fc31_input_scale = nn.Parameter(torch.tensor(1., dtype=torch.float32),
@@ -1606,8 +1630,12 @@ def create_weights(self,
                                  requires_grad=False)
         module.register_parameter("fc2_alpha", fc2_alpha)
 
-        super().create_weights(module, weight_dtype, w3_w1_weight_shape,
-                               w2_weight_shape)
+        super().create_weights(module,
+                               weight_dtype,
+                               w3_w1_weight_shape=w3_w1_weight_shape,
+                               w2_weight_shape=w2_weight_shape,
+                               w3_w1_bias_shape=w3_w1_bias_shape,
+                               w2_bias_shape=w2_bias_shape)
 
         self.setup_quant_scales(module)
 
@@ -1816,6 +1844,55 @@ def setup_quant_scales(self, module: torch.nn.Module):
 class NVFP4CutlassFusedMoEMethod(NVFP4FusedMoEMethod):
     weight_dtype = FUSED_MOE_NVFP4_WEIGHT_DTYPE
     block_scales_dtype = FUSED_MOE_NVFP4_WEIGHT_BLOCK_SCALE_DTYPE
+    NVFP4_ROW_ALIGNMENT = 128
+    NVFP4_COL_ALIGNMENT = 4
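+    # The fused w3/w1 output dimension is padded up to a multiple of
+    # NVFP4_ROW_ALIGNMENT; hidden_size must already be a multiple of
+    # NVFP4_COL_ALIGNMENT (checked in get_weights_shapes below).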
+
+    def get_weights_shapes(self, module: torch.nn.Module, weight_vec_size: int,
+                           block_scales_vec_size: int):
+        """Override the base method to return weight shapes padded for Cutlass nvfp4 alignment."""
+        intermediate_size_expand = module.intermediate_size_per_partition * module.intermediate_size_expand_ratio
+        intermediate_size_expand_aligned = (
+            intermediate_size_expand + self.NVFP4_ROW_ALIGNMENT -
+            1) // self.NVFP4_ROW_ALIGNMENT * self.NVFP4_ROW_ALIGNMENT
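+        # The expression above rounds up to the next multiple of
+        # NVFP4_ROW_ALIGNMENT, e.g. 1600 -> 1664 and 1536 -> 1536 for an
+        # alignment of 128.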
+
+        if module.hidden_size % self.NVFP4_COL_ALIGNMENT != 0:
+            raise ValueError(
+                f"hidden_size {module.hidden_size} must be divisible by {self.NVFP4_COL_ALIGNMENT}"
+            )
+        hidden_size_aligned = module.hidden_size
+
+        w3_w1_weight_shape = (module.expert_size_per_partition,
+                              intermediate_size_expand_aligned,
+                              hidden_size_aligned // weight_vec_size)
+        w2_weight_shape = (module.expert_size_per_partition,
+                           hidden_size_aligned,
+                           intermediate_size_expand_aligned //
+                           module.intermediate_size_expand_ratio //
+                           weight_vec_size)
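+        # w2's reduction dim is the per-partition intermediate size recovered
+        # from the aligned expanded size (aligned size / expand ratio).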
+
+        w3_w1_weight_scale_shape = (module.expert_size_per_partition,
+                                    intermediate_size_expand_aligned,
+                                    hidden_size_aligned //
+                                    module.scaling_vector_size //
+                                    block_scales_vec_size)
+        w2_weight_scale_shape = (module.expert_size_per_partition,
+                                 hidden_size_aligned,
+                                 intermediate_size_expand_aligned //
+                                 module.intermediate_size_expand_ratio //
+                                 module.scaling_vector_size //
+                                 block_scales_vec_size)
+
+        if module.bias:
+            w3_w1_bias_shape = (module.expert_size_per_partition,
+                                intermediate_size_expand_aligned)
+            w2_bias_shape = (module.expert_size_per_partition,
+                             hidden_size_aligned)
+        else:
+            w3_w1_bias_shape = None
+            w2_bias_shape = None
+
+        return (w3_w1_weight_shape, w2_weight_shape, w3_w1_bias_shape,
+                w2_bias_shape, w3_w1_weight_scale_shape, w2_weight_scale_shape)
 
     def create_weights(self, module: torch.nn.Module):
         weight_vec_size = torch.iinfo(self.weight_dtype).bits // 4
@@ -1842,19 +1919,34 @@ def load_expert_w3_w1_weight_scale_nvfp4(
                                             device=device)
         # Keep weights in device buffer
         # w3
-        split_length = module.intermediate_size_per_partition * module.intermediate_size_expand_ratio // 2
+        split_length = dst_w3_w1_weight_scale.shape[0] // 2
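+        # The fused buffer stacks w3 on top of w1, so each projection owns half
+        # of the (already aligned) row dimension.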
         dst_w3_weight_scale = dst_w3_w1_weight_scale.narrow(dim=0,
                                                             start=0,
                                                             length=split_length)
-        dst_w3_weight_scale.copy_(
-            w3_weight_scale.view(dst_w3_weight_scale.dtype))
+        cast_w3_weight_scale = w3_weight_scale.view(dst_w3_weight_scale.dtype)
+
+        dst_w3_row, dst_w3_col = dst_w3_weight_scale.shape
+        _w3_row, _w3_col = cast_w3_weight_scale.shape
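+        # Zero-pad the loaded scales if the checkpoint shard is smaller than
+        # the aligned destination buffer.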
+        if _w3_row != dst_w3_row or _w3_col != dst_w3_col:
+            cast_w3_weight_scale = torch.nn.functional.pad(
+                cast_w3_weight_scale,
+                (0, dst_w3_col - _w3_col, 0, dst_w3_row - _w3_row), "constant",
+                0)
+        dst_w3_weight_scale.copy_(cast_w3_weight_scale)
 
         # w1
         dst_w1_weight_scale = dst_w3_w1_weight_scale.narrow(dim=0,
                                                             start=split_length,
                                                             length=split_length)
-        dst_w1_weight_scale.copy_(
-            w1_weight_scale.view(dst_w1_weight_scale.dtype))
+        dst_w1_row, dst_w1_col = dst_w1_weight_scale.shape
+        cast_w1_weight_scale = w1_weight_scale.view(dst_w1_weight_scale.dtype)
+        _w1_row, _w1_col = cast_w1_weight_scale.shape
+        if _w1_row != dst_w1_row or _w1_col != dst_w1_col:
+            cast_w1_weight_scale = torch.nn.functional.pad(
+                cast_w1_weight_scale,
+                (0, dst_w1_col - _w1_col, 0, dst_w1_row - _w1_row), "constant",
+                0)
+        dst_w1_weight_scale.copy_(cast_w1_weight_scale)
 
         orig_shape = dst_w3_w1_weight_scale.shape
 
@@ -1876,9 +1968,19 @@ def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
                                             module.tp_rank,
                                             TensorParallelMode.ROW,
                                             device=device)
+
+        cast_w2_weight_scale = w2_weight_scale.view(dst_w2_weight_scale.dtype)
+        dst_row, dst_col = dst_w2_weight_scale.shape
+        _row, _col = cast_w2_weight_scale.shape
+        if _row != dst_row or _col != dst_col:
+            cast_w2_weight_scale = torch.nn.functional.pad(
+                cast_w2_weight_scale,
+                (0, dst_col - _col, 0,
+                 dst_row - _row),  # (left, right, top, bottom)
+                "constant",
+                0)
         # Keep weights in device buffer
-        dst_w2_weight_scale.copy_(
-            w2_weight_scale.view(dst_w2_weight_scale.dtype))
+        dst_w2_weight_scale.copy_(cast_w2_weight_scale)
 
         orig_shape = dst_w2_weight_scale.shape
 
@@ -1890,6 +1992,67 @@ def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
 
         dst_w2_weight_scale.copy_(dst_w2_weight_scale_interleaved)
 
+    def load_expert_w3_w1_weight(self, module: torch.nn.Module,
+                                 w1_weight: torch.Tensor,
+                                 w3_weight: torch.Tensor,
+                                 dst_w3_w1_weight: torch.Tensor):
+        """Load and pad w1 and w3 weights for each expert to match the shape requirements of Cutlass nvfp4 alignment."""
+        device = dst_w3_w1_weight.device
+        w1_weight_shard = load_weight_shard(w1_weight,
+                                            module.tp_size,
+                                            module.tp_rank,
+                                            TensorParallelMode.COLUMN,
+                                            device=device)
+        w3_weight_shard = load_weight_shard(w3_weight,
+                                            module.tp_size,
+                                            module.tp_rank,
+                                            TensorParallelMode.COLUMN,
+                                            device=device)
+
+        cast_w1_weight_shard = w1_weight_shard.view(dst_w3_w1_weight.dtype)
+        cast_w3_weight_shard = w3_weight_shard.view(dst_w3_w1_weight.dtype)
+
+        dst_row, dst_col = dst_w3_w1_weight.shape
+        _w1_row, _w1_col = cast_w1_weight_shard.shape
+        _w3_row, _w3_col = cast_w3_weight_shard.shape
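+        # Each projection is padded to half of the fused destination
+        # (dst_row // 2 rows, dst_col columns) before w3 and w1 are stacked.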
+        assert _w1_row == _w3_row and _w1_col == _w3_col, "w1 and w3 weights must have the same shape"
+        assert dst_row % 2 == 0, "dst_w3_w1_weight must have an even number of rows"
+        if _w1_row != dst_row // 2 or _w1_col != dst_col:
+            _pad_row = dst_row // 2 - _w1_row
+            _pad_col = dst_col - _w1_col
+            cast_w1_weight_shard = torch.nn.functional.pad(
+                cast_w1_weight_shard, (0, _pad_col, 0, _pad_row), "constant", 0)
+            cast_w3_weight_shard = torch.nn.functional.pad(
+                cast_w3_weight_shard, (0, _pad_col, 0, _pad_row), "constant", 0)
+
+        cast_w31_weight_shard = torch.cat(
+            [cast_w3_weight_shard, cast_w1_weight_shard], dim=0)
+        dst_w3_w1_weight.copy_(cast_w31_weight_shard, non_blocking=True)
+
+    def load_expert_w2_weight(self, module: torch.nn.Module,
+                              w2_weight: torch.Tensor,
+                              dst_w2_weight: torch.Tensor):
+        """Load and pad the w2 weight for each expert to match the shape requirements of Cutlass nvfp4 alignment."""
+        device = dst_w2_weight.device
+        w2_weight_shard = load_weight_shard(w2_weight,
+                                            module.tp_size,
+                                            module.tp_rank,
+                                            TensorParallelMode.ROW,
+                                            device=device)
+        cast_w2_weight_shard = w2_weight_shard.view(dst_w2_weight.dtype)
+
+        dst_row, dst_col = dst_w2_weight.shape
+        _row, _col = cast_w2_weight_shard.shape
+        if _row != dst_row or _col != dst_col:
+            cast_w2_weight_shard = torch.nn.functional.pad(
+                cast_w2_weight_shard,
+                (0, dst_col - _col, 0,
+                 dst_row - _row),  # (left, right, top, bottom)
+                "constant",
+                0)
+
+        dst_w2_weight.copy_(cast_w2_weight_shard, non_blocking=True)
+
 
 class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod):
     weight_dtype = float4_sf_dtype