perf(kimi-k2): fuse decode-path small kernels to reduce launch overhead#185

Merged
xiaguan merged 1 commit into main from feat/kimi-k2-kernel-fusions
May 28, 2026

@xiaguan xiaguan commented May 28, 2026

Summary

  • Fusion 1: split_qkv_a + 2× rms_norm → single kimi_mla_split_qkv_a_norm kernel (3→1 launch, ×61 layers, 9.5→3.5μs on H20)
  • Fusion 2: add_into + scaled_add → single kimi_residual_add_scaled_f32 kernel (2→1 launch, ×60 MoE layers, 6.3→2.8μs on H20)
  • Both fused kernels are bitwise equal to the original path
  • Total: 182 fewer kernel launches per decode step, ~0.58ms TPOT reduction at bs=1
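
The totals above can be re-derived from the per-layer figures. The sketch below is only a sanity check of the arithmetic: it assumes the per-layer savings add up linearly across layers, and all constant and function names are hypothetical. Durations are kept in tenths of a microsecond so everything stays in exact integers.

```cpp
// Per-layer figures quoted in the summary. Names are illustrative only.
constexpr int kMlaLayers = 61;                  // layers that use Fusion 1
constexpr int kMoeLayers = 60;                  // MoE layers that use Fusion 2
constexpr int kFusion1LaunchesSaved = 3 - 1;    // 3 launches -> 1
constexpr int kFusion2LaunchesSaved = 2 - 1;    // 2 launches -> 1
constexpr int kFusion1SavedTenthsUs = 95 - 35;  // 9.5 us -> 3.5 us
constexpr int kFusion2SavedTenthsUs = 63 - 28;  // 6.3 us -> 2.8 us

// Kernel launches removed from one decode step.
constexpr int launches_saved_per_step() {
  return kMlaLayers * kFusion1LaunchesSaved + kMoeLayers * kFusion2LaunchesSaved;
}

// Kernel time removed from one decode step, in tenths of a microsecond.
constexpr int kernel_time_saved_tenths_us() {
  return kMlaLayers * kFusion1SavedTenthsUs + kMoeLayers * kFusion2SavedTenthsUs;
}

static_assert(launches_saved_per_step() == 182, "61 * 2 + 60 * 1");
static_assert(kernel_time_saved_tenths_us() == 5760, "576 us, about 0.58 ms");
```

Under that linear assumption, 61 × 6.0 μs + 60 × 3.5 μs = 576 μs, which matches the ~0.58ms bs=1 figure.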

E2E Results (H20, TP1 DP8 EP8 PPLX, bs=64, output_len=128)

| Metric | Baseline | Fused | Delta |
|---|---|---|---|
| TPOT p50 | 40.10ms | 39.64ms | −0.46ms (−1.1%) |
| TPOT p99 | 43.72ms | 43.36ms | −0.36ms |

Test plan

  • Standalone C++ bench: bitwise correctness at B∈{1,4,8,16,32,64}
  • ncu profile: confirmed launch reduction and kernel duration improvement
  • E2E bench_serving dp8 tp1 ep8 bs64 on H20 cluster
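
As a rough illustration of what a bitwise correctness check for Fusion 2 could look like, here is a host-only sketch. It is not the bench from this PR. It assumes `add_into` computes `h[i] += a[i]` and `scaled_add` computes `h[i] += scale * b[i]`, and the helper names `residual_add_scaled`, `bitwise_equal`, and `fused_matches_unfused` are hypothetical.

```cpp
#include <cstddef>
#include <cstring>
#include <vector>

// Unfused reference: two passes over the buffer, mirroring two kernel launches.
// ASSUMPTION: these element-wise semantics are inferred from the kernel names.
inline void add_into(std::vector<float>& h, const std::vector<float>& a) {
  for (std::size_t i = 0; i < h.size(); ++i) h[i] = h[i] + a[i];
}
inline void scaled_add(std::vector<float>& h, const std::vector<float>& b,
                       float scale) {
  for (std::size_t i = 0; i < h.size(); ++i) h[i] = h[i] + scale * b[i];
}

// Fused version: one pass that keeps the same operation order as the two
// separate passes, which is what makes a bitwise match plausible.
inline void residual_add_scaled(std::vector<float>& h,
                                const std::vector<float>& a,
                                const std::vector<float>& b, float scale) {
  for (std::size_t i = 0; i < h.size(); ++i) {
    const float t = h[i] + a[i];
    h[i] = t + scale * b[i];
  }
}

// Compares raw bytes rather than using a tolerance.
inline bool bitwise_equal(const std::vector<float>& x,
                          const std::vector<float>& y) {
  if (x.size() != y.size()) return false;
  return x.empty() ||
         std::memcmp(x.data(), y.data(), x.size() * sizeof(float)) == 0;
}

// Runs both paths on the same synthetic input for one batch size.
inline bool fused_matches_unfused(std::size_t batch, std::size_t hidden) {
  const std::size_t n = batch * hidden;
  std::vector<float> h(n), a(n), b(n);
  for (std::size_t i = 0; i < n; ++i) {
    h[i] = static_cast<float>(i % 7);
    a[i] = static_cast<float>(i % 5);
    b[i] = static_cast<float>(i % 3);
  }
  std::vector<float> ref = h, fused = h;
  add_into(ref, a);
  scaled_add(ref, b, 0.5f);
  residual_add_scaled(fused, a, b, 0.5f);
  return bitwise_equal(ref, fused);
}
```

Calling `fused_matches_unfused(B, hidden)` for each B in {1, 4, 8, 16, 32, 64} mirrors the batch sweep listed above. On the GPU, a bitwise match additionally depends on both paths being compiled with the same FMA contraction behaviour.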

🤖 Generated with Claude Code

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request introduces optimized CUDA kernels to fuse operations and improve performance. Specifically, it adds kimi_residual_add_scaled_f32_kernel to combine residual addition and scaling, and split_qkv_a_norm_kernel to fuse the splitting and RMS normalization of QKV projections. These are integrated into the Rust runner and worker pipelines. The review feedback highlights a return type mismatch between the CUDA implementation (cudaError_t) and Rust FFI bindings (CUresult) for the new MLA split-norm function, as well as an optimization opportunity in the CUDA kernel to restrict shared memory writes to a single thread within the warp.
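
For readers unfamiliar with the split-plus-norm step, a host-side reference for one row might look like the sketch below. It is a guess at the semantics, not the kernel from this PR. It assumes each `qkv_a` row is laid out as `[q_a | ckv | k_rope]`, that RMS norm is `x * rsqrt(mean(x^2) + eps) * weight`, and that `float` can stand in for `DType`. `SplitNormOut` and `split_qkv_a_norm_row` are hypothetical names.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// RMS-normalizes n values from x with per-element weights w into out.
inline void rms_norm(const float* x, const float* w, float* out,
                     std::size_t n, float eps) {
  float sum_sq = 0.f;
  for (std::size_t i = 0; i < n; ++i) sum_sq += x[i] * x[i];
  const float inv_rms = 1.0f / std::sqrt(sum_sq / static_cast<float>(n) + eps);
  for (std::size_t i = 0; i < n; ++i) out[i] = x[i] * inv_rms * w[i];
}

struct SplitNormOut {
  std::vector<float> q_a_normed;
  std::vector<float> ckv_normed;
  std::vector<float> k_rope;
};

// Processes one row in a single pass over its three segments.
// ASSUMPTION: the row layout is [q_a | ckv | k_rope]; segment widths are
// taken from the weight vectors and rope_dim.
inline SplitNormOut split_qkv_a_norm_row(const std::vector<float>& row,
                                         const std::vector<float>& q_a_weight,
                                         const std::vector<float>& ckv_weight,
                                         std::size_t rope_dim, float eps) {
  const std::size_t q_dim = q_a_weight.size();
  const std::size_t ckv_dim = ckv_weight.size();
  assert(row.size() == q_dim + ckv_dim + rope_dim);

  SplitNormOut out{std::vector<float>(q_dim), std::vector<float>(ckv_dim),
                   std::vector<float>(rope_dim)};
  // Segment 1: q_a, normalized with its own weight.
  rms_norm(row.data(), q_a_weight.data(), out.q_a_normed.data(), q_dim, eps);
  // Segment 2: compressed KV, normalized with its own weight.
  rms_norm(row.data() + q_dim, ckv_weight.data(), out.ckv_normed.data(),
           ckv_dim, eps);
  // Segment 3: k_rope is copied through unchanged.
  for (std::size_t i = 0; i < rope_dim; ++i) {
    out.k_rope[i] = row[q_dim + ckv_dim + i];
  }
  return out;
}
```

The unfused path would presumably do the same work as three launches (one split, two norms); the sketch only shows why the three steps can share one pass over the row.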

Comment on lines +399 to +407
cudaError_t kimi_mla_split_qkv_a_norm_cuda(const DType* qkv_a,
const DType* q_a_weight,
const DType* ckv_weight,
DType* q_a_normed,
DType* ckv_normed,
DType* k_rope,
float eps,
int batch_size,
cudaStream_t stream) {

Priority: high

The function kimi_mla_split_qkv_a_norm_cuda is defined as returning cudaError_t in C++, but is declared as returning CUresult in the Rust FFI bindings (ffi.rs). While both are integer types and success (0) maps correctly, their non-success error codes do not align 1:1. This mismatch can lead to incorrect error propagation or silent failures. Consider returning CUresult and mapping the error code using cudaGetLastError() similar to how it is done in kimi_experts.cu (e.g., err == cudaSuccess ? CUDA_SUCCESS : CUDA_ERROR_LAUNCH_FAILED).
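
The mapping suggested in this comment can be sketched as follows. To keep the example compilable without the CUDA toolkit, it uses stand-in enums (`StubCudaError`, `StubCuResult`) in place of the real `cudaError_t` and `CUresult`, and the enumerator values are illustrative. In real code the input would come from `cudaGetLastError()` after the launch, and the constants would come from the CUDA headers. `to_cu_result` is a hypothetical helper name.

```cpp
// Stand-ins for cudaError_t and CUresult. ASSUMPTION: these only exist so the
// sketch builds on a machine without CUDA; real code includes the CUDA headers.
enum StubCudaError {
  kStubCudaSuccess = 0,
  kStubCudaErrorInvalidValue = 1,
  kStubCudaErrorLaunchFailure = 719
};
enum StubCuResult {
  kStubCuSuccess = 0,
  kStubCuErrorLaunchFailed = 719
};

// Collapses any runtime-API failure into a single driver-API error code, so
// the Rust side only ever sees values from the type it declared.
inline StubCuResult to_cu_result(StubCudaError err) {
  return err == kStubCudaSuccess ? kStubCuSuccess : kStubCuErrorLaunchFailed;
}
```

The trade-off is that the specific runtime error is lost at the FFI boundary. Logging the original code on the C++ side before mapping it is one way to keep that information.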

Comment on lines +169 to +175
smem_q[0] = total_q;

float total_ckv = (tx < CKV_WARPS) ? smem_ckv[tx] : 0.f;
#pragma unroll
for (int offset = kWarpSize / 2; offset > 0; offset /= 2)
total_ckv += fi_shfl_xor(total_ckv, offset);
smem_ckv[0] = total_ckv;

Priority: medium

In split_qkv_a_norm_kernel, all 32 threads of warp 0 (ty == 0) write the same reduced total_q and total_ckv values to smem_q[0] and smem_ckv[0]. To avoid redundant writes and potential bank conflicts/coherency overhead, it is recommended to restrict these writes to thread 0 of the warp (tx == 0).

    if (tx == 0) {
      smem_q[0] = total_q;
    }

    float total_ckv = (tx < CKV_WARPS) ? smem_ckv[tx] : 0.f;
#pragma unroll
    for (int offset = kWarpSize / 2; offset > 0; offset /= 2)
      total_ckv += fi_shfl_xor(total_ckv, offset);
    if (tx == 0) {
      smem_ckv[0] = total_ckv;
    }
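
To see why a single-lane write is enough, it may help to emulate the XOR butterfly on the host. The sketch below is not the kernel code. It models a 32-lane warp as an array, assumes `fi_shfl_xor` behaves like an XOR shuffle (each lane reads the value held by lane `lane ^ offset`), and uses hypothetical helper names throughout.

```cpp
#include <array>
#include <cstddef>

constexpr int kWarpSize = 32;
using Warp = std::array<float, kWarpSize>;

// Host emulation of an XOR shuffle: every lane reads from lane ^ offset.
inline Warp shfl_xor_all(const Warp& v, int offset) {
  Warp out{};
  for (int lane = 0; lane < kWarpSize; ++lane) out[lane] = v[lane ^ offset];
  return out;
}

// Same loop shape as the reduction in the review snippet.
inline Warp butterfly_sum(Warp v) {
  for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
    const Warp peer = shfl_xor_all(v, offset);
    for (int lane = 0; lane < kWarpSize; ++lane) v[lane] += peer[lane];
  }
  return v;
}

// Mirrors `(tx < CKV_WARPS) ? smem_ckv[tx] : 0.f`: the first num_warps lanes
// load a per-warp partial sum, the remaining lanes contribute zero.
inline Warp load_partials(const float* partials, int num_warps) {
  Warp v{};
  for (int lane = 0; lane < kWarpSize; ++lane) {
    v[lane] = (lane < num_warps) ? partials[lane] : 0.f;
  }
  return v;
}

// True when every lane ended up holding the same value.
inline bool all_lanes_equal(const Warp& v) {
  for (int lane = 1; lane < kWarpSize; ++lane) {
    if (v[lane] != v[0]) return false;
  }
  return true;
}
```

After the butterfly every lane holds the full total, so 32 lanes storing to `smem_q[0]` all write the same value, and guarding the store with `tx == 0` changes only how many lanes issue it, not the result.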

Two fusions for the MLA + MoE decode path (TP1 DP8 EP8 PPLX):

1. split_qkv_a + 2x rms_norm → single kimi_mla_split_qkv_a_norm kernel
   (3 launches → 1, ×61 layers, 9.5→3.5μs per layer on H20)

2. add_into + scaled_add → single kimi_residual_add_scaled_f32 kernel
   (2 launches → 1, ×60 MoE layers, 6.3→2.8μs per layer on H20)

Both fused kernels are bitwise equal to the original path.
Total: 182 fewer launches/step, ~0.58ms TPOT reduction at bs=1.
E2E bs=64: TPOT p50/p99 39.64/43.36ms (baseline 40.10/43.72ms).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
@xiaguan xiaguan force-pushed the feat/kimi-k2-kernel-fusions branch from f08ad91 to 7a64b61 Compare May 28, 2026 04:58
@xiaguan xiaguan merged commit 627ff5f into main May 28, 2026
1 check passed
@xiaguan xiaguan deleted the feat/kimi-k2-kernel-fusions branch May 28, 2026 04:59