Commit a36b48b
[#5860][autodeploy] GPT-OSS MXFP4 support (#7451)
Signed-off-by: Frida Hou <[email protected]>
Signed-off-by: Fridah-nv <[email protected]>
1 parent c33f43e commit a36b48b

13 files changed: +1071 −5 lines changed

tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 5 additions & 0 deletions
@@ -29,6 +29,8 @@ transforms:
   ############################################################################################
   match_moe_pattern:
     stage: pattern_matcher
+  match_dense_moe_pattern:
+    stage: pattern_matcher
   match_repeat_kv:
     stage: pattern_matcher
   match_eager_attention:
@@ -64,13 +66,16 @@ transforms:
     stage: pattern_matcher
   quantize_nvfp4_moe:
     stage: pattern_matcher
+  quantize_mxfp4_moe:
+    stage: pattern_matcher
   # TODO: Infer sharding parameters (tp_size, row/column sharding) from the model config.
   detect_sharding:
     stage: sharding
     simple_shard_only: false
     use_sharding_from_factory: false
     support_partial_config: false
     sharding_dims: ['tp', 'ep', 'bmm']
+    requires_shape_prop: true
   # TODO: (hg) need to ensure run_shape_prop after sharding.
   sharding_transform_executor:
     stage: sharding

tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -6,13 +6,15 @@
 from .flashinfer_rope import *
 from .linear import *
 from .mla import *
+from .mxfp4_moe import *
 from .quant import *
 from .rms_norm import *
 from .torch_attention import *
 from .torch_backend_attention import *
 from .torch_moe import *
 from .torch_quant import *
 from .torch_rope import *
+from .torch_router import *
 from .triton_attention import *
 from .triton_rope import *
 from .trtllm_moe import *
tensorrt_llm/_torch/auto_deploy/custom_ops/mxfp4_moe.py

Lines changed: 268 additions & 0 deletions
@@ -0,0 +1,268 @@ (new file)

# Triton-kernels-based MXFP4 MoE ops (GPT-OSS style) with routing, swizzling, and fused activation

from typing import Callable, Tuple

import torch
import torch.nn.functional as F

IS_TRITON_KERNELS_AVAILABLE = True
TRITON_KERNELS_UNAVAILABLE_REASON = ""

try:
    from triton_kernels.matmul_ogs import (
        FlexCtx,
        FnSpecs,
        FusedActivation,
        PrecisionConfig,
        matmul_ogs,
    )
    from triton_kernels.numerics import InFlexData
    from triton_kernels.routing import RoutingData, routing
    from triton_kernels.swiglu import swiglu_fn
    from triton_kernels.tensor import FP4, convert_layout, wrap_torch_tensor
    from triton_kernels.tensor_details import layout
    from triton_kernels.tensor_details.layout import StridedLayout

    from tensorrt_llm._torch.modules.fused_moe.fused_moe_triton import TritonEPRouter

except Exception as _e:
    IS_TRITON_KERNELS_AVAILABLE = False
    TRITON_KERNELS_UNAVAILABLE_REASON = f"{type(_e).__name__}: {_e}"

    FlexCtx = FnSpecs = FusedActivation = PrecisionConfig = matmul_ogs = None
    InFlexData = RoutingData = routing = swiglu_fn = None
    FP4 = convert_layout = wrap_torch_tensor = None
    layout = StridedLayout = None
    TritonEPRouter = None


# copied from transformers.integrations.mxfp4::swizzle_mxfp4 with minor modification
def _swizzle_mxfp4(w, w_scale):
    value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
    w = convert_layout(wrap_torch_tensor(w, dtype=FP4), value_layout, **value_layout_opts)
    w_scale = convert_layout(wrap_torch_tensor(w_scale), StridedLayout)
    return w, w_scale


RouteFn = Callable[[torch.Tensor], Tuple[RoutingData, torch.Tensor, torch.Tensor]]


def _prepare_weights_scales(
    hidden_size: int,
    gate_up_blocks: torch.Tensor,  # [E_local, 2I, H//32, 16] in uint8
    gate_up_scales: torch.Tensor,  # [E_local, 2I, H//32] in uint8
    down_blocks: torch.Tensor,  # [E_local, H, I//32, 16] in uint8
    down_scales: torch.Tensor,  # [E_local, H, I//32] in uint8
):
    local_experts = gate_up_blocks.size(0)
    intermediate_size = gate_up_blocks.shape[1] // 2

    # canonicalize shapes for swizzling (use last two dims as [K, N] style)
    gate_up_blocks = gate_up_blocks.view(local_experts, intermediate_size * 2, -1)
    triton_gate_up_w, gate_up_w_scale_raw = _swizzle_mxfp4(
        gate_up_blocks.transpose(-2, -1), gate_up_scales.transpose(-2, -1)
    )
    triton_gate_up_w.shape = torch.Size([local_experts, hidden_size, intermediate_size * 2])

    down_blocks = down_blocks.view(local_experts, -1, intermediate_size // 2)
    triton_down_w, down_w_scale_raw = _swizzle_mxfp4(
        down_blocks.transpose(-2, -1), down_scales.transpose(-2, -1)
    )
    triton_down_w.shape = torch.Size([local_experts, intermediate_size, hidden_size])

    return (
        triton_gate_up_w,
        gate_up_w_scale_raw,
        triton_down_w,
        down_w_scale_raw,
    )


def _run_mxfp4_mlp_core(
    hidden_states: torch.Tensor,  # [B, S, H] or [B*S, H]
    router_weight: torch.Tensor,
    router_bias: torch.Tensor,
    gate_up_blocks: torch.Tensor,
    gate_up_bias: torch.Tensor,
    gate_up_scales: torch.Tensor,
    alpha: float,
    limit: float,
    down_blocks: torch.Tensor,
    down_bias: torch.Tensor,
    down_scales: torch.Tensor,
    route_fn: RouteFn,  # injects routing variant
) -> torch.Tensor:
    """
    Shared core for both triton_mxfp4_moe and triton_mxfp4_moe_ep.
    - route_fn encapsulates the only difference: how we produce (routing_data, gather_idx, scatter_idx).
    """
    leading_shape = hidden_states.shape[:-1]
    hidden_size = hidden_states.shape[-1]
    x = hidden_states.reshape(-1, hidden_size)

    router_logits = F.linear(x, router_weight, router_bias)
    # route (global vs EP-aware)
    with torch.cuda.device(router_logits.device):
        routing_data, gather_idx, scatter_idx = route_fn(router_logits)

    (
        triton_gate_up_w,
        gate_up_w_scale_raw,
        triton_down_w,
        down_w_scale_raw,
    ) = _prepare_weights_scales(
        hidden_size, gate_up_blocks, gate_up_scales, down_blocks, down_scales
    )

    gate_pc = PrecisionConfig(
        weight_scale=gate_up_w_scale_raw, flex_ctx=FlexCtx(rhs_data=InFlexData())
    )
    down_pc = PrecisionConfig(
        weight_scale=down_w_scale_raw, flex_ctx=FlexCtx(rhs_data=InFlexData())
    )

    act = FusedActivation(
        FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (float(alpha), float(limit)), 2
    )

    # gate_up (with SwiGLU fused)
    inter = matmul_ogs(
        x,
        triton_gate_up_w,
        gate_up_bias.to(torch.float32),
        routing_data,
        gather_indx=gather_idx,
        precision_config=gate_pc,
        gammas=None,
        fused_activation=act,
    )

    # down
    y = matmul_ogs(
        inter,
        triton_down_w,
        down_bias.to(torch.float32),
        routing_data,
        scatter_indx=scatter_idx,
        precision_config=down_pc,
        gammas=routing_data.gate_scal,
    )

    y = y.reshape(*leading_shape, hidden_size)
    return y


@torch.library.custom_op("auto_deploy::triton_mxfp4_moe", mutates_args=())
def triton_mxfp4_moe(
    hidden_states: torch.Tensor,  # [B, S, H] or [B*S, H]
    # router
    router_weight: torch.Tensor,  # [E, H]
    router_bias: torch.Tensor,  # [E]
    top_k: int,
    # gate_up path
    gate_up_blocks: torch.Tensor,  # [E, 2I, H//32, 16] in uint8
    gate_up_bias: torch.Tensor,  # [E, 2I]
    gate_up_scales: torch.Tensor,  # [E, 2I, H//32] in uint8
    alpha: float,
    limit: float,
    # down path
    down_blocks: torch.Tensor,  # [E, H, I//32, 16] in uint8
    down_bias: torch.Tensor,  # [E, H]
    down_scales: torch.Tensor,  # [E, H, I//32] in uint8
) -> torch.Tensor:
    def _global_route_fn(logits: torch.Tensor):
        return routing(logits, top_k)

    return _run_mxfp4_mlp_core(
        hidden_states,
        router_weight,
        router_bias,
        gate_up_blocks,
        gate_up_bias,
        gate_up_scales,
        alpha,
        limit,
        down_blocks,
        down_bias,
        down_scales,
        route_fn=_global_route_fn,
    )


@triton_mxfp4_moe.register_fake
def _mxfp4_mlp_fake(
    hidden_states: torch.Tensor,
    router_weight: torch.Tensor,
    router_bias: torch.Tensor,
    top_k: int,
    gate_up_blocks: torch.Tensor,
    gate_up_bias: torch.Tensor,
    gate_up_scales: torch.Tensor,
    alpha: float,
    limit: float,
    down_blocks: torch.Tensor,
    down_bias: torch.Tensor,
    down_scales: torch.Tensor,
):
    return torch.empty_like(hidden_states)


@torch.library.custom_op("auto_deploy::triton_mxfp4_moe_ep", mutates_args=())
def triton_mxfp4_moe_ep(
    hidden_states: torch.Tensor,  # [B, S, H] or [B*S, H]
    # router (replicated across EP)
    router_weight: torch.Tensor,  # [E_total, H]
    router_bias: torch.Tensor,  # [E_total]
    top_k: int,
    # expert params (already sharded along dim 0)
    gate_up_blocks: torch.Tensor,  # [E_local, 2I, H//32, 16] in uint8
    gate_up_bias: torch.Tensor,  # [E_local, 2I]
    gate_up_scales: torch.Tensor,  # [E_local, 2I, H//32] in uint8
    alpha: float,
    limit: float,
    down_blocks: torch.Tensor,  # [E_local, H, I//32, 16] in uint8
    down_bias: torch.Tensor,  # [E_local, H]
    down_scales: torch.Tensor,  # [E_local, H, I//32] in uint8
    # EP topology
    ep_size: int,
    ep_rank: int,
) -> torch.Tensor:
    triton_ep_router = TritonEPRouter()

    def _ep_route_fn(logits: torch.Tensor):
        return triton_ep_router(logits, top_k, ep=ep_size, node_idx=ep_rank)

    return _run_mxfp4_mlp_core(
        hidden_states,
        router_weight,
        router_bias,
        gate_up_blocks,
        gate_up_bias,
        gate_up_scales,
        alpha,
        limit,
        down_blocks,
        down_bias,
        down_scales,
        route_fn=_ep_route_fn,
    )


@triton_mxfp4_moe_ep.register_fake
def _mxfp4_mlp_ep_fake(
    hidden_states: torch.Tensor,
    router_weight: torch.Tensor,
    router_bias: torch.Tensor,
    top_k: int,
    gate_up_blocks: torch.Tensor,
    gate_up_bias: torch.Tensor,
    gate_up_scales: torch.Tensor,
    alpha: float,
    limit: float,
    down_blocks: torch.Tensor,
    down_bias: torch.Tensor,
    down_scales: torch.Tensor,
    ep_size: int,
    ep_rank: int,
):
    return torch.empty_like(hidden_states)
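
For context, a minimal shape-level sketch of invoking the fused op once the package's custom ops are imported. It assumes triton_kernels is installed and a CUDA device is present; the sizes, the alpha/limit values, and the randomly packed FP4 blocks are illustrative only (a scale byte of 127 is taken here to encode an E8M0 factor of 1.0), so the output is not numerically meaningful:

import torch

import tensorrt_llm._torch.auto_deploy.custom_ops  # noqa: F401 -- registers auto_deploy::triton_mxfp4_moe

B, S, H, I, E, TOP_K = 2, 4, 64, 64, 8, 2  # illustrative sizes; H and I are multiples of 32

hidden = torch.randn(B, S, H, dtype=torch.bfloat16, device="cuda")
router_w = torch.randn(E, H, dtype=torch.bfloat16, device="cuda")
router_b = torch.zeros(E, dtype=torch.bfloat16, device="cuda")

# MXFP4 packs two FP4 values per uint8; one E8M0 scale byte covers each 32-value block.
gate_up_blocks = torch.randint(0, 256, (E, 2 * I, H // 32, 16), dtype=torch.uint8, device="cuda")
gate_up_scales = torch.full((E, 2 * I, H // 32), 127, dtype=torch.uint8, device="cuda")
gate_up_bias = torch.zeros(E, 2 * I, dtype=torch.bfloat16, device="cuda")
down_blocks = torch.randint(0, 256, (E, H, I // 32, 16), dtype=torch.uint8, device="cuda")
down_scales = torch.full((E, H, I // 32), 127, dtype=torch.uint8, device="cuda")
down_bias = torch.zeros(E, H, dtype=torch.bfloat16, device="cuda")

out = torch.ops.auto_deploy.triton_mxfp4_moe(
    hidden, router_w, router_b, TOP_K,
    gate_up_blocks, gate_up_bias, gate_up_scales, 1.702, 7.0,
    down_blocks, down_bias, down_scales,
)
print(out.shape)  # torch.Size([2, 4, 64])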

tensorrt_llm/_torch/auto_deploy/custom_ops/torch_moe.py

Lines changed: 50 additions & 0 deletions
@@ -324,3 +324,53 @@ def torch_quant_nvfp4_moe_fake(
     w3_alpha: List[torch.Tensor],
 ) -> torch.Tensor:
     return torch.empty_like(x)
+
+
+# GPT-OSS uses this style
+@torch.library.custom_op("auto_deploy::torch_moe_dense_mlp", mutates_args=())
+def torch_moe_dense_mlp(
+    hidden_states: torch.Tensor,  # [B, S, H] or [B*S, H]
+    routing_weights: torch.Tensor,  # [B*S, E]
+    gate_up_w: torch.Tensor,  # [E, H, 2I]
+    gate_up_b: torch.Tensor,  # [E, 2I]
+    down_w: torch.Tensor,  # [E, I, H]
+    down_b: torch.Tensor,  # [E, H]
+    alpha: float = 1.0,
+    limit: float = 10.0,
+) -> torch.Tensor:
+    batch_size = hidden_states.shape[0]
+    leading_shape = hidden_states.shape[:-1]
+    hidden_size = hidden_states.shape[-1]
+    hidden_states = hidden_states.reshape(-1, hidden_size)  # (num_tokens, hidden_size)
+    num_experts = routing_weights.shape[1]
+
+    hidden_states = hidden_states.repeat(num_experts, 1)
+    hidden_states = hidden_states.view(num_experts, -1, hidden_size)
+    gate_up = torch.bmm(hidden_states, gate_up_w) + gate_up_b[..., None, :]
+    gate, up = gate_up[..., ::2], gate_up[..., 1::2]
+    gate = gate.clamp(min=None, max=limit)
+    up = up.clamp(min=-limit, max=limit)
+    glu = gate * torch.sigmoid(gate * alpha)
+    next_states = torch.bmm(((up + 1) * glu), down_w)
+    next_states = next_states + down_b[..., None, :]
+    next_states = next_states.view(num_experts, batch_size, -1, hidden_size)
+    next_states = (
+        next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
+    )
+    next_states = next_states.sum(dim=0)
+    next_states = next_states.reshape(*leading_shape, hidden_size)
+    return next_states  # [B, S, H] or [B*S, H]
+
+
+@torch_moe_dense_mlp.register_fake
+def _torch_moe_dense_mlp_fake(
+    hidden_states: torch.Tensor,
+    routing_weights: torch.Tensor,
+    gate_up_w: torch.Tensor,
+    gate_up_b: torch.Tensor,
+    down_w: torch.Tensor,
+    down_b: torch.Tensor,
+    alpha: float = 1.0,
+    limit: float = 10.0,
+) -> torch.Tensor:
+    return torch.empty_like(hidden_states)
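
For context, a minimal sketch of calling the dense-MoE reference op on toy tensors. The sizes, the 0.02 weight scaling, and the alpha/limit values are illustrative only, and the import assumes the auto_deploy custom-op package is importable in the environment:

import torch

import tensorrt_llm._torch.auto_deploy.custom_ops  # noqa: F401 -- registers auto_deploy::torch_moe_dense_mlp

B, S, H, I, E = 2, 3, 8, 16, 4  # illustrative sizes
hidden = torch.randn(B, S, H)
routing = torch.softmax(torch.randn(B * S, E), dim=-1)  # dense weights over all experts
gate_up_w = torch.randn(E, H, 2 * I) * 0.02
gate_up_b = torch.zeros(E, 2 * I)
down_w = torch.randn(E, I, H) * 0.02
down_b = torch.zeros(E, H)

out = torch.ops.auto_deploy.torch_moe_dense_mlp(
    hidden, routing, gate_up_w, gate_up_b, down_w, down_b, alpha=1.702, limit=7.0
)
print(out.shape)  # torch.Size([2, 3, 8])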
