13 changes: 9 additions & 4 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/README.md
@@ -17,13 +17,14 @@ The table below lists the operators ordered by their backend.
| `torch.ops.auto_deploy.torch_attention` | Grouped SDPA implementation with `bsnd` and `bnsd` layouts supported |
| `torch.ops.auto_deploy.torch_attention_repeat_kv` | KV repetition for attention |
| `torch.ops.auto_deploy.torch_attention_sdpa` | Standard SDPA implementation |
| `torch.ops.auto_deploy.torch_dist_all_gather` | Distributed all-gather operation |
| `torch.ops.auto_deploy.torch_dist_all_reduce` | Distributed all-reduce operation |
| `torch.ops.auto_deploy.torch_dist_all_gather` | Distributed all-gather operation (PyTorch backend, demollm mode) |
| `torch.ops.auto_deploy.torch_dist_all_reduce` | Distributed all-reduce operation (PyTorch backend, demollm mode) |
| `torch.ops.auto_deploy.torch_fused_linear_all_reduce` | Fused linear layer followed by all-reduce (PyTorch backend, demollm mode) |
| `torch.ops.auto_deploy.torch_linear_simple` | Simple linear layer implementation |
| `torch.ops.auto_deploy.torch_moe` | Mixture of Experts implementation |
| `torch.ops.auto_deploy.torch_moe_fused` | Fused Mixture of Experts implementation |
| `torch.ops.auto_deploy.torch_quant_fn` | Generic quantization function that scales, rounds, and clamps input values |
| `torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce` | Fused FP8 linear layer followed by all-reduce operation |
| `torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce` | Fused FP8 linear layer followed by all-reduce (PyTorch backend, demollm mode) |
| `torch.ops.auto_deploy.torch_quant_nvfp4_linear` | FP4 quantized linear layer |
| `torch.ops.auto_deploy.torch_quant_fp8_linear` | FP8 quantized linear layer |
| `torch.ops.auto_deploy.torch_rope_with_complex_freqs` | RoPE with complex frequencies |
@@ -38,4 +39,8 @@ The table below lists the operators ordered by their backend.
| `torch.ops.auto_deploy.triton_rope_on_flattened_inputs` | Triton RoPE on flattened inputs |
| `torch.ops.auto_deploy.triton_rope_with_input_pos` | Triton RoPE with input positions |
| `torch.ops.auto_deploy.trtllm_moe_fused` | TensorRT LLM fused MoE implementation |
| `torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce` | TensorRT LLM fused linear layer followed by all-reduce operation |
| `torch.ops.auto_deploy.trtllm_dist_all_gather` | Distributed all-gather operation (TRT-LLM backend, MPI mode) |
| `torch.ops.auto_deploy.trtllm_dist_all_reduce` | Distributed all-reduce operation (TRT-LLM backend, MPI mode) |
| `torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce` | Fused linear layer followed by all-reduce (TRT-LLM backend, MPI mode) |
| `torch.ops.auto_deploy.trtllm_dist_fused_fp8_linear_all_reduce` | Fused FP8 linear layer followed by all-reduce (TRT-LLM backend, MPI mode) |
| `torch.ops.dist.trtllm_fused_allreduce_residual_rmsnorm` | Fused all-reduce + residual add + RMSNorm (TRT-LLM backend, MPI mode) |
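The split into `torch_*` and `trtllm_*` variants means the backend choice is made when the graph is built rather than inside each op. A minimal sketch of how a caller or transform might pick between the two families, assuming the `is_trtllm_op_available()` helper from `distributed/trtllm.py` is importable as shown:

```python
import torch
from tensorrt_llm._torch.auto_deploy.distributed import trtllm as trtllm_dist


def dist_all_reduce(t: torch.Tensor) -> torch.Tensor:
    """Select the backend-specific atomic op for the current run mode."""
    if trtllm_dist.is_trtllm_op_available():
        # MPI mode: TRT-LLM optimized kernel
        return torch.ops.auto_deploy.trtllm_dist_all_reduce(t)
    # demollm mode: plain torch.distributed implementation
    return torch.ops.auto_deploy.torch_dist_all_reduce(t)
```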
74 changes: 61 additions & 13 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/dist.py
@@ -1,4 +1,8 @@
"""Custom ops required for implementing tensor parallelism."""
"""Custom ops required for implementing tensor parallelism.

This module defines atomic distributed ops - each op uses a specific backend
(torch.distributed or TRT-LLM) without internal dispatch logic.
"""

from typing import List, Optional

@@ -7,38 +11,82 @@
from ..distributed import common as dist
from ..distributed import trtllm as trtllm_dist

# ============================================================================
# PyTorch Distributed Backend Ops (demollm mode)
# ============================================================================


@torch.library.custom_op("auto_deploy::torch_dist_all_gather", mutates_args=(), device_types="cuda")
def all_gather(
def torch_dist_all_gather(
tensor: torch.Tensor, dim: int = 0, sizes: Optional[List[int]] = None
) -> torch.Tensor:
"""All gather followed by concat in dim = 0. This is the default nccl behavior."""
if trtllm_dist.is_trtllm_op_available():
return trtllm_dist.trtllm_allgather(tensor, dim=dim, sizes=sizes)
"""All gather using PyTorch distributed backend.

This op always uses torch.distributed.all_gather and is used in demollm mode.
"""
tl = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
dist.all_gather(tl, tensor)
return torch.cat(tl, dim=dim)


@all_gather.register_fake
def all_gather_fake(tensor, dim=0):
@torch_dist_all_gather.register_fake
def torch_dist_all_gather_fake(tensor, dim=0, sizes=None):
return torch.cat([torch.empty_like(tensor) for _ in range(dist.get_world_size())], dim=dim)


@torch.library.custom_op("auto_deploy::torch_dist_all_reduce", mutates_args=(), device_types="cuda")
def all_reduce(t: torch.Tensor) -> torch.Tensor:
"""All_reduce across the ranks. Reduction op is SUM.
def torch_dist_all_reduce(t: torch.Tensor) -> torch.Tensor:
"""All_reduce using PyTorch distributed backend. Reduction op is SUM.

This op always uses torch.distributed.all_reduce and is used in demollm mode.

NOTE: this op requires an extra memory copy and should ONLY be used for debugging + testing. For
an efficient all_reduce, replace it with a fused op.
"""
if trtllm_dist.is_trtllm_op_available():
return trtllm_dist.trtllm_allreduce(t, op=dist.ReduceOp.SUM)
t_res = t.clone()
dist.all_reduce(t_res, op=dist.ReduceOp.SUM)
return t_res


@all_reduce.register_fake
def all_reduce_fake(tensor):
@torch_dist_all_reduce.register_fake
def torch_dist_all_reduce_fake(tensor):
return torch.empty_like(tensor)


# ============================================================================
# TRT-LLM Backend Ops (MPI mode)
# ============================================================================


@torch.library.custom_op(
"auto_deploy::trtllm_dist_all_gather", mutates_args=(), device_types="cuda"
)
def trtllm_dist_all_gather(
tensor: torch.Tensor, dim: int = 0, sizes: Optional[List[int]] = None
) -> torch.Tensor:
"""All gather using TRT-LLM optimized backend.

This op always uses TRT-LLM's optimized allgather and is used in MPI mode.
"""
return trtllm_dist.trtllm_allgather(tensor, dim=dim, sizes=sizes)


@trtllm_dist_all_gather.register_fake
def trtllm_dist_all_gather_fake(tensor, dim=0, sizes=None):
return torch.cat([torch.empty_like(tensor) for _ in range(dist.get_world_size())], dim=dim)


@torch.library.custom_op(
"auto_deploy::trtllm_dist_all_reduce", mutates_args=(), device_types="cuda"
)
def trtllm_dist_all_reduce(t: torch.Tensor) -> torch.Tensor:
"""All_reduce using TRT-LLM optimized backend. Reduction op is SUM.

This op always uses TRT-LLM's optimized allreduce and is used in MPI mode.
"""
return trtllm_dist.trtllm_allreduce(t, op=dist.ReduceOp.SUM)


@trtllm_dist_all_reduce.register_fake
def trtllm_dist_all_reduce_fake(tensor):
return torch.empty_like(tensor)
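Both families expose the same signatures, so a traced graph differs only in the op name. A hedged usage sketch (tensors are illustrative; each op expects a CUDA tensor per rank under an initialized process group or MPI):

```python
import torch


def demollm_collectives(x: torch.Tensor):
    # demollm mode: pure torch.distributed implementations
    gathered = torch.ops.auto_deploy.torch_dist_all_gather(x, dim=0)
    reduced = torch.ops.auto_deploy.torch_dist_all_reduce(x)
    return gathered, reduced


def mpi_collectives(x: torch.Tensor):
    # MPI mode: identical signatures, TRT-LLM optimized kernels
    gathered = torch.ops.auto_deploy.trtllm_dist_all_gather(x, dim=0)
    reduced = torch.ops.auto_deploy.trtllm_dist_all_reduce(x)
    return gathered, reduced
```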
42 changes: 32 additions & 10 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/linear.py
@@ -24,26 +24,48 @@ def simple(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tenso
@simple.register_fake
def simple_fake(input, weight, bias):
"""Fake implementation of simple_linear."""
# return torch.empty(
# input.shape[:-1] + (weight.shape[-1],), dtype=input.dtype, device=input.device
# )
return torch.ops.aten.linear(input, weight, bias)


# ============================================================================
# Fused Linear + AllReduce Ops (Atomic - Backend Specific)
# ============================================================================


@torch.library.custom_op(
"auto_deploy::trtllm_dist_fused_linear_all_reduce", mutates_args=(), device_types="cuda"
"auto_deploy::torch_fused_linear_all_reduce", mutates_args=(), device_types="cuda"
)
def fused_linear_all_reduce(
def torch_fused_linear_all_reduce(
input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor:
"""Fused linear followed by all_reduce on the output."""
"""Fused linear + all_reduce using PyTorch backend.

This op always uses torch.distributed and is used in demollm mode.
"""
output = torch.ops.aten.linear(input, weight, bias)
if trtllm_dist.is_trtllm_op_available():
return trtllm_dist.trtllm_allreduce(output, op=dist.ReduceOp.SUM)
dist.all_reduce(output, op=dist.ReduceOp.SUM)
return output


@fused_linear_all_reduce.register_fake
def fused_linear_all_reduce_fake(input, weight, bias):
@torch_fused_linear_all_reduce.register_fake
def torch_fused_linear_all_reduce_fake(input, weight, bias):
return torch.ops.aten.linear(input, weight, bias)


@torch.library.custom_op(
"auto_deploy::trtllm_dist_fused_linear_all_reduce", mutates_args=(), device_types="cuda"
)
def trtllm_dist_fused_linear_all_reduce(
input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]
) -> torch.Tensor:
"""Fused linear + all_reduce using TRT-LLM backend.

This op always uses TRT-LLM's optimized allreduce and is used in MPI mode.
"""
output = torch.ops.aten.linear(input, weight, bias)
return trtllm_dist.trtllm_allreduce(output, op=dist.ReduceOp.SUM)


@trtllm_dist_fused_linear_all_reduce.register_fake
def trtllm_dist_fused_linear_all_reduce_fake(input, weight, bias):
return torch.ops.aten.linear(input, weight, bias)
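Each fused op corresponds to a linear followed by a SUM all-reduce, so a pattern-matching transform can swap the unfused pair for the matching backend-specific fused op. A hedged sketch of that equivalence (op names are from this PR; the `torch_linear_simple` signature and the replacement logic are assumptions for illustration):

```python
from typing import Optional

import torch


def unfused_pattern(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]):
    # Sharded linear before fusion (demollm variants shown).
    y = torch.ops.auto_deploy.torch_linear_simple(x, weight, bias)
    return torch.ops.auto_deploy.torch_dist_all_reduce(y)


def fused_replacement(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor]):
    # Single fused op; use trtllm_dist_fused_linear_all_reduce in MPI mode.
    return torch.ops.auto_deploy.torch_fused_linear_all_reduce(x, weight, bias)
```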
49 changes: 44 additions & 5 deletions tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py
@@ -240,26 +240,65 @@ def fp8_linear_fake(
return torch.ops.aten.linear(input, weight_fp8.to(input.dtype), bias)


# ============================================================================
# Fused FP8 Linear + AllReduce Ops (Atomic - Backend Specific)
# ============================================================================


@torch.library.custom_op("auto_deploy::torch_quant_fused_fp8_linear_all_reduce", mutates_args=())
@torch.compile(dynamic=True)
def fused_fp8_linear_all_reduce(
def torch_quant_fused_fp8_linear_all_reduce(
input: torch.Tensor,
weight_fp8: torch.Tensor,
bias: Optional[torch.Tensor] = None,
input_scale: Optional[torch.Tensor] = None,
weight_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Fused FP8 linear + all_reduce using PyTorch backend.

This op always uses torch.distributed and is used in demollm mode.
"""
out = torch.ops.auto_deploy.torch_quant_fp8_linear(
input, weight_fp8, bias, input_scale, weight_scale
)
if trtllm_dist.is_trtllm_op_available():
return trtllm_dist.trtllm_allreduce(out, op=dist.ReduceOp.SUM)
dist.all_reduce(out, op=dist.ReduceOp.SUM)
return out


@fused_fp8_linear_all_reduce.register_fake
def fused_fp8_linear_all_reduce_fake(
@torch_quant_fused_fp8_linear_all_reduce.register_fake
def torch_quant_fused_fp8_linear_all_reduce_fake(
input: torch.Tensor,
weight_fp8: torch.Tensor,
bias: Optional[torch.Tensor] = None,
input_scale: Optional[torch.Tensor] = None,
weight_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.ops.auto_deploy.torch_quant_fp8_linear(
input, weight_fp8, bias, input_scale, weight_scale
)


@torch.library.custom_op("auto_deploy::trtllm_dist_fused_fp8_linear_all_reduce", mutates_args=())
@torch.compile(dynamic=True)
def trtllm_dist_fused_fp8_linear_all_reduce(
input: torch.Tensor,
weight_fp8: torch.Tensor,
bias: Optional[torch.Tensor] = None,
input_scale: Optional[torch.Tensor] = None,
weight_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
"""Fused FP8 linear + all_reduce using TRT-LLM backend.

This op always uses TRT-LLM's optimized allreduce and is used in MPI mode.
"""
out = torch.ops.auto_deploy.torch_quant_fp8_linear(
input, weight_fp8, bias, input_scale, weight_scale
)
return trtllm_dist.trtllm_allreduce(out, op=dist.ReduceOp.SUM)


@trtllm_dist_fused_fp8_linear_all_reduce.register_fake
def trtllm_dist_fused_fp8_linear_all_reduce_fake(
input: torch.Tensor,
weight_fp8: torch.Tensor,
bias: Optional[torch.Tensor] = None,
Expand Down
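The FP8 variants follow the same pattern: the fused op is the FP8 GEMM followed by a SUM all-reduce. A hedged sketch of the equivalence, assuming per-tensor `input_scale`/`weight_scale` tensors as in the signatures above:

```python
from typing import Optional

import torch


def unfused_fp8_reference(
    inp: torch.Tensor,
    weight_fp8: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    input_scale: Optional[torch.Tensor] = None,
    weight_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    out = torch.ops.auto_deploy.torch_quant_fp8_linear(
        inp, weight_fp8, bias, input_scale, weight_scale
    )
    return torch.ops.auto_deploy.torch_dist_all_reduce(out)


def fused_fp8(inp, weight_fp8, bias=None, input_scale=None, weight_scale=None) -> torch.Tensor:
    # Swap to trtllm_dist_fused_fp8_linear_all_reduce in MPI mode.
    return torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce(
        inp, weight_fp8, bias, input_scale, weight_scale
    )
```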
65 changes: 24 additions & 41 deletions tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py
@@ -1,3 +1,9 @@
"""TRT-LLM distributed operations and fused kernels.

This module defines atomic TRT-LLM-specific ops that use optimized kernels.
The torch fallback variants are defined separately to enable multi-pattern matching.
"""

import torch

from .common import ReduceOp, get_rank_world_size, is_ompi
@@ -34,51 +40,28 @@ def trtllm_allreduce(tensor, op, all_reduce_params=None):
torch_op = _allreduce_cache[cache_key]
return torch_op(tensor, all_reduce_params=all_reduce_params)

# TRT-LLM fused op (atomic - always uses TRT-LLM backend)
@torch.library.custom_op(
"dist::fused_allreduce_residual_rmsnorm", mutates_args=(), device_types="cuda"
"dist::trtllm_fused_allreduce_residual_rmsnorm", mutates_args=(), device_types="cuda"
)
def fused_allreduce_residual_rmsnorm(
def trtllm_fused_allreduce_residual_rmsnorm(
tensor: torch.Tensor, residual: torch.Tensor, norm_weight: torch.Tensor, eps: float
) -> tuple[torch.Tensor, torch.Tensor]:
"""Fusing allreduce, residual (add), and hf_rms_norm together.
"""Fused allreduce + residual + rmsnorm using TRT-LLM optimized kernel.

When TRT-LLM ops are available (MPI mode), uses the fused kernel.
Otherwise, falls back to separate operations using torch distributed.
This op always uses TRT-LLM's fused kernel and is used in MPI mode.
"""
# Only use TRT-LLM fused op when running with MPI
if is_trtllm_op_available():
all_reduce_params = AllReduceParams(
fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
bias=None,
residual=residual,
norm_weight=norm_weight,
eps=eps,
)
return trtllm_allreduce(tensor, ReduceOp.SUM, all_reduce_params=all_reduce_params)
else:
# Fallback: unfused implementation using torch distributed
# This is used in demollm mode without MPI
from .common import all_reduce as torch_all_reduce

# 1. All-reduce the tensor
tensor_reduced = tensor.clone()
torch_all_reduce(tensor_reduced, op=ReduceOp.SUM)

# 2. Add residual
tensor_with_residual = tensor_reduced + residual

# 3. Apply RMSNorm using PyTorch's built-in function
norm_out = torch.nn.functional.rms_norm(
tensor_with_residual,
normalized_shape=(tensor_with_residual.size(-1),),
weight=norm_weight,
eps=eps,
)

return norm_out, tensor_with_residual

@fused_allreduce_residual_rmsnorm.register_fake
def fused_allreduce_residual_rmsnorm_fake(
all_reduce_params = AllReduceParams(
fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM,
bias=None,
residual=residual,
norm_weight=norm_weight,
eps=eps,
)
return trtllm_allreduce(tensor, ReduceOp.SUM, all_reduce_params=all_reduce_params)

@trtllm_fused_allreduce_residual_rmsnorm.register_fake
def trtllm_fused_allreduce_residual_rmsnorm_fake(
tensor: torch.Tensor, residual: torch.Tensor, norm_weight: torch.Tensor, eps: float
) -> tuple[torch.Tensor, torch.Tensor]:
return torch.empty_like(tensor), torch.empty_like(tensor)
@@ -89,12 +72,12 @@ def fused_allreduce_residual_rmsnorm_fake(
def trtllm_allgather(tensor, dim, sizes=None):
raise ImportError("TRT-LLM is not available.")

def trtllm_allreduce(tensor, op):
def trtllm_allreduce(tensor, op, all_reduce_params=None):
raise ImportError("TRT-LLM is not available.")

TRTLLM_OP_AVAILABLE = False


def is_trtllm_op_available():
# TRT-LLM only work with MPI
"""Check if TRT-LLM ops are available and running with MPI."""
return TRTLLM_OP_AVAILABLE and is_ompi()
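With the in-op fallback removed, demollm mode needs an unfused equivalent defined elsewhere (per the module docstring, the torch variants live separately to enable multi-pattern matching). A hedged sketch of that torch-only sequence, mirroring the code deleted above rather than any specific file in this PR:

```python
import torch


def unfused_allreduce_residual_rmsnorm(
    tensor: torch.Tensor, residual: torch.Tensor, norm_weight: torch.Tensor, eps: float
) -> tuple[torch.Tensor, torch.Tensor]:
    # 1. SUM all-reduce via the torch-backend atomic op
    reduced = torch.ops.auto_deploy.torch_dist_all_reduce(tensor)
    # 2. residual add
    with_residual = reduced + residual
    # 3. RMSNorm using PyTorch's built-in functional op
    norm_out = torch.nn.functional.rms_norm(
        with_residual,
        normalized_shape=(with_residual.size(-1),),
        weight=norm_weight,
        eps=eps,
    )
    return norm_out, with_residual
```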