From fcaab12b53da72a864528e4f77a5d80ef8388df3 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Fri, 8 Aug 2025 10:40:12 +0000 Subject: [PATCH 01/13] Dockerfile add install fa3_mtp --- docker/Dockerfile | 3 ++- docker/Dockerfile.deepep | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 653b227a8..4d8bfafe5 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -39,7 +39,8 @@ RUN pip install -r /lightllm/requirements.txt --no-cache-dir RUN pip install --no-cache-dir vllm --pre --extra-index-url https://wheels.vllm.ai/nightly -RUN git clone https://github.com/ModelTC/LightKernel.git && cd LightKernel && pip install --no-deps -v . +RUN git clone https://github.com/ModelTC/LightKernel.git && cd LightKernel && pip install --no-deps -v . && \ + cd flash-attention/hopper/ && python setup.py install RUN apt-get update && apt-get install -y libnuma-dev # for sgl_kernel diff --git a/docker/Dockerfile.deepep b/docker/Dockerfile.deepep index 058181e63..f8aad1664 100644 --- a/docker/Dockerfile.deepep +++ b/docker/Dockerfile.deepep @@ -39,7 +39,8 @@ RUN pip install -r /lightllm/requirements.txt --no-cache-dir RUN pip install --no-cache-dir vllm --pre --extra-index-url https://wheels.vllm.ai/nightly -RUN git clone https://github.com/ModelTC/LightKernel.git && cd LightKernel && pip install --no-deps -v . +RUN git clone https://github.com/ModelTC/LightKernel.git && cd LightKernel && pip install --no-deps -v . && \ + cd flash-attention/hopper/ && python setup.py install RUN apt-get update && apt-get install -y libnuma-dev wget devscripts debhelper dh-make build-essential dkms RUN apt-get install -y ibverbs-providers infiniband-diags perftest rdma-core libibverbs-dev librdmacm-dev From 22e1ca672e01bc4133800fef7b067f4d165eac47 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Fri, 8 Aug 2025 10:41:36 +0000 Subject: [PATCH 02/13] add wrapper and benchmark script --- lightllm/common/flash_attn.py | 213 ++++++++++++++++++ lightllm/utils/bench_utils.py | 118 ++++++++++ .../kernel/benchmark_fa3_decode_mtp.py | 192 ++++++++++++++++ 3 files changed, 523 insertions(+) create mode 100644 lightllm/common/flash_attn.py create mode 100644 lightllm/utils/bench_utils.py create mode 100644 test/benchmark/kernel/benchmark_fa3_decode_mtp.py diff --git a/lightllm/common/flash_attn.py b/lightllm/common/flash_attn.py new file mode 100644 index 000000000..8566fd30e --- /dev/null +++ b/lightllm/common/flash_attn.py @@ -0,0 +1,213 @@ +# This file is adapted from sgl-project/sglang: +# https://github.com/sgl-project/sglang/blob/main/sgl-kernel/python/sgl_kernel/flash_attn.py +# The original code and this file are licensed under the Apache License, Version 2.0. +# +# Copyright (c) sgl-project and other contributors. +# Modifications Copyright (c) LightLLM contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from typing import List, Optional, Tuple, Union + +import torch +import flash_attn_3._C # Registers operators with PyTorch + +# isort: on + +def maybe_contiguous(x): + return x.contiguous() if x is not None and x.stride(-1) != 1 else x + +flash_attn_3_cuda = torch.ops.flash_attn_3 + +def flash_attn_with_kvcache( + q, + k_cache, + v_cache, + k=None, + v=None, + qv=None, + rotary_cos=None, + rotary_sin=None, + cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, + cache_batch_idx: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + page_table: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + rotary_seqlens: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated + rotary_interleaved=True, + scheduler_metadata=None, + num_splits=0, # Can be tuned for speed + pack_gqa=None, # Can be tuned for speed + sm_margin=0, # Can be tuned if some SMs are used for communication + return_softmax_lse=False, + mtp_step=0 +): + """ + If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from + k and v. This is useful for incremental decoding: you can pass in the cached keys/values from + the previous step, and update them with the new keys/values from the current step, and do + attention with the updated cache, all in 1 kernel. + + If you pass in k / v, you must make sure that the cache is large enough to hold the new values. + For example, the KV cache could be pre-allocated with the max sequence length, and you can use + cache_seqlens to keep track of the current sequence lengths of each sequence in the batch. + + Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be + rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. + If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos + and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. + If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at + indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens). + + See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function. + + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. 
+ + Note: Does not support backward pass. + + Arguments: + q: (batch_size, seqlen, nheads, headdim) + k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table, + or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache) + page_block_size must be a multiple of 256. + v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table, + or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache) + k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate + k with k_cache, starting at the indices specified by cache_seqlens. + v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k. + qv [optional]: (batch_size, seqlen, nheads, headdim_v) + rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding + to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16. + rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos. + cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the + KV cache. + cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache. + If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1]. + If the indices are not distinct, and k and v are provided, the values updated in the cache + might come from any of the duplicate indices. + cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0. + page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + softcap: float. Anything > 0 activates softcapping attention. + rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in. + If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False, + rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1 + (i.e. GPT-NeoX style). + num_splits: int. If > 1, split the key/value into this many chunks along the sequence. + If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic + to automatically determine the number of splits. + Don't change this unless you know what you are doing. + return_softmax_lse: bool. Whether to return the logsumexp of the attention scores. + + Return: + out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). 
+ """ + assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension" + assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension" + if softmax_scale is None: + softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** ( + -0.5 + ) + if cache_seqlens is not None and isinstance(cache_seqlens, int): + cache_seqlens = torch.full( + (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device + ) + cache_seqlens = maybe_contiguous(cache_seqlens) + + q, k_cache, k, v = [maybe_contiguous(x) for x in (q, k_cache, k, v)] + v_cache = ( + v_cache.contiguous() + if v_cache.stride(-1) != 1 and v_cache.stride(-3) != 1 + else v_cache + ) + cu_seqlens_q, cu_seqlens_k_new = [ + maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k_new) + ] + page_table, cache_batch_idx, cache_leftpad = [ + maybe_contiguous(x) for x in (page_table, cache_batch_idx, cache_leftpad) + ] + rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)] + rotary_seqlens = maybe_contiguous(rotary_seqlens) + + # out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default( + out, softmax_lse, *rest = flash_attn_3_cuda.fwd( + q, + k_cache, + v_cache, + k, + v, + qv, + None, # out + cu_seqlens_q, + None, # cu_seqlens_k + cu_seqlens_k_new, + None, # seqused_q + cache_seqlens, + max_seqlen_q, + None, # max_seqlen_k + page_table, + cache_batch_idx, + cache_leftpad, + rotary_cos, + rotary_sin, + rotary_seqlens, + q_descale, + k_descale, + v_descale, + softmax_scale, + causal, + window_size[0], + window_size[1], + 0, + softcap, + rotary_interleaved, + scheduler_metadata, + num_splits, + pack_gqa, + sm_margin, + mtp_step + ) + return (out, softmax_lse, *rest) if return_softmax_lse else out \ No newline at end of file diff --git a/lightllm/utils/bench_utils.py b/lightllm/utils/bench_utils.py new file mode 100644 index 000000000..bb063f6c8 --- /dev/null +++ b/lightllm/utils/bench_utils.py @@ -0,0 +1,118 @@ +# This file is adapted from tile-ai/tilelang: +# https://github.com/tile-ai/tilelang/blob/main/tilelang/profiler/bench.py +# The original code and this file are licensed under the Apache License, Version 2.0. +# +# Copyright (c) sgl-project and other contributors. +# Modifications Copyright (c) LightLLM contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""The profiler and convert to torch utils""" + +import torch +from typing import Callable, List, Literal, Optional, Union + + +def do_bench( + fn: Callable, + warmup: float = 25, + rep: float = 100, + _n_warmup: int = 0, + _n_repeat: int = 0, + grad_to_none: Optional[List[torch.Tensor]] = None, + quantiles: Optional[List[float]] = None, + fast_flush: bool = True, + return_mode: Literal["min", "max", "mean", "median"] = "mean", +) -> Union[float, List[float]]: + """Benchmarks the runtime of a PyTorch function. 
+ + This function handles: + - L2 cache flushing between runs for consistent timing + - Automatic warmup and repeat count calculation + - Optional gradient clearing for backward passes + - Multiple measurement modes (mean, median, min, max) + + Args: + fn: Function to benchmark + warmup: Target warmup time in milliseconds + rep: Target number of repetitions + _n_warmup: Override for number of warmup iterations + _n_repeat: Override for number of timing iterations + grad_to_none: Tensors whose gradients should be cleared between runs + quantiles: Optional performance percentiles to compute + fast_flush: Whether to use faster L2 cache flushing + return_mode: How to aggregate timing results ("mean", "median", "min", "max") + + Returns: + float: Aggregated runtime in milliseconds + """ + assert return_mode in ["min", "max", "mean", "median"] + fn() + torch.cuda.synchronize() + + # We maintain a buffer of 256 MB that we clear + # before each kernel call to make sure that the L2 + # doesn't contain any input data before the run + if fast_flush: + cache = torch.empty(int(256e6 // 4), dtype=torch.int, device="cuda") + else: + cache = torch.empty(int(256e6), dtype=torch.int8, device="cuda") + + # Estimate the runtime of the function + start_event = torch.cuda.Event(enable_timing=True) + end_event = torch.cuda.Event(enable_timing=True) + start_event.record() + for _ in range(5): + cache.zero_() + fn() + end_event.record() + torch.cuda.synchronize() + estimate_ms = start_event.elapsed_time(end_event) / 5 + + # compute number of warmup and repeat + n_warmup = max(1, int(warmup / estimate_ms)) + n_repeat = max(1, int(rep / estimate_ms)) + if _n_warmup > 0: + n_warmup = _n_warmup + if _n_repeat > 0: + n_repeat = _n_repeat + start_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] + end_event = [torch.cuda.Event(enable_timing=True) for i in range(n_repeat)] + # Warm-up + for _ in range(n_warmup): + fn() + # Benchmark + for i in range(n_repeat): + # we don't want `fn` to accumulate gradient values + # if it contains a backward pass. So we clear the + # provided gradients + if grad_to_none is not None: + for x in grad_to_none: + x.grad = None + # we clear the L2 cache before each run + cache.zero_() + # record time of `fn` + start_event[i].record() + fn() + end_event[i].record() + # Record clocks + torch.cuda.synchronize() + times = torch.tensor( + [s.elapsed_time(e) for s, e in zip(start_event, end_event)], + dtype=torch.float, + ) + if quantiles is not None: + ret = torch.quantile(times, torch.tensor(quantiles, dtype=torch.float)).tolist() + if len(ret) == 1: + ret = ret[0] + return ret + return getattr(torch, return_mode)(times).item() diff --git a/test/benchmark/kernel/benchmark_fa3_decode_mtp.py b/test/benchmark/kernel/benchmark_fa3_decode_mtp.py new file mode 100644 index 000000000..176cf0536 --- /dev/null +++ b/test/benchmark/kernel/benchmark_fa3_decode_mtp.py @@ -0,0 +1,192 @@ +# This file is adapted from tile-ai/tilelang: +# https://github.com/tile-ai/tilelang/blob/main/examples/deepseek_mla/example_mla_decode_paged.py +# The original code and this file are licensed under the Apache License, Version 2.0. +# +# Copyright (c) sgl-project and other contributors. +# Modifications Copyright (c) LightLLM contributors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# type: ignore +import torch +import argparse +import math +from typing import Callable, Optional, List, Literal, Union +from lightllm.common.flash_attn import flash_attn_with_kvcache +from lightllm.utils.bench_utils import do_bench + +def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): + query = query.float() + key = key.float() + value = value.float() + key = key.repeat_interleave(h_q // h_kv, dim=0) + value = value.repeat_interleave(h_q // h_kv, dim=0) + attn_weight = query @ key.transpose(-2, -1) / math.sqrt(query.size(-1)) + if is_causal: + s_q = query.shape[-2] + s_k = key.shape[-2] + attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype, device=query.device) + temp_mask = torch.ones( + s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + attn_weight += attn_bias + lse = attn_weight.logsumexp(dim=-1) + attn_weight = torch.softmax(attn_weight, dim=-1, dtype=torch.float32) + return attn_weight @ value, lse + + +@torch.inference_mode() +def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, + h_kv, d, dv, causal, dtype): + # q: [b, s_q, h_q, d] + # block_table: [b, max_seqlen_pad // block_size] + # blocked_k: [b * max_seqlen_pad // block_size, block_size, h_kv, d] + # cache_seqlens: [b] + blocked_v = blocked_k[..., :dv] + + def ref_mla(): + out = torch.empty(b, s_q, h_q, dv, dtype=torch.float32, device=q.device) + lse = torch.empty(b, h_q, s_q, dtype=torch.float32, device=q.device) + for i in range(b): + seq_len = cache_seqlens[i // 2] - ((i + 1) % 2) + kv_indices = block_table[i // 2, :seq_len] # 获取前seq_len个block索引 + O, LSE = scaled_dot_product_attention( + q[i].transpose(0, 1), + blocked_k.view(-1, h_kv, d)[kv_indices].transpose(0, 1), + blocked_v.view(-1, h_kv, dv)[kv_indices].transpose(0, 1), + h_q, + h_kv, + is_causal=causal, + ) + out[i] = O.transpose(0, 1) + lse[i] = LSE + return out.to(dtype), lse.to(dtype) + + out_torch, _ = ref_mla() + return out_torch + + +def run_fa3_mla_mtp(mtp_size, q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, + h_q, h_kv, d, dv, causal, dtype): + + assert d > dv, "mla with rope dim should be larger than no rope dim" + q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., + dv:].contiguous() + + dpe = d - dv + num_kv_splits = 1 + + out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device) + glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device) + + batch_mtp = b // mtp_size + cu_seqlens_q = torch.arange( + 0, batch_mtp + 1, step=s_q, dtype=torch.int32, device=q.device + ) + cu_seqlens_k = torch.cumsum(cache_seqlens, dim=0) + cu_seqlens_k = torch.cat([torch.tensor([0]).to(cu_seqlens_k), cu_seqlens_k]) + scale = (1.0 / (dv + dpe))**0.5 # log2(e) + k_descale, v_descale = None, None + BLOCK_H = h_q * mtp_size + + def flash_mla_fa3(): + out = flash_attn_with_kvcache( + q=q_pe.view(-1, BLOCK_H, dpe), 
+ k_cache=blocked_k_pe, + v_cache=blocked_k_nope, + qv=q_nope.view(-1, BLOCK_H, dv), + page_table=block_table, + cache_seqlens=cache_seqlens, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k_new=cu_seqlens_k, + max_seqlen_q=1, + softmax_scale=scale, + causal=True, + window_size=(-1, -1), + softcap=0.0, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=False, + mtp_step=1 + ) + return out.view([b, s_q, h_q, dv]) + + out_flash = flash_mla_fa3() + t = do_bench(flash_mla_fa3) + + out_ref = run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, + cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + + # 计算相对绝对误差 + def print_error(a, b, name=""): + max_absolute_error = torch.abs(a - b).max() + relative_abs_error = torch.abs(a - b) / (torch.abs(a) + 1e-4) + max_relative_abs_error = relative_abs_error.max() + mean_relative_abs_error = relative_abs_error.mean() + + print(f"{name}: Maximum absolute difference: {max_absolute_error:.6e}") + print(f"Maximum relative absolute error: {max_relative_abs_error:.6e}") + print(f"Mean relative absolute error: {mean_relative_abs_error:.6e}") + + print_error(out_flash, out_ref, "out_flash, out_ref") + torch.testing.assert_close(out_flash, out_ref, rtol=0.001, atol=0.001) + print("All close") + return out_flash, t + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument('--batch', type=int, default=128, help='batch size') + parser.add_argument('--h_q', type=int, default=16, help='q heads number') + parser.add_argument('--h_kv', type=int, default=1, help='kv heads number') + parser.add_argument('--cache_seqlen', type=int, default=8192, help='kv cache context length') + parser.add_argument('--d', type=int, default=576, help='query/key head dim, d = dv + dpe') + parser.add_argument('--dv', type=int, default=512, help='value head dim') + parser.add_argument('--mtp_size', type=int, default=2, help='Specifies the number of tokens per prediction.') + args = parser.parse_args() + b, h_q, h_kv, cache_seqlen, d, dv = args.batch, args.h_q, args.h_kv, args.cache_seqlen, args.d, args.dv + mtp_size = args.mtp_size + + device = "cuda" + dtype = torch.float16 + + s_q = 1 # for decode, s_q = 1 + block_size = 1 + batch_mtp = b // mtp_size + cache_seqlens = torch.tensor([cache_seqlen + i for i in range(batch_mtp)], + dtype=torch.int32, + device=device) + # print(cache_seqlens[-1]) + dpe = d - dv + causal = True + + total_seqlens = cache_seqlens.sum().item() + mean_seqlens = cache_seqlens.float().mean().int().item() + max_seqlen = cache_seqlens.max().item() + max_seqlen_pad = math.ceil(max_seqlen / 256) * 256 # ?为什么对齐256 + + total_flops = s_q * (total_seqlens * 2 - batch_mtp) * h_q * (d + dv) * 2 + + q = torch.randn(b, s_q, h_q, d, dtype=dtype, device=device) + block_table = torch.arange( + batch_mtp * max_seqlen_pad, dtype=torch.int32, + device=device).view(batch_mtp, max_seqlen_pad) + + blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d, dtype=dtype, device=device) + out_flash, latency = run_fa3_mla_mtp(mtp_size, q, block_table, blocked_k, max_seqlen_pad, block_size, b, + s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + + print("Tile-lang: {:.3f} ms".format(latency)) + print("Tile-lang: {:.3f} TFlops".format(total_flops / latency * 1e-9)) From e5608facc2861a833a7a6ccbc07b8199868c9868 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Mon, 11 Aug 2025 02:38:18 +0000 Subject: [PATCH 03/13] add _token_gqa_decode_attention_mtp to ds2 --- lightllm/common/flash_attn.py | 375 
+++++++++--------- .../layer_infer/transformer_layer_infer.py | 45 ++- .../kernel/benchmark_fa3_decode_mtp.py | 4 +- 3 files changed, 234 insertions(+), 190 deletions(-) diff --git a/lightllm/common/flash_attn.py b/lightllm/common/flash_attn.py index 8566fd30e..456349aa9 100644 --- a/lightllm/common/flash_attn.py +++ b/lightllm/common/flash_attn.py @@ -16,198 +16,203 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import List, Optional, Tuple, Union - import torch -import flash_attn_3._C # Registers operators with PyTorch +from typing import List, Optional, Tuple, Union +from lightllm.utils.log_utils import init_logger -# isort: on +logger = init_logger(__name__) def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x -flash_attn_3_cuda = torch.ops.flash_attn_3 - -def flash_attn_with_kvcache( - q, - k_cache, - v_cache, - k=None, - v=None, - qv=None, - rotary_cos=None, - rotary_sin=None, - cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, - cache_batch_idx: Optional[torch.Tensor] = None, - cache_leftpad: Optional[torch.Tensor] = None, - page_table: Optional[torch.Tensor] = None, - cu_seqlens_q: Optional[torch.Tensor] = None, - cu_seqlens_k_new: Optional[torch.Tensor] = None, - max_seqlen_q: Optional[int] = None, - rotary_seqlens: Optional[torch.Tensor] = None, - q_descale: Optional[torch.Tensor] = None, - k_descale: Optional[torch.Tensor] = None, - v_descale: Optional[torch.Tensor] = None, - softmax_scale=None, - causal=False, - window_size=(-1, -1), # -1 means infinite context window - softcap=0.0, # 0.0 means deactivated - rotary_interleaved=True, - scheduler_metadata=None, - num_splits=0, # Can be tuned for speed - pack_gqa=None, # Can be tuned for speed - sm_margin=0, # Can be tuned if some SMs are used for communication - return_softmax_lse=False, - mtp_step=0 -): - """ - If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from - k and v. This is useful for incremental decoding: you can pass in the cached keys/values from - the previous step, and update them with the new keys/values from the current step, and do - attention with the updated cache, all in 1 kernel. - - If you pass in k / v, you must make sure that the cache is large enough to hold the new values. - For example, the KV cache could be pre-allocated with the max sequence length, and you can use - cache_seqlens to keep track of the current sequence lengths of each sequence in the batch. - - Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be - rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. - If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos - and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. - If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at - indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens). - - See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function. - - Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads - than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. 
- For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head - 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - - If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. - For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: - 1 1 1 1 0 - 1 1 1 1 1 - If seqlen_q = 5 and seqlen_k = 2, the causal mask is: - 0 0 - 0 0 - 0 0 - 1 0 - 1 1 - If the row of the mask is all zero, the output will be zero. +try: + import flash_attn_3._C # Registers operators with PyTorch + flash_attn_3_mtp = torch.ops.flash_attn_3 - If window_size != (-1, -1), implements sliding window local attention. Query at position i - will only attend to keys between - [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. - - Note: Does not support backward pass. - - Arguments: - q: (batch_size, seqlen, nheads, headdim) - k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table, - or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache) - page_block_size must be a multiple of 256. - v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table, - or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache) - k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate - k with k_cache, starting at the indices specified by cache_seqlens. - v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k. - qv [optional]: (batch_size, seqlen, nheads, headdim_v) - rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding - to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16. - rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos. - cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the - KV cache. - cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache. - If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1]. - If the indices are not distinct, and k and v are provided, the values updated in the cache - might come from any of the duplicate indices. - cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0. - page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). - window_size: (left, right). If not (-1, -1), implements sliding window local attention. - softcap: float. Anything > 0 activates softcapping attention. - rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in. - If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False, - rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1 - (i.e. GPT-NeoX style). - num_splits: int. If > 1, split the key/value into this many chunks along the sequence. - If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic - to automatically determine the number of splits. - Don't change this unless you know what you are doing. - return_softmax_lse: bool. 
Whether to return the logsumexp of the attention scores. - - Return: - out: (batch_size, seqlen, nheads, headdim). - softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The - logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax - normalization factor). - """ - assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension" - assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension" - if softmax_scale is None: - softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** ( - -0.5 - ) - if cache_seqlens is not None and isinstance(cache_seqlens, int): - cache_seqlens = torch.full( - (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device - ) - cache_seqlens = maybe_contiguous(cache_seqlens) - - q, k_cache, k, v = [maybe_contiguous(x) for x in (q, k_cache, k, v)] - v_cache = ( - v_cache.contiguous() - if v_cache.stride(-1) != 1 and v_cache.stride(-3) != 1 - else v_cache - ) - cu_seqlens_q, cu_seqlens_k_new = [ - maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k_new) - ] - page_table, cache_batch_idx, cache_leftpad = [ - maybe_contiguous(x) for x in (page_table, cache_batch_idx, cache_leftpad) - ] - rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)] - rotary_seqlens = maybe_contiguous(rotary_seqlens) - - # out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default( - out, softmax_lse, *rest = flash_attn_3_cuda.fwd( + def flash_attn_with_kvcache_mtp( q, k_cache, v_cache, - k, - v, - qv, - None, # out - cu_seqlens_q, - None, # cu_seqlens_k - cu_seqlens_k_new, - None, # seqused_q - cache_seqlens, - max_seqlen_q, - None, # max_seqlen_k - page_table, - cache_batch_idx, - cache_leftpad, - rotary_cos, - rotary_sin, - rotary_seqlens, - q_descale, - k_descale, - v_descale, - softmax_scale, - causal, - window_size[0], - window_size[1], - 0, - softcap, - rotary_interleaved, - scheduler_metadata, - num_splits, - pack_gqa, - sm_margin, - mtp_step - ) - return (out, softmax_lse, *rest) if return_softmax_lse else out \ No newline at end of file + k=None, + v=None, + qv=None, + rotary_cos=None, + rotary_sin=None, + cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, + cache_batch_idx: Optional[torch.Tensor] = None, + cache_leftpad: Optional[torch.Tensor] = None, + page_table: Optional[torch.Tensor] = None, + cu_seqlens_q: Optional[torch.Tensor] = None, + cu_seqlens_k_new: Optional[torch.Tensor] = None, + max_seqlen_q: Optional[int] = None, + rotary_seqlens: Optional[torch.Tensor] = None, + q_descale: Optional[torch.Tensor] = None, + k_descale: Optional[torch.Tensor] = None, + v_descale: Optional[torch.Tensor] = None, + softmax_scale=None, + causal=False, + window_size=(-1, -1), # -1 means infinite context window + softcap=0.0, # 0.0 means deactivated + rotary_interleaved=True, + scheduler_metadata=None, + num_splits=0, # Can be tuned for speed + pack_gqa=None, # Can be tuned for speed + sm_margin=0, # Can be tuned if some SMs are used for communication + return_softmax_lse=False, + mtp_step=0 + ): + """ + If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from + k and v. This is useful for incremental decoding: you can pass in the cached keys/values from + the previous step, and update them with the new keys/values from the current step, and do + attention with the updated cache, all in 1 kernel. + + If you pass in k / v, you must make sure that the cache is large enough to hold the new values. 
+ For example, the KV cache could be pre-allocated with the max sequence length, and you can use + cache_seqlens to keep track of the current sequence lengths of each sequence in the batch. + + Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be + rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. + If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos + and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc. + If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at + indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens). + + See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function. + + Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads + than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. + For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head + 0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + + If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix. + For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is: + 1 1 1 1 0 + 1 1 1 1 1 + If seqlen_q = 5 and seqlen_k = 2, the causal mask is: + 0 0 + 0 0 + 0 0 + 1 0 + 1 1 + If the row of the mask is all zero, the output will be zero. + + If window_size != (-1, -1), implements sliding window local attention. Query at position i + will only attend to keys between + [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive. + + Note: Does not support backward pass. + + Arguments: + q: (batch_size, seqlen, nheads, headdim) + k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table, + or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache) + page_block_size must be a multiple of 256. + v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table, + or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache) + k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate + k with k_cache, starting at the indices specified by cache_seqlens. + v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k. + qv [optional]: (batch_size, seqlen, nheads, headdim_v) + rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding + to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16. + rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos. + cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the + KV cache. + cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache. + If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1]. + If the indices are not distinct, and k and v are provided, the values updated in the cache + might come from any of the duplicate indices. + cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0. + page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32. + softmax_scale: float. The scaling of QK^T before applying softmax. 
+ Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). + window_size: (left, right). If not (-1, -1), implements sliding window local attention. + softcap: float. Anything > 0 activates softcapping attention. + rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in. + If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False, + rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1 + (i.e. GPT-NeoX style). + num_splits: int. If > 1, split the key/value into this many chunks along the sequence. + If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic + to automatically determine the number of splits. + Don't change this unless you know what you are doing. + return_softmax_lse: bool. Whether to return the logsumexp of the attention scores. + + Return: + out: (batch_size, seqlen, nheads, headdim). + softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The + logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax + normalization factor). + """ + assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension" + assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension" + if softmax_scale is None: + softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** ( + -0.5 + ) + if cache_seqlens is not None and isinstance(cache_seqlens, int): + cache_seqlens = torch.full( + (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device + ) + cache_seqlens = maybe_contiguous(cache_seqlens) + + q, k_cache, k, v = [maybe_contiguous(x) for x in (q, k_cache, k, v)] + v_cache = ( + v_cache.contiguous() + if v_cache.stride(-1) != 1 and v_cache.stride(-3) != 1 + else v_cache + ) + cu_seqlens_q, cu_seqlens_k_new = [ + maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k_new) + ] + page_table, cache_batch_idx, cache_leftpad = [ + maybe_contiguous(x) for x in (page_table, cache_batch_idx, cache_leftpad) + ] + rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)] + rotary_seqlens = maybe_contiguous(rotary_seqlens) + + # out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default( + out, softmax_lse, *rest = flash_attn_3_mtp.fwd( + q, + k_cache, + v_cache, + k, + v, + qv, + None, # out + cu_seqlens_q, + None, # cu_seqlens_k + cu_seqlens_k_new, + None, # seqused_q + cache_seqlens, + max_seqlen_q, + None, # max_seqlen_k + page_table, + cache_batch_idx, + cache_leftpad, + rotary_cos, + rotary_sin, + rotary_seqlens, + q_descale, + k_descale, + v_descale, + softmax_scale, + causal, + window_size[0], + window_size[1], + 0, + softcap, + rotary_interleaved, + scheduler_metadata, + num_splits, + pack_gqa, + sm_margin, + mtp_step + ) + return (out, softmax_lse, *rest) if return_softmax_lse else out +except: + flash_attn_3_mtp = None + flash_attn_with_kvcache_mtp = None + logger.warning("flash_attn_3._C is not available, please install flash-attention-3 package.") diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index eccbe430d..4e3b78b5a 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -31,6 +31,7 @@ from lightllm.utils.dist_utils import get_global_world_size from lightllm.utils.log_utils import init_logger 
from lightllm.utils.sgl_utils import flash_attn_varlen_func, flash_attn_with_kvcache, merge_state_v2 +from lightllm.common.flash_attn import flash_attn_with_kvcache_mtp logger = init_logger(__name__) @@ -72,6 +73,8 @@ def __init__(self, layer_num, network_config, mode=[]): super().__init__(layer_num, network_config, mode) self.num_heads = network_config["num_attention_heads"] self.num_kv_heads = network_config["num_key_value_heads"] + self.mtp_step = get_env_start_args().mtp_step + self.mtp_size = self.mtp_step + 1 return def _bind_func(self): @@ -98,9 +101,14 @@ def _bind_attention(self): else: self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self) if get_env_start_args().enable_fa3: - self._token_attention_kernel = partial( - Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashattention, self - ) + if get_env_start_args().mtp_mode is not None and flash_attn_with_kvcache_mtp is not None: + self._token_attention_kernel = partial( + Deepseek2TransformerLayerInfer._token_gqa_decode_attention_mtp, self + ) + else: + self._token_attention_kernel = partial( + Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashattention, self + ) elif get_env_start_args().enable_flashinfer_decode: self._token_attention_kernel = partial( Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashinfer, self @@ -546,6 +554,37 @@ def _context_attention_kernel_origin_fp8( self.softmax_scale, ) return o_tensor + + def _token_gqa_decode_attention_mtp( + self, q, infer_state: Deepseek2FlashAttentionStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None + ): + q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, -self.qk_rope_head_dim :] + q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) + kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, 1, 1, self.qk_rope_head_dim) + kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank) + k_descale, v_descale = None, None + o_tensor = flash_attn_with_kvcache_mtp( + q=q_rope.view(-1, self.tp_q_head_num_ * self.mtp_size, self.qk_rope_head_dim), + k_cache=k_rope, + v_cache=kv_nope, + qv=q_nope.view(-1, self.tp_q_head_num_ * self.mtp_size, self.qk_nope_head_dim), + page_table=infer_state.page_table[self.mtp_size - 1::self.mtp_size], + cache_seqlens=infer_state.b_seq_len[self.mtp_size - 1::self.mtp_size], + cu_seqlens_q=infer_state.cu_seqlens_q[self.mtp_size - 1::self.mtp_size], + cu_seqlens_k_new=infer_state.cu_seqlens_k[self.mtp_size - 1::self.mtp_size], + max_seqlen_q=1, + softmax_scale=self.softmax_scale, + causal=True, + window_size=(-1, -1), + softcap=0.0, + k_descale=k_descale, + v_descale=v_descale, + return_softmax_lse=False, + mtp_step=self.mtp_step + ) + return o_tensor + def _token_gqa_decode_attention_flashattention( self, q, infer_state: Deepseek2FlashAttentionStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None diff --git a/test/benchmark/kernel/benchmark_fa3_decode_mtp.py b/test/benchmark/kernel/benchmark_fa3_decode_mtp.py index 176cf0536..62693b59b 100644 --- a/test/benchmark/kernel/benchmark_fa3_decode_mtp.py +++ b/test/benchmark/kernel/benchmark_fa3_decode_mtp.py @@ -22,7 +22,7 @@ import argparse import math from typing import Callable, Optional, List, Literal, Union -from lightllm.common.flash_attn import flash_attn_with_kvcache +from lightllm.common.flash_attn import flash_attn_with_kvcache_mtp from lightllm.utils.bench_utils import do_bench 
def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): @@ -102,7 +102,7 @@ def run_fa3_mla_mtp(mtp_size, q, block_table, blocked_k, max_seqlen_pad, block_s BLOCK_H = h_q * mtp_size def flash_mla_fa3(): - out = flash_attn_with_kvcache( + out = flash_attn_with_kvcache_mtp( q=q_pe.view(-1, BLOCK_H, dpe), k_cache=blocked_k_pe, v_cache=blocked_k_nope, From fddc5c65afa24d689bdc5e61c00399eb62df0270 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Mon, 11 Aug 2025 03:24:23 +0000 Subject: [PATCH 04/13] fix q_nope discontinuous issues --- .../triton_kernel/gen_decode_params.py | 13 +++++-- lightllm/common/flash_attn.py | 28 ++++++--------- .../layer_infer/transformer_layer_infer.py | 34 +++++++++---------- lightllm/server/api_cli.py | 5 +++ lightllm/server/api_start.py | 3 ++ lightllm/server/core/objs/start_args_type.py | 1 + 6 files changed, 47 insertions(+), 37 deletions(-) diff --git a/lightllm/common/basemodel/triton_kernel/gen_decode_params.py b/lightllm/common/basemodel/triton_kernel/gen_decode_params.py index fe57ae8d2..1ffdf5cc5 100644 --- a/lightllm/common/basemodel/triton_kernel/gen_decode_params.py +++ b/lightllm/common/basemodel/triton_kernel/gen_decode_params.py @@ -2,13 +2,22 @@ import triton import triton.language as tl from .gen_prefill_params import gen_cumsum_pad0_tensor +from lightllm.utils.envs_utils import get_env_start_args @torch.no_grad() def gen_decode_params(b_seq_len: torch.Tensor): b_kv_seq_len = b_seq_len position_ids = b_seq_len - 1 - b_q_seq_len = torch.ones_like(b_seq_len) - b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len) + mtp_step = get_env_start_args().mtp_step + mtp_size = mtp_step + 1 + enable_fa3_mtp = get_env_start_args().enable_fa3_mtp + + if enable_fa3_mtp: + b_q_seq_len = torch.ones_like(b_seq_len[: len(b_seq_len) // mtp_size]) + b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len[mtp_size - 1 :: mtp_size]) + else: + b_q_seq_len = torch.ones_like(b_seq_len) + b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len) return b_q_seq_len, b1_cu_q_seq_len, b_kv_seq_len, b1_cu_kv_seq_len, position_ids diff --git a/lightllm/common/flash_attn.py b/lightllm/common/flash_attn.py index 456349aa9..f3a505035 100644 --- a/lightllm/common/flash_attn.py +++ b/lightllm/common/flash_attn.py @@ -22,11 +22,14 @@ logger = init_logger(__name__) + def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x + try: - import flash_attn_3._C # Registers operators with PyTorch + import flash_attn_3._C # Registers operators with PyTorch + flash_attn_3_mtp = torch.ops.flash_attn_3 def flash_attn_with_kvcache_mtp( @@ -59,7 +62,7 @@ def flash_attn_with_kvcache_mtp( pack_gqa=None, # Can be tuned for speed sm_margin=0, # Can be tuned if some SMs are used for communication return_softmax_lse=False, - mtp_step=0 + mtp_step=0, ): """ If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from @@ -149,24 +152,14 @@ def flash_attn_with_kvcache_mtp( assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension" assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension" if softmax_scale is None: - softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** ( - -0.5 - ) + softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5) if cache_seqlens is not None and isinstance(cache_seqlens, int): - cache_seqlens = 
torch.full( - (k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device - ) + cache_seqlens = torch.full((k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device) cache_seqlens = maybe_contiguous(cache_seqlens) q, k_cache, k, v = [maybe_contiguous(x) for x in (q, k_cache, k, v)] - v_cache = ( - v_cache.contiguous() - if v_cache.stride(-1) != 1 and v_cache.stride(-3) != 1 - else v_cache - ) - cu_seqlens_q, cu_seqlens_k_new = [ - maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k_new) - ] + v_cache = v_cache.contiguous() if v_cache.stride(-1) != 1 and v_cache.stride(-3) != 1 else v_cache + cu_seqlens_q, cu_seqlens_k_new = [maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k_new)] page_table, cache_batch_idx, cache_leftpad = [ maybe_contiguous(x) for x in (page_table, cache_batch_idx, cache_leftpad) ] @@ -209,9 +202,10 @@ def flash_attn_with_kvcache_mtp( num_splits, pack_gqa, sm_margin, - mtp_step + mtp_step, ) return (out, softmax_lse, *rest) if return_softmax_lse else out + except: flash_attn_3_mtp = None flash_attn_with_kvcache_mtp = None diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index 4e3b78b5a..520bfd2b4 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -100,15 +100,14 @@ def _bind_attention(self): ) else: self._copy_kv_to_mem_cache = partial(Deepseek2TransformerLayerInfer._copy_kv_to_mem_cache_normal, self) - if get_env_start_args().enable_fa3: - if get_env_start_args().mtp_mode is not None and flash_attn_with_kvcache_mtp is not None: - self._token_attention_kernel = partial( - Deepseek2TransformerLayerInfer._token_gqa_decode_attention_mtp, self - ) - else: - self._token_attention_kernel = partial( - Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashattention, self - ) + if get_env_start_args().enable_fa3_mtp: + self._token_attention_kernel = partial( + Deepseek2TransformerLayerInfer._token_gqa_decode_attention_mtp, self + ) + elif get_env_start_args().enable_fa3: + self._token_attention_kernel = partial( + Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashattention, self + ) elif get_env_start_args().enable_flashinfer_decode: self._token_attention_kernel = partial( Deepseek2TransformerLayerInfer._token_gqa_decode_attention_flashinfer, self @@ -554,7 +553,7 @@ def _context_attention_kernel_origin_fp8( self.softmax_scale, ) return o_tensor - + def _token_gqa_decode_attention_mtp( self, q, infer_state: Deepseek2FlashAttentionStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None ): @@ -565,14 +564,14 @@ def _token_gqa_decode_attention_mtp( kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank) k_descale, v_descale = None, None o_tensor = flash_attn_with_kvcache_mtp( - q=q_rope.view(-1, self.tp_q_head_num_ * self.mtp_size, self.qk_rope_head_dim), + q=q_rope.reshape(-1, self.tp_q_head_num_ * self.mtp_size, self.qk_rope_head_dim), k_cache=k_rope, v_cache=kv_nope, - qv=q_nope.view(-1, self.tp_q_head_num_ * self.mtp_size, self.qk_nope_head_dim), - page_table=infer_state.page_table[self.mtp_size - 1::self.mtp_size], - cache_seqlens=infer_state.b_seq_len[self.mtp_size - 1::self.mtp_size], - cu_seqlens_q=infer_state.cu_seqlens_q[self.mtp_size - 1::self.mtp_size], - cu_seqlens_k_new=infer_state.cu_seqlens_k[self.mtp_size - 1::self.mtp_size], + qv=q_nope.reshape(-1, 
self.tp_q_head_num_ * self.mtp_size, self.qk_nope_head_dim), + page_table=infer_state.page_table[self.mtp_size - 1 :: self.mtp_size], + cache_seqlens=infer_state.b_seq_len[self.mtp_size - 1 :: self.mtp_size], + cu_seqlens_q=infer_state.cu_seqlens_q, + cu_seqlens_k_new=infer_state.cu_seqlens_k, max_seqlen_q=1, softmax_scale=self.softmax_scale, causal=True, @@ -581,11 +580,10 @@ def _token_gqa_decode_attention_mtp( k_descale=k_descale, v_descale=v_descale, return_softmax_lse=False, - mtp_step=self.mtp_step + mtp_step=self.mtp_step, ) return o_tensor - def _token_gqa_decode_attention_flashattention( self, q, infer_state: Deepseek2FlashAttentionStateInfo, layer_weight: Deepseek2TransformerLayerWeight, out=None ): diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index dfbf0c84b..eabb06458 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -465,6 +465,11 @@ def make_argument_parser() -> argparse.ArgumentParser: but ensure that the model is compatible with the specified step count. currently, deepseekv3 model only support 1 step""", ) + parser.add_argument( + "--enable_fa3_mtp", + action="store_true", + help="""inference backend will use the fa3_mtp kernel for decode with MTP mode""", + ) parser.add_argument( "--kv_quant_calibration_config_path", type=str, diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index c2a87b4c3..d8c2900c6 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -138,6 +138,9 @@ def normal_or_p_d_start(args): assert args.mtp_draft_model_dir is None assert args.mtp_step == 0 + if args.enable_fa3_mtp: + assert args.mtp_mode is not None, "enable_fa3_mtp must set mtp_mode" + # 检查GPU数量是否足够 if args.visual_gpu_ids is None: args.visual_gpu_ids = list(range(args.visual_dp * args.visual_tp)) diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index d4a205a15..771365e5f 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -95,4 +95,5 @@ class StartArgs: mtp_mode: Optional[str] = field(default=None) mtp_draft_model_dir: Optional[str] = field(default=None) mtp_step: int = field(default=0) + enable_fa3_mtp: bool = field(default=False) kv_quant_calibration_config_path: Optional[str] = field(default=None) From 9028d7434705f89932c0bc108b461c65a33b5697 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Mon, 11 Aug 2025 04:48:52 +0000 Subject: [PATCH 05/13] [fix]q_nope shape --- .../models/deepseek2/layer_infer/transformer_layer_infer.py | 2 +- lightllm/models/deepseek2/model.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index 520bfd2b4..8ec15f98a 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -567,7 +567,7 @@ def _token_gqa_decode_attention_mtp( q=q_rope.reshape(-1, self.tp_q_head_num_ * self.mtp_size, self.qk_rope_head_dim), k_cache=k_rope, v_cache=kv_nope, - qv=q_nope.reshape(-1, self.tp_q_head_num_ * self.mtp_size, self.qk_nope_head_dim), + qv=q_nope.reshape(-1, self.tp_q_head_num_ * self.mtp_size, self.kv_lora_rank), page_table=infer_state.page_table[self.mtp_size - 1 :: self.mtp_size], cache_seqlens=infer_state.b_seq_len[self.mtp_size - 1 :: self.mtp_size], cu_seqlens_q=infer_state.cu_seqlens_q, diff --git 
a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index 9101cb963..4be30b556 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -69,7 +69,7 @@ def __init__(self, kvargs): return def _init_inferstate_cls(self): - if get_env_start_args().enable_fa3: + if get_env_start_args().enable_fa3 or get_env_start_args().enable_fa3_mtp: self.infer_state_class = Deepseek2FlashAttentionStateInfo elif self.enable_flashinfer: self.infer_state_class = Deepseek2FlashInferStateInfo From 60bbd8635e578afa073db6d45c2c098b33dd69df Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Mon, 11 Aug 2025 04:55:38 +0000 Subject: [PATCH 06/13] format --- .../kernel/benchmark_fa3_decode_mtp.py | 99 ++++++++++++------- 1 file changed, 61 insertions(+), 38 deletions(-) diff --git a/test/benchmark/kernel/benchmark_fa3_decode_mtp.py b/test/benchmark/kernel/benchmark_fa3_decode_mtp.py index 62693b59b..b605f2fe0 100644 --- a/test/benchmark/kernel/benchmark_fa3_decode_mtp.py +++ b/test/benchmark/kernel/benchmark_fa3_decode_mtp.py @@ -25,6 +25,7 @@ from lightllm.common.flash_attn import flash_attn_with_kvcache_mtp from lightllm.utils.bench_utils import do_bench + def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): query = query.float() key = key.float() @@ -36,8 +37,7 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): s_q = query.shape[-2] s_k = key.shape[-2] attn_bias = torch.zeros(s_q, s_k, dtype=query.dtype, device=query.device) - temp_mask = torch.ones( - s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q) + temp_mask = torch.ones(s_q, s_k, dtype=torch.bool, device=query.device).tril(diagonal=s_k - s_q) attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) attn_bias.to(query.dtype) attn_weight += attn_bias @@ -47,8 +47,9 @@ def scaled_dot_product_attention(query, key, value, h_q, h_kv, is_causal=False): @torch.inference_mode() -def run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, - h_kv, d, dv, causal, dtype): +def run_torch_mla( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype +): # q: [b, s_q, h_q, d] # block_table: [b, max_seqlen_pad // block_size] # blocked_k: [b * max_seqlen_pad // block_size, block_size, h_kv, d] @@ -77,27 +78,35 @@ def ref_mla(): return out_torch -def run_fa3_mla_mtp(mtp_size, q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, - h_q, h_kv, d, dv, causal, dtype): +def run_fa3_mla_mtp( + mtp_size, + q, + block_table, + blocked_k, + max_seqlen_pad, + block_size, + b, + s_q, + cache_seqlens, + h_q, + h_kv, + d, + dv, + causal, + dtype, +): assert d > dv, "mla with rope dim should be larger than no rope dim" q_nope, q_pe = q[..., :dv].contiguous(), q[..., dv:].contiguous() - blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., - dv:].contiguous() + blocked_k_nope, blocked_k_pe = blocked_k[..., :dv].contiguous(), blocked_k[..., dv:].contiguous() dpe = d - dv - num_kv_splits = 1 - - out_partial = torch.empty(b, h_q, num_kv_splits, dv, dtype=dtype, device=q.device) - glse = torch.empty(b, h_q, num_kv_splits, dtype=dtype, device=q.device) batch_mtp = b // mtp_size - cu_seqlens_q = torch.arange( - 0, batch_mtp + 1, step=s_q, dtype=torch.int32, device=q.device - ) + cu_seqlens_q = torch.arange(0, batch_mtp + 1, step=s_q, dtype=torch.int32, device=q.device) cu_seqlens_k = 
torch.cumsum(cache_seqlens, dim=0) cu_seqlens_k = torch.cat([torch.tensor([0]).to(cu_seqlens_k), cu_seqlens_k]) - scale = (1.0 / (dv + dpe))**0.5 # log2(e) + scale = (1.0 / (dv + dpe)) ** 0.5 # log2(e) k_descale, v_descale = None, None BLOCK_H = h_q * mtp_size @@ -119,23 +128,24 @@ def flash_mla_fa3(): k_descale=k_descale, v_descale=v_descale, return_softmax_lse=False, - mtp_step=1 + mtp_step=1, ) return out.view([b, s_q, h_q, dv]) out_flash = flash_mla_fa3() t = do_bench(flash_mla_fa3) - out_ref = run_torch_mla(q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, - cache_seqlens, h_q, h_kv, d, dv, causal, dtype) - - # 计算相对绝对误差 + out_ref = run_torch_mla( + q, block_table, blocked_k, max_seqlen_pad, block_size, b, s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype + ) + + # 计算相对绝对误差 def print_error(a, b, name=""): max_absolute_error = torch.abs(a - b).max() relative_abs_error = torch.abs(a - b) / (torch.abs(a) + 1e-4) max_relative_abs_error = relative_abs_error.max() mean_relative_abs_error = relative_abs_error.mean() - + print(f"{name}: Maximum absolute difference: {max_absolute_error:.6e}") print(f"Maximum relative absolute error: {max_relative_abs_error:.6e}") print(f"Mean relative absolute error: {mean_relative_abs_error:.6e}") @@ -148,13 +158,13 @@ def print_error(a, b, name=""): if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--batch', type=int, default=128, help='batch size') - parser.add_argument('--h_q', type=int, default=16, help='q heads number') - parser.add_argument('--h_kv', type=int, default=1, help='kv heads number') - parser.add_argument('--cache_seqlen', type=int, default=8192, help='kv cache context length') - parser.add_argument('--d', type=int, default=576, help='query/key head dim, d = dv + dpe') - parser.add_argument('--dv', type=int, default=512, help='value head dim') - parser.add_argument('--mtp_size', type=int, default=2, help='Specifies the number of tokens per prediction.') + parser.add_argument("--batch", type=int, default=128, help="batch size") + parser.add_argument("--h_q", type=int, default=16, help="q heads number") + parser.add_argument("--h_kv", type=int, default=1, help="kv heads number") + parser.add_argument("--cache_seqlen", type=int, default=8192, help="kv cache context length") + parser.add_argument("--d", type=int, default=576, help="query/key head dim, d = dv + dpe") + parser.add_argument("--dv", type=int, default=512, help="value head dim") + parser.add_argument("--mtp_size", type=int, default=2, help="Specifies the number of tokens per prediction.") args = parser.parse_args() b, h_q, h_kv, cache_seqlen, d, dv = args.batch, args.h_q, args.h_kv, args.cache_seqlen, args.d, args.dv mtp_size = args.mtp_size @@ -165,9 +175,7 @@ def print_error(a, b, name=""): s_q = 1 # for decode, s_q = 1 block_size = 1 batch_mtp = b // mtp_size - cache_seqlens = torch.tensor([cache_seqlen + i for i in range(batch_mtp)], - dtype=torch.int32, - device=device) + cache_seqlens = torch.tensor([cache_seqlen + i for i in range(batch_mtp)], dtype=torch.int32, device=device) # print(cache_seqlens[-1]) dpe = d - dv causal = True @@ -175,18 +183,33 @@ def print_error(a, b, name=""): total_seqlens = cache_seqlens.sum().item() mean_seqlens = cache_seqlens.float().mean().int().item() max_seqlen = cache_seqlens.max().item() - max_seqlen_pad = math.ceil(max_seqlen / 256) * 256 # ?为什么对齐256 + max_seqlen_pad = math.ceil(max_seqlen / 256) * 256 # ?为什么对齐256 total_flops = s_q * (total_seqlens * 2 - batch_mtp) * h_q * (d + dv) * 2 q = 
torch.randn(b, s_q, h_q, d, dtype=dtype, device=device) - block_table = torch.arange( - batch_mtp * max_seqlen_pad, dtype=torch.int32, - device=device).view(batch_mtp, max_seqlen_pad) + block_table = torch.arange(batch_mtp * max_seqlen_pad, dtype=torch.int32, device=device).view( + batch_mtp, max_seqlen_pad + ) blocked_k = torch.randn(block_table.numel(), block_size, h_kv, d, dtype=dtype, device=device) - out_flash, latency = run_fa3_mla_mtp(mtp_size, q, block_table, blocked_k, max_seqlen_pad, block_size, b, - s_q, cache_seqlens, h_q, h_kv, d, dv, causal, dtype) + out_flash, latency = run_fa3_mla_mtp( + mtp_size, + q, + block_table, + blocked_k, + max_seqlen_pad, + block_size, + b, + s_q, + cache_seqlens, + h_q, + h_kv, + d, + dv, + causal, + dtype, + ) print("Tile-lang: {:.3f} ms".format(latency)) print("Tile-lang: {:.3f} TFlops".format(total_flops / latency * 1e-9)) From fc22294ac32ca215f786f4701d5342978e1c1976 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Mon, 11 Aug 2025 05:01:03 +0000 Subject: [PATCH 07/13] format --- lightllm/utils/bench_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lightllm/utils/bench_utils.py b/lightllm/utils/bench_utils.py index bb063f6c8..e4b2100c6 100644 --- a/lightllm/utils/bench_utils.py +++ b/lightllm/utils/bench_utils.py @@ -34,13 +34,13 @@ def do_bench( return_mode: Literal["min", "max", "mean", "median"] = "mean", ) -> Union[float, List[float]]: """Benchmarks the runtime of a PyTorch function. - + This function handles: - L2 cache flushing between runs for consistent timing - Automatic warmup and repeat count calculation - Optional gradient clearing for backward passes - Multiple measurement modes (mean, median, min, max) - + Args: fn: Function to benchmark warmup: Target warmup time in milliseconds @@ -51,7 +51,7 @@ def do_bench( quantiles: Optional performance percentiles to compute fast_flush: Whether to use faster L2 cache flushing return_mode: How to aggregate timing results ("mean", "median", "min", "max") - + Returns: float: Aggregated runtime in milliseconds """ From d98579dc98b139345c957b2c9859ac9013f6efb0 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Mon, 11 Aug 2025 07:20:22 +0000 Subject: [PATCH 08/13] finish, pass test --- docs/CN/source/getting_started/benchmark.rst | 8 ++++---- docs/EN/source/getting_started/benchmark.rst | 8 ++++---- lightllm/common/basemodel/cuda_graph.py | 4 ++++ .../common/basemodel/triton_kernel/gen_decode_params.py | 6 ++++-- .../deepseek2/layer_infer/transformer_layer_infer.py | 4 ++-- lightllm/server/api_start.py | 4 ++++ 6 files changed, 22 insertions(+), 12 deletions(-) diff --git a/docs/CN/source/getting_started/benchmark.rst b/docs/CN/source/getting_started/benchmark.rst index c9fc778aa..cfcd2c0ed 100644 --- a/docs/CN/source/getting_started/benchmark.rst +++ b/docs/CN/source/getting_started/benchmark.rst @@ -89,15 +89,15 @@ ShareGPT 数据集测试 (benchmark_sharegpt.py) python test/benchmark/service/benchmark_sharegpt.py \ --dataset /path/to/sharegpt_dataset.json \ --tokenizer /path/to/tokenizer \ - --num_prompts 1000 \ - --request_rate 10.0 + --num-prompts 1000 \ + --request-rate 10.0 **主要参数:** - ``--dataset``: ShareGPT 格式数据集路径 - ``--tokenizer``: 分词器路径 -- ``--num_prompts``: 测试提示数量 -- ``--request_rate``: 请求速率 (requests/s) +- ``--num-prompts``: 测试提示数量 +- ``--request-rate``: 请求速率 (requests/s) Prompt Cache 测试 diff --git a/docs/EN/source/getting_started/benchmark.rst b/docs/EN/source/getting_started/benchmark.rst index 87caaa06a..5587b8a11 100755 --- 
a/docs/EN/source/getting_started/benchmark.rst +++ b/docs/EN/source/getting_started/benchmark.rst @@ -88,15 +88,15 @@ Performance testing using ShareGPT real conversation data. python test/benchmark/service/benchmark_sharegpt.py \ --dataset /path/to/sharegpt_dataset.json \ --tokenizer /path/to/tokenizer \ - --num_prompts 1000 \ - --request_rate 10.0 + --num-prompts 1000 \ + --request-rate 10.0 **Main Parameters:** - ``--dataset``: ShareGPT format dataset path - ``--tokenizer``: Tokenizer path -- ``--num_prompts``: Number of test prompts -- ``--request_rate``: Request rate (requests/s) +- ``--num-prompts``: Number of test prompts +- ``--request-rate``: Request rate (requests/s) Prompt Cache Testing ~~~~~~~~~~~~~~~~~~~ diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 07792865e..43c9856dd 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -40,6 +40,10 @@ def __init__(self, max_batch_size=8, max_len_in_batch=8192): batch_sizes.append(max_batch_size) batch_sizes.sort() + if self.args.enable_fa3_mtp: + step_size = self.args.mtp_step + 1 + batch_sizes = [b for b in batch_sizes if b % step_size == 0] + self.cuda_graph_batch_sizes = batch_sizes assert batch_sizes[-1] == self.max_batch_size logger.info(f"cuda graph batch_sizes: {self.cuda_graph_batch_sizes}") diff --git a/lightllm/common/basemodel/triton_kernel/gen_decode_params.py b/lightllm/common/basemodel/triton_kernel/gen_decode_params.py index 1ffdf5cc5..a8237f1a7 100644 --- a/lightllm/common/basemodel/triton_kernel/gen_decode_params.py +++ b/lightllm/common/basemodel/triton_kernel/gen_decode_params.py @@ -12,10 +12,12 @@ def gen_decode_params(b_seq_len: torch.Tensor): mtp_step = get_env_start_args().mtp_step mtp_size = mtp_step + 1 enable_fa3_mtp = get_env_start_args().enable_fa3_mtp + b_q_seq_len = torch.ones_like(b_seq_len) if enable_fa3_mtp: - b_q_seq_len = torch.ones_like(b_seq_len[: len(b_seq_len) // mtp_size]) - b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len[mtp_size - 1 :: mtp_size]) + b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor( + b_q_seq_len[: len(b_seq_len) // mtp_size], b_kv_seq_len[mtp_size - 1 :: mtp_size] + ) else: b_q_seq_len = torch.ones_like(b_seq_len) b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len) diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index 8ec15f98a..62a46b3f8 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -569,7 +569,7 @@ def _token_gqa_decode_attention_mtp( v_cache=kv_nope, qv=q_nope.reshape(-1, self.tp_q_head_num_ * self.mtp_size, self.kv_lora_rank), page_table=infer_state.page_table[self.mtp_size - 1 :: self.mtp_size], - cache_seqlens=infer_state.b_seq_len[self.mtp_size - 1 :: self.mtp_size], + cache_seqlens=infer_state.b_seq_len[self.mtp_size - 1 :: self.mtp_size].contiguous(), cu_seqlens_q=infer_state.cu_seqlens_q, cu_seqlens_k_new=infer_state.cu_seqlens_k, max_seqlen_q=1, @@ -582,7 +582,7 @@ def _token_gqa_decode_attention_mtp( return_softmax_lse=False, mtp_step=self.mtp_step, ) - return o_tensor + return o_tensor.view(-1, self.tp_q_head_num_, self.kv_lora_rank) def _token_gqa_decode_attention_flashattention( self, q, infer_state: Deepseek2FlashAttentionStateInfo, layer_weight: Deepseek2TransformerLayerWeight, 
out=None diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index d8c2900c6..5895ddcc2 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -15,6 +15,7 @@ from .router.manager import start_router_process from lightllm.utils.process_check import is_process_active from lightllm.utils.multinode_utils import send_and_receive_node_ip +from lightllm.common.flash_attn import flash_attn_with_kvcache_mtp logger = init_logger(__name__) @@ -140,6 +141,9 @@ def normal_or_p_d_start(args): if args.enable_fa3_mtp: assert args.mtp_mode is not None, "enable_fa3_mtp must set mtp_mode" + assert ( + flash_attn_with_kvcache_mtp is not None + ), "flash_attn_with_kvcache_mtp is None, please check if you have installed the fa3_mtp kernel" # 检查GPU数量是否足够 if args.visual_gpu_ids is None: From d3bf481521ae6347c9c20ed76e8e0c539072ff35 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Fri, 15 Aug 2025 03:17:15 +0000 Subject: [PATCH 09/13] Enriched the benchmark --- test/benchmark/service/benchmark_longbench.py | 305 ++++++++++++++++++ test/benchmark/service/benchmark_sharegpt.py | 248 ++++++++++---- 2 files changed, 489 insertions(+), 64 deletions(-) create mode 100644 test/benchmark/service/benchmark_longbench.py diff --git a/test/benchmark/service/benchmark_longbench.py b/test/benchmark/service/benchmark_longbench.py new file mode 100644 index 000000000..0a091a0d5 --- /dev/null +++ b/test/benchmark/service/benchmark_longbench.py @@ -0,0 +1,305 @@ +# Adapted from benchmarks/benchmark_serving.py +# of the vllm-project/vllm GitHub repository. +# +# Copyright 2023 ModelTC Team +# Copyright 2023 vLLM Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import asyncio +import json +import random +import time +from typing import AsyncGenerator, List, Tuple, Union + +import aiohttp +import numpy as np +from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase + +from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast + + +def get_tokenizer( + tokenizer_name: str, + tokenizer_mode: str = "auto", + *args, + **kwargs, +) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + """Gets a tokenizer for the given model name via Huggingface.""" + if tokenizer_mode == "slow": + if kwargs.get("use_fast", False): + raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.") + kwargs["use_fast"] = False + + if "llama" in tokenizer_name.lower() and kwargs.get("use_fast", True): + pass + try: + tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, *args, **kwargs) + except TypeError as e: + err_msg = "Failed to load the tokenizer. 
{e}" + raise RuntimeError(err_msg) from e + + if not isinstance(tokenizer, PreTrainedTokenizerFast): + pass + return tokenizer + + +# (prompt len, output len, latency) +REQUEST_LATENCY: List[Tuple[int, int, float]] = [] + + +def sample_requests( + dataset_path: str, + num_requests: int, + tokenizer: PreTrainedTokenizerBase, + max_total_tokens: int = 16384, +) -> List[Tuple[List[dict], str, int, int]]: + # Load the dataset (jsonl) + dataset = [] + with open(dataset_path) as f: + for line in f.readlines(): + if not line.strip(): + continue + dataset.append(json.loads(line)) + print("read data set finish") + + def render_with_template(messages: List[dict]) -> str: + try: + return tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + except Exception: + parts = [] + for m in messages: + parts.append(f"{m['role']}: {m['content']}") + parts.append("assistant:") + return "\n".join(parts) + + built_examples: List[Tuple[List[dict], str, int, int]] = [] + + for data in dataset: + context = data.get("context") or "" + question = data.get("input") or "Summarizing government work reports" + answers = data.get("answers") + if not isinstance(context, str) or not isinstance(question, str): + continue + + # Build messages: system + user with context and question + system_prompt = "You are a helpful assistant. Read the context and answer the question concisely." + user_content = f"Context:\n{context}\nInput:\n{question}" + messages = [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_content}, + ] + + rendered_prompt = render_with_template(messages) + prompt_len = len(tokenizer(rendered_prompt).input_ids) + + # Estimate output length from reference answer if available + target_text = "" + if isinstance(answers, list) and len(answers) > 0: + first_ans = answers[0] + if isinstance(first_ans, str): + target_text = first_ans + else: + target_text = str(first_ans) + elif isinstance(answers, str): + target_text = answers + + estimated_out = len(tokenizer(target_text).input_ids) if target_text else 128 + + # Fit within max_total_tokens + available_out = max_total_tokens - 1 - prompt_len + if available_out < 4: + # Skip samples that are too long + continue + output_len = min(estimated_out, available_out) + + built_examples.append((messages, rendered_prompt, prompt_len, output_len)) + + # Take the first N valid samples + sampled_requests = built_examples[:num_requests] + sum_len = 0 + for _, _, prompt_len, output_len in sampled_requests: + sum_len += prompt_len + output_len + print("total tokens:", sum_len) + return sampled_requests + + +async def get_request( + input_requests: List[Tuple[List[dict], str, int, int]], + request_rate: float, +) -> AsyncGenerator[Tuple[List[dict], str, int, int], None]: + input_requests = iter(input_requests) + for request in input_requests: + yield request + + if request_rate == float("inf"): + # If the request rate is infinity, then we don't need to wait. + continue + # Sample the request interval from the exponential distribution. + interval = np.random.exponential(1.0 / request_rate) + # The next request will be sent after the interval. + await asyncio.sleep(interval) + + +async def send_request( + messages: List[dict], rendered_prompt: str, prompt_len: int, output_len: int, use_openai_api: bool +) -> None: + if use_openai_api: + # Use OpenAI API to send the request. + # Use local server to send the request. 
+ request_start_time = time.time() + headers = {"Content-Type": "application/json", "User-Agent": "Benchmark Client"} + url = "http://localhost:8000/v1/chat/completions" + + data = { + "model": "DeepSeek-R1", + "messages": messages, + "top_k": 1, + "top_p": 1.0, + "temperature": 0, + "stream": True, + "ignore_eos": True, + "max_tokens": output_len, + } + timeout = aiohttp.ClientTimeout(total=3 * 3600) + receive_n = 1 + + async with aiohttp.ClientSession(timeout=timeout) as session: + async with session.post(url, headers=headers, json=data) as response: + chunks = [] + text = "" + start_time = time.time() + is_first = True + async for chunk, _ in response.content.iter_chunks(): + now_time = time.time() + delta_time = now_time - start_time + if is_first: + is_first = False + ttft = delta_time + text += json.loads(chunk.decode("utf-8")[6:])["choices"][0]["delta"].get("content", "") + if delta_time < 0.005: + receive_n += 1 + chunks.append(delta_time) + start_time = now_time + + else: + # Use local server to send the request. + request_start_time = time.time() + headers = {"Content-Type": "application/json", "User-Agent": "Benchmark Client"} + url = "http://localhost:8000/generate_stream" + + data = { + "inputs": rendered_prompt, + "parameters": { + "do_sample": False, + "ignore_eos": True, + "max_new_tokens": output_len, + }, + } + + timeout = aiohttp.ClientTimeout(total=3 * 3600) + async with aiohttp.ClientSession(timeout=timeout) as session: + receive_n = 0 + text = "" + async with session.post(url, headers=headers, json=data) as response: + chunks = [] + start_time = time.time() + is_first = True + async for chunk, _ in response.content.iter_chunks(): + now_time = time.time() + delta_time = now_time - start_time + if is_first: + is_first = False + ttft = delta_time + if delta_time < 0.005: + receive_n += 1 + chunks.append(chunk) + text += json.loads(chunk.decode("utf-8")[5:])["token"]["text"] + start_time = now_time + + request_end_time = time.time() + request_latency = request_end_time - request_start_time + REQUEST_LATENCY.append((prompt_len, output_len, request_latency, ttft)) + + +async def benchmark( + input_requests: List[Tuple[List[dict], str, int, int]], + request_rate: float, + use_openai_api: bool = False, +) -> None: + tasks: List[asyncio.Task] = [] + async for request in get_request(input_requests, request_rate): + messages, rendered_prompt, prompt_len, output_len = request + task = asyncio.create_task(send_request(messages, rendered_prompt, prompt_len, output_len, use_openai_api)) + tasks.append(task) + await asyncio.gather(*tasks) + + +def main(args: argparse.Namespace): + print(args) + random.seed(args.seed) + np.random.seed(args.seed) + tokenizer = get_tokenizer(args.tokenizer, "slow") + input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer, args.max_total_tokens) + + benchmark_start_time = time.time() + asyncio.run(benchmark(input_requests, args.request_rate, args.use_openai_api)) + benchmark_end_time = time.time() + benchmark_time = benchmark_end_time - benchmark_start_time + print(f"Total time: {benchmark_time:.2f} s") + print(f"Throughput: {args.num_prompts / benchmark_time:.2f} requests/s") + + # Compute the latency statistics. 
+ avg_latency = np.mean([latency for _, _, latency, _ in REQUEST_LATENCY]) + print(f"Average latency: {avg_latency:.2f} s") + avg_time_to_first_token = np.mean([ttft for _, _, _, ttft in REQUEST_LATENCY]) + print("Average time to first token: " f"{avg_time_to_first_token:.2f} s") + avg_per_token_latency = ( + np.mean([latency / (prompt_len + output_len) for prompt_len, output_len, latency, _ in REQUEST_LATENCY]) * 1000 + ) + print(f"Average latency per token: {avg_per_token_latency:.1f} ms") + # avg_per_output_token_latency = np.mean([latency / output_len for _, output_len, latency, _ in REQUEST_LATENCY]) + # print("Average latency per output token: " f"{avg_per_output_token_latency:.2f} s") + avg_inter_token_latency = ( + np.mean( + [(latency - ttft) / (output_len - 1) for _, output_len, latency, ttft in REQUEST_LATENCY if output_len > 1] + ) + * 1000 + ) + print(f"Average inter-token latency: {avg_inter_token_latency:.1f} ms") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Benchmark the online serving throughput.") + parser.add_argument("--use_openai_api", default=False, action="store_true", help="Use OpenAI API for requests.") + parser.add_argument("--dataset", type=str, required=True, help="Path to the dataset.") + parser.add_argument("--tokenizer", type=str, required=True, help="Name or path of the tokenizer.") + parser.add_argument( + "--request-rate", + type=float, + default=float("inf"), + help="Number of requests per second. If this is inf, " + "then all the requests are sent at time 0. " + "Otherwise, we use Poisson process to synthesize " + "the request arrival times.", + ) + parser.add_argument("--num-prompts", type=int, default=1, help="Number of prompts to process.") + parser.add_argument("--max-total-tokens", type=int, default=16384, help="Max total tokens (input + output).") + parser.add_argument("--seed", type=int, default=0) + args = parser.parse_args() + main(args) diff --git a/test/benchmark/service/benchmark_sharegpt.py b/test/benchmark/service/benchmark_sharegpt.py index c9f92f098..6b0c2ba99 100644 --- a/test/benchmark/service/benchmark_sharegpt.py +++ b/test/benchmark/service/benchmark_sharegpt.py @@ -63,55 +63,111 @@ def sample_requests( dataset_path: str, num_requests: int, tokenizer: PreTrainedTokenizerBase, -) -> List[Tuple[str, int, int]]: + max_history_turns: int = 6, + max_total_tokens: int = 16384, +) -> List[Tuple[List[dict], str, int, int]]: # Load the dataset. with open(dataset_path) as f: dataset = json.load(f) - # Filter out the conversations with less than 2 turns. - dataset = [data for data in dataset if len(data["conversations"]) >= 2] - # Only keep the first two turns of each conversation. - dataset = [(data["conversations"][0]["value"], data["conversations"][1]["value"]) for data in dataset] + # Filter out the conversations with at least 2 turns. + dataset = [data for data in dataset if len(data["conversations"]) >= max_history_turns] print("read data set finish") - # Tokenize the prompts and completions. 
- import random - - dataset = random.sample(dataset, num_requests * 3) - prompts = [prompt for prompt, _ in dataset] - completions = [completion for _, completion in dataset] + dataset = dataset[: num_requests * 3] + + def to_openai_role(role_value: str) -> str: + lower_value = role_value.lower() + if lower_value in ["human", "user", "system"]: + return "user" if lower_value != "system" else "system" + return "assistant" + + # Build messages and targets + built_examples: List[Tuple[List[dict], str]] = [] + for data in dataset: + convs = data.get("conversations", []) + if not convs: + continue + # Find the last assistant turn to be used as the completion target + last_assistant_idx = -1 + for idx in range(len(convs) - 1, -1, -1): + role_val = convs[idx].get("from") or convs[idx].get("role") or "assistant" + if to_openai_role(role_val) == "assistant": + last_assistant_idx = idx + break + if last_assistant_idx <= 0: + # Need at least one prompt message before the assistant response + continue + # Determine how many turns of history to keep before the target assistant turn + start_idx = max(0, last_assistant_idx - max_history_turns) + context_convs = convs[start_idx:last_assistant_idx] + completion_text = convs[last_assistant_idx].get("value") or convs[last_assistant_idx].get("content") or "" + if not completion_text: + continue + messages: List[dict] = [] + for turn in context_convs: + role_val = turn.get("from") or turn.get("role") or "user" + content_val = turn.get("value") or turn.get("content") or "" + if not content_val: + continue + messages.append({"role": to_openai_role(role_val), "content": content_val}) + if not messages: + continue + built_examples.append((messages, completion_text)) + + # Render prompts using chat template when possible + rendered_prompts: List[str] = [] + for messages, _ in built_examples: + rendered_text = None + try: + # Prefer using the tokenizer's chat template + rendered_text = tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + ) + except Exception: + # Fallback rendering if chat template is unavailable + parts = [] + for m in messages: + parts.append(f"{m['role']}: {m['content']}") + parts.append("assistant:") + rendered_text = "\n".join(parts) + rendered_prompts.append(rendered_text) - prompt_token_ids = tokenizer(prompts).input_ids - completion_token_ids = tokenizer(completions).input_ids - tokenized_dataset = [] - for i in range(len(dataset)): + # Tokenize the prompts and completions. + prompt_token_ids = tokenizer(rendered_prompts).input_ids if rendered_prompts else [] + completion_texts = [completion for _, completion in built_examples] + completion_token_ids = tokenizer(completion_texts).input_ids if completion_texts else [] + + tokenized_dataset: List[Tuple[List[dict], str, int, int]] = [] + for i in range(len(built_examples)): + messages, _ = built_examples[i] + prompt_len = len(prompt_token_ids[i]) output_len = len(completion_token_ids[i]) - tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len)) + tokenized_dataset.append((messages, rendered_prompts[i], prompt_len, output_len)) - # Filter out too long sequences. - filtered_dataset: List[Tuple[str, int, int]] = [] - for prompt, prompt_token_ids, output_len in tokenized_dataset: - prompt_len = len(prompt_token_ids) + # Filter out too long or too short sequences. 
+ filtered_dataset: List[Tuple[List[dict], str, int, int]] = [] + for messages, rendered_prompt, prompt_len, output_len in tokenized_dataset: if prompt_len < 4 or output_len < 4: - # Prune too short sequences. continue - if prompt_len > 1024 or prompt_len + output_len > 2048: - # Prune too long sequences. + if (prompt_len + output_len) >= max_total_tokens: continue - filtered_dataset.append((prompt, prompt_len, output_len)) + filtered_dataset.append((messages, rendered_prompt, prompt_len, output_len)) # Sample the requests. - sampled_requests = random.sample(filtered_dataset, num_requests) + sampled_requests = filtered_dataset[:num_requests] sum_len = 0 - for e in sampled_requests: - sum_len += e[1] + e[2] + for _, _, prompt_len, output_len in sampled_requests: + sum_len += prompt_len + output_len print("total tokens:", sum_len) return sampled_requests async def get_request( - input_requests: List[Tuple[str, int, int]], + input_requests: List[Tuple[List[dict], str, int, int]], request_rate: float, -) -> AsyncGenerator[Tuple[str, int, int], None]: +) -> AsyncGenerator[Tuple[List[dict], str, int, int], None]: input_requests = iter(input_requests) for request in input_requests: yield request @@ -125,48 +181,96 @@ async def get_request( await asyncio.sleep(interval) -async def send_request(prompt: str, prompt_len: int, output_len: int) -> None: - request_start_time = time.time() - headers = {"Content-Type": "application/json"} - headers = {"User-Agent": "Benchmark Client"} - url = "http://localhost:8000/generate" - - data = { - "inputs": prompt, - "parameters": { - "do_sample": False, +async def send_request( + messages: List[dict], rendered_prompt: str, prompt_len: int, output_len: int, use_openai_api: bool +) -> None: + if use_openai_api: + # Use OpenAI API to send the request. + # Use local server to send the request. + request_start_time = time.time() + headers = {"Content-Type": "application/json", "User-Agent": "Benchmark Client"} + url = "http://localhost:8000/v1/chat/completions" + + data = { + "model": "DeepSeek-R1", + "messages": messages, + "top_k": 1, + "top_p": 1.0, + "temperature": 0, + "stream": True, "ignore_eos": True, - "max_new_tokens": output_len, - # 'temperature': 0.1, - }, - } - - timeout = aiohttp.ClientTimeout(total=3 * 3600) - async with aiohttp.ClientSession(timeout=timeout) as session: - while True: + "max_tokens": output_len, + } + timeout = aiohttp.ClientTimeout(total=3 * 3600) + receive_n = 1 + + async with aiohttp.ClientSession(timeout=timeout) as session: async with session.post(url, headers=headers, json=data) as response: chunks = [] + text = "" + start_time = time.time() + is_first = True async for chunk, _ in response.content.iter_chunks(): + now_time = time.time() + delta_time = now_time - start_time + if is_first: + is_first = False + ttft = delta_time + text += json.loads(chunk.decode("utf-8")[6:])["choices"][0]["delta"].get("content", "") + if delta_time < 0.005: + receive_n += 1 + chunks.append(delta_time) + start_time = now_time + + else: + # Use local server to send the request. 
+ request_start_time = time.time() + headers = {"Content-Type": "application/json", "User-Agent": "Benchmark Client"} + url = "http://localhost:8000/generate_stream" + + data = { + "inputs": rendered_prompt, + "parameters": { + "do_sample": False, + "ignore_eos": True, + "max_new_tokens": output_len, + }, + } + + timeout = aiohttp.ClientTimeout(total=3 * 3600) + async with aiohttp.ClientSession(timeout=timeout) as session: + receive_n = 0 + text = "" + async with session.post(url, headers=headers, json=data) as response: + chunks = [] + start_time = time.time() + is_first = True + async for chunk, _ in response.content.iter_chunks(): + now_time = time.time() + delta_time = now_time - start_time + if is_first: + is_first = False + ttft = delta_time + if delta_time < 0.005: + receive_n += 1 chunks.append(chunk) - output = b"".join(chunks).decode("utf-8") - output = json.loads(output) - - if "error" not in output: - break + text += json.loads(chunk.decode("utf-8")[5:])["token"]["text"] + start_time = now_time request_end_time = time.time() request_latency = request_end_time - request_start_time - REQUEST_LATENCY.append((prompt_len, output_len, request_latency)) + REQUEST_LATENCY.append((prompt_len, output_len, request_latency, ttft)) async def benchmark( - input_requests: List[Tuple[str, int, int]], + input_requests: List[Tuple[List[dict], str, int, int]], request_rate: float, + use_openai_api: bool = False, ) -> None: tasks: List[asyncio.Task] = [] async for request in get_request(input_requests, request_rate): - prompt, prompt_len, output_len = request - task = asyncio.create_task(send_request(prompt, prompt_len, output_len)) + messages, rendered_prompt, prompt_len, output_len = request + task = asyncio.create_task(send_request(messages, rendered_prompt, prompt_len, output_len, use_openai_api)) tasks.append(task) await asyncio.gather(*tasks) @@ -176,28 +280,40 @@ def main(args: argparse.Namespace): random.seed(args.seed) np.random.seed(args.seed) tokenizer = get_tokenizer(args.tokenizer, "slow") - input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer) + input_requests = sample_requests( + args.dataset, args.num_prompts, tokenizer, args.history_turns, args.max_total_tokens + ) benchmark_start_time = time.time() - asyncio.run(benchmark(input_requests, args.request_rate)) + asyncio.run(benchmark(input_requests, args.request_rate, args.use_openai_api)) benchmark_end_time = time.time() benchmark_time = benchmark_end_time - benchmark_start_time print(f"Total time: {benchmark_time:.2f} s") print(f"Throughput: {args.num_prompts / benchmark_time:.2f} requests/s") # Compute the latency statistics. 
- avg_latency = np.mean([latency for _, _, latency in REQUEST_LATENCY]) + avg_latency = np.mean([latency for _, _, latency, _ in REQUEST_LATENCY]) print(f"Average latency: {avg_latency:.2f} s") - avg_per_token_latency = np.mean( - [latency / (prompt_len + output_len) for prompt_len, output_len, latency in REQUEST_LATENCY] + avg_time_to_first_token = np.mean([ttft for _, _, _, ttft in REQUEST_LATENCY]) + print("Average time to first token: " f"{avg_time_to_first_token:.2f} s") + avg_per_token_latency = ( + np.mean([latency / (prompt_len + output_len) for prompt_len, output_len, latency, _ in REQUEST_LATENCY]) * 1000 + ) + print(f"Average latency per token: {avg_per_token_latency:.1f} ms") + # avg_per_output_token_latency = np.mean([latency / output_len for _, output_len, latency, _ in REQUEST_LATENCY]) + # print("Average latency per output token: " f"{avg_per_output_token_latency:.2f} s") + avg_inter_token_latency = ( + np.mean( + [(latency - ttft) / (output_len - 1) for _, output_len, latency, ttft in REQUEST_LATENCY if output_len > 1] + ) + * 1000 ) - print(f"Average latency per token: {avg_per_token_latency:.2f} s") - avg_per_output_token_latency = np.mean([latency / output_len for _, output_len, latency in REQUEST_LATENCY]) - print("Average latency per output token: " f"{avg_per_output_token_latency:.2f} s") + print(f"Average inter-token latency: {avg_inter_token_latency:.1f} ms") if __name__ == "__main__": parser = argparse.ArgumentParser(description="Benchmark the online serving throughput.") + parser.add_argument("--use_openai_api", default=False, action="store_true", help="Use OpenAI API for requests.") parser.add_argument("--dataset", type=str, required=True, help="Path to the dataset.") parser.add_argument("--tokenizer", type=str, required=True, help="Name or path of the tokenizer.") parser.add_argument( @@ -210,6 +326,10 @@ def main(args: argparse.Namespace): "the request arrival times.", ) parser.add_argument("--num-prompts", type=int, default=1000, help="Number of prompts to process.") + parser.add_argument( + "--history-turns", type=int, default=6, help="Max number of context turns before the target assistant reply." 
+ ) + parser.add_argument("--max-total-tokens", type=int, default=16384, help="Max total tokens (input + output).") parser.add_argument("--seed", type=int, default=0) args = parser.parse_args() main(args) From 594de7340813b668c6dd416e7c198e80e6c693e9 Mon Sep 17 00:00:00 2001 From: WANDY666 <1060304770@qq.com> Date: Thu, 28 Aug 2025 08:41:35 +0000 Subject: [PATCH 10/13] Improved the testing script --- test/benchmark/service/benchmark_longbench.py | 90 +++++-- test/benchmark/service/benchmark_sharegpt.py | 90 +++++-- .../static_inference/model_infer_mtp.py | 253 +++++++++++++----- test/benchmark/static_inference/test_model.py | 5 + 4 files changed, 336 insertions(+), 102 deletions(-) diff --git a/test/benchmark/service/benchmark_longbench.py b/test/benchmark/service/benchmark_longbench.py index 0a091a0d5..53b9eb360 100644 --- a/test/benchmark/service/benchmark_longbench.py +++ b/test/benchmark/service/benchmark_longbench.py @@ -26,6 +26,7 @@ import aiohttp import numpy as np from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase +from tqdm.asyncio import tqdm from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast @@ -142,22 +143,32 @@ def render_with_template(messages: List[dict]) -> str: async def get_request( input_requests: List[Tuple[List[dict], str, int, int]], request_rate: float, + concurrency: int = None, ) -> AsyncGenerator[Tuple[List[dict], str, int, int], None]: input_requests = iter(input_requests) - for request in input_requests: - yield request - if request_rate == float("inf"): - # If the request rate is infinity, then we don't need to wait. - continue - # Sample the request interval from the exponential distribution. - interval = np.random.exponential(1.0 / request_rate) - # The next request will be sent after the interval. - await asyncio.sleep(interval) + if concurrency is not None: + # Concurrency-based request generation + # This generator will be consumed by the benchmark function + # which will manage the concurrency + for request in input_requests: + yield request + else: + # Rate-based request generation (original logic) + for request in input_requests: + yield request + + if request_rate == float("inf"): + # If the request rate is infinity, then we don't need to wait. + continue + # Sample the request interval from the exponential distribution. + interval = np.random.exponential(1.0 / request_rate) + # The next request will be sent after the interval. + await asyncio.sleep(interval) async def send_request( - messages: List[dict], rendered_prompt: str, prompt_len: int, output_len: int, use_openai_api: bool + messages: List[dict], rendered_prompt: str, prompt_len: int, output_len: int, use_openai_api: bool, pbar=None ) -> None: if use_openai_api: # Use OpenAI API to send the request. 
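
The request-driving change above (together with the matching benchmark() change in the next hunk) boils down to two mutually exclusive modes: pace arrivals as a Poisson process via --request-rate, or cap the number of in-flight requests via --concurrency. The sketch below is an illustration of that pattern only, not code from this patch; drive() and send() are hypothetical names.

    import asyncio
    import random

    async def drive(requests, send, request_rate=float("inf"), concurrency=None):
        # Fixed-concurrency mode: a semaphore keeps at most `concurrency` requests in flight.
        if concurrency is not None:
            sem = asyncio.Semaphore(concurrency)

            async def guarded(req):
                async with sem:
                    await send(req)

            await asyncio.gather(*(guarded(r) for r in requests))
            return
        # Rate mode: exponential inter-arrival gaps give Poisson arrivals at `request_rate` req/s.
        tasks = []
        for req in requests:
            tasks.append(asyncio.create_task(send(req)))
            if request_rate != float("inf"):
                await asyncio.sleep(random.expovariate(request_rate))
        await asyncio.gather(*tasks)

Because the two modes cannot be combined, the scripts reject invocations that set both --request-rate and --concurrency.
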
@@ -191,7 +202,7 @@ async def send_request( if is_first: is_first = False ttft = delta_time - text += json.loads(chunk.decode("utf-8")[6:])["choices"][0]["delta"].get("content", "") + # text += json.loads(chunk.decode("utf-8")[6:])["choices"][0]["delta"].get("content", "") if delta_time < 0.005: receive_n += 1 chunks.append(delta_time) @@ -236,18 +247,50 @@ async def send_request( request_latency = request_end_time - request_start_time REQUEST_LATENCY.append((prompt_len, output_len, request_latency, ttft)) + # Update progress bar if provided + if pbar: + pbar.update(1) + async def benchmark( input_requests: List[Tuple[List[dict], str, int, int]], request_rate: float, use_openai_api: bool = False, + concurrency: int = None, ) -> None: - tasks: List[asyncio.Task] = [] - async for request in get_request(input_requests, request_rate): - messages, rendered_prompt, prompt_len, output_len = request - task = asyncio.create_task(send_request(messages, rendered_prompt, prompt_len, output_len, use_openai_api)) - tasks.append(task) - await asyncio.gather(*tasks) + total_requests = len(input_requests) + + # Create progress bar + pbar = tqdm(total=total_requests, desc="Processing requests", unit="req") + + if concurrency is not None: + # Concurrency-based processing + semaphore = asyncio.Semaphore(concurrency) + tasks: List[asyncio.Task] = [] + + async def send_with_semaphore(messages, rendered_prompt, prompt_len, output_len): + async with semaphore: + await send_request(messages, rendered_prompt, prompt_len, output_len, use_openai_api, pbar) + + async for request in get_request(input_requests, request_rate, concurrency): + messages, rendered_prompt, prompt_len, output_len = request + task = asyncio.create_task(send_with_semaphore(messages, rendered_prompt, prompt_len, output_len)) + tasks.append(task) + + await asyncio.gather(*tasks) + else: + # Rate-based processing (original logic) + tasks: List[asyncio.Task] = [] + async for request in get_request(input_requests, request_rate, concurrency): + messages, rendered_prompt, prompt_len, output_len = request + task = asyncio.create_task( + send_request(messages, rendered_prompt, prompt_len, output_len, use_openai_api, pbar) + ) + tasks.append(task) + await asyncio.gather(*tasks) + + # Close progress bar + pbar.close() def main(args: argparse.Namespace): @@ -258,7 +301,7 @@ def main(args: argparse.Namespace): input_requests = sample_requests(args.dataset, args.num_prompts, tokenizer, args.max_total_tokens) benchmark_start_time = time.time() - asyncio.run(benchmark(input_requests, args.request_rate, args.use_openai_api)) + asyncio.run(benchmark(input_requests, args.request_rate, args.use_openai_api, args.concurrency)) benchmark_end_time = time.time() benchmark_time = benchmark_end_time - benchmark_start_time print(f"Total time: {benchmark_time:.2f} s") @@ -298,8 +341,19 @@ def main(args: argparse.Namespace): "Otherwise, we use Poisson process to synthesize " "the request arrival times.", ) + parser.add_argument( + "--concurrency", + type=int, + default=None, + help="Number of concurrent requests to maintain. 
" "Cannot be used together with --request-rate.", + ) parser.add_argument("--num-prompts", type=int, default=1, help="Number of prompts to process.") parser.add_argument("--max-total-tokens", type=int, default=16384, help="Max total tokens (input + output).") parser.add_argument("--seed", type=int, default=0) args = parser.parse_args() + + # Validate that only one of request_rate or concurrency is set + if args.concurrency is not None and args.request_rate != float("inf"): + raise ValueError("Cannot set both --request-rate and --concurrency. Please use only one.") + main(args) diff --git a/test/benchmark/service/benchmark_sharegpt.py b/test/benchmark/service/benchmark_sharegpt.py index 6b0c2ba99..9a7ea556f 100644 --- a/test/benchmark/service/benchmark_sharegpt.py +++ b/test/benchmark/service/benchmark_sharegpt.py @@ -26,6 +26,7 @@ import aiohttp import numpy as np from transformers import AutoModelForCausalLM, PreTrainedTokenizerBase +from tqdm.asyncio import tqdm from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast @@ -167,22 +168,32 @@ def to_openai_role(role_value: str) -> str: async def get_request( input_requests: List[Tuple[List[dict], str, int, int]], request_rate: float, + concurrency: int = None, ) -> AsyncGenerator[Tuple[List[dict], str, int, int], None]: input_requests = iter(input_requests) - for request in input_requests: - yield request - if request_rate == float("inf"): - # If the request rate is infinity, then we don't need to wait. - continue - # Sample the request interval from the exponential distribution. - interval = np.random.exponential(1.0 / request_rate) - # The next request will be sent after the interval. - await asyncio.sleep(interval) + if concurrency is not None: + # Concurrency-based request generation + # This generator will be consumed by the benchmark function + # which will manage the concurrency + for request in input_requests: + yield request + else: + # Rate-based request generation (original logic) + for request in input_requests: + yield request + + if request_rate == float("inf"): + # If the request rate is infinity, then we don't need to wait. + continue + # Sample the request interval from the exponential distribution. + interval = np.random.exponential(1.0 / request_rate) + # The next request will be sent after the interval. + await asyncio.sleep(interval) async def send_request( - messages: List[dict], rendered_prompt: str, prompt_len: int, output_len: int, use_openai_api: bool + messages: List[dict], rendered_prompt: str, prompt_len: int, output_len: int, use_openai_api: bool, pbar=None ) -> None: if use_openai_api: # Use OpenAI API to send the request. 
@@ -216,7 +227,7 @@ async def send_request( if is_first: is_first = False ttft = delta_time - text += json.loads(chunk.decode("utf-8")[6:])["choices"][0]["delta"].get("content", "") + # text += json.loads(chunk.decode("utf-8")[6:])["choices"][0]["delta"].get("content", "") if delta_time < 0.005: receive_n += 1 chunks.append(delta_time) @@ -261,18 +272,50 @@ async def send_request( request_latency = request_end_time - request_start_time REQUEST_LATENCY.append((prompt_len, output_len, request_latency, ttft)) + # Update progress bar if provided + if pbar: + pbar.update(1) + async def benchmark( input_requests: List[Tuple[List[dict], str, int, int]], request_rate: float, use_openai_api: bool = False, + concurrency: int = None, ) -> None: - tasks: List[asyncio.Task] = [] - async for request in get_request(input_requests, request_rate): - messages, rendered_prompt, prompt_len, output_len = request - task = asyncio.create_task(send_request(messages, rendered_prompt, prompt_len, output_len, use_openai_api)) - tasks.append(task) - await asyncio.gather(*tasks) + total_requests = len(input_requests) + + # Create progress bar + pbar = tqdm(total=total_requests, desc="Processing requests", unit="req") + + if concurrency is not None: + # Concurrency-based processing + semaphore = asyncio.Semaphore(concurrency) + tasks: List[asyncio.Task] = [] + + async def send_with_semaphore(messages, rendered_prompt, prompt_len, output_len): + async with semaphore: + await send_request(messages, rendered_prompt, prompt_len, output_len, use_openai_api, pbar) + + async for request in get_request(input_requests, request_rate, concurrency): + messages, rendered_prompt, prompt_len, output_len = request + task = asyncio.create_task(send_with_semaphore(messages, rendered_prompt, prompt_len, output_len)) + tasks.append(task) + + await asyncio.gather(*tasks) + else: + # Rate-based processing (original logic) + tasks: List[asyncio.Task] = [] + async for request in get_request(input_requests, request_rate, concurrency): + messages, rendered_prompt, prompt_len, output_len = request + task = asyncio.create_task( + send_request(messages, rendered_prompt, prompt_len, output_len, use_openai_api, pbar) + ) + tasks.append(task) + await asyncio.gather(*tasks) + + # Close progress bar + pbar.close() def main(args: argparse.Namespace): @@ -285,7 +328,7 @@ def main(args: argparse.Namespace): ) benchmark_start_time = time.time() - asyncio.run(benchmark(input_requests, args.request_rate, args.use_openai_api)) + asyncio.run(benchmark(input_requests, args.request_rate, args.use_openai_api, args.concurrency)) benchmark_end_time = time.time() benchmark_time = benchmark_end_time - benchmark_start_time print(f"Total time: {benchmark_time:.2f} s") @@ -325,6 +368,12 @@ def main(args: argparse.Namespace): "Otherwise, we use Poisson process to synthesize " "the request arrival times.", ) + parser.add_argument( + "--concurrency", + type=int, + default=None, + help="Number of concurrent requests to maintain. " "Cannot be used together with --request-rate.", + ) parser.add_argument("--num-prompts", type=int, default=1000, help="Number of prompts to process.") parser.add_argument( "--history-turns", type=int, default=6, help="Max number of context turns before the target assistant reply." 
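
The --history-turns option shown above controls how many ShareGPT turns of context precede the target assistant reply. A condensed, self-contained version of that message-building logic is sketched below (illustrative only; the helper names are not from the patch, and it assumes the conversation contains at least one assistant turn).

    def to_openai_role(role: str) -> str:
        role = role.lower()
        if role == "system":
            return "system"
        return "user" if role in ("human", "user") else "assistant"

    def build_messages(conversations, max_history_turns=6):
        # Use the last assistant turn as the completion target and keep at most
        # max_history_turns turns of context immediately before it.
        last = max(i for i, turn in enumerate(conversations)
                   if to_openai_role(turn.get("from", "assistant")) == "assistant")
        context = conversations[max(0, last - max_history_turns):last]
        messages = [{"role": to_openai_role(t["from"]), "content": t["value"]}
                    for t in context if t.get("value")]
        return messages, conversations[last]["value"]

    def render(messages, tokenizer=None):
        # Prefer the tokenizer's chat template; fall back to plain "role: content" lines.
        if tokenizer is not None:
            try:
                return tokenizer.apply_chat_template(messages, tokenize=False,
                                                     add_generation_prompt=True)
            except Exception:
                pass
        return "\n".join(f"{m['role']}: {m['content']}" for m in messages) + "\nassistant:"

For a ten-turn conversation ending in an assistant reply, this keeps at most the six preceding turns; the rendered prompt length plus the tokenized reference reply is what the max-total-tokens filter then checks.
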
@@ -332,4 +381,9 @@ def main(args: argparse.Namespace): parser.add_argument("--max-total-tokens", type=int, default=16384, help="Max total tokens (input + output).") parser.add_argument("--seed", type=int, default=0) args = parser.parse_args() + + # Validate that only one of request_rate or concurrency is set + if args.concurrency is not None and args.request_rate != float("inf"): + raise ValueError("Cannot set both --request-rate and --concurrency. Please use only one.") + main(args) diff --git a/test/benchmark/static_inference/model_infer_mtp.py b/test/benchmark/static_inference/model_infer_mtp.py index eb36bc873..fa7e92dbd 100644 --- a/test/benchmark/static_inference/model_infer_mtp.py +++ b/test/benchmark/static_inference/model_infer_mtp.py @@ -21,32 +21,40 @@ def init_mtp_model(args: StartArgs, kvargs, main_model): mtp_step = args.mtp_step draft_models = [] + logger.info(f"Initializing {mtp_step} MTP draft models") + os.environ["DISABLE_CHECK_MAX_LEN_INFER"] = "1" mtp_model_kvargs = kvargs mtp_model_kvargs.update( { "weight_dir": args.mtp_draft_model_dir, "max_total_token_num": main_model.mem_manager.size, - "use_dynamic_prompt_cache": False, "disable_chunked_prefill": True, "mtp_mode": args.mtp_mode, "main_model": main_model, } ) for i in range(mtp_step): - mtp_model_cfg, _ = PretrainedConfig.get_config_dict(args.mtp_draft_model_dir) - mtp_model_kvargs.update( - { - "weight_dir": args.spec_model_dir, - "max_total_token_num": main_model.mem_manager.size, - "use_dynamic_prompt_cache": False, - "disable_chunked_prefill": True, - "mtp_mode": args.mtp_mode, - "main_model": main_model, - "mem_layer_start": main_model.config["num_hidden_layers"] + i * mtp_model_cfg["num_hidden_layers"], - } - ) - draft_models.append(Deepseek3MTPModel(mtp_model_kvargs)) + try: + mtp_model_cfg, _ = PretrainedConfig.get_config_dict(args.mtp_draft_model_dir) + mtp_model_kvargs.update( + { + "weight_dir": args.mtp_draft_model_dir, + "max_total_token_num": main_model.mem_manager.size, + "disable_chunked_prefill": True, + "mtp_mode": args.mtp_mode, + "main_model": main_model, + "mem_layer_start": main_model.config["num_hidden_layers"] + i * mtp_model_cfg["num_hidden_layers"], + } + ) + draft_model = Deepseek3MTPModel(mtp_model_kvargs) + draft_models.append(draft_model) + logger.info(f"Successfully initialized draft model {i+1}/{mtp_step}") + except Exception as e: + logger.error(f"Failed to initialize draft model {i+1}: {str(e)}") + raise + + logger.info(f"Successfully initialized all {len(draft_models)} draft models") return draft_models @@ -70,12 +78,11 @@ def test_model_inference_mtp(args): "max_total_token_num": args.max_total_token_num, "graph_max_len_in_batch": args.max_req_total_len, "graph_max_batch_size": args.graph_max_batch_size, - "mem_faction": args.mem_fraction, + "mem_fraction": args.mem_fraction, "max_req_num": 2000, "batch_max_tokens": 2048, "run_mode": "normal", "max_seq_length": args.max_req_total_len, - "spec_algo": args.spec_algo, "disable_cudagraph": args.disable_cudagraph, } proc = multiprocessing.Process( @@ -94,7 +101,7 @@ def test_model_inference_mtp(args): return -def torch_profile(fn, log_dir=None): +def torch_profile(fn, batch_size, log_dir=None): torch.cuda.synchronize() with profile( activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], @@ -103,69 +110,124 @@ def torch_profile(fn, log_dir=None): on_trace_ready=torch.profiler.tensorboard_trace_handler(log_dir), ) as prof: fn() + torch.cuda.synchronize() if get_current_rank_in_dp() == 0: - 
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10)) - - -def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_models, warmup=False): + logger.info(f"batch_size {batch_size}\n{prof.key_averages().table(sort_by='cuda_time_total', row_limit=20)}") + table = prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=20) + logger.info(table if table else " (no ops recorded)") + + +def run_forward_once( + args, + input_len, + output_len, + batch_size, + main_model, + draft_models, + warmup=False, + enable_torch_profile=False, + skip_prefill=False, +): import time + import torch.distributed as dist + + dist.barrier() torch.cuda.synchronize() prefill_start_time = time.time() - test_data = np.vstack([np.random.randint(0, 50256, input_len) for _ in range(batch_size)]) - test_data = test_data.reshape(-1) - test_data = torch.from_numpy(test_data).cuda() - b_req_idx = torch.tensor( [main_model.req_manager.alloc() for _ in range(batch_size)], dtype=torch.int32, device="cuda" ) + b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda") b_seq_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda") b_ready_cache_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda") for i in range(batch_size): b_seq_len[i] = input_len total_token_num = input_len * batch_size - mem_indexes = main_model.req_manager.mem_manager.alloc(test_data.shape[0]).cuda() - # Main model Prefill - model_input = ModelInput( - batch_size=batch_size, - total_token_num=total_token_num, - max_len_in_batch=input_len, - input_ids=test_data, - mem_indexes=mem_indexes, - b_req_idx=b_req_idx, - b_seq_len=b_seq_len, - is_prefill=True, - b_ready_cache_len=b_ready_cache_len, - ) - model_output: ModelOutput = main_model.forward(model_input) - prob_out = torch.softmax(model_output.logits, dim=-1) - predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) - predict_ids = predict_ids.detach().cpu().numpy() - - draft_ids = [predict_ids] + if skip_prefill: + # Skip prefill computation but simulate the state after prefill + # Generate dummy output tokens as if prefill happened + draft_ids = [] + + # Generate dummy token IDs for main model and draft models + # Simulate one token output per model (main + draft models) + for model_idx in range(len(draft_models) + 1): + # Generate random token IDs as if they were predicted + dummy_predict_ids = np.random.randint(1000, 10000, (batch_size, 1)) + draft_ids.append(dummy_predict_ids) + + # Update sequence lengths to reflect that prefill "happened" + # No need to update b_seq_len as it already contains input_len + + if get_current_rank_in_dp() == 0 and not warmup: + logger.info(f"Skipped prefill phase, simulated {len(draft_ids)} draft outputs") + else: + # Generate test data for actual prefill + test_data = np.vstack([np.random.randint(0, 50256, input_len) for _ in range(batch_size)]) + test_data = test_data.reshape(-1) + test_data = torch.from_numpy(test_data).cuda() + + # Allocate memory for prefill tokens + mem_indexes = main_model.req_manager.mem_manager.alloc(test_data.shape[0]).cuda() + # Main model Prefill + model_input = ModelInput( + batch_size=batch_size, + total_token_num=total_token_num, + max_len_in_batch=input_len, + input_ids=test_data, + b_req_idx=b_req_idx, + b_mtp_index=b_mtp_index, + b_seq_len=b_seq_len, + mem_indexes=mem_indexes, + is_prefill=True, + b_ready_cache_len=b_ready_cache_len, + ) - # Draft model Prefill - # For simplicity, we'll just take the input of main_model to draft model. 
- model_input.deepseekv3_mtp_draft_input_hiddens = model_output.deepseekv3_mtp_main_output_hiddens - for draft_model_id in range(len(draft_models)): - draft_model = draft_models[draft_model_id] - model_output = draft_model.forward(model_input) + model_output: ModelOutput = main_model.forward(model_input) prob_out = torch.softmax(model_output.logits, dim=-1) predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) predict_ids = predict_ids.detach().cpu().numpy() - draft_ids.append(predict_ids) + + draft_ids = [predict_ids] + + # Draft model Prefill + # For simplicity, we'll just take the input of main_model to draft model. model_input.deepseekv3_mtp_draft_input_hiddens = model_output.deepseekv3_mtp_main_output_hiddens + for draft_model_id in range(len(draft_models)): + draft_model = draft_models[draft_model_id] + model_output = draft_model.forward(model_input) + prob_out = torch.softmax(model_output.logits, dim=-1) + predict_ids = torch.argmax(prob_out, dim=1, keepdim=True) + predict_ids = predict_ids.detach().cpu().numpy() + draft_ids.append(predict_ids) + model_input.deepseekv3_mtp_draft_input_hiddens = model_output.deepseekv3_mtp_main_output_hiddens torch.cuda.synchronize() prefill_end_time = time.time() - if get_current_rank_in_dp() == 0 and not warmup: - print("prefill time cost:", (prefill_end_time - prefill_start_time) * 1000) - print( - f"Prefill throughput: {batch_size * input_len * args.dp / (prefill_end_time - prefill_start_time)} tokens/s" - ) + + rank_id = get_current_rank_in_dp() + + if rank_id == 0 and not warmup and not skip_prefill: + prefill_time = (prefill_end_time - prefill_start_time) * 1000 + dp_size = getattr(args, "dp", 1) + throughput = dp_size * batch_size * input_len / (prefill_end_time - prefill_start_time) + logger.info(f"prefill time cost: {prefill_time:.2f} ms, prefill throughput: {throughput:.2f} tokens/s") + + # Add profiling support for prefill + if enable_torch_profile and not warmup and not skip_prefill: + logger.info("Profile Prefill") + try: + torch_profile( + lambda: main_model.forward(model_input), + batch_size, + log_dir=f"./logs/forward_prefill_mtp_bs{batch_size}_{rank_id}", + ) + except Exception as e: + logger.error(f"Profiling error: {str(e)}") + # Continue without profiling torch.cuda.synchronize() @@ -174,12 +236,14 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ # build main decode input: nopad_b_seq_idx = [] + nopad_b_mtp_index = [] nopad_b_seq_len = [] nopad_total_token_num = 0 nopad_max_len_in_batch = 0 for i in range(batch_size): nopad_b_seq_idx.append(b_req_idx[i]) + nopad_b_mtp_index.append(0) seq_len = b_seq_len[i].item() nopad_b_seq_len.append(seq_len + 1) nopad_total_token_num += seq_len + 1 @@ -187,11 +251,13 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ for step in range(len(draft_models)): nopad_b_seq_idx.append(b_req_idx[i]) + nopad_b_mtp_index.append(step + 1) nopad_b_seq_len.append(seq_len + step + 2) nopad_total_token_num += seq_len + step + 2 nopad_max_len_in_batch = max(nopad_max_len_in_batch, seq_len + step + 2) nopad_b_seq_idx = torch.tensor(nopad_b_seq_idx, dtype=torch.int32, device="cuda") + nopad_b_mtp_index = torch.tensor(nopad_b_mtp_index, dtype=torch.int32, device="cuda") nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda") mem_indexes = main_model.req_manager.mem_manager.alloc(batch_size * (len(draft_models) + 1)).cuda() @@ -200,9 +266,10 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, 
@@ -200,9 +266,10 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_
         total_token_num=nopad_total_token_num,
         max_len_in_batch=nopad_max_len_in_batch,
         input_ids=decode_input_ids,
-        mem_indexes=mem_indexes,
         b_req_idx=nopad_b_seq_idx,
+        b_mtp_index=nopad_b_mtp_index,
         b_seq_len=nopad_b_seq_len,
+        mem_indexes=mem_indexes,
         is_prefill=False,
     )
 
@@ -234,15 +301,31 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_
         model_input.input_ids = predict_ids.reshape(-1)
         model_input.deepseekv3_mtp_draft_input_hiddens = model_output.deepseekv3_mtp_main_output_hiddens
         torch.cuda.synchronize()
-        if i % 100 == 0 or i == output_len - 1:
+        if i % 100 == 0 or i == output_len - (len(draft_models) + 1):
             step_end_time = time.time()
-            if get_current_rank_in_dp() == 0 and not warmup:
+            if rank_id == 0 and not warmup:
                 step_time = step_end_time - step_start_time
-                print(i, " step cost time:", step_time * 1000)
-                print(f"Decode throughput: {batch_size * (len(draft_models) + 1) * args.dp / step_time} tokens/s")
+                dp_size = getattr(args, "dp", 1)
+                throughput = dp_size * batch_size * (len(draft_models) + 1) / step_time
+                logger.info(f"i: {i}, step cost time: {step_time * 1000:.2f} ms, throughput: {throughput:.2f} tokens/s")
+
+        # Add profiling support for decode on last step
+        if enable_torch_profile and not warmup and i == output_len - (len(draft_models) + 1):
+            logger.info("Profile Decode")
+            try:
+                torch_profile(
+                    lambda: main_model.forward(model_input),
+                    batch_size,
+                    log_dir=f"./logs/forward_decode_mtp_bs{batch_size}_{rank_id}",
+                )
+            except Exception as e:
+                logger.error(f"Profiling error: {str(e)}")
+                # Continue without profiling
 
     main_model.mem_manager.free_all()
     main_model.req_manager.free_all()
+    torch.cuda.synchronize()
+    torch.cuda.empty_cache()
 
 
 def tppart_model_infer(args, model_kvargs, batch_sizes, input_len, output_len, ans_queue):
@@ -252,11 +335,22 @@ def tppart_model_infer(args, model_kvargs, batch_sizes, input_len, output_len, a
     from lightllm.distributed import dist_group_manager
     from lightllm.utils.dist_utils import set_current_device_id
 
+    # Handle batch_sizes as either int or list
+    if isinstance(batch_sizes, int):
+        batch_sizes = [batch_sizes]
+    else:
+        # Default batch sizes for comprehensive testing
+        batch_sizes = [16, 32, 64]
+
+    logger.info(f"Testing batch sizes: {batch_sizes}")
+
     import torch.distributed as dist
 
     enable_decode_overlap = args.enable_decode_microbatch_overlap
     group_size = 1
     if enable_decode_overlap or args.enable_prefill_microbatch_overlap:
+        for bs in batch_sizes:
+            assert bs % 2 == 0, f"batch size {bs} must be an even number for overlap mode"
         group_size = 2
     init_distributed_env(model_kvargs)
     dist_group_manager.create_groups(group_size=group_size)
@@ -267,14 +361,41 @@ def tppart_model_infer(args, model_kvargs, batch_sizes, input_len, output_len, a
     main_model, _ = get_model(model_cfg, model_kvargs)
     draft_models = init_mtp_model(args, model_kvargs, main_model)
 
-    if isinstance(batch_sizes, int):
-        batch_sizes = [batch_sizes]
-
+    rank_id = model_kvargs["rank_id"]
+    skip_prefill = getattr(args, "skip_prefill", False)
     for batch_size in batch_sizes:
+        if rank_id == 0:
+            logger.info(f"Testing batch size {batch_size}")
+
         # warm up
-        run_forward_once(args, input_len, output_len, batch_size, main_model, draft_models, warmup=True)
+        run_forward_once(
+            args,
+            input_len,
+            10,
+            batch_size,
+            main_model,
+            draft_models,
+            warmup=True,
+            enable_torch_profile=False,
+            skip_prefill=skip_prefill,
+        )
         torch.cuda.synchronize()
-        run_forward_once(args, input_len, output_len, batch_size, main_model, draft_models, warmup=False)
+
+        # actual test
+        enable_profiling = getattr(args, "torch_profile", False)
+        run_forward_once(
+            args,
+            input_len,
+            output_len,
+            batch_size,
+            main_model,
+            draft_models,
+            warmup=False,
+            enable_torch_profile=enable_profiling,
+            skip_prefill=skip_prefill,
+        )
+        if rank_id == 0:
+            logger.info("=" * 50)
 
     dist.barrier()
     ans_queue.put(True)
diff --git a/test/benchmark/static_inference/test_model.py b/test/benchmark/static_inference/test_model.py
index 5b3751bcc..8725ac267 100644
--- a/test/benchmark/static_inference/test_model.py
+++ b/test/benchmark/static_inference/test_model.py
@@ -40,6 +40,11 @@ def test_model_infer(self):
         action="store_true",
         help="Enable torch profiler to profile the model",
     )
+    parser.add_argument(
+        "--skip_prefill",
+        action="store_true",
+        help="Skip the prefill phase, which can easily run out of memory at large batch sizes",
+    )
     args = parser.parse_args()
     set_env_start_args(args)
     torch.multiprocessing.set_start_method("spawn")

From d70d188b494b2e3efe4cfc8cd0736d26c7089ec8 Mon Sep 17 00:00:00 2001
From: wanzihao <1060304770@qq.com>
Date: Fri, 29 Aug 2025 10:22:13 +0800
Subject: [PATCH 11/13] Remove redundant initialization of b_q_seq_len

---
 lightllm/common/basemodel/triton_kernel/gen_decode_params.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/lightllm/common/basemodel/triton_kernel/gen_decode_params.py b/lightllm/common/basemodel/triton_kernel/gen_decode_params.py
index a8237f1a7..1b40638e7 100644
--- a/lightllm/common/basemodel/triton_kernel/gen_decode_params.py
+++ b/lightllm/common/basemodel/triton_kernel/gen_decode_params.py
@@ -19,7 +19,6 @@ def gen_decode_params(b_seq_len: torch.Tensor):
             b_q_seq_len[: len(b_seq_len) // mtp_size], b_kv_seq_len[mtp_size - 1 :: mtp_size]
         )
     else:
-        b_q_seq_len = torch.ones_like(b_seq_len)
         b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len)
 
     return b_q_seq_len, b1_cu_q_seq_len, b_kv_seq_len, b1_cu_kv_seq_len, position_ids

From d45865cf64185e0f92ff486a4be8a125e00400f8 Mon Sep 17 00:00:00 2001
From: WANDY666 <1060304770@qq.com>
Date: Fri, 29 Aug 2025 08:06:24 +0000
Subject: [PATCH 12/13] custom flash_attn

---
 .../triton_kernel/gen_decode_params.py        |   2 +-
 lightllm/common/flash_attn.py                 | 181 ++++--------------
 .../layer_infer/transformer_layer_infer.py    |  11 +-
 .../kernel/benchmark_fa3_decode_mtp.py        |  11 +-
 4 files changed, 47 insertions(+), 158 deletions(-)

diff --git a/lightllm/common/basemodel/triton_kernel/gen_decode_params.py b/lightllm/common/basemodel/triton_kernel/gen_decode_params.py
index 1b40638e7..c3f604c8c 100644
--- a/lightllm/common/basemodel/triton_kernel/gen_decode_params.py
+++ b/lightllm/common/basemodel/triton_kernel/gen_decode_params.py
@@ -16,7 +16,7 @@ def gen_decode_params(b_seq_len: torch.Tensor):
 
     if enable_fa3_mtp:
         b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(
-            b_q_seq_len[: len(b_seq_len) // mtp_size], b_kv_seq_len[mtp_size - 1 :: mtp_size]
+            b_q_seq_len[mtp_size - 1 :: mtp_size], b_kv_seq_len[mtp_size - 1 :: mtp_size]
        )
     else:
         b1_cu_q_seq_len, b1_cu_kv_seq_len = gen_cumsum_pad0_tensor(b_q_seq_len, b_kv_seq_len)
 
     return b_q_seq_len, b1_cu_q_seq_len, b_kv_seq_len, b1_cu_kv_seq_len, position_ids
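With the interleaved layout, rows mtp_size - 1, 2 * mtp_size - 1, ... are the last MTP slot of each request, so the strided slice in the hunk above collapses the batch back to one (q_len, kv_len) pair per request before the cumulative-length tensors are built. A small illustration with invented toy values; cumsum_pad0 below is only a stand-in for gen_cumsum_pad0_tensor, assumed to return a zero-prefixed cumulative sum in the usual cu_seqlens style:

import torch

mtp_size = 2  # one main row plus one draft row per request
# interleaved per (request, mtp slot): [req0/main, req0/mtp1, req1/main, req1/mtp1]
b_seq_len = torch.tensor([9, 10, 10, 11], dtype=torch.int32)

b_q_seq_len = torch.ones_like(b_seq_len)
b_kv_seq_len = b_seq_len

# One entry per request: the last MTP slot carries the longest kv length.
q_per_req = b_q_seq_len[mtp_size - 1 :: mtp_size]
kv_per_req = b_kv_seq_len[mtp_size - 1 :: mtp_size]

def cumsum_pad0(x):
    # Stand-in for gen_cumsum_pad0_tensor: zero-prefixed cumulative sum.
    return torch.nn.functional.pad(torch.cumsum(x, dim=0, dtype=torch.int32), (1, 0))

print(cumsum_pad0(q_per_req))   # tensor([0, 1, 2], dtype=torch.int32)
print(cumsum_pad0(kv_per_req))  # tensor([ 0, 10, 21], dtype=torch.int32)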
diff --git a/lightllm/common/flash_attn.py b/lightllm/common/flash_attn.py
index f3a505035..66609e700 100644
--- a/lightllm/common/flash_attn.py
+++ b/lightllm/common/flash_attn.py
@@ -1,21 +1,3 @@
-# This file is adapted from sgl-project/sglang:
-# https://github.com/sgl-project/sglang/blob/main/sgl-kernel/python/sgl_kernel/flash_attn.py
-# The original code and this file are licensed under the Apache License, Version 2.0.
-#
-# Copyright (c) sgl-project and other contributors.
-# Modifications Copyright (c) LightLLM contributors.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
 import torch
 from typing import List, Optional, Tuple, Union
 from lightllm.utils.log_utils import init_logger
@@ -23,7 +5,7 @@
 logger = init_logger(__name__)
 
 
-def maybe_contiguous(x):
+def get_contiguous(x):
     return x.contiguous() if x is not None and x.stride(-1) != 1 else x
 
 
@@ -34,152 +16,61 @@ def maybe_contiguous(x):
     def flash_attn_with_kvcache_mtp(
         q,
-        k_cache,
-        v_cache,
-        k=None,
-        v=None,
-        qv=None,
-        rotary_cos=None,
-        rotary_sin=None,
-        cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None,
-        cache_batch_idx: Optional[torch.Tensor] = None,
-        cache_leftpad: Optional[torch.Tensor] = None,
-        page_table: Optional[torch.Tensor] = None,
+        k,
+        v,
+        k_new: Optional[torch.Tensor] = None,
+        v_new: Optional[torch.Tensor] = None,
+        q_v: Optional[torch.Tensor] = None,
         cu_seqlens_q: Optional[torch.Tensor] = None,
+        cu_seqlens_k: Optional[torch.Tensor] = None,
         cu_seqlens_k_new: Optional[torch.Tensor] = None,
+        seqused_q: Optional[torch.Tensor] = None,
+        seqused_k: Optional[torch.Tensor] = None,
         max_seqlen_q: Optional[int] = None,
+        max_seqlen_k: Optional[int] = None,
+        page_table: Optional[torch.Tensor] = None,
+        cache_batch_idx: Optional[torch.Tensor] = None,
+        cache_leftpad: Optional[torch.Tensor] = None,
+        rotary_cos: Optional[torch.Tensor] = None,
+        rotary_sin: Optional[torch.Tensor] = None,
         rotary_seqlens: Optional[torch.Tensor] = None,
         q_descale: Optional[torch.Tensor] = None,
         k_descale: Optional[torch.Tensor] = None,
         v_descale: Optional[torch.Tensor] = None,
         softmax_scale=None,
-        causal=False,
-        window_size=(-1, -1),  # -1 means infinite context window
+        is_causal=False,
+        window_size=(-1, -1),
         softcap=0.0,  # 0.0 means deactivated
-        rotary_interleaved=True,
+        is_rotary_interleaved=True,
         scheduler_metadata=None,
-        num_splits=0,  # Can be tuned for speed
-        pack_gqa=None,  # Can be tuned for speed
-        sm_margin=0,  # Can be tuned if some SMs are used for communication
-        return_softmax_lse=False,
+        num_splits=0,
+        pack_gqa=None,
+        sm_margin=0,
         mtp_step=0,
     ):
-        """
-        If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
-        k and v. This is useful for incremental decoding: you can pass in the cached keys/values from
-        the previous step, and update them with the new keys/values from the current step, and do
-        attention with the updated cache, all in 1 kernel.
-
-        If you pass in k / v, you must make sure that the cache is large enough to hold the new values.
-        For example, the KV cache could be pre-allocated with the max sequence length, and you can use
-        cache_seqlens to keep track of the current sequence lengths of each sequence in the batch.
-
-        Also apply rotary embedding if rotary_cos and rotary_sin are passed in. The key @k will be
-        rotated by rotary_cos and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
-        If causal or local (i.e., window_size != (-1, -1)), the query @q will be rotated by rotary_cos
-        and rotary_sin at indices cache_seqlens, cache_seqlens + 1, etc.
-        If not causal and not local, the query @q will be rotated by rotary_cos and rotary_sin at
-        indices cache_seqlens only (i.e. we consider all tokens in @q to be at position cache_seqlens).
-
-        See tests/test_flash_attn.py::test_flash_attn_kvcache for examples of how to use this function.
-
-        Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
-        than Q. Note that the number of heads in Q must be divisible by the number of heads in KV.
-        For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head
-        0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V.
-
-        If causal=True, the causal mask is aligned to the bottom right corner of the attention matrix.
-        For example, if seqlen_q = 2 and seqlen_k = 5, the causal mask (1 = keep, 0 = masked out) is:
-            1 1 1 1 0
-            1 1 1 1 1
-        If seqlen_q = 5 and seqlen_k = 2, the causal mask is:
-            0 0
-            0 0
-            0 0
-            1 0
-            1 1
-        If the row of the mask is all zero, the output will be zero.
-
-        If window_size != (-1, -1), implements sliding window local attention. Query at position i
-        will only attend to keys between
-        [i + seqlen_k - seqlen_q - window_size[0], i + seqlen_k - seqlen_q + window_size[1]] inclusive.
-
-        Note: Does not support backward pass.
-
-        Arguments:
-            q: (batch_size, seqlen, nheads, headdim)
-            k_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim) if there's no page_table,
-                or (num_blocks, page_block_size, nheads_k, headdim) if there's a page_table (i.e. paged KV cache)
-                page_block_size must be a multiple of 256.
-            v_cache: (batch_size_cache, seqlen_cache, nheads_k, headdim_v) if there's no page_table,
-                or (num_blocks, page_block_size, nheads_k, headdim_v) if there's a page_table (i.e. paged KV cache)
-            k [optional]: (batch_size, seqlen_new, nheads_k, headdim). If not None, we concatenate
-                k with k_cache, starting at the indices specified by cache_seqlens.
-            v [optional]: (batch_size, seqlen_new, nheads_k, headdim_v). Similar to k.
-            qv [optional]: (batch_size, seqlen, nheads, headdim_v)
-            rotary_cos [optional]: (seqlen_ro, rotary_dim / 2). If not None, we apply rotary embedding
-                to k and q. Only applicable if k and v are passed in. rotary_dim must be divisible by 16.
-            rotary_sin [optional]: (seqlen_ro, rotary_dim / 2). Similar to rotary_cos.
-            cache_seqlens: int, or (batch_size,), dtype torch.int32. The sequence lengths of the
-                KV cache.
-            cache_batch_idx: (batch_size,), dtype torch.int32. The indices used to index into the KV cache.
-                If None, we assume that the batch indices are [0, 1, 2, ..., batch_size - 1].
-                If the indices are not distinct, and k and v are provided, the values updated in the cache
-                might come from any of the duplicate indices.
-            cache_leftpad: (batch_size,), dtype torch.int32. The index that the KV cache starts. If None, assume 0.
-            page_table [optional]: (batch_size, max_num_blocks_per_seq), dtype torch.int32.
-            softmax_scale: float. The scaling of QK^T before applying softmax.
-                Default to 1 / sqrt(headdim).
-            causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling).
-            window_size: (left, right). If not (-1, -1), implements sliding window local attention.
-            softcap: float. Anything > 0 activates softcapping attention.
-            rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in.
-                If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False,
-                rotary embedding will combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1
-                (i.e. GPT-NeoX style).
-            num_splits: int. If > 1, split the key/value into this many chunks along the sequence.
-                If num_splits == 1, we don't split the key/value. If num_splits == 0, we use a heuristic
-                to automatically determine the number of splits.
-                Don't change this unless you know what you are doing.
-            return_softmax_lse: bool. Whether to return the logsumexp of the attention scores.
-
-        Return:
-            out: (batch_size, seqlen, nheads, headdim).
-            softmax_lse [optional, if return_softmax_lse=True]: (batch_size, nheads, seqlen). The
-                logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
-                normalization factor).
-        """
-        assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
-        assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
+        assert k.stride(-1) == 1, "k must have contiguous last dimension"
+        assert v.stride(-1) == 1, "v must have contiguous last dimension"
         if softmax_scale is None:
-            softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** (-0.5)
-        if cache_seqlens is not None and isinstance(cache_seqlens, int):
-            cache_seqlens = torch.full((k_cache.shape[0],), cache_seqlens, dtype=torch.int32, device=k_cache.device)
-            cache_seqlens = maybe_contiguous(cache_seqlens)
-
-        q, k_cache, k, v = [maybe_contiguous(x) for x in (q, k_cache, k, v)]
-        v_cache = v_cache.contiguous() if v_cache.stride(-1) != 1 and v_cache.stride(-3) != 1 else v_cache
-        cu_seqlens_q, cu_seqlens_k_new = [maybe_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k_new)]
-        page_table, cache_batch_idx, cache_leftpad = [
-            maybe_contiguous(x) for x in (page_table, cache_batch_idx, cache_leftpad)
-        ]
-        rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)]
-        rotary_seqlens = maybe_contiguous(rotary_seqlens)
+            softmax_scale = (q.shape[-1] + (q_v.shape[-1] if q_v is not None else 0)) ** (-0.5)
+        seqused_k = get_contiguous(seqused_k)
 
-        # out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default(
+        q, k, k_new, v_new = [get_contiguous(x) for x in (q, k, k_new, v_new)]
+        v = v.contiguous() if v.stride(-1) != 1 and v.stride(-3) != 1 else v
+        cu_seqlens_q, cu_seqlens_k_new = [get_contiguous(x) for x in (cu_seqlens_q, cu_seqlens_k_new)]
+        page_table = get_contiguous(page_table)
 
         out, softmax_lse, *rest = flash_attn_3_mtp.fwd(
             q,
-            k_cache,
-            v_cache,
             k,
             v,
-            qv,
+            k_new,
+            v_new,
+            q_v,
             None,  # out
             cu_seqlens_q,
             None,  # cu_seqlens_k
             cu_seqlens_k_new,
             None,  # seqused_q
-            cache_seqlens,
+            seqused_k,
             max_seqlen_q,
             None,  # max_seqlen_k
             page_table,
@@ -192,19 +83,19 @@ def flash_attn_with_kvcache_mtp(
             k_descale,
             v_descale,
             softmax_scale,
-            causal,
+            is_causal,
             window_size[0],
             window_size[1],
             0,
             softcap,
-            rotary_interleaved,
+            is_rotary_interleaved,
             scheduler_metadata,
             num_splits,
             pack_gqa,
             sm_margin,
             mtp_step,
         )
-        return (out, softmax_lse, *rest) if return_softmax_lse else out
+        return out
 except:
     flash_attn_3_mtp = None
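The call-site change in the next hunk relies on one more trick: rather than passing mtp_step separate rows per request, the deepseek2 layer reshapes q so that the MTP slot is folded into the head dimension, and the kernel then sees one row per request with tp_q_head_num_ * mtp_size query heads. A toy sketch of that reshape with invented shapes (2 requests, mtp_size 2, 4 heads, rope dim 8):

import torch

batch, mtp_size, heads, dpe = 2, 2, 4, 8

# Decode tokens arrive interleaved per (request, MTP slot): (batch * mtp_size, heads, dpe).
q_rope = torch.randn(batch * mtp_size, heads, dpe)

# The call site folds the MTP slot into the head axis: (batch, heads * mtp_size, dpe).
q_folded = q_rope.reshape(-1, heads * mtp_size, dpe)
print(q_folded.shape)  # torch.Size([2, 8, 8])

# Row 0 of the folded tensor stacks both MTP slots of request 0 along the head axis.
assert torch.equal(q_folded[0, :heads], q_rope[0])
assert torch.equal(q_folded[0, heads:], q_rope[1])

This folded layout is also why the page_table and b_seq_len slices below keep only the last MTP slot per request: the kernel needs a single KV view per request, covering the longest interleaved row.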
diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
index 99246cb81..923140cc6 100644
--- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
@@ -563,21 +563,20 @@ def _token_gqa_decode_attention_mtp(
         k_descale, v_descale = None, None
         o_tensor = flash_attn_with_kvcache_mtp(
             q=q_rope.reshape(-1, self.tp_q_head_num_ * self.mtp_size, self.qk_rope_head_dim),
-            k_cache=k_rope,
-            v_cache=kv_nope,
-            qv=q_nope.reshape(-1, self.tp_q_head_num_ * self.mtp_size, self.kv_lora_rank),
+            k=k_rope,
+            v=kv_nope,
+            q_v=q_nope.reshape(-1, self.tp_q_head_num_ * self.mtp_size, self.kv_lora_rank),
             page_table=infer_state.page_table[self.mtp_size - 1 :: self.mtp_size],
-            cache_seqlens=infer_state.b_seq_len[self.mtp_size - 1 :: self.mtp_size].contiguous(),
+            seqused_k=infer_state.b_seq_len[self.mtp_size - 1 :: self.mtp_size].contiguous(),
             cu_seqlens_q=infer_state.cu_seqlens_q,
             cu_seqlens_k_new=infer_state.cu_seqlens_k,
             max_seqlen_q=1,
             softmax_scale=self.softmax_scale,
-            causal=True,
+            is_causal=True,
             window_size=(-1, -1),
             softcap=0.0,
             k_descale=k_descale,
             v_descale=v_descale,
-            return_softmax_lse=False,
             mtp_step=self.mtp_step,
         )
         return o_tensor.view(-1, self.tp_q_head_num_, self.kv_lora_rank)
diff --git a/test/benchmark/kernel/benchmark_fa3_decode_mtp.py b/test/benchmark/kernel/benchmark_fa3_decode_mtp.py
index b605f2fe0..05054a43c 100644
--- a/test/benchmark/kernel/benchmark_fa3_decode_mtp.py
+++ b/test/benchmark/kernel/benchmark_fa3_decode_mtp.py
@@ -113,21 +113,20 @@ def run_fa3_mla_mtp(
     def flash_mla_fa3():
         out = flash_attn_with_kvcache_mtp(
             q=q_pe.view(-1, BLOCK_H, dpe),
-            k_cache=blocked_k_pe,
-            v_cache=blocked_k_nope,
-            qv=q_nope.view(-1, BLOCK_H, dv),
+            k=blocked_k_pe,
+            v=blocked_k_nope,
+            q_v=q_nope.view(-1, BLOCK_H, dv),
             page_table=block_table,
-            cache_seqlens=cache_seqlens,
+            seqused_k=cache_seqlens,
             cu_seqlens_q=cu_seqlens_q,
             cu_seqlens_k_new=cu_seqlens_k,
             max_seqlen_q=1,
             softmax_scale=scale,
-            causal=True,
+            is_causal=True,
             window_size=(-1, -1),
             softcap=0.0,
             k_descale=k_descale,
             v_descale=v_descale,
-            return_softmax_lse=False,
             mtp_step=1,
         )
         return out.view([b, s_q, h_q, dv])

From 67a1c87fb49654c4f1593b02a77f7415c6f61fa4 Mon Sep 17 00:00:00 2001
From: wanzihao <1060304770@qq.com>
Date: Tue, 23 Sep 2025 02:56:37 +0000
Subject: [PATCH 13/13] fix flops

---
 test/benchmark/kernel/benchmark_fa3_decode_mtp.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/benchmark/kernel/benchmark_fa3_decode_mtp.py b/test/benchmark/kernel/benchmark_fa3_decode_mtp.py
index 05054a43c..08564c669 100644
--- a/test/benchmark/kernel/benchmark_fa3_decode_mtp.py
+++ b/test/benchmark/kernel/benchmark_fa3_decode_mtp.py
@@ -184,7 +184,7 @@ def print_error(a, b, name=""):
     max_seqlen = cache_seqlens.max().item()
     max_seqlen_pad = math.ceil(max_seqlen / 256) * 256  # why align to 256?
 
-    total_flops = s_q * (total_seqlens * 2 - batch_mtp) * h_q * (d + dv) * 2
+    total_flops = s_q * total_seqlens * h_q * (d + dv) * 2 * mtp_size
 
     q = torch.randn(b, s_q, h_q, d, dtype=dtype, device=device)
     block_table = torch.arange(batch_mtp * max_seqlen_pad, dtype=torch.int32, device=device).view(