1 change: 1 addition & 0 deletions tensorrt_llm/_torch/modules/fused_moe/interface.py
@@ -198,6 +198,7 @@ def __init__(
self.parallel_rank = self.mapping.tp_rank
self.parallel_size = self.mapping.tp_size
self.intermediate_size_per_partition = intermediate_size // self.tp_size
self.expand_intermediate_size_per_partition = self.intermediate_size_per_partition * self.intermediate_size_expand_ratio

self.all_reduce = None
if not self.use_dp and self.mapping.tp_size > 1:
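For context, a minimal sketch of how the new derived attribute relates to the existing per-partition size; the concrete numbers below are illustrative assumptions, not values taken from this PR:

```python
# Illustrative sizes only; `intermediate_size_expand_ratio` is assumed to be
# an expansion factor already configured on the module.
intermediate_size = 14336
tp_size = 4
intermediate_size_expand_ratio = 2

intermediate_size_per_partition = intermediate_size // tp_size  # 3584
expand_intermediate_size_per_partition = (
    intermediate_size_per_partition * intermediate_size_expand_ratio)  # 7168
```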
221 changes: 172 additions & 49 deletions tensorrt_llm/_torch/modules/fused_moe/quantization.py
@@ -219,9 +219,9 @@ def create_weights(
# bias
if module.bias:
if w3_w1_bias_shape is None:
-                w3_w1_bias_shape = (module.expert_size_per_partition,
-                                    module.intermediate_size_per_partition *
-                                    module.intermediate_size_expand_ratio)
+                w3_w1_bias_shape = (
+                    module.expert_size_per_partition,
+                    module.expand_intermediate_size_per_partition)
if w2_bias_shape is None:
w2_bias_shape = (module.expert_size_per_partition,
module.hidden_size)
@@ -518,8 +518,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase):
def create_weights(self, module: torch.nn.Module):
weight_dtype = module.dtype
w3_w1_weight_shape = (module.expert_size_per_partition,
-                              module.intermediate_size_per_partition *
-                              module.intermediate_size_expand_ratio,
+                              module.expand_intermediate_size_per_partition,
module.hidden_size)
w2_weight_shape = (
module.expert_size_per_partition,
@@ -584,7 +583,7 @@ def requantize_expert_w3_w1_weight_fp8_qdq(module: torch.nn.Module,
w3_weight_scale = w3_weight_scale[...].reshape([])
max_w3_w1_weight_scale = max(w1_weight_scale, w3_weight_scale)

-    split_length = module.intermediate_size_per_partition * module.intermediate_size_expand_ratio // 2
+    split_length = module.expand_intermediate_size_per_partition // 2
w3_weight = dst_w3_w1_weight.narrow(
dim=0, start=0, length=split_length).to(dtype=module.dtype)
w1_weight = dst_w3_w1_weight.narrow(
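A minimal self-contained sketch of the split convention used above, assuming the fused [w3; w1] layout along dim 0 and a toy expanded size of 8 rows:

```python
import torch

# Toy stand-in for one expert's fused w3/w1 weight: w3 occupies the first
# half of dim 0, w1 the second half.
expand_intermediate_size_per_partition = 8
split_length = expand_intermediate_size_per_partition // 2

dst_w3_w1_weight = torch.arange(8 * 4, dtype=torch.float32).reshape(8, 4)
w3_weight = dst_w3_w1_weight.narrow(dim=0, start=0, length=split_length)
w1_weight = dst_w3_w1_weight.narrow(dim=0, start=split_length,
                                    length=split_length)
assert w3_weight.shape == w1_weight.shape == (split_length, 4)
```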
@@ -608,8 +607,7 @@ def create_weights(self, module: torch.nn.Module):
weight_dtype = torch.float8_e4m3fn

w3_w1_weight_shape = (module.expert_size_per_partition,
-                              module.intermediate_size_per_partition *
-                              module.intermediate_size_expand_ratio,
+                              module.expand_intermediate_size_per_partition,
module.hidden_size)
w2_weight_shape = (
module.expert_size_per_partition,
@@ -1658,6 +1656,38 @@ class NVFP4FusedMoEMethod(FusedMoEMethodBase):
Base class for NVFP4 fused MoE methods for all backends.
"""

def get_weights_shapes(self, module: torch.nn.Module, weight_vec_size: int,
block_scales_vec_size: int):
# Divide by 16 because we use int64 to pack 16 fp4 values
w3_w1_weight_shape = (module.expert_size_per_partition,
module.expand_intermediate_size_per_partition,
module.hidden_size // weight_vec_size)
w2_weight_shape = (module.expert_size_per_partition, module.hidden_size,
module.intermediate_size_per_partition //
weight_vec_size)

w3_w1_weight_scale_shape = (
module.expert_size_per_partition,
module.expand_intermediate_size_per_partition, module.hidden_size //
module.scaling_vector_size // block_scales_vec_size)
w2_weight_scale_shape = (module.expert_size_per_partition,
module.hidden_size,
module.intermediate_size_per_partition //
module.scaling_vector_size //
block_scales_vec_size)

if module.bias:
w3_w1_bias_shape = (module.expert_size_per_partition,
module.expand_intermediate_size_per_partition)
w2_bias_shape = (module.expert_size_per_partition,
module.hidden_size)
else:
w3_w1_bias_shape = None
w2_bias_shape = None

return (w3_w1_weight_shape, w2_weight_shape, w3_w1_bias_shape,
w2_bias_shape, w3_w1_weight_scale_shape, w2_weight_scale_shape)

def create_weights(self,
module: torch.nn.Module,
weight_dtype,
@@ -1667,35 +1697,23 @@ def create_weights(self,
scaling_vector_size=16):

module.scaling_vector_size = scaling_vector_size
-        # Divide by 16 because we use int64 to pack 16 fp4 values
-        w3_w1_weight_shape = (module.expert_size_per_partition,
-                              module.intermediate_size_per_partition *
-                              module.intermediate_size_expand_ratio,
-                              module.hidden_size // weight_vec_size)
-        w2_weight_shape = (module.expert_size_per_partition, module.hidden_size,
-                           module.intermediate_size_per_partition //
-                           weight_vec_size)

+        (w3_w1_weight_shape, w2_weight_shape, w3_w1_bias_shape, w2_bias_shape,
+         w3_w1_weight_scale_shape,
+         w2_weight_scale_shape) = self.get_weights_shapes(
+             module, weight_vec_size, block_scales_vec_size)

# Divide by 4 because we use int32 to pack 4 fp8 values
# column parallel
-        w3_w1_weight_scale = nn.Parameter(
-            torch.ones(module.expert_size_per_partition,
-                       module.intermediate_size_per_partition *
-                       module.intermediate_size_expand_ratio,
-                       module.hidden_size // module.scaling_vector_size //
-                       block_scales_vec_size,
-                       dtype=block_scales_dtype),
-            requires_grad=False)
+        w3_w1_weight_scale = nn.Parameter(torch.ones(w3_w1_weight_scale_shape,
+                                                     dtype=block_scales_dtype),
+                                          requires_grad=False)
module.register_parameter("w3_w1_weight_scale", w3_w1_weight_scale)

# row parallel
-        w2_weight_scale = nn.Parameter(
-            torch.ones(module.expert_size_per_partition,
-                       module.hidden_size,
-                       module.intermediate_size_per_partition //
-                       module.scaling_vector_size // block_scales_vec_size,
-                       dtype=block_scales_dtype),
-            requires_grad=False)
+        w2_weight_scale = nn.Parameter(torch.ones(w2_weight_scale_shape,
+                                                  dtype=block_scales_dtype),
+                                       requires_grad=False)
module.register_parameter("w2_weight_scale", w2_weight_scale)

fc31_input_scale = nn.Parameter(torch.tensor(1., dtype=torch.float32),
Expand All @@ -1716,8 +1734,12 @@ def create_weights(self,
requires_grad=False)
module.register_parameter("fc2_alpha", fc2_alpha)

-        super().create_weights(module, weight_dtype, w3_w1_weight_shape,
-                               w2_weight_shape)
+        super().create_weights(module,
+                               weight_dtype,
+                               w3_w1_weight_shape=w3_w1_weight_shape,
+                               w2_weight_shape=w2_weight_shape,
+                               w3_w1_bias_shape=w3_w1_bias_shape,
+                               w2_bias_shape=w2_bias_shape)

self.setup_quant_scales(module)

@@ -1926,6 +1948,55 @@ def setup_quant_scales(self, module: torch.nn.Module):
class NVFP4CutlassFusedMoEMethod(NVFP4FusedMoEMethod):
weight_dtype = FUSED_MOE_NVFP4_WEIGHT_DTYPE
block_scales_dtype = FUSED_MOE_NVFP4_WEIGHT_BLOCK_SCALE_DTYPE
NVFP4_ROW_ALIGNMENT = 128
NVFP4_COL_ALIGNMENT = 4

def get_weights_shapes(self, module: torch.nn.Module, weight_vec_size: int,
block_scales_vec_size: int):
"""Override the base method to get aligned weights shapes for Cutlass nvfp4 alignment."""
intermediate_size_expand_aligned = (
module.expand_intermediate_size_per_partition +
self.NVFP4_ROW_ALIGNMENT -
1) // self.NVFP4_ROW_ALIGNMENT * self.NVFP4_ROW_ALIGNMENT

if module.hidden_size % self.NVFP4_COL_ALIGNMENT != 0:
raise ValueError(
f"hidden_size {module.hidden_size} must be divisible by {self.NVFP4_COL_ALIGNMENT}"
)
hidden_size_aligned = module.hidden_size

w3_w1_weight_shape = (module.expert_size_per_partition,
intermediate_size_expand_aligned,
hidden_size_aligned // weight_vec_size)
w2_weight_shape = (module.expert_size_per_partition,
hidden_size_aligned,
intermediate_size_expand_aligned //
module.intermediate_size_expand_ratio //
weight_vec_size)

w3_w1_weight_scale_shape = (module.expert_size_per_partition,
intermediate_size_expand_aligned,
hidden_size_aligned //
module.scaling_vector_size //
block_scales_vec_size)
w2_weight_scale_shape = (module.expert_size_per_partition,
hidden_size_aligned,
intermediate_size_expand_aligned //
module.intermediate_size_expand_ratio //
module.scaling_vector_size //
block_scales_vec_size)

if module.bias:
w3_w1_bias_shape = (module.expert_size_per_partition,
intermediate_size_expand_aligned)
w2_bias_shape = (module.expert_size_per_partition,
hidden_size_aligned)
else:
w3_w1_bias_shape = None
w2_bias_shape = None

return (w3_w1_weight_shape, w2_weight_shape, w3_w1_bias_shape,
w2_bias_shape, w3_w1_weight_scale_shape, w2_weight_scale_shape)

def create_weights(self, module: torch.nn.Module):
weight_vec_size = torch.iinfo(self.weight_dtype).bits // 4
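The row alignment above is plain ceil-to-multiple arithmetic; a worked example, assuming an expanded per-partition size of 300:

```python
NVFP4_ROW_ALIGNMENT = 128

def ceil_to_multiple(x: int, align: int) -> int:
    # Same expression as in get_weights_shapes: ceil(x / align) * align.
    return (x + align - 1) // align * align

assert ceil_to_multiple(300, NVFP4_ROW_ALIGNMENT) == 384
assert ceil_to_multiple(256, NVFP4_ROW_ALIGNMENT) == 256
```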
@@ -1950,21 +2021,16 @@ def load_expert_w3_w1_weight_scale_nvfp4(
module.tp_rank,
TensorParallelMode.COLUMN,
device=device)
-        # Keep weights in device buffer
-        # w3
-        split_length = module.intermediate_size_per_partition * module.intermediate_size_expand_ratio // 2
-        dst_w3_weight_scale = dst_w3_w1_weight_scale.narrow(dim=0,
-                                                            start=0,
-                                                            length=split_length)
-        dst_w3_weight_scale.copy_(
-            w3_weight_scale.view(dst_w3_weight_scale.dtype))
-
-        # w1
-        dst_w1_weight_scale = dst_w3_w1_weight_scale.narrow(dim=0,
-                                                            start=split_length,
-                                                            length=split_length)
-        dst_w1_weight_scale.copy_(
-            w1_weight_scale.view(dst_w1_weight_scale.dtype))
+        cast_w3_weight_scale = w3_weight_scale.view(
+            dst_w3_w1_weight_scale.dtype)
+        cast_w1_weight_scale = w1_weight_scale.view(
+            dst_w3_w1_weight_scale.dtype)
+        cast_w31_weight_scale = torch.cat(
+            [cast_w3_weight_scale, cast_w1_weight_scale], dim=0)
+        cast_w31_weight_scale = self._maybe_padding_shape(
+            cast_w31_weight_scale, dst_w3_w1_weight_scale)
+        dst_w3_w1_weight_scale.copy_(cast_w31_weight_scale)

orig_shape = dst_w3_w1_weight_scale.shape

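A self-contained sketch of the concat-then-pad copy above, with toy 2D shards and a destination whose row count is assumed to have been rounded up:

```python
import torch
import torch.nn.functional as F

# Toy shards: w3 and w1 scale halves of 2 rows each; the destination has
# 6 rows, as if padded to an alignment boundary.
w3_scale = torch.ones(2, 4)
w1_scale = torch.full((2, 4), 2.0)
dst = torch.empty(6, 4)

src = torch.cat([w3_scale, w1_scale], dim=0)          # (4, 4)
pad_rows = dst.shape[0] - src.shape[0]
src = F.pad(src, (0, 0, 0, pad_rows), "constant", 0)  # zero-fill the tail
dst.copy_(src)
assert torch.equal(dst[4:], torch.zeros(2, 4))
```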
@@ -1986,9 +2052,12 @@ def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,
module.tp_rank,
TensorParallelMode.ROW,
device=device)

+        cast_w2_weight_scale = w2_weight_scale.view(dst_w2_weight_scale.dtype)
+        cast_w2_weight_scale = self._maybe_padding_shape(
+            cast_w2_weight_scale, dst_w2_weight_scale)
         # Keep weights in device buffer
-        dst_w2_weight_scale.copy_(
-            w2_weight_scale.view(dst_w2_weight_scale.dtype))
+        dst_w2_weight_scale.copy_(cast_w2_weight_scale)

orig_shape = dst_w2_weight_scale.shape

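Note that `.view(dtype)` in these loaders reinterprets the shard's bytes rather than converting values; a tiny sketch of that semantics (both dtypes share a 1-byte itemsize):

```python
import torch

raw = torch.zeros(2, 4, dtype=torch.uint8)
reinterpreted = raw.view(torch.float8_e4m3fn)  # same storage, new dtype
assert reinterpreted.shape == raw.shape
assert reinterpreted.dtype == torch.float8_e4m3fn
```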
Expand All @@ -2000,6 +2069,60 @@ def load_expert_w2_weight_scale_nvfp4(self, module: torch.nn.Module,

dst_w2_weight_scale.copy_(dst_w2_weight_scale_interleaved)

def load_expert_w3_w1_weight(self, module: torch.nn.Module,
w1_weight: torch.Tensor,
w3_weight: torch.Tensor,
dst_w3_w1_weight: torch.Tensor):
"""Load and pad w1 and w3 weights for each expert, to match shape requirements for Cutlass nvfp4 alignment."""
device = dst_w3_w1_weight.device
w1_weight_shard = load_weight_shard(w1_weight,
module.tp_size,
module.tp_rank,
TensorParallelMode.COLUMN,
device=device)
w3_weight_shard = load_weight_shard(w3_weight,
module.tp_size,
module.tp_rank,
TensorParallelMode.COLUMN,
device=device)

cast_w1_weight_shard = w1_weight_shard.view(dst_w3_w1_weight.dtype)
cast_w3_weight_shard = w3_weight_shard.view(dst_w3_w1_weight.dtype)
cast_w31_weight_shard = torch.cat(
[cast_w3_weight_shard, cast_w1_weight_shard], dim=0)
cast_w31_weight_shard = self._maybe_padding_shape(
cast_w31_weight_shard, dst_w3_w1_weight)
dst_w3_w1_weight.copy_(cast_w31_weight_shard, non_blocking=True)

def load_expert_w2_weight(self, module: torch.nn.Module,
w2_weight: torch.Tensor,
dst_w2_weight: torch.Tensor):
"""Load and pad w2 weight for each expert, to match shape requirements for Cutlass nvfp4 alignment."""
device = dst_w2_weight.device
w2_weight_shard = load_weight_shard(w2_weight,
module.tp_size,
module.tp_rank,
TensorParallelMode.ROW,
device=device)
cast_w2_weight_shard = w2_weight_shard.view(dst_w2_weight.dtype)
cast_w2_weight_shard = self._maybe_padding_shape(
cast_w2_weight_shard, dst_w2_weight)
dst_w2_weight.copy_(cast_w2_weight_shard, non_blocking=True)

def _maybe_padding_shape(self, source_tensor, dst_tensor):
"""Pad the source tensor to match the shape of the destination tensor."""
        # In `get_weights_shapes`, the weight and weight-scale shapes may be
        # rounded up to `NVFP4_ROW_ALIGNMENT`; pad `source_tensor` with zeros
        # here to match the shape of `dst_tensor`.
assert len(source_tensor.shape) == 2 and len(
dst_tensor.shape) == 2, "Only support 2D weights padding for now."
dst_row, dst_col = dst_tensor.shape
_row, _col = source_tensor.shape
if _row != dst_row or _col != dst_col:
source_tensor = torch.nn.functional.pad(
source_tensor, (0, dst_col - _col, 0, dst_row - _row),
"constant", 0).contiguous()
return source_tensor


class NVFP4CuteDslFusedMoEMethod(NVFP4CutlassFusedMoEMethod):

2 changes: 2 additions & 0 deletions tests/integration/test_lists/test-db/l0_a10.yml
@@ -19,6 +19,8 @@ l0_a10:
- unittest/utils/test_util.py
- unittest/_torch/modeling/test_modeling_mistral.py
- unittest/_torch/modeling/test_modeling_pixtral.py
- unittest/_torch/modules/test_fused_moe.py::test_nvfp4_cutlass_get_weights_shapes
- unittest/_torch/modules/test_fused_moe.py::test_nvfp4_cutlass_get_weights_shapes_error_cases
- unittest/_torch/sampler/test_trtllm_sampler.py
# NOTE: this is a CPU-only test, but we do not have a dedicated job for this (and therefore no
# test list either).