|
| 1 | +# Triton-kernels-based MXFP4 MoE ops (GPT-OSS style) with routing, swizzling, and fused activation |
| 2 | + |
| 3 | +from typing import Callable, Tuple |
| 4 | + |
| 5 | +import torch |
| 6 | +import torch.nn.functional as F |
| 7 | + |
| 8 | +IS_TRITON_KERNELS_AVAILABLE = True |
| 9 | +TRITON_KERNELS_UNAVAILABLE_REASON = "" |
| 10 | + |
| 11 | +try: |
| 12 | + from triton_kernels.matmul_ogs import ( |
| 13 | + FlexCtx, |
| 14 | + FnSpecs, |
| 15 | + FusedActivation, |
| 16 | + PrecisionConfig, |
| 17 | + matmul_ogs, |
| 18 | + ) |
| 19 | + from triton_kernels.numerics import InFlexData |
| 20 | + from triton_kernels.routing import RoutingData, routing |
| 21 | + from triton_kernels.swiglu import swiglu_fn |
| 22 | + from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor |
| 23 | + from triton_kernels.tensor_details import layout |
| 24 | + from triton_kernels.tensor_details.layout import StridedLayout |
| 25 | + |
| 26 | + from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import TritonEPRouter |
| 27 | + |
| 28 | +except Exception as _e: |
| 29 | + IS_TRITON_KERNELS_AVAILABLE = False |
| 30 | + TRITON_KERNELS_UNAVAILABLE_REASON = f"{type(_e).__name__}: {_e}" |
| 31 | + |
| 32 | + FlexCtx = FnSpecs = FusedActivation = PrecisionConfig = matmul_ogs = None |
| 33 | + InFlexData = RoutingData = routing = swiglu_fn = None |
| 34 | + FP4 = convert_layout = wrap_torch_tensor = None |
| 35 | + layout = StridedLayout = None |
| 36 | + TritonEPRouter = None |
| 37 | + |
| 38 | + |
| 39 | +# copied from transformers.integrations.mxfp4::swizzle_mxfp4 with minor modification |
| 40 | +def _swizzle_mxfp4(w, w_scale): |
| 41 | + value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1) |
| 42 | + w = convert_layout(wrap_torch_tensor(w, dtype=FP4), value_layout, **value_layout_opts) |
| 43 | + w_scale = convert_layout(wrap_torch_tensor(w_scale), StridedLayout) |
| 44 | + return w, w_scale |
| 45 | + |
| 46 | + |
| 47 | +RouteFn = Callable[[torch.Tensor], Tuple[RoutingData, torch.Tensor, torch.Tensor]] |
| 48 | + |
| 49 | + |
| 50 | +def _prepare_weights_scales( |
| 51 | + hidden_size: int, |
| 52 | + gate_up_blocks: torch.Tensor, # [E_local, 2I, H//32, 16] in unit8 |
| 53 | + gate_up_scales: torch.Tensor, # [E_local, 2I, H//32] in unit8 |
| 54 | + down_blocks: torch.Tensor, # [E_local, H, I//32, 16] in uint8 |
| 55 | + down_scales: torch.Tensor, # [E_local, H, I//32] in uint8 |
| 56 | +): |
| 57 | + local_experts = gate_up_blocks.size(0) |
| 58 | + intermediate_size = gate_up_blocks.shape[1] // 2 |
| 59 | + |
| 60 | + # canon shapes for swizzling (use last two dims as [K, N] style) |
| 61 | + gate_up_blocks = gate_up_blocks.view(local_experts, intermediate_size * 2, -1) |
| 62 | + triton_gate_up_w, gate_up_w_scale_raw = _swizzle_mxfp4( |
| 63 | + gate_up_blocks.transpose(-2, -1), gate_up_scales.transpose(-2, -1) |
| 64 | + ) |
| 65 | + triton_gate_up_w.shape = torch.Size([local_experts, hidden_size, intermediate_size * 2]) |
| 66 | + |
| 67 | + down_blocks = down_blocks.view(local_experts, -1, intermediate_size // 2) |
| 68 | + triton_down_w, down_w_scale_raw = _swizzle_mxfp4( |
| 69 | + down_blocks.transpose(-2, -1), down_scales.transpose(-2, -1) |
| 70 | + ) |
| 71 | + triton_down_w.shape = torch.Size([local_experts, intermediate_size, hidden_size]) |
| 72 | + |
| 73 | + return ( |
| 74 | + triton_gate_up_w, |
| 75 | + gate_up_w_scale_raw, |
| 76 | + triton_down_w, |
| 77 | + down_w_scale_raw, |
| 78 | + ) |
| 79 | + |
| 80 | + |
| 81 | +def _run_mxfp4_mlp_core( |
| 82 | + hidden_states: torch.Tensor, # [B, S, H] or [B*S, H] |
| 83 | + router_weight: torch.Tensor, |
| 84 | + router_bias: torch.Tensor, |
| 85 | + gate_up_blocks: torch.Tensor, |
| 86 | + gate_up_bias: torch.Tensor, |
| 87 | + gate_up_scales: torch.Tensor, |
| 88 | + alpha: float, |
| 89 | + limit: float, |
| 90 | + down_blocks: torch.Tensor, |
| 91 | + down_bias: torch.Tensor, |
| 92 | + down_scales: torch.Tensor, |
| 93 | + route_fn: RouteFn, # injects routing variant |
| 94 | +) -> torch.Tensor: |
| 95 | + """ |
| 96 | + Shared core for both triton_mxfp4_moe and triton_mxfp4_moe_ep. |
| 97 | + - route_fn encapsulates the only difference: how we produce (routing_data, gather_idx, scatter_idx). |
| 98 | + """ |
| 99 | + leading_shape = hidden_states.shape[:-1] |
| 100 | + hidden_size = hidden_states.shape[-1] |
| 101 | + x = hidden_states.reshape(-1, hidden_size) |
| 102 | + |
| 103 | + router_logits = F.linear(x, router_weight, router_bias) |
| 104 | + # route (global vs EP-aware) |
| 105 | + with torch.cuda.device(router_logits.device): |
| 106 | + routing_data, gather_idx, scatter_idx = route_fn(router_logits) |
| 107 | + |
| 108 | + ( |
| 109 | + triton_gate_up_w, |
| 110 | + gate_up_w_scale_raw, |
| 111 | + triton_down_w, |
| 112 | + down_w_scale_raw, |
| 113 | + ) = _prepare_weights_scales( |
| 114 | + hidden_size, gate_up_blocks, gate_up_scales, down_blocks, down_scales |
| 115 | + ) |
| 116 | + |
| 117 | + gate_pc = PrecisionConfig( |
| 118 | + weight_scale=gate_up_w_scale_raw, flex_ctx=FlexCtx(rhs_data=InFlexData()) |
| 119 | + ) |
| 120 | + down_pc = PrecisionConfig( |
| 121 | + weight_scale=down_w_scale_raw, flex_ctx=FlexCtx(rhs_data=InFlexData()) |
| 122 | + ) |
| 123 | + |
| 124 | + act = FusedActivation( |
| 125 | + FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (float(alpha), float(limit)), 2 |
| 126 | + ) |
| 127 | + |
| 128 | + # gate_up (with SWiGLU fused) |
| 129 | + inter = matmul_ogs( |
| 130 | + x, |
| 131 | + triton_gate_up_w, |
| 132 | + gate_up_bias.to(torch.float32), |
| 133 | + routing_data, |
| 134 | + gather_indx=gather_idx, |
| 135 | + precision_config=gate_pc, |
| 136 | + gammas=None, |
| 137 | + fused_activation=act, |
| 138 | + ) |
| 139 | + |
| 140 | + # down |
| 141 | + y = matmul_ogs( |
| 142 | + inter, |
| 143 | + triton_down_w, |
| 144 | + down_bias.to(torch.float32), |
| 145 | + routing_data, |
| 146 | + scatter_indx=scatter_idx, |
| 147 | + precision_config=down_pc, |
| 148 | + gammas=routing_data.gate_scal, |
| 149 | + ) |
| 150 | + |
| 151 | + y = y.reshape(*leading_shape, hidden_size) |
| 152 | + return y |
| 153 | + |
| 154 | + |
| 155 | +@torch.library.custom_op("auto_deploy::triton_mxfp4_moe", mutates_args=()) |
| 156 | +def triton_mxfp4_moe( |
| 157 | + hidden_states: torch.Tensor, # [B, S, H] or [B*S, H] |
| 158 | + # router |
| 159 | + router_weight: torch.Tensor, # [E, H] |
| 160 | + router_bias: torch.Tensor, # [E] |
| 161 | + top_k: int, |
| 162 | + # gate_up path |
| 163 | + gate_up_blocks: torch.Tensor, # [E, 2I, H//32, 16] in unit8 |
| 164 | + gate_up_bias: torch.Tensor, # [E, 2I] |
| 165 | + gate_up_scales: torch.Tensor, # [E, 2I, H//32] in unit8 |
| 166 | + alpha: float, |
| 167 | + limit: float, |
| 168 | + # down path |
| 169 | + down_blocks: torch.Tensor, # [E, H, I//32, 16] in uint8 |
| 170 | + down_bias: torch.Tensor, # [E, H] |
| 171 | + down_scales: torch.Tensor, # [E, H, I//32] in uint8 |
| 172 | +) -> torch.Tensor: |
| 173 | + def _global_route_fn(logits: torch.Tensor): |
| 174 | + return routing(logits, top_k) |
| 175 | + |
| 176 | + return _run_mxfp4_mlp_core( |
| 177 | + hidden_states, |
| 178 | + router_weight, |
| 179 | + router_bias, |
| 180 | + gate_up_blocks, |
| 181 | + gate_up_bias, |
| 182 | + gate_up_scales, |
| 183 | + alpha, |
| 184 | + limit, |
| 185 | + down_blocks, |
| 186 | + down_bias, |
| 187 | + down_scales, |
| 188 | + route_fn=_global_route_fn, |
| 189 | + ) |
| 190 | + |
| 191 | + |
| 192 | +@triton_mxfp4_moe.register_fake |
| 193 | +def _mxfp4_mlp_fake( |
| 194 | + hidden_states: torch.Tensor, |
| 195 | + router_weight: torch.Tensor, |
| 196 | + router_bias: torch.Tensor, |
| 197 | + top_k: int, |
| 198 | + gate_up_blocks: torch.Tensor, |
| 199 | + gate_up_bias: torch.Tensor, |
| 200 | + gate_up_scales: torch.Tensor, |
| 201 | + alpha: float, |
| 202 | + limit: float, |
| 203 | + down_blocks: torch.Tensor, |
| 204 | + down_bias: torch.Tensor, |
| 205 | + down_scales: torch.Tensor, |
| 206 | +): |
| 207 | + return torch.empty_like(hidden_states) |
| 208 | + |
| 209 | + |
| 210 | +@torch.library.custom_op("auto_deploy::triton_mxfp4_moe_ep", mutates_args=()) |
| 211 | +def triton_mxfp4_moe_ep( |
| 212 | + hidden_states: torch.Tensor, # [B, S, H] or [B*S, H] |
| 213 | + # router (replicated across EP) |
| 214 | + router_weight: torch.Tensor, # [E_total, H] |
| 215 | + router_bias: torch.Tensor, # [E_total] |
| 216 | + top_k: int, |
| 217 | + # expert params (already sharded along dim 0) |
| 218 | + gate_up_blocks: torch.Tensor, # [E_local, 2I, H//32, 16] in unit8 |
| 219 | + gate_up_bias: torch.Tensor, # [E_local, 2I] |
| 220 | + gate_up_scales: torch.Tensor, # [E_local, 2I, H//32] in unit8 |
| 221 | + alpha: float, |
| 222 | + limit: float, |
| 223 | + down_blocks: torch.Tensor, # [E_local, H, I//32, 16] in uint8 |
| 224 | + down_bias: torch.Tensor, # [E_local, H] |
| 225 | + down_scales: torch.Tensor, # [E_local, H, I//32] in uint8 |
| 226 | + # EP topology |
| 227 | + ep_size: int, |
| 228 | + ep_rank: int, |
| 229 | +) -> torch.Tensor: |
| 230 | + triton_ep_router = TritonEPRouter() |
| 231 | + |
| 232 | + def _ep_route_fn(logits: torch.Tensor): |
| 233 | + return triton_ep_router(logits, top_k, ep=ep_size, node_idx=ep_rank) |
| 234 | + |
| 235 | + return _run_mxfp4_mlp_core( |
| 236 | + hidden_states, |
| 237 | + router_weight, |
| 238 | + router_bias, |
| 239 | + gate_up_blocks, |
| 240 | + gate_up_bias, |
| 241 | + gate_up_scales, |
| 242 | + alpha, |
| 243 | + limit, |
| 244 | + down_blocks, |
| 245 | + down_bias, |
| 246 | + down_scales, |
| 247 | + route_fn=_ep_route_fn, |
| 248 | + ) |
| 249 | + |
| 250 | + |
| 251 | +@triton_mxfp4_moe_ep.register_fake |
| 252 | +def _mxfp4_mlp_ep_fake( |
| 253 | + hidden_states: torch.Tensor, |
| 254 | + router_weight: torch.Tensor, |
| 255 | + router_bias: torch.Tensor, |
| 256 | + top_k: int, |
| 257 | + gate_up_blocks: torch.Tensor, |
| 258 | + gate_up_bias: torch.Tensor, |
| 259 | + gate_up_scales: torch.Tensor, |
| 260 | + alpha: float, |
| 261 | + limit: float, |
| 262 | + down_blocks: torch.Tensor, |
| 263 | + down_bias: torch.Tensor, |
| 264 | + down_scales: torch.Tensor, |
| 265 | + ep_size: int, |
| 266 | + ep_rank: int, |
| 267 | +): |
| 268 | + return torch.empty_like(hidden_states) |
0 commit comments