Commit bf5c6a7

[FMDL-1222][feat] Support weight and weight_scale padding for NVFP4 MoE cutlass
Signed-off-by: Wanli Jiang <[email protected]>
1 parent 004299a commit bf5c6a7

File tree

1 file changed

+195 −32 lines changed

tensorrt_llm/_torch/modules/fused_moe/quantization.py

Lines changed: 195 additions & 32 deletions
@@ -1548,6 +1548,42 @@ class NVFP4FusedMoEMethod(FusedMoEMethodBase):
     Base class for NVFP4 fused MoE methods for all backends.
     """
 
+    def get_weights_shapes(self, module: torch.nn.Module, weight_vec_size: int,
+                           block_scales_vec_size: int):
+        # Divide by 16 because we use int64 to pack 16 fp4 values
+        w3_w1_weight_shape = (module.expert_size_per_partition,
+                              module.intermediate_size_per_partition *
+                              module.intermediate_size_expand_ratio,
+                              module.hidden_size // weight_vec_size)
+        w2_weight_shape = (module.expert_size_per_partition, module.hidden_size,
+                           module.intermediate_size_per_partition //
+                           weight_vec_size)
+
+        w3_w1_weight_scale_shape = (module.expert_size_per_partition,
+                                    module.intermediate_size_per_partition *
+                                    module.intermediate_size_expand_ratio,
+                                    module.hidden_size //
+                                    module.scaling_vector_size //
+                                    block_scales_vec_size)
+        w2_weight_scale_shape = (module.expert_size_per_partition,
+                                 module.hidden_size,
+                                 module.intermediate_size_per_partition //
+                                 module.scaling_vector_size //
+                                 block_scales_vec_size)
+
+        if module.bias:
+            w3_w1_bias_shape = (module.expert_size_per_partition,
+                                module.intermediate_size_per_partition *
+                                module.intermediate_size_expand_ratio)
+            w2_bias_shape = (module.expert_size_per_partition,
+                             module.hidden_size)
+        else:
+            w3_w1_bias_shape = None
+            w2_bias_shape = None
+
+        return (w3_w1_weight_shape, w2_weight_shape, w3_w1_bias_shape,
+                w2_bias_shape, w3_w1_weight_scale_shape, w2_weight_scale_shape)
+
     def create_weights(self,
                        module: torch.nn.Module,
                        weight_dtype,
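For reference, the shape arithmetic in the new get_weights_shapes helper hinges on packing 16 fp4 values into one int64 element and 4 fp8 block scales into one int32 element. The standalone sketch below recomputes the fused w3/w1 shapes for made-up sizes; the block-scale vector-size derivation is an assumption based on the in-code comments, and none of the dimensions come from a real model config.

import torch

# Made-up sizes, for illustration only.
expert_size_per_partition = 8
intermediate_size_per_partition = 2048
intermediate_size_expand_ratio = 2      # w3 and w1 stacked in one buffer
hidden_size = 4096
scaling_vector_size = 16                # one fp8 block scale per 16 fp4 values

weight_vec_size = torch.iinfo(torch.int64).bits // 4        # 16 fp4 per int64
block_scales_vec_size = torch.iinfo(torch.int32).bits // 8  # 4 fp8 per int32

w3_w1_weight_shape = (expert_size_per_partition,
                      intermediate_size_per_partition * intermediate_size_expand_ratio,
                      hidden_size // weight_vec_size)
w3_w1_weight_scale_shape = (expert_size_per_partition,
                            intermediate_size_per_partition * intermediate_size_expand_ratio,
                            hidden_size // scaling_vector_size // block_scales_vec_size)

print(w3_w1_weight_shape)        # (8, 4096, 256)
print(w3_w1_weight_scale_shape)  # (8, 4096, 64)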
@@ -1557,35 +1593,23 @@ def create_weights(self,
                        scaling_vector_size=16):
 
         module.scaling_vector_size = scaling_vector_size
-        # Divide by 16 because we use int64 to pack 16 fp4 values
-        w3_w1_weight_shape = (module.expert_size_per_partition,
-                              module.intermediate_size_per_partition *
-                              module.intermediate_size_expand_ratio,
-                              module.hidden_size // weight_vec_size)
-        w2_weight_shape = (module.expert_size_per_partition, module.hidden_size,
-                           module.intermediate_size_per_partition //
-                           weight_vec_size)
+
+        (w3_w1_weight_shape, w2_weight_shape, w3_w1_bias_shape, w2_bias_shape,
+         w3_w1_weight_scale_shape,
+         w2_weight_scale_shape) = self.get_weights_shapes(
+             module, weight_vec_size, block_scales_vec_size)
 
         # Divide by 4 because we use int32 to pack 4 fp8 values
         # column parallel
-        w3_w1_weight_scale = nn.Parameter(
-            torch.ones(module.expert_size_per_partition,
-                       module.intermediate_size_per_partition *
-                       module.intermediate_size_expand_ratio,
-                       module.hidden_size // module.scaling_vector_size //
-                       block_scales_vec_size,
-                       dtype=block_scales_dtype),
-            requires_grad=False)
+        w3_w1_weight_scale = nn.Parameter(torch.ones(w3_w1_weight_scale_shape,
+                                                     dtype=block_scales_dtype),
+                                          requires_grad=False)
         module.register_parameter("w3_w1_weight_scale", w3_w1_weight_scale)
 
         # row parallel
-        w2_weight_scale = nn.Parameter(
-            torch.ones(module.expert_size_per_partition,
-                       module.hidden_size,
-                       module.intermediate_size_per_partition //
-                       module.scaling_vector_size // block_scales_vec_size,
-                       dtype=block_scales_dtype),
-            requires_grad=False)
+        w2_weight_scale = nn.Parameter(torch.ones(w2_weight_scale_shape,
+                                                  dtype=block_scales_dtype),
+                                       requires_grad=False)
         module.register_parameter("w2_weight_scale", w2_weight_scale)
 
         fc31_input_scale = nn.Parameter(torch.tensor(1., dtype=torch.float32),
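The refactor above also relies on torch.ones accepting a precomputed shape tuple as well as unpacked dimensions, so building each block-scale parameter from the tuple returned by get_weights_shapes preserves the old behavior. A small sketch with invented dimensions and an int32 stand-in dtype:

import torch
from torch import nn

# Invented dimensions and dtype, for illustration only.
shape = (8, 4096, 64)
a = torch.ones(8, 4096, 64, dtype=torch.int32)  # old style: unpacked dims
b = torch.ones(shape, dtype=torch.int32)        # new style: one shape tuple
assert a.shape == b.shape

# requires_grad=False lets nn.Parameter wrap a non-floating-point tensor.
scale_param = nn.Parameter(b, requires_grad=False)
print(scale_param.shape)  # torch.Size([8, 4096, 64])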
@@ -1606,8 +1630,12 @@ def create_weights(self,
                                  requires_grad=False)
         module.register_parameter("fc2_alpha", fc2_alpha)
 
-        super().create_weights(module, weight_dtype, w3_w1_weight_shape,
-                               w2_weight_shape)
+        super().create_weights(module,
+                               weight_dtype,
+                               w3_w1_weight_shape=w3_w1_weight_shape,
+                               w2_weight_shape=w2_weight_shape,
+                               w3_w1_bias_shape=w3_w1_bias_shape,
+                               w2_bias_shape=w2_bias_shape)
 
         self.setup_quant_scales(module)
 
@@ -1816,6 +1844,55 @@ def setup_quant_scales(self, module: torch.nn.Module):
 class NVFP4CutlassFusedMoEMethod(NVFP4FusedMoEMethod):
     weight_dtype = FUSED_MOE_NVFP4_WEIGHT_DTYPE
     block_scales_dtype = FUSED_MOE_NVFP4_WEIGHT_BLOCK_SCALE_DTYPE
+    NVFP4_ROW_ALIGNMENT = 128
+    NVFP4_COL_ALIGNMENT = 4
+
+    def get_weights_shapes(self, module: torch.nn.Module, weight_vec_size: int,
+                           block_scales_vec_size: int):
+        """Override the base method to get aligned weights shapes for Cutlass nvfp4 alignment."""
+        intermediate_size_expand = module.intermediate_size_per_partition * module.intermediate_size_expand_ratio
+        intermediate_size_expand_aligned = (
+            intermediate_size_expand + self.NVFP4_ROW_ALIGNMENT -
+            1) // self.NVFP4_ROW_ALIGNMENT * self.NVFP4_ROW_ALIGNMENT
+
+        if module.hidden_size % self.NVFP4_COL_ALIGNMENT != 0:
+            raise ValueError(
+                f"hidden_size {module.hidden_size} must be divisible by {self.NVFP4_COL_ALIGNMENT}"
+            )
+        hidden_size_aligned = module.hidden_size
+
+        w3_w1_weight_shape = (module.expert_size_per_partition,
+                              intermediate_size_expand_aligned,
+                              hidden_size_aligned // weight_vec_size)
+        w2_weight_shape = (module.expert_size_per_partition,
+                           hidden_size_aligned,
+                           intermediate_size_expand_aligned //
+                           module.intermediate_size_expand_ratio //
+                           weight_vec_size)
+
+        w3_w1_weight_scale_shape = (module.expert_size_per_partition,
+                                    intermediate_size_expand_aligned,
+                                    hidden_size_aligned //
+                                    module.scaling_vector_size //
+                                    block_scales_vec_size)
+        w2_weight_scale_shape = (module.expert_size_per_partition,
+                                 hidden_size_aligned,
+                                 intermediate_size_expand_aligned //
+                                 module.intermediate_size_expand_ratio //
+                                 module.scaling_vector_size //
+                                 block_scales_vec_size)
+
+        if module.bias:
+            w3_w1_bias_shape = (module.expert_size_per_partition,
+                                intermediate_size_expand_aligned)
+            w2_bias_shape = (module.expert_size_per_partition,
+                             hidden_size_aligned)
+        else:
+            w3_w1_bias_shape = None
+            w2_bias_shape = None
+
+        return (w3_w1_weight_shape, w2_weight_shape, w3_w1_bias_shape,
+                w2_bias_shape, w3_w1_weight_scale_shape, w2_weight_scale_shape)
 
     def create_weights(self, module: torch.nn.Module):
         weight_vec_size = torch.iinfo(self.weight_dtype).bits // 4
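The intermediate_size_expand_aligned expression above is the standard round-up-to-multiple (ceil-division) idiom applied against NVFP4_ROW_ALIGNMENT; hidden_size is only validated, not padded, since it must already be a multiple of NVFP4_COL_ALIGNMENT. A tiny worked sketch with invented sizes:

NVFP4_ROW_ALIGNMENT = 128

def round_up(value: int, alignment: int) -> int:
    # (value + alignment - 1) // alignment * alignment rounds `value` up to
    # the nearest multiple of `alignment`.
    return (value + alignment - 1) // alignment * alignment

# Invented example: intermediate_size_per_partition = 1440, expand ratio = 2.
intermediate_size_expand = 1440 * 2                             # 2880
print(round_up(intermediate_size_expand, NVFP4_ROW_ALIGNMENT))  # 2944
print(round_up(2816, NVFP4_ROW_ALIGNMENT))                      # 2816, already aligned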
@@ -1842,19 +1919,34 @@ def load_expert_w3_w1_weight_scale_nvfp4(
                                              device=device)
         # Keep weights in device buffer
         # w3
-        split_length = module.intermediate_size_per_partition * module.intermediate_size_expand_ratio // 2
+        split_length = dst_w3_w1_weight_scale.shape[0] // 2
         dst_w3_weight_scale = dst_w3_w1_weight_scale.narrow(dim=0,
                                                             start=0,
                                                             length=split_length)
-        dst_w3_weight_scale.copy_(
-            w3_weight_scale.view(dst_w3_weight_scale.dtype))
+        cast_w3_weight_scale = w3_weight_scale.view(dst_w3_weight_scale.dtype)
+
+        dst_w3_row, dst_w3_col = dst_w3_weight_scale.shape
+        _w3_row, _w3_col = cast_w3_weight_scale.shape
+        if _w3_row != dst_w3_row or _w3_col != dst_w3_col:
+            cast_w3_weight_scale = torch.nn.functional.pad(
+                cast_w3_weight_scale,
+                (0, dst_w3_col - _w3_col, 0, dst_w3_row - _w3_row), "constant",
+                0)
+        dst_w3_weight_scale.copy_(cast_w3_weight_scale)
 
         # w1
         dst_w1_weight_scale = dst_w3_w1_weight_scale.narrow(dim=0,
                                                             start=split_length,
                                                             length=split_length)
-        dst_w1_weight_scale.copy_(
-            w1_weight_scale.view(dst_w1_weight_scale.dtype))
+        dst_w1_row, dst_w1_col = dst_w1_weight_scale.shape
+        cast_w1_weight_scale = w1_weight_scale.view(dst_w1_weight_scale.dtype)
+        _w1_row, _w1_col = cast_w1_weight_scale.shape
+        if _w1_row != dst_w1_row or _w1_col != dst_w1_col:
+            cast_w1_weight_scale = torch.nn.functional.pad(
+                cast_w1_weight_scale,
+                (0, dst_w1_col - _w1_col, 0, dst_w1_row - _w1_row), "constant",
+                0)
+        dst_w1_weight_scale.copy_(cast_w1_weight_scale)
 
         orig_shape = dst_w3_w1_weight_scale.shape
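The padding logic above follows torch.nn.functional.pad's convention for 2-D tensors: the pad tuple is ordered (left, right, top, bottom), with the last dimension listed first, so zero columns and rows are appended on the right and bottom until the shard matches the destination buffer. A self-contained sketch with toy shapes standing in for a sharded weight scale and its padded destination (all values invented):

import torch
import torch.nn.functional as F

src = torch.arange(6, dtype=torch.float32).reshape(2, 3)  # unpadded shard
dst_rows, dst_cols = 4, 5                                  # padded buffer shape

# Last dimension first: pad 2 columns on the right, then 2 rows at the bottom.
padded = F.pad(src,
               (0, dst_cols - src.shape[1], 0, dst_rows - src.shape[0]),
               "constant", 0)

assert padded.shape == (dst_rows, dst_cols)
print(padded)  # original 2x3 block in the top-left corner, zeros elsewhere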

@@ -1876,9 +1968,19 @@ def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
                                              module.tp_rank,
                                              TensorParallelMode.ROW,
                                              device=device)
+
+        cast_w2_weight_scale = w2_weight_scale.view(dst_w2_weight_scale.dtype)
+        dst_row, dst_col = dst_w2_weight_scale.shape
+        _row, _col = cast_w2_weight_scale.shape
+        if _row != dst_row or _col != dst_col:
+            cast_w2_weight_scale = torch.nn.functional.pad(
+                cast_w2_weight_scale,
+                (0, dst_col - _col, 0,
+                 dst_row - _row),  # (left, right, top, bottom)
+                "constant",
+                0)
         # Keep weights in device buffer
-        dst_w2_weight_scale.copy_(
-            w2_weight_scale.view(dst_w2_weight_scale.dtype))
+        dst_w2_weight_scale.copy_(cast_w2_weight_scale)
 
         orig_shape = dst_w2_weight_scale.shape
 
@@ -1890,6 +1992,67 @@ def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
 
         dst_w2_weight_scale.copy_(dst_w2_weight_scale_interleaved)
 
+    def load_expert_w3_w1_weight(self, module: torch.nn.Module,
+                                 w1_weight: torch.Tensor,
+                                 w3_weight: torch.Tensor,
+                                 dst_w3_w1_weight: torch.Tensor):
+        """Load and pad w1 and w3 weights for each expert, to match shape requirements for Cutlass nvfp4 alignment."""
+        device = dst_w3_w1_weight.device
+        w1_weight_shard = load_weight_shard(w1_weight,
+                                            module.tp_size,
+                                            module.tp_rank,
+                                            TensorParallelMode.COLUMN,
+                                            device=device)
+        w3_weight_shard = load_weight_shard(w3_weight,
+                                            module.tp_size,
+                                            module.tp_rank,
+                                            TensorParallelMode.COLUMN,
+                                            device=device)
+
+        cast_w1_weight_shard = w1_weight_shard.view(dst_w3_w1_weight.dtype)
+        cast_w3_weight_shard = w3_weight_shard.view(dst_w3_w1_weight.dtype)
+
+        dst_row, dst_col = dst_w3_w1_weight.shape
+        _w1_row, _w1_col = cast_w1_weight_shard.shape
+        _w3_row, _w3_col = cast_w3_weight_shard.shape
+        assert _w1_row == _w3_row and _w1_col == _w3_col, "w1 and w3 weights must have the same shape"
+        assert dst_row % 2 == 0, "dst_w3_w1_weight must have even number of rows"
+        if _w1_row != dst_row // 2 or _w1_col != dst_col:
+            _pad_row = dst_row // 2 - _w1_row
+            _pad_col = dst_col - _w1_col
+            cast_w1_weight_shard = torch.nn.functional.pad(
+                cast_w1_weight_shard, (0, _pad_col, 0, _pad_row), "constant", 0)
+            cast_w3_weight_shard = torch.nn.functional.pad(
+                cast_w3_weight_shard, (0, _pad_col, 0, _pad_row), "constant", 0)
+
+        cast_w31_weight_shard = torch.cat(
+            [cast_w3_weight_shard, cast_w1_weight_shard], dim=0)
+        dst_w3_w1_weight.copy_(cast_w31_weight_shard, non_blocking=True)
+
+    def load_expert_w2_weight(self, module: torch.nn.Module,
+                              w2_weight: torch.Tensor,
+                              dst_w2_weight: torch.Tensor):
+        """Load and pad w2 weight for each expert, to match shape requirements for Cutlass nvfp4 alignment."""
+        device = dst_w2_weight.device
+        w2_weight_shard = load_weight_shard(w2_weight,
+                                            module.tp_size,
+                                            module.tp_rank,
+                                            TensorParallelMode.ROW,
+                                            device=device)
+        cast_w2_weight_shard = w2_weight_shard.view(dst_w2_weight.dtype)
+
+        dst_row, dst_col = dst_w2_weight.shape
+        _row, _col = cast_w2_weight_shard.shape
+        if _row != dst_row or _col != dst_col:
+            cast_w2_weight_shard = torch.nn.functional.pad(
+                cast_w2_weight_shard,
+                (0, dst_col - _col, 0,
+                 dst_row - _row),  # (left, right, top, bottom)
+                "constant",
+                0)
+
+        dst_w2_weight.copy_(cast_w2_weight_shard, non_blocking=True)
+
 
 class NVFP4TRTLLMGenFusedMoEMethod(NVFP4FusedMoEMethod):
     weight_dtype = float4_sf_dtype
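The new load_expert_w3_w1_weight path boils down to a pad-then-concatenate pattern: each shard is zero-padded to half the destination rows (and the full destination width), w3 is stacked on top of w1, and the result is copied into the fused buffer. A self-contained sketch of that pattern with dummy tensors, leaving out the tensor-parallel sharding and dtype reinterpretation (all shapes invented):

import torch
import torch.nn.functional as F

w3 = torch.randn(6, 8)            # unpadded w3 shard
w1 = torch.randn(6, 8)            # unpadded w1 shard
dst = torch.zeros(16, 8)          # fused buffer: 2 * 8 padded rows

pad_rows = dst.shape[0] // 2 - w3.shape[0]
pad_cols = dst.shape[1] - w3.shape[1]
w3_padded = F.pad(w3, (0, pad_cols, 0, pad_rows), "constant", 0)
w1_padded = F.pad(w1, (0, pad_cols, 0, pad_rows), "constant", 0)

# w3 occupies the top half of the buffer, w1 starts at the padded midpoint.
dst.copy_(torch.cat([w3_padded, w1_padded], dim=0))
assert torch.equal(dst[:6], w3)
assert torch.equal(dst[8:14], w1)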
