From 7fcaa2ae13b8e2d0cf7ae439a7aa3aa48d3ec148 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Fri, 26 Jun 2026 10:29:22 -0500 Subject: [PATCH 1/6] Fix gathered MXFP4 activation scales in Gluon MoE When MXFP4 activations are gathered through routed MoE metadata, X block-scale rows must be gathered using the scale tile's own row layout instead of reusing the already-gathered activation data rows. Keep W scales on the swizzled LDS path, but route gathered X scales through the direct load path so the scale indices match the gathered activation rows. This is scoped to MXFP4-X with gather and does not affect the existing FP8 activation path. Signed-off-by: Quinn Dawkins --- AGENTS.md | 3 + .../ops/moe/fused_mxfp_gfx950.py | 145 ++++++++++-------- .../test/ops/test_gluon_moe_gemm_gfx950.py | 83 ++++++++++ 3 files changed, 167 insertions(+), 64 deletions(-) diff --git a/AGENTS.md b/AGENTS.md index f0b984d6c..7b13d5c98 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -12,6 +12,9 @@ * Add tests and update docs for the changed code. * Before creating commits, run `pre-commit run --all-files` to format. +* Do not substitute a narrower lint command for the repository hook before + committing. Always run the exact `pre-commit run --all-files` command and + commit any formatter changes it makes. * When creating commits, perform sign off on behalf of the author. ## Dependency boundaries diff --git a/tokenspeed-kernel-amd/python/tokenspeed_kernel_amd/ops/moe/fused_mxfp_gfx950.py b/tokenspeed-kernel-amd/python/tokenspeed_kernel_amd/ops/moe/fused_mxfp_gfx950.py index c33c9b78f..2010659da 100644 --- a/tokenspeed-kernel-amd/python/tokenspeed_kernel_amd/ops/moe/fused_mxfp_gfx950.py +++ b/tokenspeed-kernel-amd/python/tokenspeed_kernel_amd/ops/moe/fused_mxfp_gfx950.py @@ -455,6 +455,8 @@ class MoEConfig: WITH_W_MX_SCALE: gl.constexpr SCALE_LOAD_MODE: gl.constexpr SCALE_VIA_LDS: gl.constexpr + X_SCALE_VIA_LDS: gl.constexpr + W_SCALE_VIA_LDS: gl.constexpr PRESHUFFLE_FACTOR: gl.constexpr BLOCK_M_PRESHUFFLED: gl.constexpr BLOCK_N_PRESHUFFLED: gl.constexpr @@ -510,6 +512,8 @@ def __init__( W_PRESHUFFLED=False, W_VIA_VGPR=False, W_PREFETCH=True, + X_SCALE_VIA_LDS=None, + W_SCALE_VIA_LDS=None, ): if SCALE_LOAD_MODE not in _SCALE_LOAD_MODES: raise ValueError( @@ -533,10 +537,14 @@ def __init__( self.DTYPE_X = gl.constexpr(DTYPE_X) self.DTYPE_W = gl.constexpr(DTYPE_W) - _scale_via_lds = SCALE_LOAD_MODE == "swizzle" and ( - WITH_X_MX_SCALE or WITH_W_MX_SCALE - ) + if X_SCALE_VIA_LDS is None: + X_SCALE_VIA_LDS = SCALE_LOAD_MODE == "swizzle" and WITH_X_MX_SCALE + if W_SCALE_VIA_LDS is None: + W_SCALE_VIA_LDS = SCALE_LOAD_MODE == "swizzle" and WITH_W_MX_SCALE + _scale_via_lds = X_SCALE_VIA_LDS or W_SCALE_VIA_LDS self.SCALE_VIA_LDS = gl.constexpr(_scale_via_lds) + self.X_SCALE_VIA_LDS = gl.constexpr(X_SCALE_VIA_LDS) + self.W_SCALE_VIA_LDS = gl.constexpr(W_SCALE_VIA_LDS) self.PRESHUFFLE_FACTOR = gl.constexpr(_SCALE_PRESHUFFLE_FACTOR) self.BLOCK_M_PRESHUFFLED = gl.constexpr(BLOCK_M // _SCALE_PRESHUFFLE_FACTOR) self.BLOCK_N_PRESHUFFLED = gl.constexpr(BLOCK_N // _SCALE_PRESHUFFLE_FACTOR) @@ -558,9 +566,9 @@ def __init__( if not W_VIA_VGPR: num_loads += 1 # w (LDS path) if _scale_via_lds: - if WITH_X_MX_SCALE: + if X_SCALE_VIA_LDS: num_loads += 1 - if WITH_W_MX_SCALE: + if W_SCALE_VIA_LDS: num_loads += 1 self.NUM_LOADS_IN_BATCH = gl.constexpr(num_loads) @@ -727,11 +735,11 @@ def issue_global_loads(self, load_idx, pred=1, USE_MASK: gl.constexpr = -1): load_idx, self.w_buffer, pred, USE_MASK=USE_MASK ) if cfg.SCALE_VIA_LDS: - if cfg.WITH_X_MX_SCALE: + if cfg.X_SCALE_VIA_LDS: self.x_scale_desc.issue_async_load( load_idx, self.x_scale_buffer, pred, USE_MASK=USE_MASK ) - if cfg.WITH_W_MX_SCALE: + if cfg.W_SCALE_VIA_LDS: self.w_scale_desc.issue_async_load( load_idx, self.w_scale_buffer, pred, USE_MASK=USE_MASK ) @@ -1167,16 +1175,8 @@ def __init__( self.cfg = cfg self.x_buffer = x_buffer self.w_buffer = w_buffer if not cfg.W_VIA_VGPR else gl.constexpr(0) - self.x_scale_buffer = ( - x_scale_buffer - if (cfg.SCALE_VIA_LDS and cfg.WITH_X_MX_SCALE) - else gl.constexpr(0) - ) - self.w_scale_buffer = ( - w_scale_buffer - if (cfg.SCALE_VIA_LDS and cfg.WITH_W_MX_SCALE) - else gl.constexpr(0) - ) + self.x_scale_buffer = x_scale_buffer if cfg.X_SCALE_VIA_LDS else gl.constexpr(0) + self.w_scale_buffer = w_scale_buffer if cfg.W_SCALE_VIA_LDS else gl.constexpr(0) self.x_desc = x_desc self.w_desc = w_desc self.x_scale_desc = x_scale_desc if cfg.WITH_X_MX_SCALE else gl.constexpr(0) @@ -1215,7 +1215,7 @@ def initialize(cfg: MoEConfig, x_desc, w_desc, x_scale_desc, w_scale_desc): layout=cfg.shared_layout_w, ) - if cfg.SCALE_VIA_LDS and cfg.WITH_X_MX_SCALE: + if cfg.X_SCALE_VIA_LDS: x_scale_buffer = gl.allocate_shared_memory( gl.uint8, shape=[ @@ -1228,7 +1228,7 @@ def initialize(cfg: MoEConfig, x_desc, w_desc, x_scale_desc, w_scale_desc): else: x_scale_buffer = gl.constexpr(0) - if cfg.SCALE_VIA_LDS and cfg.WITH_W_MX_SCALE: + if cfg.W_SCALE_VIA_LDS: w_scale_buffer = gl.allocate_shared_memory( gl.uint8, shape=[ @@ -1308,7 +1308,7 @@ def _load_x_scales(self, mfma_idx): BLOCK_K_SCALE: gl.constexpr = cfg.BLOCK_K // cfg.SCALE_BLOCK if cfg.USE_MFMA_SCALED: if cfg.WITH_X_MX_SCALE: - if cfg.SCALE_VIA_LDS: + if cfg.X_SCALE_VIA_LDS: scale_x = self.x_scale_desc.issue_local_load_unswizzle( mfma_idx, self.x_scale_buffer, @@ -1327,7 +1327,7 @@ def _load_x_scales(self, mfma_idx): layout=cfg.layout_x_scale, ) if cfg.WITH_W_MX_SCALE: - if cfg.SCALE_VIA_LDS: + if cfg.W_SCALE_VIA_LDS: scale_w = self.w_scale_desc.issue_local_load_unswizzle( mfma_idx, self.w_scale_buffer, @@ -1359,7 +1359,7 @@ def _load_scales(self, mfma_idx): if cfg.USE_MFMA_SCALED: if cfg.WITH_X_MX_SCALE: - if cfg.SCALE_VIA_LDS: + if cfg.X_SCALE_VIA_LDS: scale_x = self.x_scale_desc.issue_local_load_unswizzle( mfma_idx, self.x_scale_buffer, @@ -1379,7 +1379,7 @@ def _load_scales(self, mfma_idx): ) if cfg.WITH_W_MX_SCALE: - if cfg.SCALE_VIA_LDS: + if cfg.W_SCALE_VIA_LDS: scale_w = self.w_scale_desc.issue_local_load_unswizzle( mfma_idx, self.w_scale_buffer, @@ -1644,16 +1644,8 @@ def __init__( self.x_buffer_bot = x_buffer_bot self.w_buffer_left = w_buffer_left self.w_buffer_right = w_buffer_right - self.x_scale_buffer = ( - x_scale_buffer - if (cfg.SCALE_VIA_LDS and cfg.WITH_X_MX_SCALE) - else gl.constexpr(0) - ) - self.w_scale_buffer = ( - w_scale_buffer - if (cfg.SCALE_VIA_LDS and cfg.WITH_W_MX_SCALE) - else gl.constexpr(0) - ) + self.x_scale_buffer = x_scale_buffer if cfg.X_SCALE_VIA_LDS else gl.constexpr(0) + self.w_scale_buffer = w_scale_buffer if cfg.W_SCALE_VIA_LDS else gl.constexpr(0) self.x_desc_top = x_desc_top self.x_desc_bot = x_desc_bot self.w_desc_left = w_desc_left @@ -1705,7 +1697,7 @@ def initialize( layout=cfg.shared_layout_w_half_n, ) - if cfg.SCALE_VIA_LDS and cfg.WITH_X_MX_SCALE: + if cfg.X_SCALE_VIA_LDS: x_scale_buffer = gl.allocate_shared_memory( gl.uint8, shape=[ @@ -1718,7 +1710,7 @@ def initialize( else: x_scale_buffer = gl.constexpr(0) - if cfg.SCALE_VIA_LDS and cfg.WITH_W_MX_SCALE: + if cfg.W_SCALE_VIA_LDS: w_scale_buffer = gl.allocate_shared_memory( gl.uint8, shape=[ @@ -1841,7 +1833,7 @@ def issue_x_top(self, load_idx, pred=1, USE_MASK: gl.constexpr = -1): load_idx, self.x_buffer_top, pred, USE_MASK=USE_MASK, COMMIT=0 ) if cfg.SCALE_VIA_LDS: - if cfg.WITH_X_MX_SCALE: + if cfg.X_SCALE_VIA_LDS: self.x_scale_desc.issue_async_load( load_idx, self.x_scale_buffer, @@ -1849,7 +1841,7 @@ def issue_x_top(self, load_idx, pred=1, USE_MASK: gl.constexpr = -1): USE_MASK=USE_MASK, COMMIT=0, ) - if cfg.WITH_W_MX_SCALE: + if cfg.W_SCALE_VIA_LDS: self.w_scale_desc.issue_async_load( load_idx, self.w_scale_buffer, @@ -2074,16 +2066,8 @@ def __init__( self.x_buffer = x_buffer self.w_buffer_top = w_buffer_top if not cfg.W_VIA_VGPR else gl.constexpr(0) self.w_buffer_bot = w_buffer_bot if not cfg.W_VIA_VGPR else gl.constexpr(0) - self.x_scale_buffer = ( - x_scale_buffer - if (cfg.SCALE_VIA_LDS and cfg.WITH_X_MX_SCALE) - else gl.constexpr(0) - ) - self.w_scale_buffer = ( - w_scale_buffer - if (cfg.SCALE_VIA_LDS and cfg.WITH_W_MX_SCALE) - else gl.constexpr(0) - ) + self.x_scale_buffer = x_scale_buffer if cfg.X_SCALE_VIA_LDS else gl.constexpr(0) + self.w_scale_buffer = w_scale_buffer if cfg.W_SCALE_VIA_LDS else gl.constexpr(0) self.x_desc = x_desc self.w_desc_top = w_desc_top self.w_desc_bot = w_desc_bot @@ -2153,7 +2137,7 @@ def initialize( layout=cfg.shared_layout_w_half_n, ) - if cfg.SCALE_VIA_LDS and cfg.WITH_X_MX_SCALE: + if cfg.X_SCALE_VIA_LDS: x_scale_buffer = gl.allocate_shared_memory( gl.uint8, shape=[ @@ -2166,7 +2150,7 @@ def initialize( else: x_scale_buffer = gl.constexpr(0) - if cfg.SCALE_VIA_LDS and cfg.WITH_W_MX_SCALE: + if cfg.W_SCALE_VIA_LDS: w_scale_buffer = gl.allocate_shared_memory( gl.uint8, shape=[ @@ -2216,7 +2200,7 @@ def issue_local_load_x(self, mfma_idx): if cfg.USE_MFMA_SCALED: if cfg.WITH_X_MX_SCALE: - if cfg.SCALE_VIA_LDS: + if cfg.X_SCALE_VIA_LDS: scale_x = self.x_scale_desc.issue_local_load_unswizzle( mfma_idx, self.x_scale_buffer, @@ -2247,7 +2231,7 @@ def issue_global_load_top(self, load_idx, pred=1, USE_MASK: gl.constexpr = -1): load_idx, self.x_buffer, pred, USE_MASK=USE_MASK, COMMIT=0 ) if cfg.SCALE_VIA_LDS: - if cfg.WITH_X_MX_SCALE: + if cfg.X_SCALE_VIA_LDS: self.x_scale_desc.issue_async_load( load_idx, self.x_scale_buffer, @@ -2255,7 +2239,7 @@ def issue_global_load_top(self, load_idx, pred=1, USE_MASK: gl.constexpr = -1): USE_MASK=USE_MASK, COMMIT=0, ) - if cfg.WITH_W_MX_SCALE: + if cfg.W_SCALE_VIA_LDS: self.w_scale_desc.issue_async_load( load_idx, self.w_scale_buffer, @@ -3738,6 +3722,8 @@ def _pipelined_moe_tile_compute( W_VIA_VGPR: gl.constexpr = False, W_PREFETCH: gl.constexpr = True, W_CACHE_CG: gl.constexpr = False, + X_SCALE_VIA_LDS: gl.constexpr = False, + W_SCALE_VIA_LDS: gl.constexpr = False, USE_NARROW_N_STORE_LAYOUT: gl.constexpr = False, ): expert_id = compact_idx @@ -3791,6 +3777,8 @@ def _pipelined_moe_tile_compute( W_PRESHUFFLED=W_PRESHUFFLED, W_VIA_VGPR=W_VIA_VGPR, W_PREFETCH=W_PREFETCH, + X_SCALE_VIA_LDS=X_SCALE_VIA_LDS, + W_SCALE_VIA_LDS=W_SCALE_VIA_LDS, ) BLOCK_K_X: gl.constexpr = cfg.BLOCK_K // cfg.DIV_FACTOR_X BLOCK_K_W: gl.constexpr = cfg.BLOCK_K // cfg.DIV_FACTOR_W @@ -3831,23 +3819,21 @@ def _pipelined_moe_tile_compute( # Post-gather rows_m is in global token-id space (size M_X); # mask out junk gather_idx values too. Don't conflate M_X with # ``M`` (= dispatched tile count, can exceed M_X for top-k>1). - mask_m = pre_gather_mask & (rows_m < M_X) mask_m_x = pre_gather_mask_x & (rows_m_x < M_X) else: # Clamp OOB lanes to 0 so the buffer_load address stays in # bounds during HIP graph warm-up; mask still filters. rows_m = gl.where(pre_gather_mask, rows_m, gl.zeros_like(rows_m)) rows_m_x = gl.where(pre_gather_mask_x, rows_m_x, gl.zeros_like(rows_m_x)) - mask_m = pre_gather_mask mask_m_x = pre_gather_mask_x k_limit_x = gl.multiple_of(K // cfg.DIV_FACTOR_X, 16) k_limit_w = gl.multiple_of(K // cfg.DIV_FACTOR_W, 16) - # SCALE_VIA_LDS uses post-swizzle HBM shape via buffer_load_to_shared; - # other modes load scales G->VGPR via gl.load. + # Swizzled scale loads use post-swizzle HBM shape via buffer_load_to_shared; + # direct scale loads use G->VGPR gl.load and can follow gathered X rows. if HAS_X_BLOCK_SCALE: - if cfg.SCALE_VIA_LDS: + if cfg.X_SCALE_VIA_LDS: BLOCK_M_PS: gl.constexpr = cfg.BLOCK_M_PRESHUFFLED BLOCK_K_S_PS: gl.constexpr = cfg.BLOCK_K_SCALE_PRESHUFFLED LX_S: gl.constexpr = cfg.load_layout_x_scale @@ -3884,7 +3870,25 @@ def _pipelined_moe_tile_compute( ) rows_m_scale = off_m + offs_xs_m if HAS_GATHER: - rows_m_scale = rows_m + pre_gather_mask_scale = rows_m_scale < m_limit + rows_m_scale_safe = gl.where( + pre_gather_mask_scale, + rows_m_scale, + gl.zeros_like(rows_m_scale), + ) + rows_m_scale = gl.load( + gather_idx_ptr + rows_m_scale_safe, + mask=pre_gather_mask_scale, + other=0, + ).to(gl.int32) + mask_m_scale = pre_gather_mask_scale & (rows_m_scale < M_X) + else: + mask_m_scale = rows_m_scale < m_limit + rows_m_scale = gl.where( + mask_m_scale, + rows_m_scale, + gl.zeros_like(rows_m_scale), + ) x_scale_desc = AsyncCopyDescriptor.initialize( cfg, 0, @@ -3894,14 +3898,14 @@ def _pipelined_moe_tile_compute( offs_xs_k, stride_xsm, stride_xsk, - rows_m_scale[:, None] < M_X, + mask_m_scale[:, None], K // cfg.SCALE_BLOCK, ) else: x_scale_desc: gl.constexpr = 0 if HAS_W_BLOCK_SCALE: - if cfg.SCALE_VIA_LDS: + if cfg.W_SCALE_VIA_LDS: BLOCK_N_PS: gl.constexpr = cfg.BLOCK_N_PRESHUFFLED BLOCK_K_S_PS_W: gl.constexpr = cfg.BLOCK_K_SCALE_PRESHUFFLED LW_S: gl.constexpr = cfg.load_layout_w_scale @@ -4718,6 +4722,8 @@ def _pipelined_moe_kernel_scaled( W_VIA_VGPR: gl.constexpr = False, W_PREFETCH: gl.constexpr = True, W_CACHE_CG: gl.constexpr = False, + X_SCALE_VIA_LDS: gl.constexpr = False, + W_SCALE_VIA_LDS: gl.constexpr = False, USE_NARROW_N_STORE_LAYOUT: gl.constexpr = False, IS_MEDIUM_DECODE: gl.constexpr = False, MEDIUM_COMBINE: gl.constexpr = False, @@ -4876,6 +4882,8 @@ def _pipelined_moe_kernel_scaled( W_VIA_VGPR=W_VIA_VGPR, W_PREFETCH=W_PREFETCH, W_CACHE_CG=W_CACHE_CG, + X_SCALE_VIA_LDS=X_SCALE_VIA_LDS, + W_SCALE_VIA_LDS=W_SCALE_VIA_LDS, USE_NARROW_N_STORE_LAYOUT=USE_NARROW_N_STORE_LAYOUT, ) @@ -5269,20 +5277,27 @@ def _launch_kernel( # N-contig W staged as [BK, BN] in LDS. stride_wn, stride_wk = w3.stride(-1), w3.stride(-2) + x_scale_load_mode = scale_load_mode + if has_x_block_scale and gather_indx is not None: + x_scale_load_mode = "transpose" + w_scale_load_mode = scale_load_mode + x_scale_via_lds = x_scale_load_mode == "swizzle" and has_x_block_scale + w_scale_via_lds = w_scale_load_mode == "swizzle" and has_w_block_scale + if has_w_block_scale: w_scale3 = w_scale if w_scale.ndim == 3 else w_scale.unsqueeze(0) - w_scale_proc3 = _preprocess_scale(w_scale3, scale_load_mode) + w_scale_proc3 = _preprocess_scale(w_scale3, w_scale_load_mode) stride_wse = w_scale_proc3.stride(0) - stride_wsn, stride_wsk = _scale_strides(w_scale_proc3, scale_load_mode) + stride_wsn, stride_wsk = _scale_strides(w_scale_proc3, w_scale_load_mode) w_scale_buf = w_scale_proc3 else: stride_wse = stride_wsn = stride_wsk = 0 w_scale_buf = _make_dummy(x.device, torch.uint8) x_scale_proc = ( - _preprocess_scale(x_scale, scale_load_mode) if has_x_block_scale else None + _preprocess_scale(x_scale, x_scale_load_mode) if has_x_block_scale else None ) - stride_xsm, stride_xsk = _scale_strides(x_scale_proc, scale_load_mode) + stride_xsm, stride_xsk = _scale_strides(x_scale_proc, x_scale_load_mode) x_scale_buf = ( x_scale_proc if x_scale_proc is not None else _make_dummy(x.device, torch.uint8) @@ -5415,6 +5430,8 @@ def _launch_kernel( W_VIA_VGPR=False, W_PREFETCH=False, W_CACHE_CG=bool(w_cache_cg), + X_SCALE_VIA_LDS=bool(x_scale_via_lds), + W_SCALE_VIA_LDS=bool(w_scale_via_lds), USE_NARROW_N_STORE_LAYOUT=bool(use_narrow_n_store_layout), GRID_N=grid_n, GROUP_M=group_m, diff --git a/tokenspeed-kernel-amd/test/ops/test_gluon_moe_gemm_gfx950.py b/tokenspeed-kernel-amd/test/ops/test_gluon_moe_gemm_gfx950.py index c02037ca5..5cac900f2 100644 --- a/tokenspeed-kernel-amd/test/ops/test_gluon_moe_gemm_gfx950.py +++ b/tokenspeed-kernel-amd/test/ops/test_gluon_moe_gemm_gfx950.py @@ -5,6 +5,7 @@ import pytest import torch +from tokenspeed_kernel import quantize_mxfp4 def _is_gfx950() -> bool: @@ -529,6 +530,32 @@ def _compute_torch_gemm1_reference( return output +def _compute_torch_gemm1_mxfp4_reference( + raw: RawMxfp4Weights, + gemm1_input: torch.Tensor, + gemm1_scale: torch.Tensor, + weights: Mxfp4Weights, + ragged_metadata: Any, + gather_indx: torch.Tensor, +) -> torch.Tensor: + output = torch.empty( + (gather_indx.numel(), INTERMEDIATE_SIZE), + device=gemm1_input.device, + dtype=torch.bfloat16, + ) + x = _mxfp4_dequant(gemm1_input, gemm1_scale) + for expert, (start, end) in enumerate(_expert_ranges(ragged_metadata)): + if start == end: + continue + row_idx = gather_indx[start:end].long() + w13 = _mxfp4_dequant(raw.w13_weight[expert], raw.w13_scale[expert]) + gate_up = x[row_idx] @ w13.T + if weights.w13_bias is not None: + gate_up = gate_up + weights.w13_bias[expert][None, :] + output[start:end] = _swiglu_reference(gate_up).to(torch.bfloat16) + return output + + def _compute_torch_gemm2_reference( raw: RawMxfp4Weights, gemm2_input: torch.Tensor, @@ -700,3 +727,59 @@ def test_gluon_moe_gemms_with_preshuffle_match_torch_gfx950( weights=mxfp4_weights.preshuffled, torch_references=torch_references, ) + + +@requires_gfx950 +@pytest.mark.parametrize("variant", ("nonpreshuffled", "preshuffled")) +def test_gluon_moe_gemm1_dynamic_mxfp4_gather_scales_match_torch_gfx950( + mxfp4_weights: Mxfp4WeightVariants, + variant: str, +) -> None: + num_tokens = 64 + weights = getattr(mxfp4_weights, variant) + hidden_states, router_logits = _make_hidden_and_router(num_tokens) + ragged_metadata, gather_indx, _scatter_indx, _gate_scal = default_route( + router_logits, + TOPK, + dtype=router_logits.dtype, + ) + gemm1_input, gemm1_scale = quantize_mxfp4( + hidden_states, + scale_size=MXFP4_BLOCK, + scale_layout="linear", + solution="triton", + enable_pdl=False, + ) + precision_config = gluon_moe.PrecisionConfig( + a_mx_scale=gemm1_scale, + a_microblock_size=MXFP4_BLOCK, + b_mx_scale=weights.w13_precision_config.b_mx_scale, + b_microblock_size=MXFP4_BLOCK, + out_dtype=torch.bfloat16, + ) + + with torch.no_grad(): + actual = gluon_moe.gluon_mxfp_ragged_matmul( + gemm1_input, + weights.w13_weight, + weights.w13_bias, + a_ragged_metadata=ragged_metadata, + gather_indx=gather_indx, + precision_config=precision_config, + fused_activation=_swiglu_activation(), + ) + expected = _compute_torch_gemm1_mxfp4_reference( + mxfp4_weights.raw, + gemm1_input, + gemm1_scale, + weights, + ragged_metadata, + gather_indx, + ) + + torch.testing.assert_close( + actual.float(), + expected.float(), + atol=GEMM_ATOL, + rtol=RTOL, + ) From 0d51f3c8bf6e14d4b914e48d9a4918b47b1ce17a Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Fri, 26 Jun 2026 12:34:09 -0500 Subject: [PATCH 2/6] Address MXFP4 gathered-scale review comments Split X and W scale LDS handling fully by removing the aggregate SCALE_VIA_LDS config field, and make SliceMN direct scale loads read the requested subtile from the descriptor instead of assuming an LDS scale buffer exists. Also delay the vendor-neutral quantize_mxfp4 test import until after gfx950 collection gating and run the new gathered-scale test over the existing key token counts. Signed-off-by: Quinn Dawkins --- .../ops/moe/fused_mxfp_gfx950.py | 111 ++++++++++++------ .../test/ops/test_gluon_moe_gemm_gfx950.py | 6 +- 2 files changed, 79 insertions(+), 38 deletions(-) diff --git a/tokenspeed-kernel-amd/python/tokenspeed_kernel_amd/ops/moe/fused_mxfp_gfx950.py b/tokenspeed-kernel-amd/python/tokenspeed_kernel_amd/ops/moe/fused_mxfp_gfx950.py index 2010659da..43d50dd59 100644 --- a/tokenspeed-kernel-amd/python/tokenspeed_kernel_amd/ops/moe/fused_mxfp_gfx950.py +++ b/tokenspeed-kernel-amd/python/tokenspeed_kernel_amd/ops/moe/fused_mxfp_gfx950.py @@ -454,7 +454,6 @@ class MoEConfig: WITH_X_MX_SCALE: gl.constexpr WITH_W_MX_SCALE: gl.constexpr SCALE_LOAD_MODE: gl.constexpr - SCALE_VIA_LDS: gl.constexpr X_SCALE_VIA_LDS: gl.constexpr W_SCALE_VIA_LDS: gl.constexpr PRESHUFFLE_FACTOR: gl.constexpr @@ -542,7 +541,6 @@ def __init__( if W_SCALE_VIA_LDS is None: W_SCALE_VIA_LDS = SCALE_LOAD_MODE == "swizzle" and WITH_W_MX_SCALE _scale_via_lds = X_SCALE_VIA_LDS or W_SCALE_VIA_LDS - self.SCALE_VIA_LDS = gl.constexpr(_scale_via_lds) self.X_SCALE_VIA_LDS = gl.constexpr(X_SCALE_VIA_LDS) self.W_SCALE_VIA_LDS = gl.constexpr(W_SCALE_VIA_LDS) self.PRESHUFFLE_FACTOR = gl.constexpr(_SCALE_PRESHUFFLE_FACTOR) @@ -734,7 +732,8 @@ def issue_global_loads(self, load_idx, pred=1, USE_MASK: gl.constexpr = -1): self.w_desc.issue_async_load( load_idx, self.w_buffer, pred, USE_MASK=USE_MASK ) - if cfg.SCALE_VIA_LDS: + scale_via_lds: gl.constexpr = cfg.X_SCALE_VIA_LDS or cfg.W_SCALE_VIA_LDS + if scale_via_lds: if cfg.X_SCALE_VIA_LDS: self.x_scale_desc.issue_async_load( load_idx, self.x_scale_buffer, pred, USE_MASK=USE_MASK @@ -1145,6 +1144,23 @@ def _load_scale_tile_via_gl_load(desc, mfma_idx): return gl.load(base + desc.offsets, mask=mask, other=0) +@gluon.jit +def _load_scale_subtile_via_gl_load( + desc, mfma_idx, subtile_start_nonk: gl.constexpr, SUBTILE_NONK: gl.constexpr +): + EVEN_K: gl.constexpr = desc.cfg.EVEN_K + off_k_step = mfma_idx * desc.BLOCK_K + base = desc.ptr + off_k_step * desc.stride_k + offsets = desc.offsets.slice(subtile_start_nonk, SUBTILE_NONK, 0) + masks_nonk = desc.masks_nonk.slice(subtile_start_nonk, SUBTILE_NONK, 0) + if EVEN_K: + mask = masks_nonk + else: + mask_k = gl.expand_dims(off_k_step + desc.off_k, desc.op_idx) < desc.k_limit + mask = mask_k & masks_nonk + return gl.load(base + offsets, mask=mask, other=0) + + @composition @gluon.aggregate class MoEPipelinedProgram: @@ -1754,16 +1770,24 @@ def issue_local_load_x_sub(self, mfma_idx, subtile_idx_m: gl.constexpr): if cfg.USE_MFMA_SCALED: if cfg.WITH_X_MX_SCALE: - scale_x = self.x_scale_desc.issue_local_load_unswizzle_sub( - mfma_idx, - self.x_scale_buffer, - cfg.layout_x_scale, - cfg.BLOCK_M_PRESHUFFLED, - cfg.BLOCK_M, - BLOCK_K_SCALE, - SUBTILE_M, - subtile_start_m, - ) + if cfg.X_SCALE_VIA_LDS: + scale_x = self.x_scale_desc.issue_local_load_unswizzle_sub( + mfma_idx, + self.x_scale_buffer, + cfg.layout_x_scale, + cfg.BLOCK_M_PRESHUFFLED, + cfg.BLOCK_M, + BLOCK_K_SCALE, + SUBTILE_M, + subtile_start_m, + ) + else: + scale_x = _load_scale_subtile_via_gl_load( + self.x_scale_desc, + mfma_idx, + subtile_start_m, + SUBTILE_M, + ) else: scale_x = gl.full( [SUBTILE_M, BLOCK_K_SCALE], @@ -1797,16 +1821,24 @@ def issue_local_load_w_sub(self, mfma_idx, subtile_idx_n: gl.constexpr): if cfg.USE_MFMA_SCALED: if cfg.WITH_W_MX_SCALE: - scale_w = self.w_scale_desc.issue_local_load_unswizzle_sub( - mfma_idx, - self.w_scale_buffer, - cfg.layout_w_scale, - cfg.BLOCK_N_PRESHUFFLED, - cfg.BLOCK_N, - BLOCK_K_SCALE, - SUBTILE_N, - subtile_start_n, - ) + if cfg.W_SCALE_VIA_LDS: + scale_w = self.w_scale_desc.issue_local_load_unswizzle_sub( + mfma_idx, + self.w_scale_buffer, + cfg.layout_w_scale, + cfg.BLOCK_N_PRESHUFFLED, + cfg.BLOCK_N, + BLOCK_K_SCALE, + SUBTILE_N, + subtile_start_n, + ) + else: + scale_w = _load_scale_subtile_via_gl_load( + self.w_scale_desc, + mfma_idx, + subtile_start_n, + SUBTILE_N, + ) else: scale_w = gl.full( [SUBTILE_N, BLOCK_K_SCALE], @@ -1832,7 +1864,8 @@ def issue_x_top(self, load_idx, pred=1, USE_MASK: gl.constexpr = -1): self.x_desc_top.issue_async_load( load_idx, self.x_buffer_top, pred, USE_MASK=USE_MASK, COMMIT=0 ) - if cfg.SCALE_VIA_LDS: + scale_via_lds: gl.constexpr = cfg.X_SCALE_VIA_LDS or cfg.W_SCALE_VIA_LDS + if scale_via_lds: if cfg.X_SCALE_VIA_LDS: self.x_scale_desc.issue_async_load( load_idx, @@ -2230,7 +2263,8 @@ def issue_global_load_top(self, load_idx, pred=1, USE_MASK: gl.constexpr = -1): self.x_desc.issue_async_load( load_idx, self.x_buffer, pred, USE_MASK=USE_MASK, COMMIT=0 ) - if cfg.SCALE_VIA_LDS: + scale_via_lds: gl.constexpr = cfg.X_SCALE_VIA_LDS or cfg.W_SCALE_VIA_LDS + if scale_via_lds: if cfg.X_SCALE_VIA_LDS: self.x_scale_desc.issue_async_load( load_idx, @@ -3091,6 +3125,7 @@ def _run_moe_tile_w_via_vgpr( if USE_SLICE_N: SUB_BN: gl.constexpr = BLOCK_N // 2 + scale_via_lds_slice_n: gl.constexpr = cfg.X_SCALE_VIA_LDS or cfg.W_SCALE_VIA_LDS gl.static_assert( SUB_BN == 128 and BLOCK_K_W == 128 and NUM_WARPS == 4, "USE_SLICE_N + W_VIA_VGPR requires SUB_BN=BLOCK_K_W=128 " @@ -3098,7 +3133,7 @@ def _run_moe_tile_w_via_vgpr( "this shape (re-derive otherwise).", ) LOAD_W_HALF_COPY_LAYOUT: gl.constexpr = _preshuffled_w_copy_layout( - SUB_BN // 16, BLOCK_K_W, cfg.SCALE_VIA_LDS, False + SUB_BN // 16, BLOCK_K_W, scale_via_lds_slice_n, False ) ( offsets_h, @@ -3144,6 +3179,7 @@ def _run_moe_tile_w_via_vgpr( ) return pgm.pipeline(K) else: + scale_via_lds_full: gl.constexpr = cfg.X_SCALE_VIA_LDS or cfg.W_SCALE_VIA_LDS gl.static_assert( BLOCK_N == 128, "W_VIA_VGPR full-tile layout bases assume BLOCK_N=128. " @@ -3151,7 +3187,7 @@ def _run_moe_tile_w_via_vgpr( ) BLOCK_N_LAYOUT: gl.constexpr = BLOCK_N LOAD_W_COPY_LAYOUT: gl.constexpr = _preshuffled_w_copy_layout( - BLOCK_N_LAYOUT // 16, BLOCK_K_W, cfg.SCALE_VIA_LDS, False + BLOCK_N_LAYOUT // 16, BLOCK_K_W, scale_via_lds_full, False ) offsets_b_vgpr, base_off_b_vgpr = _make_preshuffled_w_full_offsets( w_base_offset, @@ -3237,6 +3273,7 @@ def _run_moe_tile_preshuffled_lds_w( if USE_SLICE_N: SUB_BN: gl.constexpr = BLOCK_N // 2 + scale_via_lds_slice_n: gl.constexpr = cfg.X_SCALE_VIA_LDS or cfg.W_SCALE_VIA_LDS gl.static_assert( SUB_BN == 128 and BLOCK_K_W == 128 and NUM_WARPS == 4, "USE_SLICE_N + preshuffled W requires SUB_BN=BLOCK_K_W=128 " @@ -3244,10 +3281,10 @@ def _run_moe_tile_preshuffled_lds_w( "this shape (re-derive otherwise).", ) LOAD_W_HALF_LAYOUT: gl.constexpr = _preshuffled_w_read_layout( - SUB_BN // 16, BLOCK_K_W, cfg.SCALE_VIA_LDS + SUB_BN // 16, BLOCK_K_W, scale_via_lds_slice_n ) LOAD_W_HALF_COPY_LAYOUT: gl.constexpr = _preshuffled_w_copy_layout( - SUB_BN // 16, BLOCK_K_W, cfg.SCALE_VIA_LDS, True + SUB_BN // 16, BLOCK_K_W, scale_via_lds_slice_n, True ) ( offsets_h, @@ -3303,11 +3340,12 @@ def _run_moe_tile_preshuffled_lds_w( # Keep the original half-tile layout in that specialization so the # preshuffled copy/read layouts remain valid during compilation. BLOCK_N_LAYOUT: gl.constexpr = (BLOCK_N // 2) if USE_SLICE_N else BLOCK_N + scale_via_lds_full: gl.constexpr = cfg.X_SCALE_VIA_LDS or cfg.W_SCALE_VIA_LDS LOAD_W_LAYOUT: gl.constexpr = _preshuffled_w_read_layout( - BLOCK_N_LAYOUT // 16, BLOCK_K_W, cfg.SCALE_VIA_LDS + BLOCK_N_LAYOUT // 16, BLOCK_K_W, scale_via_lds_full ) LOAD_W_COPY_LAYOUT: gl.constexpr = _preshuffled_w_copy_layout( - BLOCK_N_LAYOUT // 16, BLOCK_K_W, cfg.SCALE_VIA_LDS, True + BLOCK_N_LAYOUT // 16, BLOCK_K_W, scale_via_lds_full, True ) offsets_b_vgpr, base_off_b_vgpr = _make_preshuffled_w_full_offsets( w_base_offset, @@ -5115,7 +5153,7 @@ def _launch_kernel( block_n, block_k, scale_block=32, - has_x_scale=has_x_block_scale, + has_x_scale=has_x_block_scale and gather_indx is None, has_w_scale=has_w_block_scale, k=K, x_format=x_format, @@ -7252,7 +7290,7 @@ def _warp_decode_stage1_coop_compute( """Cooperative gate_up GEMM + bias + SwiGLU + fp8-quant + store for one (token, slot, expert). N runs over the INTERLEAVED gate_up rows (2*I); ``_swiglu_reduce`` splits even=gate / odd=up. Mirrors the plain path of - ``_pipelined_moe_tile_compute`` (W_TRANSPOSE=False, SCALE_VIA_LDS w-scale, + ``_pipelined_moe_tile_compute`` (W_TRANSPOSE=False, swizzled w-scale, per-tensor x scale) but specialized to a single decode token (row 0 of the BLOCK_M tile). """ @@ -7274,7 +7312,7 @@ def _warp_decode_stage1_coop_compute( not W_PRESHUFFLED, # W_TRANSPOSE for non-preshuffled K-packed-contiguous W False, # WITH_X_MX_SCALE (per-tensor x scale only) True, # WITH_W_MX_SCALE (e8m0 block scales) - "swizzle", # SCALE_LOAD_MODE -> SCALE_VIA_LDS unswizzle + "swizzle", # SCALE_LOAD_MODE -> W_SCALE_VIA_LDS unswizzle gl.int32, (1, 1, 1), # NUM_SUBTILES False, # EVEN_K (D=2880 not a multiple of BLOCK_K) @@ -7325,11 +7363,12 @@ def _warp_decode_stage1_coop_compute( "warp_decode preshuffled W13 path assumes 128x128 W tiles " "and NUM_WARPS=4; re-derive the copy/read layouts for other shapes.", ) + scale_via_lds: gl.constexpr = cfg.X_SCALE_VIA_LDS or cfg.W_SCALE_VIA_LDS LOAD_W_LAYOUT: gl.constexpr = _preshuffled_w_read_layout( - BLOCK_N // 16, BLOCK_K_W, cfg.SCALE_VIA_LDS + BLOCK_N // 16, BLOCK_K_W, scale_via_lds ) LOAD_W_COPY_LAYOUT: gl.constexpr = _preshuffled_w_copy_layout( - BLOCK_N // 16, BLOCK_K_W, cfg.SCALE_VIA_LDS, True + BLOCK_N // 16, BLOCK_K_W, scale_via_lds, True ) offsets_w, base_off_w = _make_preshuffled_w_full_offsets( w_base_offset, diff --git a/tokenspeed-kernel-amd/test/ops/test_gluon_moe_gemm_gfx950.py b/tokenspeed-kernel-amd/test/ops/test_gluon_moe_gemm_gfx950.py index 5cac900f2..488c92215 100644 --- a/tokenspeed-kernel-amd/test/ops/test_gluon_moe_gemm_gfx950.py +++ b/tokenspeed-kernel-amd/test/ops/test_gluon_moe_gemm_gfx950.py @@ -5,7 +5,6 @@ import pytest import torch -from tokenspeed_kernel import quantize_mxfp4 def _is_gfx950() -> bool: @@ -730,12 +729,15 @@ def test_gluon_moe_gemms_with_preshuffle_match_torch_gfx950( @requires_gfx950 +@pytest.mark.parametrize("num_tokens", KEY_NUM_TOKENS) @pytest.mark.parametrize("variant", ("nonpreshuffled", "preshuffled")) def test_gluon_moe_gemm1_dynamic_mxfp4_gather_scales_match_torch_gfx950( + num_tokens: int, mxfp4_weights: Mxfp4WeightVariants, variant: str, ) -> None: - num_tokens = 64 + from tokenspeed_kernel import quantize_mxfp4 + weights = getattr(mxfp4_weights, variant) hidden_states, router_logits = _make_hidden_and_router(num_tokens) ragged_metadata, gather_indx, _scatter_indx, _gate_scal = default_route( From e0451d9b722598037fb9d13c073cbc0b73cd6b29 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Fri, 26 Jun 2026 12:55:17 -0500 Subject: [PATCH 3/6] Keep MXFP4 gathered-scale test AMD-local Replace the gathered activation test's tokenspeed_kernel quantization dependency with a small test-local MXFP4 quantizer. This keeps tokenspeed-kernel-amd tests free of vendor-neutral package imports while preserving the packed e2m1/e8m0 inputs needed by the Gluon MoE path. Signed-off-by: Quinn Dawkins --- .../test/ops/test_gluon_moe_gemm_gfx950.py | 40 ++++++++++++++----- 1 file changed, 31 insertions(+), 9 deletions(-) diff --git a/tokenspeed-kernel-amd/test/ops/test_gluon_moe_gemm_gfx950.py b/tokenspeed-kernel-amd/test/ops/test_gluon_moe_gemm_gfx950.py index 488c92215..d3eef496c 100644 --- a/tokenspeed-kernel-amd/test/ops/test_gluon_moe_gemm_gfx950.py +++ b/tokenspeed-kernel-amd/test/ops/test_gluon_moe_gemm_gfx950.py @@ -441,6 +441,36 @@ def _make_hidden_and_router(num_tokens: int) -> tuple[torch.Tensor, torch.Tensor return hidden_states, router_logits +def _quantize_mxfp4_for_test(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + if x.shape[-1] % MXFP4_BLOCK != 0: + raise ValueError("MXFP4 test quantization requires 32-element blocks") + + x_blocks = x.to(torch.float32).reshape( + *x.shape[:-1], + x.shape[-1] // MXFP4_BLOCK, + MXFP4_BLOCK, + ) + max_abs = x_blocks.abs().amax(dim=-1) + min_exp = torch.full_like(max_abs, -127.0) + scale_exp = torch.where( + max_abs > 0, + torch.floor(torch.log2(max_abs)) - 2, + min_exp, + ).clamp(-127, 127) + + scaled = x_blocks * torch.exp2(-scale_exp).unsqueeze(-1) + abs_scaled = scaled.abs() + codes = torch.zeros_like(abs_scaled, dtype=torch.uint8) + for threshold in (0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0): + codes += (abs_scaled >= threshold).to(torch.uint8) + codes |= (scaled < 0).to(torch.uint8) * 8 + + packed_blocks = codes[..., 0::2] | (codes[..., 1::2] << 4) + packed = packed_blocks.reshape(*x.shape[:-1], x.shape[-1] // 2).contiguous() + scales = (scale_exp.to(torch.int16) + 127).to(torch.uint8).contiguous() + return packed, scales + + def _make_gemm2_input(num_tokens: int, scale: torch.Tensor) -> torch.Tensor: generator = torch.Generator(device="cuda").manual_seed(19000 + num_tokens) exact_values = ( @@ -736,8 +766,6 @@ def test_gluon_moe_gemm1_dynamic_mxfp4_gather_scales_match_torch_gfx950( mxfp4_weights: Mxfp4WeightVariants, variant: str, ) -> None: - from tokenspeed_kernel import quantize_mxfp4 - weights = getattr(mxfp4_weights, variant) hidden_states, router_logits = _make_hidden_and_router(num_tokens) ragged_metadata, gather_indx, _scatter_indx, _gate_scal = default_route( @@ -745,13 +773,7 @@ def test_gluon_moe_gemm1_dynamic_mxfp4_gather_scales_match_torch_gfx950( TOPK, dtype=router_logits.dtype, ) - gemm1_input, gemm1_scale = quantize_mxfp4( - hidden_states, - scale_size=MXFP4_BLOCK, - scale_layout="linear", - solution="triton", - enable_pdl=False, - ) + gemm1_input, gemm1_scale = _quantize_mxfp4_for_test(hidden_states) precision_config = gluon_moe.PrecisionConfig( a_mx_scale=gemm1_scale, a_microblock_size=MXFP4_BLOCK, From 692a6fd3a8783d17b6f0eb0ddc4f63901851d9c6 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Mon, 29 Jun 2026 14:44:36 -0500 Subject: [PATCH 4/6] Exercise gathered MXFP4 scales in Gluon test Signed-off-by: Quinn Dawkins --- .../test/ops/test_gluon_moe_gemm_gfx950.py | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/tokenspeed-kernel-amd/test/ops/test_gluon_moe_gemm_gfx950.py b/tokenspeed-kernel-amd/test/ops/test_gluon_moe_gemm_gfx950.py index d3eef496c..c18538be0 100644 --- a/tokenspeed-kernel-amd/test/ops/test_gluon_moe_gemm_gfx950.py +++ b/tokenspeed-kernel-amd/test/ops/test_gluon_moe_gemm_gfx950.py @@ -774,22 +774,18 @@ def test_gluon_moe_gemm1_dynamic_mxfp4_gather_scales_match_torch_gfx950( dtype=router_logits.dtype, ) gemm1_input, gemm1_scale = _quantize_mxfp4_for_test(hidden_states) - precision_config = gluon_moe.PrecisionConfig( - a_mx_scale=gemm1_scale, - a_microblock_size=MXFP4_BLOCK, - b_mx_scale=weights.w13_precision_config.b_mx_scale, - b_microblock_size=MXFP4_BLOCK, - out_dtype=torch.bfloat16, - ) with torch.no_grad(): actual = gluon_moe.gluon_mxfp_ragged_matmul( gemm1_input, weights.w13_weight, weights.w13_bias, + w_mx_scale=weights.w13_precision_config.b_mx_scale, + x_mx_scale=gemm1_scale, + x_format="e2m1", + out_dtype=weights.w13_precision_config.out_dtype, a_ragged_metadata=ragged_metadata, gather_indx=gather_indx, - precision_config=precision_config, fused_activation=_swiglu_activation(), ) expected = _compute_torch_gemm1_mxfp4_reference( From 19a4a7821250dd68c0dd94abf9f37924dd75aa13 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Tue, 30 Jun 2026 14:52:09 -0500 Subject: [PATCH 5/6] Fix preshuffled MXFP4 scale layouts Signed-off-by: Quinn Dawkins --- .../ops/moe/fused_mxfp_gfx950.py | 78 ++++++++++--------- 1 file changed, 40 insertions(+), 38 deletions(-) diff --git a/tokenspeed-kernel-amd/python/tokenspeed_kernel_amd/ops/moe/fused_mxfp_gfx950.py b/tokenspeed-kernel-amd/python/tokenspeed_kernel_amd/ops/moe/fused_mxfp_gfx950.py index 43d50dd59..6831df0f0 100644 --- a/tokenspeed-kernel-amd/python/tokenspeed_kernel_amd/ops/moe/fused_mxfp_gfx950.py +++ b/tokenspeed-kernel-amd/python/tokenspeed_kernel_amd/ops/moe/fused_mxfp_gfx950.py @@ -3125,7 +3125,7 @@ def _run_moe_tile_w_via_vgpr( if USE_SLICE_N: SUB_BN: gl.constexpr = BLOCK_N // 2 - scale_via_lds_slice_n: gl.constexpr = cfg.X_SCALE_VIA_LDS or cfg.W_SCALE_VIA_LDS + w_scale_via_lds_slice_n: gl.constexpr = cfg.W_SCALE_VIA_LDS gl.static_assert( SUB_BN == 128 and BLOCK_K_W == 128 and NUM_WARPS == 4, "USE_SLICE_N + W_VIA_VGPR requires SUB_BN=BLOCK_K_W=128 " @@ -3133,7 +3133,7 @@ def _run_moe_tile_w_via_vgpr( "this shape (re-derive otherwise).", ) LOAD_W_HALF_COPY_LAYOUT: gl.constexpr = _preshuffled_w_copy_layout( - SUB_BN // 16, BLOCK_K_W, scale_via_lds_slice_n, False + SUB_BN // 16, BLOCK_K_W, w_scale_via_lds_slice_n, False ) ( offsets_h, @@ -3179,7 +3179,7 @@ def _run_moe_tile_w_via_vgpr( ) return pgm.pipeline(K) else: - scale_via_lds_full: gl.constexpr = cfg.X_SCALE_VIA_LDS or cfg.W_SCALE_VIA_LDS + w_scale_via_lds_full: gl.constexpr = cfg.W_SCALE_VIA_LDS gl.static_assert( BLOCK_N == 128, "W_VIA_VGPR full-tile layout bases assume BLOCK_N=128. " @@ -3187,7 +3187,7 @@ def _run_moe_tile_w_via_vgpr( ) BLOCK_N_LAYOUT: gl.constexpr = BLOCK_N LOAD_W_COPY_LAYOUT: gl.constexpr = _preshuffled_w_copy_layout( - BLOCK_N_LAYOUT // 16, BLOCK_K_W, scale_via_lds_full, False + BLOCK_N_LAYOUT // 16, BLOCK_K_W, w_scale_via_lds_full, False ) offsets_b_vgpr, base_off_b_vgpr = _make_preshuffled_w_full_offsets( w_base_offset, @@ -3273,7 +3273,7 @@ def _run_moe_tile_preshuffled_lds_w( if USE_SLICE_N: SUB_BN: gl.constexpr = BLOCK_N // 2 - scale_via_lds_slice_n: gl.constexpr = cfg.X_SCALE_VIA_LDS or cfg.W_SCALE_VIA_LDS + w_scale_via_lds_slice_n: gl.constexpr = cfg.W_SCALE_VIA_LDS gl.static_assert( SUB_BN == 128 and BLOCK_K_W == 128 and NUM_WARPS == 4, "USE_SLICE_N + preshuffled W requires SUB_BN=BLOCK_K_W=128 " @@ -3281,10 +3281,10 @@ def _run_moe_tile_preshuffled_lds_w( "this shape (re-derive otherwise).", ) LOAD_W_HALF_LAYOUT: gl.constexpr = _preshuffled_w_read_layout( - SUB_BN // 16, BLOCK_K_W, scale_via_lds_slice_n + SUB_BN // 16, BLOCK_K_W, w_scale_via_lds_slice_n ) LOAD_W_HALF_COPY_LAYOUT: gl.constexpr = _preshuffled_w_copy_layout( - SUB_BN // 16, BLOCK_K_W, scale_via_lds_slice_n, True + SUB_BN // 16, BLOCK_K_W, w_scale_via_lds_slice_n, True ) ( offsets_h, @@ -3340,12 +3340,12 @@ def _run_moe_tile_preshuffled_lds_w( # Keep the original half-tile layout in that specialization so the # preshuffled copy/read layouts remain valid during compilation. BLOCK_N_LAYOUT: gl.constexpr = (BLOCK_N // 2) if USE_SLICE_N else BLOCK_N - scale_via_lds_full: gl.constexpr = cfg.X_SCALE_VIA_LDS or cfg.W_SCALE_VIA_LDS + w_scale_via_lds_full: gl.constexpr = cfg.W_SCALE_VIA_LDS LOAD_W_LAYOUT: gl.constexpr = _preshuffled_w_read_layout( - BLOCK_N_LAYOUT // 16, BLOCK_K_W, scale_via_lds_full + BLOCK_N_LAYOUT // 16, BLOCK_K_W, w_scale_via_lds_full ) LOAD_W_COPY_LAYOUT: gl.constexpr = _preshuffled_w_copy_layout( - BLOCK_N_LAYOUT // 16, BLOCK_K_W, scale_via_lds_full, True + BLOCK_N_LAYOUT // 16, BLOCK_K_W, w_scale_via_lds_full, True ) offsets_b_vgpr, base_off_b_vgpr = _make_preshuffled_w_full_offsets( w_base_offset, @@ -6629,9 +6629,9 @@ def _gluon_mxfp4_fp8_warp_decode_moe( return None if w13_preshuffled and two_i % 128 != 0: return None - I = two_i // 2 + i_dim = two_i // 2 w2_k_pk = int(getattr(w2_raw, "original_k_pk", int(w2_raw.shape[1]))) - if w2_k_pk * 2 != I: + if w2_k_pk * 2 != i_dim: return None w2_n_phys = int(w2_raw.shape[2]) N = int(getattr(w2_raw, "original_n", w2_n_phys)) @@ -6676,7 +6676,7 @@ def _gluon_mxfp4_fp8_warp_decode_moe( s2_split_k = 1 inter = torch.empty( - (n_tokens * top_k, I), dtype=x_fp8.dtype, device=hidden_states.device + (n_tokens * top_k, i_dim), dtype=x_fp8.dtype, device=hidden_states.device ) # Cooperative-LDS, num_warps=4, software-pipelined stage1 -- the smallest-M # decode path (this wrapper is only entered for n_tokens <= WARP_DECODE_MAX_M). @@ -6687,14 +6687,14 @@ def _gluon_mxfp4_fp8_warp_decode_moe( COOP_BLOCK_N = 128 if w13_preshuffled else 64 COOP_BLOCK_K = 256 COOP_NUM_BUFFERS = 3 - coop_grid = (n_tokens * ((2 * I + COOP_BLOCK_N - 1) // COOP_BLOCK_N) * top_k,) + coop_grid = (n_tokens * ((2 * i_dim + COOP_BLOCK_N - 1) // COOP_BLOCK_N) * top_k,) # X is stored as raw i8 in LDS and bitcast to e4m3 in mfma_scaled; pass the # uint8 view (an fp8 LDS buffer fails to lower). x_uint8 = x_fp8.view(torch.uint8) # fmt: off _warp_decode_topk_stage1_coop_kernel[coop_grid]( x_uint8, router_logits_c, w13_raw, w13_scale, topk_ids, topk_weights, inter, - n_tokens, n_experts, D, I, + n_tokens, n_experts, D, i_dim, x_uint8.stride(0), x_uint8.stride(1), router_logits_c.stride(0), topk_ids.stride(0), topk_weights.stride(0), w13_raw.stride(0), w13_raw.stride(-2), w13_raw.stride(-1), @@ -6734,13 +6734,13 @@ def _gluon_mxfp4_fp8_warp_decode_moe( # fmt: off _warp_decode_stage2_fp8_mxfp4_kernel[s2_grid]( inter, w2_raw, w2_scale, topk_ids, topk_weights, s2_dst, - n_tokens, N, w2_n_phys, I, + n_tokens, N, w2_n_phys, i_dim, inter.stride(0), inter.stride(1), w2_raw.stride(0), w2_raw.stride(-2), w2_raw.stride(-1), w2_scale.stride(0), w2_scale.stride(-2), w2_scale.stride(-1), s2_stride_om, s2_stride_on, s2_stride_ok, w2_act_scale, b2, - I_PACKED=I // 2, TOPK=top_k, + I_PACKED=i_dim // 2, TOPK=top_k, BLOCK_K=BLOCK_K, BLOCK_N=S2_BLOCK_N, M_DUP=S2_M_DUP, W_PRESHUFFLED=w2_preshuffled, HAS_BIAS=w2_bias is not None, SPLIT_K=s2_split_k, @@ -7262,7 +7262,7 @@ def _warp_decode_stage1_coop_compute( Y, M, D, - I, + i_dim, stride_xm, stride_xk, stride_we, @@ -7294,7 +7294,7 @@ def _warp_decode_stage1_coop_compute( per-tensor x scale) but specialized to a single decode token (row 0 of the BLOCK_M tile). """ - N = 2 * I + N = 2 * i_dim off_n = pid_n * BLOCK_N # Keep base offsets int32 (buffer_load_to_shared requires int32/uint32 # offsets); expert * stride fits int32 for GPT-OSS shapes. @@ -7363,12 +7363,12 @@ def _warp_decode_stage1_coop_compute( "warp_decode preshuffled W13 path assumes 128x128 W tiles " "and NUM_WARPS=4; re-derive the copy/read layouts for other shapes.", ) - scale_via_lds: gl.constexpr = cfg.X_SCALE_VIA_LDS or cfg.W_SCALE_VIA_LDS + w_scale_via_lds: gl.constexpr = cfg.W_SCALE_VIA_LDS LOAD_W_LAYOUT: gl.constexpr = _preshuffled_w_read_layout( - BLOCK_N // 16, BLOCK_K_W, scale_via_lds + BLOCK_N // 16, BLOCK_K_W, w_scale_via_lds ) LOAD_W_COPY_LAYOUT: gl.constexpr = _preshuffled_w_copy_layout( - BLOCK_N // 16, BLOCK_K_W, scale_via_lds, True + BLOCK_N // 16, BLOCK_K_W, w_scale_via_lds, True ) offsets_w, base_off_w = _make_preshuffled_w_full_offsets( w_base_offset, @@ -7477,7 +7477,7 @@ def _warp_decode_stage1_coop_compute( + offs_y_n[None, :].to(gl.int64) * stride_yn + offs_y_m[:, None].to(gl.int64) * 0 ) - mask_y = (offs_y_m[:, None] == 0) & valid & (offs_y_n[None, :] < I) + mask_y = (offs_y_m[:, None] == 0) & valid & (offs_y_n[None, :] < i_dim) gl.store(Y + y_offs, out, mask=mask_y) @@ -7493,7 +7493,7 @@ def _warp_decode_topk_stage1_coop_kernel( M, E, D, - I, + i_dim, stride_xm, stride_xk, stride_lm, @@ -7532,7 +7532,7 @@ def _warp_decode_topk_stage1_coop_kernel( by TOPK. Routing layouts span all warps (EP/TKP padded to 64*NUM_WARPS). """ pid = gl.program_id(axis=0) - num_pid_n = gl.cdiv(2 * I, BLOCK_N) + num_pid_n = gl.cdiv(2 * i_dim, BLOCK_N) slot = pid % TOPK rest = pid // TOPK pid_n = rest % num_pid_n @@ -7585,7 +7585,7 @@ def _warp_decode_topk_stage1_coop_kernel( _warp_decode_stage1_coop_compute( token, slot, expert, pid_n, X, W, WScale, Y, - M, D, I, + M, D, i_dim, stride_xm, stride_xk, stride_we, stride_wk, stride_wn, stride_wse, stride_wsk, stride_wsn, @@ -7645,7 +7645,7 @@ def _warp_decode_stage2_load_tile( stride_wk, stride_wsk, N_PHYS, - I, + i_dim, BLOCK_K: gl.constexpr, BLOCK_K_PACKED: gl.constexpr, BLOCK_K_SCALE: gl.constexpr, @@ -7674,8 +7674,10 @@ def _warp_decode_stage2_load_tile( if MASK_TAIL: # Partial / odd final K-tile (K = intermediate dim I): mask out-of-range # K lanes to 0 so they contribute nothing and never over-read. - sk_valid = (kt * BLOCK_K_SCALE + bsk) < (I // 32) - a = gl.amd.cdna4.buffer_load(ptr=X, offsets=a_off, mask=k_elem < I, other=0.0) + sk_valid = (kt * BLOCK_K_SCALE + bsk) < (i_dim // 32) + a = gl.amd.cdna4.buffer_load( + ptr=X, offsets=a_off, mask=k_elem < i_dim, other=0.0 + ) b_mask = k_pack < I_PACKED b = gl.amd.cdna4.buffer_load(ptr=W, offsets=b_off, mask=b_mask, other=0) s = gl.amd.cdna4.buffer_load(ptr=WScale, offsets=s_off, mask=sk_valid, other=0) @@ -7706,7 +7708,7 @@ def _warp_decode_stage2_load_pair( stride_wk, stride_wsk, N_PHYS, - I, + i_dim, BLOCK_K: gl.constexpr, BLOCK_K_PACKED: gl.constexpr, BLOCK_K_SCALE: gl.constexpr, @@ -7718,13 +7720,13 @@ def _warp_decode_stage2_load_pair( a_even, b_even, s_even = _warp_decode_stage2_load_tile( kt, ak, bk, bsk, am, X, W, WScale, x_row_off, w_expert_off, w_n_off, ws_expert_off, scale_row_off, - n_cols, stride_xk, stride_wk, stride_wsk, N_PHYS, I, + n_cols, stride_xk, stride_wk, stride_wsk, N_PHYS, i_dim, BLOCK_K, BLOCK_K_PACKED, BLOCK_K_SCALE, I_PACKED, W_PRESHUFFLED, ) a_odd, b_odd, s_odd = _warp_decode_stage2_load_tile( kt + 1, ak, bk, bsk, am, X, W, WScale, x_row_off, w_expert_off, w_n_off, ws_expert_off, scale_row_off, - n_cols, stride_xk, stride_wk, stride_wsk, N_PHYS, I, + n_cols, stride_xk, stride_wk, stride_wsk, N_PHYS, i_dim, BLOCK_K, BLOCK_K_PACKED, BLOCK_K_SCALE, I_PACKED, W_PRESHUFFLED, ) # fmt: on @@ -7760,7 +7762,7 @@ def _warp_decode_stage2_fp8_mxfp4_kernel( M, N, N_PHYS, - I, + i_dim, stride_xm, stride_xk, stride_we, @@ -7814,8 +7816,8 @@ def _warp_decode_stage2_fp8_mxfp4_kernel( # Full + partial K-tile coverage (K = intermediate dim I). The old # `num_kt = I // BLOCK_K` dropped the partial final tile, miscomputing any # I not a multiple of BLOCK_K (GPT-OSS I=2880 lost K=2816..2879). - num_full = I // BLOCK_K - total_kt = (I + BLOCK_K - 1) // BLOCK_K + num_full = i_dim // BLOCK_K + total_kt = (i_dim + BLOCK_K - 1) // BLOCK_K kt_per = (total_kt + SPLIT_K - 1) // SPLIT_K kt_start = pid_k * kt_per kt_stop = gl.minimum(kt_start + kt_per, total_kt) @@ -7869,7 +7871,7 @@ def _warp_decode_stage2_fp8_mxfp4_kernel( a_odd, b_odd, s_odd) = _warp_decode_stage2_load_pair( kt_start, ak, bk, bsk, am, X, W, WScale, x_row_off, w_expert_off, w_n_off, ws_expert_off, scale_row_off, - n_cols, stride_xk, stride_wk, stride_wsk, N_PHYS, I, + n_cols, stride_xk, stride_wk, stride_wsk, N_PHYS, i_dim, BLOCK_K, BLOCK_K_PACKED, BLOCK_K_SCALE, I_PACKED, W_PRESHUFFLED, ) for kt in range(kt_start, main_end - 2, 2): @@ -7877,7 +7879,7 @@ def _warp_decode_stage2_fp8_mxfp4_kernel( nxt_a_odd, nxt_b_odd, nxt_s_odd) = _warp_decode_stage2_load_pair( kt + 2, ak, bk, bsk, am, X, W, WScale, x_row_off, w_expert_off, w_n_off, ws_expert_off, scale_row_off, - n_cols, stride_xk, stride_wk, stride_wsk, N_PHYS, I, + n_cols, stride_xk, stride_wk, stride_wsk, N_PHYS, i_dim, BLOCK_K, BLOCK_K_PACKED, BLOCK_K_SCALE, I_PACKED, W_PRESHUFFLED, ) acc = _warp_decode_stage2_mfma_pair( @@ -7896,7 +7898,7 @@ def _warp_decode_stage2_fp8_mxfp4_kernel( a_t, b_t, s_t = _warp_decode_stage2_load_tile( kt, ak, bk, bsk, am, X, W, WScale, x_row_off, w_expert_off, w_n_off, ws_expert_off, scale_row_off, - n_cols, stride_xk, stride_wk, stride_wsk, N_PHYS, I, + n_cols, stride_xk, stride_wk, stride_wsk, N_PHYS, i_dim, BLOCK_K, BLOCK_K_PACKED, BLOCK_K_SCALE, I_PACKED, W_PRESHUFFLED, MASK_TAIL=True, ) From 2a7bf4a33440ac58f40553e70fbfa4eed36475a9 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Tue, 30 Jun 2026 15:13:49 -0500 Subject: [PATCH 6/6] Inline preshuffled W scale layout config Signed-off-by: Quinn Dawkins --- .../ops/moe/fused_mxfp_gfx950.py | 21 +++++++------------ 1 file changed, 8 insertions(+), 13 deletions(-) diff --git a/tokenspeed-kernel-amd/python/tokenspeed_kernel_amd/ops/moe/fused_mxfp_gfx950.py b/tokenspeed-kernel-amd/python/tokenspeed_kernel_amd/ops/moe/fused_mxfp_gfx950.py index 6831df0f0..927104155 100644 --- a/tokenspeed-kernel-amd/python/tokenspeed_kernel_amd/ops/moe/fused_mxfp_gfx950.py +++ b/tokenspeed-kernel-amd/python/tokenspeed_kernel_amd/ops/moe/fused_mxfp_gfx950.py @@ -3125,7 +3125,6 @@ def _run_moe_tile_w_via_vgpr( if USE_SLICE_N: SUB_BN: gl.constexpr = BLOCK_N // 2 - w_scale_via_lds_slice_n: gl.constexpr = cfg.W_SCALE_VIA_LDS gl.static_assert( SUB_BN == 128 and BLOCK_K_W == 128 and NUM_WARPS == 4, "USE_SLICE_N + W_VIA_VGPR requires SUB_BN=BLOCK_K_W=128 " @@ -3133,7 +3132,7 @@ def _run_moe_tile_w_via_vgpr( "this shape (re-derive otherwise).", ) LOAD_W_HALF_COPY_LAYOUT: gl.constexpr = _preshuffled_w_copy_layout( - SUB_BN // 16, BLOCK_K_W, w_scale_via_lds_slice_n, False + SUB_BN // 16, BLOCK_K_W, cfg.W_SCALE_VIA_LDS, False ) ( offsets_h, @@ -3179,7 +3178,6 @@ def _run_moe_tile_w_via_vgpr( ) return pgm.pipeline(K) else: - w_scale_via_lds_full: gl.constexpr = cfg.W_SCALE_VIA_LDS gl.static_assert( BLOCK_N == 128, "W_VIA_VGPR full-tile layout bases assume BLOCK_N=128. " @@ -3187,7 +3185,7 @@ def _run_moe_tile_w_via_vgpr( ) BLOCK_N_LAYOUT: gl.constexpr = BLOCK_N LOAD_W_COPY_LAYOUT: gl.constexpr = _preshuffled_w_copy_layout( - BLOCK_N_LAYOUT // 16, BLOCK_K_W, w_scale_via_lds_full, False + BLOCK_N_LAYOUT // 16, BLOCK_K_W, cfg.W_SCALE_VIA_LDS, False ) offsets_b_vgpr, base_off_b_vgpr = _make_preshuffled_w_full_offsets( w_base_offset, @@ -3273,7 +3271,6 @@ def _run_moe_tile_preshuffled_lds_w( if USE_SLICE_N: SUB_BN: gl.constexpr = BLOCK_N // 2 - w_scale_via_lds_slice_n: gl.constexpr = cfg.W_SCALE_VIA_LDS gl.static_assert( SUB_BN == 128 and BLOCK_K_W == 128 and NUM_WARPS == 4, "USE_SLICE_N + preshuffled W requires SUB_BN=BLOCK_K_W=128 " @@ -3281,10 +3278,10 @@ def _run_moe_tile_preshuffled_lds_w( "this shape (re-derive otherwise).", ) LOAD_W_HALF_LAYOUT: gl.constexpr = _preshuffled_w_read_layout( - SUB_BN // 16, BLOCK_K_W, w_scale_via_lds_slice_n + SUB_BN // 16, BLOCK_K_W, cfg.W_SCALE_VIA_LDS ) LOAD_W_HALF_COPY_LAYOUT: gl.constexpr = _preshuffled_w_copy_layout( - SUB_BN // 16, BLOCK_K_W, w_scale_via_lds_slice_n, True + SUB_BN // 16, BLOCK_K_W, cfg.W_SCALE_VIA_LDS, True ) ( offsets_h, @@ -3340,12 +3337,11 @@ def _run_moe_tile_preshuffled_lds_w( # Keep the original half-tile layout in that specialization so the # preshuffled copy/read layouts remain valid during compilation. BLOCK_N_LAYOUT: gl.constexpr = (BLOCK_N // 2) if USE_SLICE_N else BLOCK_N - w_scale_via_lds_full: gl.constexpr = cfg.W_SCALE_VIA_LDS LOAD_W_LAYOUT: gl.constexpr = _preshuffled_w_read_layout( - BLOCK_N_LAYOUT // 16, BLOCK_K_W, w_scale_via_lds_full + BLOCK_N_LAYOUT // 16, BLOCK_K_W, cfg.W_SCALE_VIA_LDS ) LOAD_W_COPY_LAYOUT: gl.constexpr = _preshuffled_w_copy_layout( - BLOCK_N_LAYOUT // 16, BLOCK_K_W, w_scale_via_lds_full, True + BLOCK_N_LAYOUT // 16, BLOCK_K_W, cfg.W_SCALE_VIA_LDS, True ) offsets_b_vgpr, base_off_b_vgpr = _make_preshuffled_w_full_offsets( w_base_offset, @@ -7363,12 +7359,11 @@ def _warp_decode_stage1_coop_compute( "warp_decode preshuffled W13 path assumes 128x128 W tiles " "and NUM_WARPS=4; re-derive the copy/read layouts for other shapes.", ) - w_scale_via_lds: gl.constexpr = cfg.W_SCALE_VIA_LDS LOAD_W_LAYOUT: gl.constexpr = _preshuffled_w_read_layout( - BLOCK_N // 16, BLOCK_K_W, w_scale_via_lds + BLOCK_N // 16, BLOCK_K_W, cfg.W_SCALE_VIA_LDS ) LOAD_W_COPY_LAYOUT: gl.constexpr = _preshuffled_w_copy_layout( - BLOCK_N // 16, BLOCK_K_W, w_scale_via_lds, True + BLOCK_N // 16, BLOCK_K_W, cfg.W_SCALE_VIA_LDS, True ) offsets_w, base_off_w = _make_preshuffled_w_full_offsets( w_base_offset,