perf(kimi-k2): fuse decode-path small kernels to reduce launch overhead #185
Code Review
This pull request introduces optimized CUDA kernels to fuse operations and improve performance. Specifically, it adds kimi_residual_add_scaled_f32_kernel to combine residual addition and scaling, and split_qkv_a_norm_kernel to fuse the splitting and RMS normalization of QKV projections. These are integrated into the Rust runner and worker pipelines. The review feedback highlights a return type mismatch between the CUDA implementation (cudaError_t) and Rust FFI bindings (CUresult) for the new MLA split-norm function, as well as an optimization opportunity in the CUDA kernel to restrict shared memory writes to a single thread within the warp.
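As a rough host-side picture of the residual fusion summarized above, the sketch below compares a two-pass update with a single-pass one. The function names, operand names, and the exact update (`h += a` followed by `h += scale * b`) are assumptions made for illustration; the PR page does not show the kernel body. The point is only that a fused pass which keeps the same per-element operation order produces the same floating-point results as the two separate passes.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Unfused reference: two passes over the hidden state, standing in for two
// separate kernel launches. Names and semantics are assumed, not from the PR.
inline void add_into(std::vector<float>& h, const std::vector<float>& a) {
  assert(h.size() == a.size());
  for (std::size_t i = 0; i < h.size(); ++i) h[i] = h[i] + a[i];
}

inline void scaled_add(std::vector<float>& h, const std::vector<float>& b,
                       float scale) {
  assert(h.size() == b.size());
  for (std::size_t i = 0; i < h.size(); ++i) h[i] = h[i] + scale * b[i];
}

// Fused sketch: one pass, same per-element operation order as the two passes
// above, so every element goes through identical floating-point operations.
inline void residual_add_scaled(std::vector<float>& h,
                                const std::vector<float>& a,
                                const std::vector<float>& b, float scale) {
  assert(h.size() == a.size() && h.size() == b.size());
  for (std::size_t i = 0; i < h.size(); ++i) {
    const float t = h[i] + a[i];  // step performed by add_into
    h[i] = t + scale * b[i];      // step performed by scaled_add
  }
}
```

Running both paths on copies of the same input and comparing element by element is a cheap way to check the equality claim before benchmarking.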
cudaError_t kimi_mla_split_qkv_a_norm_cuda(const DType* qkv_a,
                                           const DType* q_a_weight,
                                           const DType* ckv_weight,
                                           DType* q_a_normed,
                                           DType* ckv_normed,
                                           DType* k_rope,
                                           float eps,
                                           int batch_size,
                                           cudaStream_t stream) {
The function kimi_mla_split_qkv_a_norm_cuda is defined as returning cudaError_t in C++, but is declared as returning CUresult in the Rust FFI bindings (ffi.rs). While both are integer types and success (0) maps correctly, their non-success error codes do not align 1:1. This mismatch can lead to incorrect error propagation or silent failures. Consider returning CUresult and mapping the error code using cudaGetLastError() similar to how it is done in kimi_experts.cu (e.g., err == cudaSuccess ? CUDA_SUCCESS : CUDA_ERROR_LAUNCH_FAILED).
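The mapping suggested in this comment can be sketched without the CUDA toolkit. In the snippet below, `RtError` and `DrvResult` are hypothetical stand-ins for `cudaError_t` and `CUresult`, and `to_drv_result` is an illustrative helper name, not something taken from `kimi_experts.cu`. It collapses every runtime-side failure into a single driver-side code, which is the behavior of the quoted expression.

```cpp
#include <cassert>

// Stand-ins so the sketch compiles without CUDA headers. Only the values
// used here are listed; 0 is success in both APIs and 719 is the value of
// CUDA_ERROR_LAUNCH_FAILED in the driver API.
enum RtError { kRtSuccess = 0, kRtSomeFailure = 1 };         // for cudaError_t
enum DrvResult { kDrvSuccess = 0, kDrvLaunchFailed = 719 };  // for CUresult

// Collapse any runtime-API failure into one driver-API error code, so the
// Rust side only ever sees values that are valid for its declared type.
inline DrvResult to_drv_result(RtError err) {
  return err == kRtSuccess ? kDrvSuccess : kDrvLaunchFailed;
}

// Small self-check of the mapping.
inline void check_mapping() {
  assert(to_drv_result(kRtSuccess) == kDrvSuccess);
  assert(to_drv_result(kRtSomeFailure) == kDrvLaunchFailed);
}
```

The trade-off of this approach is that the specific runtime error is lost; if the caller needs it, the alternative is to change the FFI declaration to match the C++ return type instead.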
smem_q[0] = total_q;

float total_ckv = (tx < CKV_WARPS) ? smem_ckv[tx] : 0.f;
#pragma unroll
for (int offset = kWarpSize / 2; offset > 0; offset /= 2)
  total_ckv += fi_shfl_xor(total_ckv, offset);
smem_ckv[0] = total_ckv;
In split_qkv_a_norm_kernel, all 32 threads of warp 0 (ty == 0) write the same reduced total_q and total_ckv values to smem_q[0] and smem_ckv[0]. To avoid redundant writes and potential bank conflicts/coherency overhead, it is recommended to restrict these writes to thread 0 of the warp (tx == 0).
if (tx == 0) {
  smem_q[0] = total_q;
}
float total_ckv = (tx < CKV_WARPS) ? smem_ckv[tx] : 0.f;
#pragma unroll
for (int offset = kWarpSize / 2; offset > 0; offset /= 2)
  total_ckv += fi_shfl_xor(total_ckv, offset);
if (tx == 0) {
  smem_ckv[0] = total_ckv;
}
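To see why a single writer is enough here, the following host-side model mimics the XOR-shuffle (butterfly) loop from the snippet. It is a plain C++ simulation, not device code: `butterfly_sum`, `all_lanes_agree`, and `store_from_lane0` are illustrative names, and the array stands in for the per-lane registers of one warp. After the last step every lane holds the same total, so restricting the shared-memory store to lane 0 changes nothing about the stored value.

```cpp
#include <array>
#include <cassert>

constexpr int kWarpSize = 32;
using Lanes = std::array<float, kWarpSize>;

// Host model of a warp-wide XOR-shuffle sum: at each step, lane i adds the
// value that lane (i ^ offset) held before that step.
inline Lanes butterfly_sum(Lanes lanes) {
  for (int offset = kWarpSize / 2; offset > 0; offset /= 2) {
    const Lanes prev = lanes;  // snapshot: all lanes exchange simultaneously
    for (int lane = 0; lane < kWarpSize; ++lane)
      lanes[lane] = prev[lane] + prev[lane ^ offset];
  }
  return lanes;
}

// True when every lane ended up with the same value as lane 0.
inline bool all_lanes_agree(const Lanes& lanes) {
  for (float v : lanes)
    if (v != lanes[0]) return false;
  return true;
}

// Model of the guarded write: only lane 0 stores the reduced total.
inline float store_from_lane0(const Lanes& lanes) {
  assert(all_lanes_agree(lanes));  // any lane would have stored this value
  return lanes[0];
}
```

Whether the guarded store is measurably faster on a given GPU is a separate question that only a profile of the real kernel can answer.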
Two fusions for the MLA + MoE decode path (TP1 DP8 EP8 PPLX):

1. split_qkv_a + 2x rms_norm → single kimi_mla_split_qkv_a_norm kernel
   (3 launches → 1, ×61 layers, 9.5→3.5μs per layer on H20)
2. add_into + scaled_add → single kimi_residual_add_scaled_f32 kernel
   (2 launches → 1, ×60 MoE layers, 6.3→2.8μs per layer on H20)

Both fused kernels are bitwise equal to the original path.
Total: 182 fewer launches/step, ~0.58ms TPOT reduction at bs=1.
E2E bs=64: TPOT p50/p99 39.64/43.36ms (baseline 40.10/43.72ms).

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
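The totals in the commit message follow from the per-layer figures. The sketch below redoes that arithmetic; the `Fusion` struct and helper names are illustrative, and it assumes the per-layer savings simply add up across layers, which is how the quoted totals appear to have been derived.

```cpp
#include <cassert>

// Per-fusion figures as quoted in the commit message.
struct Fusion {
  int layers;           // number of layers the fusion applies to
  int launches_before;  // kernel launches per layer before fusing
  int launches_after;   // kernel launches per layer after fusing
  double us_before;     // per-layer time before fusing, microseconds
  double us_after;      // per-layer time after fusing, microseconds
};

// Launches removed per decode step by one fusion.
inline int launches_saved(const Fusion& f) {
  assert(f.launches_before >= f.launches_after);
  return f.layers * (f.launches_before - f.launches_after);
}

// Time removed per decode step by one fusion, in microseconds.
inline double us_saved(const Fusion& f) {
  return f.layers * (f.us_before - f.us_after);
}

constexpr Fusion kSplitNorm{61, 3, 1, 9.5, 3.5};  // split_qkv_a + 2x rms_norm
constexpr Fusion kResidual{60, 2, 1, 6.3, 2.8};   // add_into + scaled_add
```

With these inputs the launch count works out to 61 × 2 + 60 × 1 = 182, and the time to 61 × 6.0 + 60 × 3.5 = 576 μs, which rounds to the ~0.58 ms quoted for bs=1.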
f08ad91 to 7a64b61
Summary
1. split_qkv_a + 2× rms_norm → single kimi_mla_split_qkv_a_norm kernel (3→1 launch, ×61 layers, 9.5→3.5μs on H20)
2. add_into + scaled_add → single kimi_residual_add_scaled_f32 kernel (2→1 launch, ×60 MoE layers, 6.3→2.8μs on H20)

E2E Results (H20, TP1 DP8 EP8 PPLX, bs=64, output_len=128)
Test plan
bench_serving dp8 tp1 ep8 bs64 on H20 cluster

🤖 Generated with Claude Code