diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/README.md b/tensorrt_llm/_torch/auto_deploy/custom_ops/README.md index 5a65d610493..16b19f95e3d 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/README.md +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/README.md @@ -17,13 +17,14 @@ The table below lists the operators ordered by their backend. | `torch.ops.auto_deploy.torch_attention` | Grouped SDPA implementation with `bsnd` and `bnsd` layout supported | | `torch.ops.auto_deploy.torch_attention_repeat_kv` | KV repetition for attention | | `torch.ops.auto_deploy.torch_attention_sdpa` | Standard SDPA implementation | -| `torch.ops.auto_deploy.torch_dist_all_gather` | Distributed all-gather operation | -| `torch.ops.auto_deploy.torch_dist_all_reduce` | Distributed all-reduce operation | +| `torch.ops.auto_deploy.torch_dist_all_gather` | Distributed all-gather operation (PyTorch backend, demollm mode) | +| `torch.ops.auto_deploy.torch_dist_all_reduce` | Distributed all-reduce operation (PyTorch backend, demollm mode) | +| `torch.ops.auto_deploy.torch_fused_linear_all_reduce` | Fused linear layer followed by all-reduce (PyTorch backend, demollm mode) | | `torch.ops.auto_deploy.torch_linear_simple` | Simple linear layer implementation | | `torch.ops.auto_deploy.torch_moe` | Mixture of Experts implementation | | `torch.ops.auto_deploy.torch_moe_fused` | Fused Mixture of Experts implementation | | `torch.ops.auto_deploy.torch_quant_fn` | Generic quantization function that scales, rounds, and clamps input values | -| `torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce` | Fused FP8 linear layer followed by all-reduce operation | +| `torch.ops.auto_deploy.torch_quant_fused_fp8_linear_all_reduce` | Fused FP8 linear layer followed by all-reduce (PyTorch backend, demollm mode) | | `torch.ops.auto_deploy.torch_quant_nvfp4_linear` | FP4 quantized linear layer | | `torch.ops.auto_deploy.torch_quant_fp8_linear` | FP8 quantized linear layer | | `torch.ops.auto_deploy.torch_rope_with_complex_freqs` | RoPE with complex frequencies | @@ -38,4 +39,8 @@ The table below lists the operators ordered by their backend. 
| `torch.ops.auto_deploy.triton_rope_on_flattened_inputs` | Triton RoPE on flattened inputs | | `torch.ops.auto_deploy.triton_rope_with_input_pos` | Triton RoPE with input positions | | `torch.ops.auto_deploy.trtllm_moe_fused` | TensorRT LLM fused MoE implementation | -| `torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce` | TensorRT LLM fused linear layer followed by all-reduce operation | +| `torch.ops.auto_deploy.trtllm_dist_all_gather` | Distributed all-gather operation (TRT-LLM backend, MPI mode) | +| `torch.ops.auto_deploy.trtllm_dist_all_reduce` | Distributed all-reduce operation (TRT-LLM backend, MPI mode) | +| `torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce` | Fused linear layer followed by all-reduce (TRT-LLM backend, MPI mode) | +| `torch.ops.auto_deploy.trtllm_dist_fused_fp8_linear_all_reduce` | Fused FP8 linear layer followed by all-reduce (TRT-LLM backend, MPI mode) | +| `torch.ops.dist.trtllm_fused_allreduce_residual_rmsnorm` | Fused all-reduce + residual add + RMSNorm (TRT-LLM backend, MPI mode) | diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/dist.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/dist.py index d6f13fbedd7..e94b039a05b 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/dist.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/dist.py @@ -1,4 +1,8 @@ -"""Custom ops required for implementing tensor parallelism.""" +"""Custom ops required for implementing tensor parallelism. + +This module defines atomic distributed ops - each op uses a specific backend +(torch.distributed or TRT-LLM) without internal dispatch logic. +""" from typing import List, Optional @@ -7,38 +11,82 @@ from ..distributed import common as dist from ..distributed import trtllm as trtllm_dist +# ============================================================================ +# PyTorch Distributed Backend Ops (demollm mode) +# ============================================================================ + @torch.library.custom_op("auto_deploy::torch_dist_all_gather", mutates_args=(), device_types="cuda") -def all_gather( +def torch_dist_all_gather( tensor: torch.Tensor, dim: int = 0, sizes: Optional[List[int]] = None ) -> torch.Tensor: - """All gather followed by concat in dim = 0. This is the default nccl behavior.""" - if trtllm_dist.is_trtllm_op_available(): - return trtllm_dist.trtllm_allgather(tensor, dim=dim, sizes=sizes) + """All gather using PyTorch distributed backend. + + This op always uses torch.distributed.all_gather and is used in demollm mode. + """ tl = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())] dist.all_gather(tl, tensor) return torch.cat(tl, dim=dim) -@all_gather.register_fake -def all_gather_fake(tensor, dim=0): +@torch_dist_all_gather.register_fake +def torch_dist_all_gather_fake(tensor, dim=0, sizes=None): return torch.cat([torch.empty_like(tensor) for _ in range(dist.get_world_size())], dim=dim) @torch.library.custom_op("auto_deploy::torch_dist_all_reduce", mutates_args=(), device_types="cuda") -def all_reduce(t: torch.Tensor) -> torch.Tensor: - """All_reduce across the ranks. Reduction op is SUM. +def torch_dist_all_reduce(t: torch.Tensor) -> torch.Tensor: + """All_reduce using PyTorch distributed backend. Reduction op is SUM. + + This op always uses torch.distributed.all_reduce and is used in demollm mode. NOTE: this op requires an extra memory copy and should ONLY be used for debugging + testing. For efficient all_reduce ops one should write/replace it with a fused op. 
""" - if trtllm_dist.is_trtllm_op_available(): - return trtllm_dist.trtllm_allreduce(t, op=dist.ReduceOp.SUM) t_res = t.clone() dist.all_reduce(t_res, op=dist.ReduceOp.SUM) return t_res -@all_reduce.register_fake -def all_reduce_fake(tensor): +@torch_dist_all_reduce.register_fake +def torch_dist_all_reduce_fake(tensor): + return torch.empty_like(tensor) + + +# ============================================================================ +# TRT-LLM Backend Ops (MPI mode) +# ============================================================================ + + +@torch.library.custom_op( + "auto_deploy::trtllm_dist_all_gather", mutates_args=(), device_types="cuda" +) +def trtllm_dist_all_gather( + tensor: torch.Tensor, dim: int = 0, sizes: Optional[List[int]] = None +) -> torch.Tensor: + """All gather using TRT-LLM optimized backend. + + This op always uses TRT-LLM's optimized allgather and is used in MPI mode. + """ + return trtllm_dist.trtllm_allgather(tensor, dim=dim, sizes=sizes) + + +@trtllm_dist_all_gather.register_fake +def trtllm_dist_all_gather_fake(tensor, dim=0, sizes=None): + return torch.cat([torch.empty_like(tensor) for _ in range(dist.get_world_size())], dim=dim) + + +@torch.library.custom_op( + "auto_deploy::trtllm_dist_all_reduce", mutates_args=(), device_types="cuda" +) +def trtllm_dist_all_reduce(t: torch.Tensor) -> torch.Tensor: + """All_reduce using TRT-LLM optimized backend. Reduction op is SUM. + + This op always uses TRT-LLM's optimized allreduce and is used in MPI mode. + """ + return trtllm_dist.trtllm_allreduce(t, op=dist.ReduceOp.SUM) + + +@trtllm_dist_all_reduce.register_fake +def trtllm_dist_all_reduce_fake(tensor): return torch.empty_like(tensor) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/linear.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/linear.py index fda48e4ba57..8a1cfb5bcfc 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/linear.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/linear.py @@ -24,26 +24,48 @@ def simple(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tenso @simple.register_fake def simple_fake(input, weight, bias): """Fake implementation of simple_linear.""" - # return torch.empty( - # input.shape[:-1] + (weight.shape[-1],), dtype=input.dtype, device=input.device - # ) return torch.ops.aten.linear(input, weight, bias) +# ============================================================================ +# Fused Linear + AllReduce Ops (Atomic - Backend Specific) +# ============================================================================ + + @torch.library.custom_op( - "auto_deploy::trtllm_dist_fused_linear_all_reduce", mutates_args=(), device_types="cuda" + "auto_deploy::torch_fused_linear_all_reduce", mutates_args=(), device_types="cuda" ) -def fused_linear_all_reduce( +def torch_fused_linear_all_reduce( input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] ) -> torch.Tensor: - """Fused linear followed by all_reduce on the output.""" + """Fused linear + all_reduce using PyTorch backend. + + This op always uses torch.distributed and is used in demollm mode. 
+ """ output = torch.ops.aten.linear(input, weight, bias) - if trtllm_dist.is_trtllm_op_available(): - return trtllm_dist.trtllm_allreduce(output, op=dist.ReduceOp.SUM) dist.all_reduce(output, op=dist.ReduceOp.SUM) return output -@fused_linear_all_reduce.register_fake -def fused_linear_all_reduce_fake(input, weight, bias): +@torch_fused_linear_all_reduce.register_fake +def torch_fused_linear_all_reduce_fake(input, weight, bias): + return torch.ops.aten.linear(input, weight, bias) + + +@torch.library.custom_op( + "auto_deploy::trtllm_dist_fused_linear_all_reduce", mutates_args=(), device_types="cuda" +) +def trtllm_dist_fused_linear_all_reduce( + input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] +) -> torch.Tensor: + """Fused linear + all_reduce using TRT-LLM backend. + + This op always uses TRT-LLM's optimized allreduce and is used in MPI mode. + """ + output = torch.ops.aten.linear(input, weight, bias) + return trtllm_dist.trtllm_allreduce(output, op=dist.ReduceOp.SUM) + + +@trtllm_dist_fused_linear_all_reduce.register_fake +def trtllm_dist_fused_linear_all_reduce_fake(input, weight, bias): return torch.ops.aten.linear(input, weight, bias) diff --git a/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py b/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py index 90ea04db862..41ac73a9b3b 100644 --- a/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py +++ b/tensorrt_llm/_torch/auto_deploy/custom_ops/quant.py @@ -240,26 +240,65 @@ def fp8_linear_fake( return torch.ops.aten.linear(input, weight_fp8.to(input.dtype), bias) +# ============================================================================ +# Fused FP8 Linear + AllReduce Ops (Atomic - Backend Specific) +# ============================================================================ + + @torch.library.custom_op("auto_deploy::torch_quant_fused_fp8_linear_all_reduce", mutates_args=()) @torch.compile(dynamic=True) -def fused_fp8_linear_all_reduce( +def torch_quant_fused_fp8_linear_all_reduce( input: torch.Tensor, weight_fp8: torch.Tensor, bias: Optional[torch.Tensor] = None, input_scale: Optional[torch.Tensor] = None, weight_scale: Optional[torch.Tensor] = None, ) -> torch.Tensor: + """Fused FP8 linear + all_reduce using PyTorch backend. + + This op always uses torch.distributed and is used in demollm mode. + """ out = torch.ops.auto_deploy.torch_quant_fp8_linear( input, weight_fp8, bias, input_scale, weight_scale ) - if trtllm_dist.is_trtllm_op_available(): - return trtllm_dist.trtllm_allreduce(out, op=dist.ReduceOp.SUM) dist.all_reduce(out, op=dist.ReduceOp.SUM) return out -@fused_fp8_linear_all_reduce.register_fake -def fused_fp8_linear_all_reduce_fake( +@torch_quant_fused_fp8_linear_all_reduce.register_fake +def torch_quant_fused_fp8_linear_all_reduce_fake( + input: torch.Tensor, + weight_fp8: torch.Tensor, + bias: Optional[torch.Tensor] = None, + input_scale: Optional[torch.Tensor] = None, + weight_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + return torch.ops.auto_deploy.torch_quant_fp8_linear( + input, weight_fp8, bias, input_scale, weight_scale + ) + + +@torch.library.custom_op("auto_deploy::trtllm_dist_fused_fp8_linear_all_reduce", mutates_args=()) +@torch.compile(dynamic=True) +def trtllm_dist_fused_fp8_linear_all_reduce( + input: torch.Tensor, + weight_fp8: torch.Tensor, + bias: Optional[torch.Tensor] = None, + input_scale: Optional[torch.Tensor] = None, + weight_scale: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """Fused FP8 linear + all_reduce using TRT-LLM backend. 
+ + This op always uses TRT-LLM's optimized allreduce and is used in MPI mode. + """ + out = torch.ops.auto_deploy.torch_quant_fp8_linear( + input, weight_fp8, bias, input_scale, weight_scale + ) + return trtllm_dist.trtllm_allreduce(out, op=dist.ReduceOp.SUM) + + +@trtllm_dist_fused_fp8_linear_all_reduce.register_fake +def trtllm_dist_fused_fp8_linear_all_reduce_fake( input: torch.Tensor, weight_fp8: torch.Tensor, bias: Optional[torch.Tensor] = None, diff --git a/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py b/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py index 434cc1693eb..fde4f78994b 100644 --- a/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py +++ b/tensorrt_llm/_torch/auto_deploy/distributed/trtllm.py @@ -1,3 +1,9 @@ +"""TRT-LLM distributed operations and fused kernels. + +This module defines atomic TRT-LLM-specific ops that use optimized kernels. +The torch fallback variants are defined separately to enable multi-pattern matching. +""" + import torch from .common import ReduceOp, get_rank_world_size, is_ompi @@ -34,51 +40,28 @@ def trtllm_allreduce(tensor, op, all_reduce_params=None): torch_op = _allreduce_cache[cache_key] return torch_op(tensor, all_reduce_params=all_reduce_params) + # TRT-LLM fused op (atomic - always uses TRT-LLM backend) @torch.library.custom_op( - "dist::fused_allreduce_residual_rmsnorm", mutates_args=(), device_types="cuda" + "dist::trtllm_fused_allreduce_residual_rmsnorm", mutates_args=(), device_types="cuda" ) - def fused_allreduce_residual_rmsnorm( + def trtllm_fused_allreduce_residual_rmsnorm( tensor: torch.Tensor, residual: torch.Tensor, norm_weight: torch.Tensor, eps: float ) -> tuple[torch.Tensor, torch.Tensor]: - """Fusing allreduce, residual (add), and hf_rms_norm together. + """Fused allreduce + residual + rmsnorm using TRT-LLM optimized kernel. - When TRT-LLM ops are available (MPI mode), uses the fused kernel. - Otherwise, falls back to separate operations using torch distributed. + This op always uses TRT-LLM's fused kernel and is used in MPI mode. """ - # Only use TRT-LLM fused op when running with MPI - if is_trtllm_op_available(): - all_reduce_params = AllReduceParams( - fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, - bias=None, - residual=residual, - norm_weight=norm_weight, - eps=eps, - ) - return trtllm_allreduce(tensor, ReduceOp.SUM, all_reduce_params=all_reduce_params) - else: - # Fallback: unfused implementation using torch distributed - # This is used in demollm mode without MPI - from .common import all_reduce as torch_all_reduce - - # 1. All-reduce the tensor - tensor_reduced = tensor.clone() - torch_all_reduce(tensor_reduced, op=ReduceOp.SUM) - - # 2. Add residual - tensor_with_residual = tensor_reduced + residual - - # 3. 
Apply RMSNorm using PyTorch's built-in function - norm_out = torch.nn.functional.rms_norm( - tensor_with_residual, - normalized_shape=(tensor_with_residual.size(-1),), - weight=norm_weight, - eps=eps, - ) - - return norm_out, tensor_with_residual - - @fused_allreduce_residual_rmsnorm.register_fake - def fused_allreduce_residual_rmsnorm_fake( + all_reduce_params = AllReduceParams( + fusion_op=AllReduceFusionOp.RESIDUAL_RMS_NORM, + bias=None, + residual=residual, + norm_weight=norm_weight, + eps=eps, + ) + return trtllm_allreduce(tensor, ReduceOp.SUM, all_reduce_params=all_reduce_params) + + @trtllm_fused_allreduce_residual_rmsnorm.register_fake + def trtllm_fused_allreduce_residual_rmsnorm_fake( tensor: torch.Tensor, residual: torch.Tensor, norm_weight: torch.Tensor, eps: float ) -> tuple[torch.Tensor, torch.Tensor]: return torch.empty_like(tensor), torch.empty_like(tensor) @@ -89,12 +72,12 @@ def fused_allreduce_residual_rmsnorm_fake( def trtllm_allgather(tensor, dim, sizes=None): raise ImportError("TRT-LLM is not available.") - def trtllm_allreduce(tensor, op): + def trtllm_allreduce(tensor, op, all_reduce_params=None): raise ImportError("TRT-LLM is not available.") TRTLLM_OP_AVAILABLE = False def is_trtllm_op_available(): - # TRT-LLM only work with MPI + """Check if TRT-LLM ops are available and running with MPI.""" return TRTLLM_OP_AVAILABLE and is_ompi() diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py b/tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py index d0ebcd0eec8..2a800efccbf 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/collectives.py @@ -1,3 +1,10 @@ +"""Transformations for fusing collective operations. + +This module registers TRT-LLM backend patterns only. Fusion is only applied +when TRT-LLM is available (MPI mode) since it provides optimized fused kernels. +The torch backend (demollm mode) does not benefit from fusion. +""" + from typing import Tuple import torch @@ -16,63 +23,85 @@ # * ... -def _allreduce_residual_rmsnorm_pattern( - x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 0.1253 -): - """ - Reference PyTorch composition of: - y = all_reduce(x) - z = residual + y - normed = RMSNorm(z, weight, eps) - Returns (normed, z) - """ +# ============================================================================ +# Pattern Template Factory Functions +# ============================================================================ - input_dtype = x.dtype - hidden_states = torch.ops.auto_deploy.torch_dist_all_reduce(x) - add = residual + hidden_states - hidden_states = add.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + eps) +def _make_allreduce_residual_rmsnorm_pattern(add_order: str = "residual_first"): + """Factory function to create pattern functions for allreduce+residual+rmsnorm fusion. 
- normed = weight * hidden_states.to(input_dtype) + Args: + add_order: Either "residual_first" (residual + x) or "x_first" (x + residual) - return normed, add + Returns: + A pattern function that can be used with register_ad_pattern + """ + def pattern_fn( + x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 0.1253 + ): + """Pattern: trtllm_dist_all_reduce(x) -> add residual -> RMSNorm -def _allreduce_residual_rmsnorm_pattern2( - x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float = 0.1253 -): - """ - Reference PyTorch composition of: - y = all_reduce(x) - z = y + residual - normed = RMSNorm(z, weight, eps) - Returns (normed, z) - """ + Reference PyTorch composition: + y = trtllm_dist_all_reduce(x) + z = residual + y (or y + residual) + normed = RMSNorm(z, weight, eps) + Returns (normed, z) + """ + input_dtype = x.dtype + hidden_states = torch.ops.auto_deploy.trtllm_dist_all_reduce(x) + + # Handle addition order + if add_order == "residual_first": + add = residual + hidden_states + else: # x_first + add = hidden_states + residual - input_dtype = x.dtype - hidden_states = torch.ops.auto_deploy.torch_dist_all_reduce(x) - add = hidden_states + residual + hidden_states = add.to(torch.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + eps) - hidden_states = add.to(torch.float32) - variance = hidden_states.pow(2).mean(-1, keepdim=True) - hidden_states = hidden_states * torch.rsqrt(variance + eps) + normed = weight * hidden_states.to(input_dtype) - normed = weight * hidden_states.to(input_dtype) + return normed, add - return normed, add + return pattern_fn -def _allreduce_residual_rmsnorm_repl( +def _allreduce_residual_rmsnorm_replacement( x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, eps: float ): - return torch.ops.dist.fused_allreduce_residual_rmsnorm(x, residual, weight, eps) + """Replacement using TRT-LLM fused kernel.""" + return torch.ops.dist.trtllm_fused_allreduce_residual_rmsnorm(x, residual, weight, eps) + + +# ============================================================================ +# Instantiate Pattern Functions +# ============================================================================ + +# TRT-LLM backend (MPI mode) - two patterns for different addition orders +_allreduce_residual_rmsnorm_pattern_trtllm = _make_allreduce_residual_rmsnorm_pattern( + add_order="residual_first" +) +_allreduce_residual_rmsnorm_pattern2_trtllm = _make_allreduce_residual_rmsnorm_pattern( + add_order="x_first" +) + + +# ============================================================================ +# Transform Implementation +# ============================================================================ @TransformRegistry.register("fuse_allreduce_residual_rmsnorm") class FuseAllreduceResidualRMSNorm(BaseTransform): - """Fuse (allreduce + residual add + RMSNorm) into one fused op with tuple output.""" + """Fuse (allreduce + residual add + RMSNorm) into one fused op with tuple output. + + This transform only applies when TRT-LLM ops are used (MPI mode), as it provides + optimized fused kernels. The torch backend (demollm mode) does not benefit from + this fusion and uses unfused operations. 
+ """ def _apply( self, @@ -92,21 +121,28 @@ def _apply( 0.1253, # eps ] + op_ignore_types = {torch.ops.aten.to.dtype: (torch.dtype,)} + scalar_workaround = {"eps": 0.1253} + + # Register TRT-LLM backend patterns only (no torch backend fusion) + # Pattern 1: residual + allreduce(x) register_ad_pattern( - search_fn=_allreduce_residual_rmsnorm_pattern, - replace_fn=_allreduce_residual_rmsnorm_repl, + search_fn=_allreduce_residual_rmsnorm_pattern_trtllm, + replace_fn=_allreduce_residual_rmsnorm_replacement, patterns=patterns, dummy_args=dummy_args, - op_ignore_types={torch.ops.aten.to.dtype: (torch.dtype,)}, - scalar_workaround={"eps": 0.1253}, + op_ignore_types=op_ignore_types, + scalar_workaround=scalar_workaround, ) + + # Pattern 2: allreduce(x) + residual register_ad_pattern( - search_fn=_allreduce_residual_rmsnorm_pattern2, - replace_fn=_allreduce_residual_rmsnorm_repl, + search_fn=_allreduce_residual_rmsnorm_pattern2_trtllm, + replace_fn=_allreduce_residual_rmsnorm_replacement, patterns=patterns, dummy_args=dummy_args, - op_ignore_types={torch.ops.aten.to.dtype: (torch.dtype,)}, - scalar_workaround={"eps": 0.1253}, + op_ignore_types=op_ignore_types, + scalar_workaround=scalar_workaround, ) num_matches = patterns.apply(gm.graph) diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py index 1bb99974ac1..7bb5d4c2d20 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/sharding.py @@ -1004,7 +1004,8 @@ def detect_dp_bmm_shard(gm: GraphModule, sharding_config: ShardingConfig) -> Tra base_size = bmm_batch_size // world_size remainder = bmm_batch_size % world_size - # NOTE: our torch.ops.auto_deploy.torch_dist_all_gather doesn't support uneven splits at the moment. + # NOTE: our torch.ops.auto_deploy.torch_dist_all_gather/trtllm_dist_all_gather + # doesn't support uneven splits at the moment. if remainder: ad_logger.warning( f"BMM batch size {bmm_batch_size} is not divisible by world size {world_size}. 
" diff --git a/tensorrt_llm/_torch/auto_deploy/transform/library/visualization.py b/tensorrt_llm/_torch/auto_deploy/transform/library/visualization.py index 0ed97c3ad24..10c13f9635e 100644 --- a/tensorrt_llm/_torch/auto_deploy/transform/library/visualization.py +++ b/tensorrt_llm/_torch/auto_deploy/transform/library/visualization.py @@ -77,6 +77,7 @@ def add_outputs_metadata(self, fx_node: torch.fx.node.Node, node: GraphNode): # TODO(yudong): make custom_ops configurable CUSTOM_OPS = ( torch.ops.auto_deploy.torch_dist_all_reduce.default, + torch.ops.auto_deploy.trtllm_dist_all_reduce.default, torch.ops.aten.slice.Tensor, torch.ops.auto_deploy.triton_attention_fused_mha_with_cache.default, torch.ops.auto_deploy.trtllm_dist_fused_linear_all_reduce.default, diff --git a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py index 33c31a9ba22..9cdb20422fb 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/node_utils.py @@ -305,10 +305,14 @@ def is_bmm_op(node: Node) -> bool: def is_dist_op(node: Node) -> bool: - """Check if the node is a distributed op.""" + """Check if the node is a distributed op (torch or trtllm backend).""" dist_ops = { + # PyTorch backend ops torch.ops.auto_deploy.torch_dist_all_gather, torch.ops.auto_deploy.torch_dist_all_reduce, + # TRT-LLM backend ops + torch.ops.auto_deploy.trtllm_dist_all_gather, + torch.ops.auto_deploy.trtllm_dist_all_reduce, } return is_op(node, dist_ops) diff --git a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py index 736318d355a..7093daefb77 100644 --- a/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py +++ b/tensorrt_llm/_torch/auto_deploy/utils/sharding_utils.py @@ -29,6 +29,27 @@ ) +def _get_dist_ops(): + """Get the appropriate distributed ops based on backend availability. + + Returns tuple of (all_gather_op, all_reduce_op) for the current backend. + """ + from ..distributed.trtllm import is_trtllm_op_available + + if is_trtllm_op_available(): + # Use TRT-LLM optimized ops in MPI mode + return ( + torch.ops.auto_deploy.trtllm_dist_all_gather.default, + torch.ops.auto_deploy.trtllm_dist_all_reduce.default, + ) + else: + # Use PyTorch distributed ops in demollm mode + return ( + torch.ops.auto_deploy.torch_dist_all_gather.default, + torch.ops.auto_deploy.torch_dist_all_reduce.default, + ) + + def _load_hook( state_dict, prefix, @@ -485,10 +506,11 @@ def _shard_parameter_node( ) return - # figure out the right dist op + # figure out the right dist op (backend-aware) + all_gather_op, all_reduce_op = _get_dist_ops() dist_lookup = { - 0: (torch.ops.auto_deploy.torch_dist_all_gather.default, -1), - 1: (torch.ops.auto_deploy.torch_dist_all_reduce.default,), + 0: (all_gather_op, -1), + 1: (all_reduce_op,), } fn_dist, *dist_args = dist_lookup[dim] @@ -859,7 +881,8 @@ def validate(self, gm: GraphModule = None, node: Node = None) -> bool: # Check if the distribution is balanced remainder = bmm_batch_size % self.world_size - # NOTE: our torch.ops.auto_deploy.torch_dist_all_gather doesn't support uneven splits at the moment. + # NOTE: our torch.ops.auto_deploy.torch_dist_all_gather/trtllm_dist_all_gather + # doesn't support uneven splits at the moment. if remainder: ad_logger.warning( f"BMM batch size {bmm_batch_size} is not divisible by world size {self.world_size}. 
" @@ -924,9 +947,10 @@ def slice_tensor(t: torch.Tensor) -> torch.Tensor: handle_tensor(node, rhs_tensor, 1, self.start_idx, self.end_idx) # Add all_gather node after BMM to collect results + all_gather_op, _ = _get_dist_ops() with gm.graph.inserting_after(node): gather_node = gm.graph.call_function( - torch.ops.auto_deploy.torch_dist_all_gather.default, + all_gather_op, args=(node, 0), # Gather along batch dimension (0) ) node.replace_all_uses_with(gather_node) @@ -1012,10 +1036,9 @@ def get_partition(lst, world_size, rank): node.args = tuple(args) # -- add an all_reduce node -- + _, all_reduce_op = _get_dist_ops() with gm.graph.inserting_after(node): - dist_node = gm.graph.call_function( - torch.ops.auto_deploy.torch_dist_all_reduce.default, args=(node,) - ) + dist_node = gm.graph.call_function(all_reduce_op, args=(node,)) node.replace_all_uses_with(dist_node) dist_node.replace_input_with(dist_node, node) @@ -1048,7 +1071,7 @@ def _insert_sharded_mxfp4_mlp_ep( Transform a call to auto_deploy::triton_mxfp4_moe into: - sharded expert parameters along dim 0 (this rank's slice), - call to auto_deploy::triton_mxfp4_moe_ep(..., local_lo, local_hi), - - followed by torch_dist_all_reduce. + - followed by torch_dist_all_reduce/trtllm_dist_all_reduce. Expects the original op signature: (hidden_states, @@ -1084,8 +1107,9 @@ def _insert_sharded_mxfp4_mlp_ep( node.args = args_ep # Add a dist all-reduce after the op (sum partial results across EP ranks) + _, all_reduce_op = _get_dist_ops() with gm.graph.inserting_after(node): - red = gm.graph.call_function(torch.ops.auto_deploy.torch_dist_all_reduce, args=(node,)) + red = gm.graph.call_function(all_reduce_op, args=(node,)) node.replace_all_uses_with(red) # keep dataflow: red(input=node) red.replace_input_with(red, node) diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/custom_ops/test_dist.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/custom_ops/test_dist.py index d4c8091158a..3dcc7fabaf8 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/custom_ops/test_dist.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/custom_ops/test_dist.py @@ -9,6 +9,7 @@ def _run_all_reduce_test(rank, world_size): x = torch.ones(10, 10).to("cuda") + # Test torch backend (demollm mode with Python multiprocessing) y = torch.ops.auto_deploy.torch_dist_all_reduce(x) assert torch.equal(x * world_size, y) @@ -16,6 +17,7 @@ def _run_all_reduce_test(rank, world_size): def _run_all_gather_test(rank, world_size): x = torch.ones(10, 10).to("cuda") + # Test torch backend (demollm mode with Python multiprocessing) y = torch.ops.auto_deploy.torch_dist_all_gather(x) assert torch.sum(y) == world_size * torch.sum(x) diff --git a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py index 797e9f94cec..05a21b80183 100644 --- a/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py +++ b/tests/unittest/_torch/auto_deploy/unit/multigpu/transformations/library/test_allreduce_residual_rmsnorm_fusion.py @@ -37,7 +37,8 @@ def __init__(self, hidden_size, dtype): self.norm = RMSNorm(hidden_size, 1e-5, dtype) def forward(self, x, residual): - x = torch.ops.auto_deploy.torch_dist_all_reduce.default(x) + # Use trtllm backend ops when running with MPI/TRT-LLM + x = torch.ops.auto_deploy.trtllm_dist_all_reduce.default(x) y = residual + x 
normed = self.norm(y) return normed, y @@ -51,7 +52,8 @@ def __init__(self, hidden_size, dtype): self.norm = RMSNorm(hidden_size, 1e-5, dtype) def forward(self, x, residual): - x = torch.ops.auto_deploy.torch_dist_all_reduce.default(x) + # Use trtllm backend ops when running with MPI/TRT-LLM + x = torch.ops.auto_deploy.trtllm_dist_all_reduce.default(x) y = x + residual normed = self.norm(y) return normed, y @@ -94,7 +96,7 @@ def _test_allreduce_fusion(port: int, ModuleCls): # Check if fused node in the graph has_fused_node = False for node in gm_transformed.graph.nodes: - if is_op(node, torch.ops.dist.fused_allreduce_residual_rmsnorm): + if is_op(node, torch.ops.dist.trtllm_fused_allreduce_residual_rmsnorm): has_fused_node = True assert has_fused_node, "Fused node not found."
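
A minimal usage sketch (not part of the patch) of how callers are expected to pick between the two atomic op families after this change. The op names, module paths, and `is_trtllm_op_available()` are taken from the diff above; the wrapper function, its name, and the assumption that an appropriate process group / MPI session is already initialized are illustrative only:

```python
import torch

# Importing the custom-op module registers the torch.ops.auto_deploy.* ops
# (registration happens via the torch.library decorators at import time).
from tensorrt_llm._torch.auto_deploy.custom_ops import dist as _ad_dist_ops  # noqa: F401
from tensorrt_llm._torch.auto_deploy.distributed.trtllm import is_trtllm_op_available


def all_reduce_sum(t: torch.Tensor) -> torch.Tensor:
    """Illustrative backend selection; the atomic ops themselves contain no fallback logic.

    Assumes the surrounding job has already set up either MPI (TRT-LLM backend)
    or a torch.distributed process group (demollm mode).
    """
    if is_trtllm_op_available():
        # MPI mode: TRT-LLM optimized all-reduce kernel
        return torch.ops.auto_deploy.trtllm_dist_all_reduce(t)
    # demollm mode: plain torch.distributed all-reduce
    return torch.ops.auto_deploy.torch_dist_all_reduce(t)
```

The same selection applies to the all-gather and fused linear + all-reduce variants (centralized for graph transforms by `_get_dist_ops()` in `sharding_utils.py`), while the fused allreduce + residual + RMSNorm pattern is only rewritten to `torch.ops.dist.trtllm_fused_allreduce_residual_rmsnorm` when the TRT-LLM backend is in use.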