diff --git a/lightllm/common/basemodel/basemodel.py b/lightllm/common/basemodel/basemodel.py index ca20a7551..088df2291 100755 --- a/lightllm/common/basemodel/basemodel.py +++ b/lightllm/common/basemodel/basemodel.py @@ -684,10 +684,10 @@ def _check_max_len_infer(self): logger.info("begin check max_len infer") dummy_input_ids = torch.ones(self.batch_max_tokens, dtype=torch.int32, device="cuda") b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda") - mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda() b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda") b_seq_len[:] = self.batch_max_tokens b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda") + mem_indexes = self.req_manager.alloc_mem_indices(len(dummy_input_ids), b_seq_len, b_ready_cache_len).cuda() total_token_num = self.batch_max_tokens b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda") model_input = ModelInput( @@ -757,12 +757,14 @@ def _autotune_warmup(self): 0, 10000, (input_len,), dtype=torch.int32, device="cuda", generator=rand_gen ) b_req_idx = torch.tensor([self.req_manager.alloc()], dtype=torch.int32, device="cuda") - mem_indexes = self.mem_manager.alloc(len(dummy_input_ids)).cuda() b_seq_len = torch.ones(1, dtype=torch.int32, device="cuda") b_seq_len[:] = input_len b_ready_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda") total_token_num = input_len b_mtp_index = torch.zeros(1, dtype=torch.int32, device="cuda") + mem_indexes = self.req_manager.alloc_mem_indices( + len(dummy_input_ids), b_seq_len, b_ready_cache_len + ).cuda() model_input = ModelInput( batch_size=1, total_token_num=total_token_num, diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index 07792865e..3027f4ca3 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -196,13 +196,16 @@ def warmup(self, model): total_token_num = batch_size * seq_len max_len_in_batch = self.graph_max_len_in_batch input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda") - mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda() b_req_idx = torch.tensor( [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda" ) b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda") b_seq_len.fill_(seq_len) + b_last_mem_index = torch.zeros_like(b_seq_len) b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + mem_indexes = model.req_manager.alloc_mem_indices( + len(input_ids), b_seq_len, b_last_mem_index=b_last_mem_index + ).cuda() model_input = ModelInput( batch_size=batch_size, @@ -252,13 +255,16 @@ def warmup_overlap(self, model): total_token_num = batch_size * seq_len max_len_in_batch = self.graph_max_len_in_batch input_ids = torch.tensor([1 for _ in range(batch_size)], dtype=torch.int32, device="cuda") - mem_indexes = model.mem_manager.alloc(len(input_ids)).cuda() b_req_idx = torch.tensor( [model.req_manager.HOLD_REQUEST_ID for _ in range(batch_size)], dtype=torch.int32, device="cuda" ) b_seq_len = torch.empty(batch_size, dtype=torch.int32, device="cuda") b_seq_len.fill_(seq_len) + b_last_mem_index = torch.zeros_like(b_seq_len) b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + mem_indexes = model.req_manager.alloc_mem_indices( + len(input_ids), b_seq_len, b_last_mem_index=b_last_mem_index + ).cuda() micro_batch = ModelInput( is_prefill=False, diff --git 
a/lightllm/common/deepseek2_mem_manager.py b/lightllm/common/deepseek2_mem_manager.py index 6ddec24e2..75a4c3039 100644 --- a/lightllm/common/deepseek2_mem_manager.py +++ b/lightllm/common/deepseek2_mem_manager.py @@ -8,6 +8,7 @@ from lightllm.common.kv_trans_kernel.kv_trans import kv_trans from lightllm.common.kv_trans_kernel.kv_trans_v2 import kv_trans_v2_for_d_node, kv_trans_v2_for_p_node from lightllm.distributed.pynccl import PyNcclCommunicator +from lightllm.utils.envs_utils import get_page_size logger = init_logger(__name__) @@ -20,7 +21,12 @@ def get_cell_size(self): return self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype) def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): - self.kv_buffer = torch.empty((layer_num, size + 1, head_num, head_dim), dtype=dtype, device="cuda") + page_size = get_page_size() + self.kv_buffer = torch.empty( + (layer_num, (size // page_size + 1) * page_size, head_num, head_dim), + dtype=dtype, + device="cuda", + ) # todo, etp or edp use the same work buffer here # also it can be used for any kernels for work buffer witout save info only diff --git a/lightllm/common/mem_manager.py b/lightllm/common/mem_manager.py index 4142ce4aa..26ec7c970 100755 --- a/lightllm/common/mem_manager.py +++ b/lightllm/common/mem_manager.py @@ -2,6 +2,7 @@ import os import torch import torch.distributed as dist +import triton from typing import List, Union from lightllm.server.pd_io_struct import KVMoveTask from lightllm.utils.log_utils import init_logger @@ -9,7 +10,7 @@ from lightllm.utils.profile_max_tokens import get_available_gpu_memory, get_total_gpu_memory from lightllm.common.kv_trans_kernel.kv_trans import kv_trans from lightllm.utils.dist_utils import get_current_rank_in_node -from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args +from lightllm.utils.envs_utils import get_unique_server_name, get_env_start_args, get_page_size from lightllm.distributed.pynccl import PyNcclCommunicator from lightllm.utils.dist_utils import get_current_device_id @@ -81,7 +82,12 @@ def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): # 分配,内部实际也没有管理,这个token是预留来对一些特殊的运行模式,如多dp下,overlap microbatch # 等模式下 padding 一些请求,使推理过程可以正常运行采用的,其索引值为size,存储在HOLD_TOKEN_MEMINDEX # 成员变量中,其与 req_manager 中的HOLD_REQUEST_ID具有类似的作用和意义。 - self.kv_buffer = torch.empty((layer_num, size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda") + page_size = get_page_size() + self.kv_buffer = torch.empty( + (layer_num, (size // page_size + 1) * page_size, 2 * head_num, head_dim), + dtype=dtype, + device="cuda", + ) def alloc_kv_move_buffer(self, max_req_total_len): """ @@ -244,6 +250,7 @@ def _free_buffers(self): self.kv_buffer = None def alloc(self, need_size) -> torch.Tensor: + assert need_size % get_page_size() == 0, "Need size must be a multiple of page size" if need_size > self.mark_end - self.mark_start: logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}") assert False, "error alloc state" @@ -265,18 +272,25 @@ def free(self, free_index: Union[torch.Tensor, List[int]]): """ end = self.mark_start - start = self.mark_start - len(free_index) - assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}" + page_size = get_page_size() + free_len = page_size * triton.cdiv(len(free_index), page_size) + start = self.mark_start - free_len + assert start >= 0, f"error free state start: {self.mark_start} free len {free_len}" if isinstance(free_index, list): 
- self.mem_state.numpy()[start:end] = free_index + free_index = torch.tensor(free_index) + + # the gpu-to-cpu copy is a stream-blocking operation + if page_size > 1: + base_free_index = free_index[free_index % page_size == 0] + token_idxs = base_free_index[:, None] + torch.arange(page_size) + self.mem_state[start:end] = token_idxs.flatten() else: - # the gpu-to-cpu copy is a stream-blocking operation self.mem_state[start:end] = free_index - self.mark_start -= len(free_index) + self.mark_start -= free_len - self.can_use_mem_size += len(free_index) + self.can_use_mem_size += free_len self.shared_can_use_token_num.set_value(self.can_use_mem_size) if self.can_use_mem_size == len(self.mem_state): diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index dcd1b3072..3f85d7f57 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -1,11 +1,12 @@ import torch import collections +import triton from lightllm.utils.log_utils import init_logger from .mem_manager import MemoryManager from typing import List, Optional from lightllm.common.basemodel.triton_kernel.gen_sampling_params import token_id_counter from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter -from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args +from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args, get_page_size from lightllm.utils.config_utils import get_vocab_size logger = init_logger(__name__) @@ -67,6 +68,27 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: MemoryMana self.max_request_num = max_request_num self.HOLD_REQUEST_ID = max_request_num + def calc_real_need_token_num(self, need_token_num, b_seq_len, b_ready_cache_len=None): + return max(need_token_num, self._get_need_paged_token_num(b_seq_len, b_ready_cache_len)) + + def calc_last_mem_index_in_prefill(self, mem_indices, b_seq_len, b_ready_cache_len=None): + b_token_len = b_seq_len + if b_ready_cache_len is not None: + b_token_len = b_seq_len - b_ready_cache_len + b_token_len_cumsum = torch.cumsum(b_token_len, dim=0) + b_last_mem_index = mem_indices[b_token_len_cumsum - 1] + return b_last_mem_index + + # b_last_mem_index is only needed when b_ready_cache_len is None + def alloc_mem_indices( + self, need_size, b_seq_len=None, b_ready_cache_len=None, b_last_mem_index=None + ) -> torch.Tensor: + page_size = get_page_size() + if page_size > 1 and b_seq_len is not None: + return self._alloc_paged_mem_indices(page_size, b_seq_len, b_ready_cache_len, b_last_mem_index) + else: + return self.mem_manager.alloc(need_size) + def alloc(self): return self.req_list.alloc() @@ -92,6 +114,58 @@ def free_all(self): self.req_list = _ReqLinkedList(self.max_request_num) return + def _expand_by_page_size(self, b_token_len, page_size): + # expand seq_len into per-page lengths, e.g. seq_len = [9,9,9] -> p_token_len = [4,4,1,4,4,1,4,4,1], page_size = 4 + b_page_len = triton.cdiv(b_token_len, page_size) + need_pages_num = b_page_len.sum() + p_token_len = torch.full((need_pages_num,), page_size, dtype=b_token_len.dtype, device=b_token_len.device) + cumsum_pages = torch.cumsum(b_page_len, dim=0) + last_page_positions = cumsum_pages - 1 + remainders = b_token_len - (b_page_len - 1) * page_size + p_token_len[last_page_positions] = remainders + return need_pages_num, p_token_len + + def _alloc_paged_mem_indices(self, page_size, b_seq_len, b_ready_cache_len, b_last_mem_index): + b_seq_len = b_seq_len.cpu() + if b_ready_cache_len is not None: + # prefill + b_ready_cache_len = b_ready_cache_len.cpu() + b_token_len = b_seq_len -
b_ready_cache_len + total_pages_needed, p_token_len = self._expand_by_page_size(b_token_len, page_size) + paged_token_idxs = self.mem_manager.alloc(total_pages_needed * page_size) + pages = paged_token_idxs.view(-1, page_size) + mask = torch.arange(page_size, device=p_token_len.device) < p_token_len.unsqueeze(1) + return pages[mask] + else: + # decode + assert b_last_mem_index is not None + b_last_mem_index = b_last_mem_index.cpu() + need_new_page_mask = (b_seq_len - 1) % page_size == 0 + new_pages_num = need_new_page_mask.sum() + token_idxs = torch.zeros_like(b_seq_len, device=b_seq_len.device) + if new_pages_num > 0: + new_pages_tokens = self.mem_manager.alloc(new_pages_num * page_size) + token_idxs[need_new_page_mask] = new_pages_tokens[::page_size] + mask = ~need_new_page_mask + if mask.any(): + token_idxs[mask] = b_last_mem_index[mask] + 1 + return token_idxs + + def _get_need_paged_token_num(self, b_seq_len, b_ready_cache_len=None): + page_size = get_page_size() + if page_size == 1: + return 0 + + need_new_pages = 0 + if b_ready_cache_len is not None: + need_tokens_array = b_seq_len - b_ready_cache_len + need_pages_array = (need_tokens_array + page_size - 1) // page_size + need_new_pages = need_pages_array.sum() + else: + mask = (b_seq_len - 1) % page_size == 0 + need_new_pages = mask.sum() + return need_new_pages * page_size + class ReqSamplingParamsManager: """ diff --git a/lightllm/models/deepseek2/flashattention_infer_struct.py b/lightllm/models/deepseek2/flashattention_infer_struct.py index d2ae055ce..be979e288 100644 --- a/lightllm/models/deepseek2/flashattention_infer_struct.py +++ b/lightllm/models/deepseek2/flashattention_infer_struct.py @@ -2,8 +2,10 @@ import torch import numpy as np import torch.distributed as dist +import triton from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo from lightllm.utils.dist_utils import get_current_device_id +from lightllm.utils.envs_utils import get_page_size class Deepseek2FlashAttentionStateInfo(Deepseek2InferStateInfo): @@ -11,6 +13,7 @@ class Deepseek2FlashAttentionStateInfo(Deepseek2InferStateInfo): def __init__(self): super().__init__() + self.page_size = get_page_size() @classmethod def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int): @@ -39,19 +42,19 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): self.cu_seqlens_k = self.b1_cu_kv_seq_len max_seq_len_k = self.max_kv_seq_len if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch: - page_buffer = Deepseek2FlashAttentionStateInfo.get_page_table_buffer( - model.graph_max_batch_size, model.graph_max_len_in_batch + length = triton.cdiv(model.graph_max_len_in_batch, self.page_size) + page_buffer = Deepseek2FlashAttentionStateInfo.get_page_table_buffer(model.graph_max_batch_size, length) + self.page_table = page_buffer[self.microbatch_index][: self.batch_size * length].reshape( + self.batch_size, length ) - self.page_table = page_buffer[self.microbatch_index][ - : self.batch_size * model.graph_max_len_in_batch - ].reshape(self.batch_size, model.graph_max_len_in_batch) else: - self.page_table = torch.empty((self.batch_size, self.max_len_in_batch), dtype=torch.int32).to( - input_ids.device - ) + length = triton.cdiv(self.max_len_in_batch, self.page_size) + self.page_table = torch.empty((self.batch_size, length), dtype=torch.int32).to(input_ids.device) - self.page_table[:, :max_seq_len_k].copy_( - model.req_manager.req_to_token_indexs[self.b_req_idx, :max_seq_len_k] - ) - 
self.page_table[:, max_seq_len_k:].fill_(0) + length = triton.cdiv(max_seq_len_k, self.page_size) + token_indexs = model.req_manager.req_to_token_indexs[ + self.b_req_idx, : length * self.page_size : self.page_size + ] + self.page_table[:, :length].copy_(token_indexs // self.page_size) + self.page_table[:, length:].fill_(0) return diff --git a/lightllm/models/deepseek2/flashinfer_struct.py b/lightllm/models/deepseek2/flashinfer_struct.py index a00c45601..213b89ad7 100644 --- a/lightllm/models/deepseek2/flashinfer_struct.py +++ b/lightllm/models/deepseek2/flashinfer_struct.py @@ -2,8 +2,9 @@ import torch import numpy as np import torch.distributed as dist +import triton from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo -from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.envs_utils import get_env_start_args, get_page_size from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index @@ -13,6 +14,7 @@ def __init__(self): self.prefill_wrapper = None self.decode_wrapper = None self.flashinfer_extra_state = None + self.page_size = get_page_size() def init_some_extra_state(self, model, input_ids: torch.Tensor): super().init_some_extra_state(model, input_ids) @@ -23,22 +25,26 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): if not self.is_prefill: if get_env_start_args().enable_flashinfer_decode: self.q_indptr = torch.arange(self.batch_size + 1, dtype=torch.int32).to(input_ids.device) + length = triton.cdiv(self.flashinfer_extra_state.max_seq_length, self.page_size) if self.batch_size <= model.graph_max_batch_size: self.kv_indices = self.flashinfer_extra_state.kv_indices_buffer[self.microbatch_index][ - : self.batch_size * self.flashinfer_extra_state.max_seq_length + : self.batch_size * length ] else: self.kv_indices = torch.empty( - self.batch_size * self.flashinfer_extra_state.max_seq_length, + self.batch_size * length, dtype=torch.int32, device=input_ids.device, ) + b_page_len = triton.cdiv(self.b_seq_len, self.page_size) + self.kv_starts[1:] = b_page_len.cumsum(0) repack_kv_index( self.req_manager.req_to_token_indexs, self.b_req_idx, - self.b_seq_len, - self.b_start_loc, - self.max_len_in_batch, + b_page_len, + self.kv_starts[:-1], + triton.cdiv(self.max_len_in_batch, self.page_size), + self.page_size, self.kv_indices, ) if self.decode_wrapper is None: @@ -58,7 +64,7 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): self.flashinfer_extra_state.tp_q_head_num, self.flashinfer_extra_state.kv_lora_rank, self.flashinfer_extra_state.qk_rope_head_dim, - 1, + self.page_size, False, # causal self.flashinfer_extra_state.softmax_scale, self.flashinfer_extra_state.q_data_type, @@ -97,7 +103,7 @@ def copy_for_cuda_graph(self, new_infer_state): new_infer_state.flashinfer_extra_state.tp_q_head_num, new_infer_state.flashinfer_extra_state.kv_lora_rank, new_infer_state.flashinfer_extra_state.qk_rope_head_dim, - 1, + self.page_size, False, # causal new_infer_state.flashinfer_extra_state.softmax_scale, new_infer_state.flashinfer_extra_state.q_data_type, diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py index ace54bba4..5ca0d3215 100644 --- a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py @@ -551,8 +551,8 @@ def _token_gqa_decode_attention_flashattention( q_nope, q_rope = q[:, :, : -self.qk_rope_head_dim], q[:, :, 
-self.qk_rope_head_dim :] q_nope = layer_weight.k_b_proj_.bmm(q_nope.transpose(0, 1)).transpose(0, 1) kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, 1, 1, self.qk_rope_head_dim) - kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, 1, 1, self.kv_lora_rank) + k_rope = kv[:, :, -self.qk_rope_head_dim :].reshape(-1, infer_state.page_size, 1, self.qk_rope_head_dim) + kv_nope = kv[:, :, : -self.qk_rope_head_dim].reshape(-1, infer_state.page_size, 1, self.kv_lora_rank) k_descale, v_descale = None, None o_tensor = flash_attn_with_kvcache( q=q_rope, @@ -583,11 +583,15 @@ def _token_gqa_decode_attention_flashinfer( kv = infer_state.mem_manager.kv_buffer[self.layer_num_] o_tensor = self.alloc_tensor(q_nope.shape, dtype=q_nope.dtype) + k_nope = kv[:, :, : -self.qk_rope_head_dim] + k_rope = kv[:, :, -self.qk_rope_head_dim :] infer_state.decode_wrapper.run( q_nope, q_rope, - kv[:, :, : -self.qk_rope_head_dim], - kv[:, :, -self.qk_rope_head_dim :], + k_nope if infer_state.page_size == 1 else k_nope.reshape(-1, infer_state.page_size, 1, self.kv_lora_rank), + k_rope + if infer_state.page_size == 1 + else k_rope.reshape(-1, infer_state.page_size, 1, self.qk_rope_head_dim), out=o_tensor, return_lse=False, ) diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py index 9101cb963..2131539fa 100644 --- a/lightllm/models/deepseek2/model.py +++ b/lightllm/models/deepseek2/model.py @@ -21,35 +21,6 @@ logger = init_logger(__name__) -class DeepSeek2FlashInferStateExtraInfo: - def __init__(self, model): - num_heads = model.config["num_attention_heads"] - self.tp_q_head_num = num_heads // get_dp_world_size() - self.qk_nope_head_dim = model.qk_nope_head_dim - self.qk_rope_head_dim = model.qk_rope_head_dim - self.kv_lora_rank = model.kv_lora_rank - self.q_data_type = model.data_type - self.kv_data_type = model.data_type - self.workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, device=get_current_device_id()) - self.max_seq_length = model.max_seq_length - self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5) - self.kv_indices_buffer = [ - torch.empty( - model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() - ), - torch.empty( - model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() - ), - ] - if model.config["rope_scaling"] is not None: - rope_scaling = model.config["rope_scaling"] - mscale_all_dim = rope_scaling.get("mscale_all_dim", 0) - scaling_factor = rope_scaling["factor"] - if mscale_all_dim: - mscale = get_deepseek_mscale(scaling_factor, mscale_all_dim) - self.softmax_scale = self.softmax_scale * mscale * mscale - - @ModelRegistry(["deepseek_v2", "deepseek_v3"]) class Deepseek2TpPartModel(LlamaTpPartModel): # weight class @@ -197,3 +168,32 @@ def _init_to_get_yarn_rotary(self): def _context_forward(self, input_ids, infer_state): predict_logics = super()._context_forward(input_ids, infer_state) return predict_logics + + +class DeepSeek2FlashInferStateExtraInfo: + def __init__(self, model): + num_heads = model.config["num_attention_heads"] + self.tp_q_head_num = num_heads // get_dp_world_size() + self.qk_nope_head_dim = model.qk_nope_head_dim + self.qk_rope_head_dim = model.qk_rope_head_dim + self.kv_lora_rank = model.kv_lora_rank + self.q_data_type = model.data_type + self.kv_data_type = model.data_type + self.workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8, 
device=get_current_device_id()) + self.max_seq_length = model.max_seq_length + self.softmax_scale = (self.qk_nope_head_dim + self.qk_rope_head_dim) ** (-0.5) + self.kv_indices_buffer = [ + torch.empty( + model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() + ), + torch.empty( + model.graph_max_batch_size * self.max_seq_length, dtype=torch.int32, device=get_current_device_id() + ), + ] + if model.config["rope_scaling"] is not None: + rope_scaling = model.config["rope_scaling"] + mscale_all_dim = rope_scaling.get("mscale_all_dim", 0) + scaling_factor = rope_scaling["factor"] + if mscale_all_dim: + mscale = get_deepseek_mscale(scaling_factor, mscale_all_dim) + self.softmax_scale = self.softmax_scale * mscale * mscale diff --git a/lightllm/models/deepseek2/triton_kernel/repack_kv_index.py b/lightllm/models/deepseek2/triton_kernel/repack_kv_index.py index e86d2e819..65a165656 100644 --- a/lightllm/models/deepseek2/triton_kernel/repack_kv_index.py +++ b/lightllm/models/deepseek2/triton_kernel/repack_kv_index.py @@ -33,8 +33,44 @@ def _fwd_kernel_repack_kv_index( return +@triton.jit +def _fwd_kernel_repack_page_kv_index_from_tokens( + req_to_token_indexs, + req_index, + out_kv_index, + seq_len, + start_loc, + page_size, + token_stride_h, + SEQ_BLOCK: tl.constexpr, +): + cur_batch = tl.program_id(0) + start_seq_n = tl.program_id(1) + + cur_batch_seq_len = tl.load(seq_len + cur_batch) + cur_batch_req_idx = tl.load(req_index + cur_batch) + cur_batch_start_loc = tl.load(start_loc + cur_batch) + + offs_seq = (start_seq_n * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK)) * page_size + block_end_loc = (tl.minimum((start_seq_n + 1) * SEQ_BLOCK, cur_batch_seq_len)) * page_size + token_data = tl.load( + req_to_token_indexs + token_stride_h * cur_batch_req_idx + offs_seq, + mask=offs_seq < block_end_loc, + other=0, + ) + valid_mask = (token_data % page_size) == 0 + valid_mask = valid_mask & (token_data > 0) # 确保是有效的 token 索引 + page_data = tl.where(valid_mask, token_data // page_size, 0) + + offs_seq = start_seq_n * SEQ_BLOCK + tl.arange(0, SEQ_BLOCK) + block_end_loc = tl.minimum((start_seq_n + 1) * SEQ_BLOCK, cur_batch_seq_len) + out_kv_index_ptr = out_kv_index + cur_batch_start_loc + offs_seq + tl.store(out_kv_index_ptr, page_data, mask=offs_seq < block_end_loc) + return + + @torch.no_grad() -def repack_kv_index(kv_index, req_index, seq_len, start_loc, max_seq_len, out_kv_index): +def repack_kv_index(req_to_token_indexs, req_index, seq_len, start_loc, max_seq_len, page_size, out_kv_index): batch_size = req_index.shape[0] # flashinfer requires out_kv_index to be zeroed before use out_kv_index.zero_() @@ -44,29 +80,56 @@ def repack_kv_index(kv_index, req_index, seq_len, start_loc, max_seq_len, out_kv triton.cdiv(max_seq_len, BLOCK), ) - _fwd_kernel_repack_kv_index[grid]( - kv_index, - req_index, - out_kv_index, - seq_len, - start_loc, - kv_index.stride(0), - SEQ_BLOCK=BLOCK, - num_warps=8, - num_stages=1, - ) + if page_size > 1: + _fwd_kernel_repack_page_kv_index_from_tokens[grid]( + req_to_token_indexs, + req_index, + out_kv_index, + seq_len, + start_loc, + page_size, + req_to_token_indexs.stride(0), + SEQ_BLOCK=BLOCK, + num_warps=8, + num_stages=1, + ) + else: + _fwd_kernel_repack_kv_index[grid]( + req_to_token_indexs, + req_index, + out_kv_index, + seq_len, + start_loc, + req_to_token_indexs.stride(0), + SEQ_BLOCK=BLOCK, + num_warps=8, + num_stages=1, + ) return -def repack_kv_ref(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, output): - for b, sl, start in 
zip(b_req_idx, b_seq_len, b_start_loc): - output[start : start + sl] = req_to_token_indexs[b][:sl] +def repack_kv_ref(req_to_token_indexs, req_index, seq_len, start_loc, max_seq_len, page_size, out_kv_index): + page_indexs = torch.zeros_like(req_to_token_indexs) + valid_mask = req_to_token_indexs % page_size == 0 + batch_size, seq_len_dim = req_to_token_indexs.shape + valid_positions = torch.cumsum(valid_mask.int(), dim=1) - 1 + batch_indices = torch.arange(batch_size, device=req_to_token_indexs.device).unsqueeze(1).expand(-1, seq_len_dim) + page_indexs.view(-1).scatter_add_( + 0, + (batch_indices * seq_len_dim + torch.where(valid_mask, valid_positions, 0)).flatten(), + (torch.where(valid_mask, req_to_token_indexs // page_size, 0) * valid_mask.int()).flatten(), + ) + + for b, sl, start in zip(req_index, seq_len, start_loc): + out_kv_index[start : start + sl] = page_indexs[b][:sl] + return if __name__ == "__main__": import torch.nn.functional as F BATCH, MAX_SEQ_LEN = 10, 1024 + PAGE_SIZE = 64 rand_idx = torch.randperm(2 * MAX_SEQ_LEN * BATCH).cuda().int() b_req_idx = torch.randperm(BATCH).cuda().int() b_seq_len = torch.randint(1, MAX_SEQ_LEN, (BATCH,)).cuda().int() @@ -77,14 +140,31 @@ def repack_kv_ref(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, output .int() ) + # 为每个batch生成基于page的连续索引 + for b in range(2 * BATCH): + start_page_id = b * 100 # 确保不同batch有不同的page ID范围 + for token_idx in range(2 * MAX_SEQ_LEN): + page_offset = token_idx // PAGE_SIZE + token_in_page = token_idx % PAGE_SIZE + page_id = start_page_id + page_offset + token_index = page_id * PAGE_SIZE + token_in_page + req_to_token_indexs[b, token_idx] = token_index + output = torch.zeros((b_seq_len.sum(),)).cuda().int() ref = torch.zeros((b_seq_len.sum(),)).cuda().int() - for b, sl, start in zip(b_req_idx, b_seq_len, b_start_loc): - req_to_token_indexs[b][:sl] = rand_idx[start : start + sl] - fn1 = lambda: repack_kv_ref(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, ref) - fn2 = lambda: repack_kv_index(req_to_token_indexs, b_req_idx, b_seq_len, b_start_loc, MAX_SEQ_LEN, output) + b_page_len = triton.cdiv(b_seq_len, PAGE_SIZE) + page_output = torch.zeros((b_page_len.sum(),)).cuda().int() + page_ref = torch.zeros((b_page_len.sum(),)).cuda().int() + b_start_loc[1:] = b_page_len.cumsum(0)[:-1] + max_seq_len = triton.cdiv(MAX_SEQ_LEN, PAGE_SIZE) + fn1 = lambda: repack_kv_ref( + req_to_token_indexs, b_req_idx, b_page_len, b_start_loc, max_seq_len, PAGE_SIZE, page_ref + ) + fn2 = lambda: repack_kv_index( + req_to_token_indexs, b_req_idx, b_page_len, b_start_loc, max_seq_len, PAGE_SIZE, page_output + ) ms1 = triton.testing.do_bench(fn1) ms2 = triton.testing.do_bench_cudagraph(fn2) - print(ms1, ms2) - assert torch.allclose(output.float(), ref.float()) + print(f"repack_kv_index: ref={ms1:.3f}ms, triton={ms2:.3f}ms") + assert torch.allclose(page_output.float(), page_ref.float()) diff --git a/lightllm/models/llama/flashattention_infer_struct.py b/lightllm/models/llama/flashattention_infer_struct.py index 98f628f07..1c62052e3 100644 --- a/lightllm/models/llama/flashattention_infer_struct.py +++ b/lightllm/models/llama/flashattention_infer_struct.py @@ -2,8 +2,9 @@ import torch import numpy as np import torch.distributed as dist +import triton from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.envs_utils import get_env_start_args, get_page_size from lightllm.utils.dist_utils import get_current_device_id from 
lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index from lightllm.common.basemodel.batch_objs import ModelInput @@ -14,6 +15,7 @@ class FlashAttentionStateInfo(LlamaInferStateInfo): def __init__(self): super().__init__() + self.page_size = get_page_size() @classmethod def get_page_table_buffer(cls, graph_max_batch_size: int, max_seq_len: int): @@ -28,32 +30,33 @@ def _init_flash_attention_state(self, model, input_ids: torch.Tensor): if self.is_prefill: self.cu_seqlens_q = self.b1_cu_q_seq_len.int() self.cu_seqlens_k = self.b1_cu_kv_seq_len.int() - self.page_table = torch.empty( - (self.batch_size, self.max_seq_len), dtype=torch.int32, device=input_ids.device - ) - self.page_table.copy_(model.req_manager.req_to_token_indexs[self.b_req_idx, : self.max_seq_len]) + length = triton.cdiv(self.max_seq_len, self.page_size) + self.page_table = torch.empty((self.batch_size, length), dtype=torch.int32, device=input_ids.device) + token_indexs = model.req_manager.req_to_token_indexs[ + self.b_req_idx, : length * self.page_size : self.page_size + ] + self.page_table.copy_(token_indexs // self.page_size) else: # Meta information of flashattention for decoding self.cu_seqlens_q = self.b1_cu_q_seq_len.int() self.cu_seqlens_k = self.b1_cu_kv_seq_len.int() max_seq_len_k = self.max_kv_seq_len if self.batch_size <= model.graph_max_batch_size and self.max_len_in_batch <= model.graph_max_len_in_batch: - page_buffer = FlashAttentionStateInfo.get_page_table_buffer( - model.graph_max_batch_size, model.graph_max_len_in_batch + length = triton.cdiv(model.graph_max_len_in_batch, self.page_size) + page_buffer = FlashAttentionStateInfo.get_page_table_buffer(model.graph_max_batch_size, length) + self.page_table = page_buffer[self.microbatch_index][: self.batch_size * length].reshape( + self.batch_size, length ) - self.page_table = page_buffer[self.microbatch_index][ - : self.batch_size * model.graph_max_len_in_batch - ].reshape(self.batch_size, model.graph_max_len_in_batch) else: - self.page_table = torch.empty( - (self.batch_size, self.max_len_in_batch), dtype=torch.int32, device=input_ids.device - ) + length = triton.cdiv(self.max_len_in_batch, self.page_size) + self.page_table = torch.empty((self.batch_size, length), dtype=torch.int32, device=input_ids.device) - self.page_table[:, :max_seq_len_k].copy_( - model.req_manager.req_to_token_indexs[self.b_req_idx, :max_seq_len_k], - non_blocking=True, - ) - self.page_table[:, max_seq_len_k:].fill_(0) + length = triton.cdiv(max_seq_len_k, self.page_size) + token_indexs = model.req_manager.req_to_token_indexs[ + self.b_req_idx, : length * self.page_size : self.page_size + ] + self.page_table[:, :length].copy_(token_indexs // self.page_size) + self.page_table[:, length:].fill_(0) if "offline_calibration_fp8kv" in model.mode: if self.is_prefill: diff --git a/lightllm/models/llama/flashinfer_struct.py b/lightllm/models/llama/flashinfer_struct.py index a0c40b57a..1d655d9ff 100644 --- a/lightllm/models/llama/flashinfer_struct.py +++ b/lightllm/models/llama/flashinfer_struct.py @@ -2,8 +2,9 @@ import torch import numpy as np import torch.distributed as dist +import triton from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.envs_utils import get_env_start_args, get_page_size from lightllm.models.deepseek2.triton_kernel.repack_kv_index import repack_kv_index @@ -13,6 +14,7 @@ def __init__(self): self.prefill_wrapper = None self.decode_wrapper = None 
self.flashinfer_extra_state = None + self.page_size = get_page_size() def init_some_extra_state(self, model, input_ids: torch.Tensor): super().init_some_extra_state(model, input_ids) @@ -22,29 +24,33 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): if not self.is_prefill: if get_env_start_args().enable_flashinfer_decode: - self.kv_last_page_len_buffer = torch.full( - (self.batch_size,), 1, dtype=torch.int32, device=input_ids.device - ) + self.kv_last_page_len = torch.full((self.batch_size,), 1, dtype=torch.int32, device=input_ids.device) + length = triton.cdiv(self.flashinfer_extra_state.max_seq_length, self.page_size) if self.batch_size <= model.graph_max_batch_size: self.kv_indices = self.flashinfer_extra_state.kv_indices_buffer[self.microbatch_index][ - : self.batch_size * self.flashinfer_extra_state.max_seq_length + : self.batch_size * length ] else: self.kv_indices = torch.empty( - self.batch_size * self.flashinfer_extra_state.max_seq_length, + self.batch_size * length, dtype=torch.int32, device=input_ids.device, ) + self.kv_starts = self.b1_cu_kv_seq_len.int() + b_page_len = triton.cdiv(self.b_seq_len, self.page_size) + if self.page_size > 1: + self.kv_starts[1:] = b_page_len.cumsum(0) + self.kv_last_page_len = self.b_seq_len - (b_page_len - 1) * self.page_size repack_kv_index( self.req_manager.req_to_token_indexs, self.b_req_idx, - self.b_seq_len, - self.b_start_loc, - self.max_len_in_batch, + b_page_len, + self.kv_starts[:-1], + triton.cdiv(self.max_kv_seq_len, self.page_size), + self.page_size, self.kv_indices, ) - self.kv_starts = self.b1_cu_kv_seq_len.int() if self.decode_wrapper is None: self.decode_wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( self.flashinfer_extra_state.workspace_buffer, @@ -53,16 +59,16 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): use_tensor_cores=True, paged_kv_indptr_buffer=self.kv_starts, paged_kv_indices_buffer=self.kv_indices, - paged_kv_last_page_len_buffer=self.kv_last_page_len_buffer, + paged_kv_last_page_len_buffer=self.kv_last_page_len, ) self.decode_wrapper.plan( self.kv_starts, self.kv_indices, - self.kv_last_page_len_buffer, + self.kv_last_page_len, self.flashinfer_extra_state.tp_q_head_num, self.flashinfer_extra_state.tp_kv_head_num, self.flashinfer_extra_state.head_dim, - 1, + self.page_size, q_data_type=self.flashinfer_extra_state.q_data_type, kv_data_type=self.flashinfer_extra_state.kv_data_type, non_blocking=True, @@ -72,17 +78,23 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): q_starts = self.b1_cu_q_seq_len.int() kv_starts = self.b1_cu_kv_seq_len.int() kv_last_page_len = torch.full((self.batch_size,), 1, dtype=torch.int32, device=input_ids.device) + length = triton.cdiv(self.flashinfer_extra_state.max_seq_length, self.page_size) kv_indices = torch.empty( - self.batch_size * self.flashinfer_extra_state.max_seq_length, + self.batch_size * length, dtype=torch.int32, device=input_ids.device, ) + b_page_len = triton.cdiv(self.b_seq_len, self.page_size) + if self.page_size > 1: + kv_starts[1:] = b_page_len.cumsum(0) + kv_last_page_len = self.b_seq_len - (b_page_len - 1) * self.page_size repack_kv_index( self.req_manager.req_to_token_indexs, self.b_req_idx, - self.b_seq_len, + b_page_len, kv_starts[:-1], - self.max_kv_seq_len, + triton.cdiv(self.max_kv_seq_len, self.page_size), + self.page_size, kv_indices, ) self.prefill_wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( @@ -100,7 +112,7 @@ def init_some_extra_state(self, model, input_ids: 
torch.Tensor): self.flashinfer_extra_state.tp_q_head_num, self.flashinfer_extra_state.tp_kv_head_num, self.flashinfer_extra_state.head_dim, - 1, + self.page_size, causal=True, pos_encoding_mode="NONE", logits_soft_cap=0.0, @@ -115,11 +127,11 @@ def copy_for_cuda_graph(self, new_infer_state): self.decode_wrapper.plan( new_infer_state.kv_starts, new_infer_state.kv_indices, - new_infer_state.kv_last_page_len_buffer, + new_infer_state.kv_last_page_len, new_infer_state.flashinfer_extra_state.tp_q_head_num, new_infer_state.flashinfer_extra_state.tp_kv_head_num, new_infer_state.flashinfer_extra_state.head_dim, - 1, + self.page_size, q_data_type=new_infer_state.flashinfer_extra_state.q_data_type, kv_data_type=new_infer_state.flashinfer_extra_state.kv_data_type, non_blocking=True, diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py index b00215cff..6191ba745 100755 --- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py @@ -257,8 +257,9 @@ def _context_attention_flashinfer_kernel( self, q, kv, infer_state: LlamaFlashInferStateInfo, layer_weight, out=None ) -> torch.Tensor: o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_] - kv = kv.unsqueeze(1) + kv = infer_state.mem_manager.kv_buffer[self.layer_num_].view( + -1, infer_state.page_size, 2 * self.tp_k_head_num_, self.head_dim_ + ) infer_state.prefill_wrapper.run( q.view(q.shape[0], -1, self.head_dim_), (kv[:, :, : self.tp_k_head_num_, :], kv[:, :, self.tp_k_head_num_ :, :]), @@ -319,11 +320,11 @@ def _context_attention_kernel_ppl_int8kv( def _context_attention_flashattention(self, q, kv, infer_state: FlashAttentionStateInfo, layer_weight, out=None): cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( - -1, 1, self.tp_k_head_num_, self.head_dim_ + -1, infer_state.page_size, self.tp_k_head_num_, self.head_dim_ ) cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : - ].reshape(-1, 1, self.tp_v_head_num_, self.head_dim_) + ].reshape(-1, infer_state.page_size, self.tp_v_head_num_, self.head_dim_) q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_) k_descale, v_descale = None, None # disable quantization Lq = q.shape[-1] @@ -538,7 +539,9 @@ def _token_decode_attention_flashinfer(self, q, infer_state: LlamaFlashInferStat calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_) o_tensor = self.alloc_tensor(q.shape, q.dtype) if out is None else out - kv = infer_state.mem_manager.kv_buffer[self.layer_num_].unsqueeze(1) + kv = infer_state.mem_manager.kv_buffer[self.layer_num_].view( + -1, infer_state.page_size, 2 * self.tp_k_head_num_, self.head_dim_ + ) infer_state.decode_wrapper.run( q.view(calcu_shape1), (kv[:, :, : self.tp_k_head_num_, :], kv[:, :, self.tp_k_head_num_ :, :]), @@ -826,11 +829,11 @@ def _token_decode_attention_gqa_flashdecoding_vsm( def _token_decode_attention_flashattention(self, q, infer_state: FlashAttentionStateInfo, layer_weight, out=None): cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :].reshape( - -1, 1, self.tp_k_head_num_, self.head_dim_ + -1, infer_state.page_size, self.tp_k_head_num_, self.head_dim_ ) cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ :, self.tp_k_head_num_ : self.tp_k_head_num_ + 
self.tp_v_head_num_, : - ].reshape(-1, 1, self.tp_v_head_num_, self.head_dim_) + ].reshape(-1, infer_state.page_size, self.tp_v_head_num_, self.head_dim_) q = q.reshape(-1, self.tp_q_head_num_, self.head_dim_) k_descale, v_descale = None, None # disable quantization Lq = q.shape[-1] diff --git a/lightllm/server/router/dynamic_prompt/paged_radix_cache.py b/lightllm/server/router/dynamic_prompt/paged_radix_cache.py new file mode 100644 index 000000000..71f41d38a --- /dev/null +++ b/lightllm/server/router/dynamic_prompt/paged_radix_cache.py @@ -0,0 +1,407 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/managers/router/radix_cache.py +import torch +import numpy as np +from typing import Tuple, Dict, Set, List +from sortedcontainers import SortedSet +from .shared_arr import SharedArray +from lightllm.utils.envs_utils import get_page_size + + +class UniqueTimeIdGenerator: + def __init__(self): + self.counter = 0 + + def generate_time_id(self): + self.counter += 1 + return self.counter + + +time_gen = UniqueTimeIdGenerator() + + +class TreeNode: + def __init__(self): + self.children: Dict[int, TreeNode] = {} # page_hash -> TreeNode + self.parent: TreeNode = None + self.token_id_key: torch.Tensor = None + self.token_mem_index_value: torch.Tensor = None # records the stored token_index; each element is an index into the token mem + self.ref_counter = 0 + self.time_id = time_gen.generate_time_id() # used to identify the time period + + self.node_value_len = 0 + self.node_prefix_total_len = 0 + self.total_children_count = 0 + self.page_size = get_page_size() + self._page_size_is_power_of_2 = (self.page_size & (self.page_size - 1)) == 0 + self._page_size_mask = self.page_size - 1 if self._page_size_is_power_of_2 else None + + def get_compare_key(self): + return (0 if self.ref_counter == 0 else 1, self.total_children_count, self.time_id) + + def _compute_key(self, tokens: torch.Tensor) -> int: + page_tokens = tokens[: self.page_size] + return page_tokens.item() if self.page_size == 1 else hash(page_tokens.cpu().numpy().tobytes()) + + def find_matched_child(self, token_id_key: torch.Tensor) -> Tuple["TreeNode", int]: + target_key = self._compute_key(token_id_key) + if target_key in self.children: + child = self.children[target_key] + prefix_len = match(token_id_key, child.token_id_key) + # only match a length that is a multiple of page_size + if self.page_size > 1: + if prefix_len % self.page_size != 0: + if self._page_size_is_power_of_2: + # bitwise fast path + prefix_len = prefix_len & ~self._page_size_mask + else: + prefix_len = (prefix_len // self.page_size) * self.page_size + if prefix_len == 0: + return None, 0 + return child, prefix_len + + return None, 0 + + def split_node(self, prefix_len): + split_parent_node = TreeNode() + split_parent_node.parent = self.parent + self.parent.children[self._compute_key(self.token_id_key)] = split_parent_node + + split_parent_node.token_id_key = self.token_id_key[0:prefix_len] + split_parent_node.token_mem_index_value = self.token_mem_index_value[0:prefix_len] + split_parent_node.children = {} + + remaining_tokens = self.token_id_key[prefix_len:] + split_parent_node.children[self._compute_key(remaining_tokens)] = self + split_parent_node.ref_counter = self.ref_counter + split_parent_node.total_children_count = 1 + + new_len = len(split_parent_node.token_mem_index_value) + split_parent_node.node_value_len = new_len + split_parent_node.node_prefix_total_len = split_parent_node.parent.node_prefix_total_len + new_len + + self.token_id_key = remaining_tokens + self.token_mem_index_value =
self.token_mem_index_value[prefix_len:] + self.parent = split_parent_node + new_len = len(self.token_mem_index_value) + self.node_value_len = new_len + self.node_prefix_total_len = self.parent.node_prefix_total_len + new_len + return split_parent_node + + def add_and_return_new_child(self, token_id_key, token_mem_index_value): + child = TreeNode() + child.token_id_key = token_id_key + child.token_mem_index_value = token_mem_index_value + + self.children[self._compute_key(token_id_key)] = child + child.parent = self + self.total_children_count += 1 + + new_len = len(child.token_mem_index_value) + child.node_value_len = new_len + child.node_prefix_total_len = child.parent.node_prefix_total_len + new_len + return child + + def remove_child(self, child_node: "TreeNode"): + del self.children[self._compute_key(child_node.token_id_key)] + child_node.parent = None + self.total_children_count -= 1 + return + + def update_time(self): + self.time_id = time_gen.generate_time_id() + + def is_leaf(self): + return self.total_children_count == 0 + + +def match(t1: torch.Tensor, t2: torch.Tensor) -> int: + # Ensure same shape for comparison: flatten and get min length + t1_flat = t1.flatten() + t2_flat = t2.flatten() + min_len = min(t1_flat.size(0), t2_flat.size(0)) + + # Compare elements and find first mismatch + diff = t1_flat[:min_len] != t2_flat[:min_len] + mismatch_indices = torch.nonzero(diff) + + if mismatch_indices.numel() == 0: + return min_len # All matched up to min_len + else: + return mismatch_indices[0].item() + + +class PagedRadixCache: + """ + unique_name is mainly used to avoid shm name conflicts when multiple instances are deployed on a single machine + """ + + def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None): + self.mem_manager = mem_manager + self._key_dtype = torch.int64 + self._value_dtype = torch.int64 + # precompute page_size related constants + self.page_size = get_page_size() + self._page_size_is_power_of_2 = (self.page_size & (self.page_size - 1)) == 0 + self._page_size_mask = self.page_size - 1 if self._page_size_is_power_of_2 else None + + self.root_node = TreeNode() + self.root_node.token_id_key = torch.zeros((0,), device="cpu", dtype=self._key_dtype) + self.root_node.token_mem_index_value = torch.zeros((0,), device="cpu", dtype=self._value_dtype) + self.root_node.ref_counter = 1 # initialized to 1 so the root node is never evicted + + self.evict_tree_set: Set[TreeNode] = SortedSet(key=lambda x: x.get_compare_key()) # custom comparator + self.evict_tree_set.add(self.root_node) + + self.refed_tokens_num = SharedArray(f"{unique_name}_refed_tokens_num_{rank_in_node}", (1,), dtype=np.int64) + self.refed_tokens_num.arr[0] = 0 + self.tree_total_tokens_num = SharedArray( + f"{unique_name}_tree_total_tokens_num_{rank_in_node}", (1,), dtype=np.int64 + ) + self.tree_total_tokens_num.arr[0] = 0 + + def _get_page_aligned_key(self, key, value=None, free_truncated=False): + aligned_len = len(key) + if aligned_len == 0: + return None, None + # when page_size > 1, the input key length must be truncated to a multiple of page_size + if self.page_size > 1: + if aligned_len % self.page_size != 0: + if self._page_size_is_power_of_2: + # bitwise fast path + aligned_len = aligned_len & ~self._page_size_mask + else: + aligned_len = (aligned_len // self.page_size) * self.page_size + + # free the truncated tail + if free_truncated and aligned_len < len(key) and self.mem_manager is not None: + truncated_value = value[aligned_len:] if value is not None else key[aligned_len:] + if len(truncated_value) > 0: + self.mem_manager.free(truncated_value) + + return ( + key[:aligned_len] if aligned_len > 0 else None, + value[:aligned_len] if value is not None and aligned_len > 0 else
None, + ) + return key, value + + def insert(self, key, value=None): + if value is None: + value = key + + assert len(key) == len(value) # and len(key) >= 1 + key, value = self._get_page_aligned_key(key, value, free_truncated=True) + if key is None: + return 0 + return self._insert_helper(self.root_node, key, value) + + def _insert_helper(self, node: TreeNode, key, value): + if node.is_leaf(): + self.evict_tree_set.discard(node) + + try: + child, prefix_len = node.find_matched_child(key) + if child is not None: + if prefix_len == len(key): + if child.is_leaf(): + self.evict_tree_set.discard(child) + child.update_time() + if child.is_leaf(): + self.evict_tree_set.add(child) + return prefix_len + elif prefix_len < len(key) and prefix_len < len(child.token_id_key): + if child.is_leaf(): + self.evict_tree_set.discard(child) + + remaining_key = key[prefix_len:] + remaining_value = value[prefix_len:] + split_parent_node = child.split_node(prefix_len) + new_node = split_parent_node.add_and_return_new_child(remaining_key, remaining_value) + # update total token num + self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) + if new_node.is_leaf(): + self.evict_tree_set.add(new_node) + + if split_parent_node.is_leaf(): + self.evict_tree_set.add(split_parent_node) + + if child.is_leaf(): + self.evict_tree_set.add(child) + return prefix_len + elif prefix_len < len(key) and prefix_len == len(child.token_id_key): + return prefix_len + self._insert_helper(child, key[prefix_len:], value[prefix_len:]) + else: + assert False, "can not run to here" + + new_node = node.add_and_return_new_child(key, value) + # update total token num + self.tree_total_tokens_num.arr[0] += len(new_node.token_mem_index_value) + if new_node.is_leaf(): + self.evict_tree_set.add(new_node) + return 0 + finally: + node.update_time() + if node.is_leaf(): + self.evict_tree_set.add(node) + + def match_prefix(self, key, update_refs=False): + assert len(key) != 0 + key, _ = self._get_page_aligned_key(key) + if key is None: + return None, 0, None + + ans_value_list = [] + tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) + if tree_node != self.root_node: + if len(ans_value_list) != 0: + value = torch.concat(ans_value_list) + else: + value = torch.zeros((0,), device="cpu", dtype=self._value_dtype) + return tree_node, len(value), value + else: + self.dec_node_ref_counter(self.root_node) + return None, 0, None + + def _match_prefix_helper(self, node: TreeNode, key, ans_value_list: list, update_refs=False) -> TreeNode: + if node.is_leaf(): + self.evict_tree_set.discard(node) + + if update_refs: + node.ref_counter += 1 + # from 0 to 1 need update refs token num + if node.ref_counter == 1: + self.refed_tokens_num.arr[0] += len(node.token_mem_index_value) + + try: + if len(key) == 0: + return node + + child, prefix_len = node.find_matched_child(key) + if child is not None: + if prefix_len == len(child.token_id_key): + ans_value_list.append(child.token_mem_index_value) + return self._match_prefix_helper(child, key[prefix_len:], ans_value_list, update_refs=update_refs) + elif prefix_len < len(child.token_id_key): + if child.is_leaf(): + self.evict_tree_set.discard(child) + + split_parent_node = child.split_node(prefix_len) + ans_value_list.append(split_parent_node.token_mem_index_value) + + if update_refs: + split_parent_node.ref_counter += 1 + # from 0 to 1 need update refs token num + if split_parent_node.ref_counter == 1: + self.refed_tokens_num.arr[0] += 
len(split_parent_node.token_mem_index_value) + + if child.is_leaf(): + self.evict_tree_set.add(child) + if split_parent_node.is_leaf(): + self.evict_tree_set.add(split_parent_node) + + return split_parent_node + else: + assert False, "error state" + + return node + finally: + node.update_time() + if node.is_leaf(): + self.evict_tree_set.add(node) + + def evict(self, need_remove_tokens, evict_callback): + if self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0] < need_remove_tokens: + assert False, f"""can not free tree tokens {need_remove_tokens}, + tree_total_tokens_num {self.tree_total_tokens_num.arr[0]}, + refed_tokens_num {self.refed_tokens_num.arr[0]}""" + num_evicted = 0 + while num_evicted < need_remove_tokens: + node: TreeNode = self.evict_tree_set.pop(0) + assert node.ref_counter == 0 and node.is_leaf() and node != self.root_node, "error evict tree node state" + num_evicted += len(node.token_mem_index_value) + evict_callback(node.token_mem_index_value) + # update total token num + self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value) + parent_node: TreeNode = node.parent + parent_node.remove_child(node) + if parent_node.is_leaf(): + self.evict_tree_set.add(parent_node) + + return + + def assert_leafs_is_right(self): + for node in self.evict_tree_set: + if node.is_leaf() and node.ref_counter == 0: + a = node.token_mem_index_value.cuda() + assert (self.mem_manager.mem_state[a] == 1).sum().item() == len(a) + + def clear_tree_nodes(self): + """ + 该函数只在测试时调用 + """ + while True: + node: TreeNode = self.evict_tree_set.pop(0) + if node != self.root_node: + parent_node: TreeNode = node.parent + parent_node.remove_child(node) + if parent_node.is_leaf(): + self.evict_tree_set.add(parent_node) + else: + break + + self.tree_total_tokens_num.arr[0] = 0 + self.refed_tokens_num.arr[0] = 0 + return + + def dec_node_ref_counter(self, node: TreeNode): + if node is None: + return + # 如果减引用的是叶节点,需要先从 evict_tree_set 中移除 + old_node = node + if old_node.is_leaf(): + self.evict_tree_set.discard(old_node) + + while node is not None: + if node.ref_counter == 1: + self.refed_tokens_num.arr[0] -= len(node.token_mem_index_value) + node.ref_counter -= 1 + node = node.parent + + # 加回。 + if old_node.is_leaf(): + self.evict_tree_set.add(old_node) + return + + def get_refed_tokens_num(self): + return self.refed_tokens_num.arr[0] + + def get_tree_total_tokens_num(self): + return self.tree_total_tokens_num.arr[0] + + def print_self(self, indent=0): + self._print_helper(self.root_node, indent) + + def _print_helper(self, node: TreeNode, indent): + print( + " " * indent, + f"k: {node.token_id_key[0:10]} v: {node.token_mem_index_value[0:10]} refs: {node.ref_counter} \ + time_id: {node.time_id} prefix_total_len: {node.node_prefix_total_len} \ + node_value_len: {node.node_value_len}", + ) + for _, child in node.children.items(): + self._print_helper(child, indent=indent + 2) + return + + def free_radix_cache_to_get_enough_token(self, need_token_num): + assert self.mem_manager is not None + if need_token_num > self.mem_manager.can_use_mem_size: + need_evict_token_num = need_token_num - self.mem_manager.can_use_mem_size + release_mems = [] + + def release_mem(mem_index): + release_mems.append(mem_index) + return + + self.evict(need_evict_token_num, release_mem) + mem_index = torch.concat(release_mems) + self.mem_manager.free(mem_index) + return diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 01ae6c9c5..a5eb4c8b9 100644 --- 
a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -288,6 +288,7 @@ def __init__( self.shm_index = shm_index self.multimodal_params = multimodal_params self.vocab_size = vocab_size + self.last_kv_mem_index = -1 # 请求需要被暂停 self.wait_pause = False diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 7c2311d56..80ebb09d7 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -10,6 +10,7 @@ from lightllm.utils.log_utils import init_logger from lightllm.models import get_model from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache +from lightllm.server.router.dynamic_prompt.paged_radix_cache import PagedRadixCache from lightllm.server.router.model_infer.infer_batch import InferReq, InferReqUpdatePack from lightllm.server.router.token_load import TokenLoad from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock @@ -26,7 +27,7 @@ from lightllm.utils.dist_utils import get_dp_world_size, get_global_dp_rank, get_current_rank_in_dp from lightllm.utils.dist_utils import get_current_device_id, get_current_rank_in_node, get_node_world_size from lightllm.utils.dist_utils import get_dp_rank_in_node, create_new_group_for_current_node -from lightllm.utils.envs_utils import get_env_start_args +from lightllm.utils.envs_utils import get_env_start_args, get_page_size from lightllm.distributed import dist_group_manager from lightllm.server.router.shm_reqs_io_buffer import ShmReqsIOBuffer from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventManager, OverlapEventPack @@ -139,8 +140,9 @@ def init_model(self, kvargs): self.model, self.is_multimodal = get_model(model_cfg, model_kvargs) self.model: TpPartBaseModel = self.model # for easy typing set_random_seed(2147483647) + radix_cache_class = PagedRadixCache if get_page_size() > 1 else RadixCache self.radix_cache = ( - RadixCache( + radix_cache_class( get_unique_server_name(), self.model.mem_manager.size, self.rank_in_node, diff --git a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py index 10090a576..38079cb3f 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_padded_pre_process.py @@ -77,8 +77,19 @@ def padded_prepare_prefill_inputs( # dynamic prompt cache 准备 token g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0] - padded_req_num) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0] - padded_req_num) + token_num = g_infer_context.req_manager.calc_real_need_token_num( + input_ids.shape[0] - padded_req_num, b_seq_len, b_ready_cache_len + ) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num) + mem_indexes = g_infer_context.req_manager.alloc_mem_indices( + input_ids.shape[0] - padded_req_num, b_seq_len, b_ready_cache_len + ) + b_last_mem_index = g_infer_context.req_manager.calc_last_mem_index_in_prefill( + mem_indexes, b_seq_len, b_ready_cache_len + ) + for i, req in enumerate(req_objs): + req.last_kv_mem_index = b_last_mem_index[i].item() + g_infer_state_lock.release() if 
padded_req_num > 0: @@ -117,6 +128,7 @@ def padded_prepare_decode_inputs( b_req_idx = [] b_mtp_index = [] b_seq_len = [] + b_last_mem_index = [] for req in req_objs: run_reqs.append(req) b_req_idx.append(req.req_idx) @@ -126,6 +138,7 @@ def padded_prepare_decode_inputs( total_token_num += seq_len max_len_in_batch = max(max_len_in_batch, seq_len) b_mtp_index.append(0) + b_last_mem_index.append(req.last_kv_mem_index) # process the draft tokens. for step in range(req.mtp_step): run_reqs.append(req) @@ -158,12 +171,18 @@ def padded_prepare_decode_inputs( b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu") b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu") b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu") + b_last_mem_index = torch.tensor(b_last_mem_index, dtype=torch.int32, device="cpu") # dynamic prompt cache 准备 token g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0] - padded_req_num) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0] - padded_req_num) + token_num = g_infer_context.req_manager.calc_real_need_token_num(b_seq_len.shape[0] - padded_req_num, b_seq_len) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num) + mem_indexes = g_infer_context.req_manager.alloc_mem_indices( + b_seq_len.shape[0] - padded_req_num, b_seq_len, b_last_mem_index=b_last_mem_index + ) + for i, req in enumerate(req_objs): + req.last_kv_mem_index = mem_indexes[i] g_infer_state_lock.release() if padded_req_num > 0: diff --git a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py index d5bba1ae5..531f54869 100644 --- a/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py +++ b/lightllm/server/router/model_infer/mode_backend/generic_pre_process.py @@ -55,8 +55,16 @@ def prepare_prefill_inputs( # dynamic prompt cache 准备 token g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(input_ids.shape[0]) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(input_ids.shape[0]) + token_num = g_infer_context.req_manager.calc_real_need_token_num( + input_ids.shape[0], b_seq_len, b_ready_cache_len + ) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num) + mem_indexes = g_infer_context.req_manager.alloc_mem_indices(input_ids.shape[0], b_seq_len, b_ready_cache_len) + b_last_mem_index = g_infer_context.req_manager.calc_last_mem_index_in_prefill( + mem_indexes, b_seq_len, b_ready_cache_len + ) + for i, req in enumerate(req_objs): + req.last_kv_mem_index = b_last_mem_index[i].item() g_infer_state_lock.release() model_input = ModelInput( @@ -85,6 +93,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In b_req_idx = [] b_mtp_index = [] b_seq_len = [] + b_last_mem_index = [] for req in req_objs: run_reqs.append(req) b_req_idx.append(req.req_idx) @@ -94,6 +103,7 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In total_token_num += seq_len max_len_in_batch = max(max_len_in_batch, seq_len) b_mtp_index.append(0) + b_last_mem_index.append(req.last_kv_mem_index) # process the draft tokens. 
for step in range(req.mtp_step): run_reqs.append(req) @@ -107,12 +117,18 @@ def prepare_decode_inputs(req_objs: List[InferReq]) -> Tuple[ModelInput, List[In b_req_idx = torch.tensor(b_req_idx, dtype=torch.int32, device="cpu") b_seq_len = torch.tensor(b_seq_len, dtype=torch.int32, device="cpu") b_mtp_index = torch.tensor(b_mtp_index, dtype=torch.int32, device="cpu") + b_last_mem_index = torch.tensor(b_last_mem_index, dtype=torch.int32, device="cpu") # dynamic prompt cache 准备 token g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(b_seq_len.shape[0]) - mem_indexes = g_infer_context.req_manager.mem_manager.alloc(b_seq_len.shape[0]) + token_num = g_infer_context.req_manager.calc_real_need_token_num(b_seq_len.shape[0], b_seq_len) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(token_num) + mem_indexes = g_infer_context.req_manager.alloc_mem_indices( + b_seq_len.shape[0], b_seq_len, b_last_mem_index=b_last_mem_index + ) + for i, req in enumerate(req_objs): + req.last_kv_mem_index = mem_indexes[i] g_infer_state_lock.release() model_input = ModelInput( diff --git a/lightllm/server/router/req_queue/chunked_prefill/impl.py b/lightllm/server/router/req_queue/chunked_prefill/impl.py index f1dae4cac..8fa6248b3 100644 --- a/lightllm/server/router/req_queue/chunked_prefill/impl.py +++ b/lightllm/server/router/req_queue/chunked_prefill/impl.py @@ -3,6 +3,7 @@ from ...batch import Batch, Req from lightllm.server.router.req_queue.base_queue import BaseQueue from lightllm.common.basemodel.infer_lock import g_router_lock +from lightllm.utils.envs_utils import get_page_size class ChunkedPrefillQueue(BaseQueue): @@ -32,9 +33,11 @@ def _can_add_new_req(self, req: Req, is_busy, new_batch_first_router_need_tokens need_max_token_num = (left_out_len_array * size_array + cum_run_len_array).max() with g_router_lock.obj: + page_size = get_page_size() + page_remaining = (len(self.cache_len_list) - 1) * page_size if page_size > 1 else 0 ok_token_num = ( need_max_token_num + self.router.shared_token_load.get_frozened_token_count(self.dp_index) - < self.max_total_tokens + < self.max_total_tokens - page_remaining ) ok_req_num = len(self.cache_len_list) <= self.running_max_req_size diff --git a/lightllm/utils/envs_utils.py b/lightllm/utils/envs_utils.py index 6c9a070e2..3effd1a40 100644 --- a/lightllm/utils/envs_utils.py +++ b/lightllm/utils/envs_utils.py @@ -158,6 +158,11 @@ def set_triton_autotune_level(level: int): return +@lru_cache(maxsize=None) +def get_page_size(): + return int(os.getenv("PAGE_SIZE", 1)) + + g_model_init_done = False diff --git a/test/benchmark/static_inference/model_infer.py b/test/benchmark/static_inference/model_infer.py index b7c07d17a..e0f262f93 100644 --- a/test/benchmark/static_inference/model_infer.py +++ b/test/benchmark/static_inference/model_infer.py @@ -258,7 +258,8 @@ def run_forward_once( b_seq_len[i] = input_len total_token_num = batch_size * input_len - mem_indexes = model_part.req_manager.mem_manager.alloc(test_data.shape[0]) + mem_indexes = model_part.req_manager.alloc_mem_indices(test_data.shape[0], b_seq_len, b_ready_cache_len) + b_last_mem_index = model_part.req_manager.calc_last_mem_index_in_prefill(mem_indexes, b_seq_len, b_ready_cache_len) b_mtp_index = torch.zeros(batch_size, dtype=torch.int32, device="cpu") rank_id = model_kvargs["rank_id"] @@ -321,7 +322,10 @@ def run_forward_once( step_start = time.time() total_token_num += batch_size b_seq_len += 1 - mem_indexes = 
model_part.req_manager.mem_manager.alloc(predict_ids.shape[0]) + mem_indexes = model_part.req_manager.alloc_mem_indices( + predict_ids.shape[0], b_seq_len, b_last_mem_index=b_last_mem_index + ) + b_last_mem_index = mem_indexes max_len_in_batch = input_len + i + 1 logits = decode_fn( model_part, diff --git a/test/benchmark/static_inference/model_infer_mtp.py b/test/benchmark/static_inference/model_infer_mtp.py index ba90e709b..cdb1f592e 100644 --- a/test/benchmark/static_inference/model_infer_mtp.py +++ b/test/benchmark/static_inference/model_infer_mtp.py @@ -124,7 +124,8 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ b_seq_len[i] = input_len total_token_num = input_len * batch_size - mem_indexes = main_model.req_manager.mem_manager.alloc(test_data.shape[0]).cuda() + mem_indexes = main_model.req_manager.alloc_mem_indices(test_data.shape[0], b_seq_len, b_ready_cache_len).cuda() + b_last_mem_index = main_model.req_manager.calc_last_mem_index_in_prefill(mem_indexes, b_seq_len, b_ready_cache_len) # Main model Prefill model_input = ModelInput( batch_size=batch_size, @@ -191,7 +192,9 @@ def run_forward_once(args, input_len, output_len, batch_size, main_model, draft_ nopad_b_seq_idx = torch.tensor(nopad_b_seq_idx, dtype=torch.int32, device="cuda") nopad_b_seq_len = torch.tensor(nopad_b_seq_len, dtype=torch.int32, device="cuda") - mem_indexes = main_model.req_manager.mem_manager.alloc(batch_size * (len(draft_models) + 1)).cuda() + mem_indexes = main_model.req_manager.alloc_mem_indices( + batch_size * (len(draft_models) + 1), nopad_b_seq_len, b_last_mem_index=b_last_mem_index + ).cuda() model_input = ModelInput( batch_size=batch_size * (len(draft_models) + 1), diff --git a/unit_tests/models/llama/test_context_flashattention_nopad.py b/unit_tests/models/llama/test_context_flashattention_nopad.py index f24ab619b..94e61cfda 100644 --- a/unit_tests/models/llama/test_context_flashattention_nopad.py +++ b/unit_tests/models/llama/test_context_flashattention_nopad.py @@ -10,7 +10,6 @@ context_attention_fwd_no_prompt_cache, ) from lightllm.models.llama.infer_struct import LlamaInferStateInfo -from lightllm.common.req_manager import ReqManager logger = init_logger(__name__) @@ -56,8 +55,6 @@ def test_context_attention_fwd(batch, seqlen, q_heads, kv_heads, head_dim): infer_state.batch_size = Z infer_state.max_len_in_batch = N_CTX infer_state.total_token_num = Z * N_CTX - infer_state.req_manager = ReqManager(Z, N_CTX, None) - infer_state.req_manager.req_to_token_indexs = req_to_token_indexs infer_state.b_req_idx = b_req_idx infer_state.b_seq_len = b_seq_len infer_state.b_ready_cache_len = b_ready_cache_len @@ -73,7 +70,7 @@ def test_context_attention_fwd(batch, seqlen, q_heads, kv_heads, head_dim): infer_state.b_seq_len, infer_state.b_ready_cache_len, infer_state.max_len_in_batch, - infer_state.req_manager.req_to_token_indexs, + req_to_token_indexs, ) batch_size = Z diff --git a/unit_tests/models/llama/test_context_flashattention_nopad_fa3_paged.py b/unit_tests/models/llama/test_context_flashattention_nopad_fa3_paged.py new file mode 100644 index 000000000..e7702f084 --- /dev/null +++ b/unit_tests/models/llama/test_context_flashattention_nopad_fa3_paged.py @@ -0,0 +1,163 @@ +import torch +import time +import pytest +import triton as tl +import numpy as np +import torch.nn.functional as F +from lightllm.utils.log_utils import init_logger +from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( + context_attention_fwd, +) +from 
lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.utils.sgl_utils import flash_attn_with_kvcache +from lightllm.common.basemodel.triton_kernel.q_per_head_fp8_quant import q_per_head_fp8_quant + +logger = init_logger(__name__) + +seed = 42 +torch.manual_seed(seed) + +if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def kv_quantize_per_head_fp8(kv_buffer: torch.Tensor, seq_lens): + device = kv_buffer.device + B = seq_lens.size(0) + min_fp8 = torch.finfo(torch.float8_e4m3fn).min + max_fp8 = torch.finfo(torch.float8_e4m3fn).max + _, S_max, H, D = kv_buffer.shape + seq_range = torch.arange(S_max, device=device)[None, :] + valid_mask = (seq_range < seq_lens[:, None]).view(B, S_max, 1, 1) + masked = kv_buffer * valid_mask + max_per_bh = masked.abs().amax(dim=(1, 3)) # [B, H] + scales = torch.where(max_per_bh > 0, max_per_bh / max_fp8, torch.ones_like(max_per_bh)).to(torch.float32) + scales_exp = scales.view(B, 1, H, 1) + q = (kv_buffer / scales_exp).clamp(min_fp8, max_fp8).to(torch.float8_e4m3fn) + return q, scales + + +@pytest.mark.parametrize( + "batch, seqlen, q_heads, kv_heads, head_dim", + [ + (a, b, c, d, e) + for a in [1, 16, 32, 128, 512] + for b in [16, 32, 512, 1024] + for c in [28] + for d in [4] + for e in [128] + ], +) +def test_context_attention_fwd_fa3_fp8(batch, seqlen, q_heads, kv_heads, head_dim): + Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim + dtype = torch.bfloat16 + page_size = 4 + kv = torch.randn((Z * N_CTX // page_size, page_size, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + max_input_len = Z * N_CTX + req_to_page_indexs = ( + torch.randperm(max_input_len // page_size, dtype=torch.int32).cuda().view(Z, N_CTX // page_size) + ) + req_to_token_indexs = ( + req_to_page_indexs.unsqueeze(-1) * page_size + torch.arange(page_size, dtype=torch.int32, device="cuda") + ).reshape(Z, N_CTX) + b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * (N_CTX // 2) + rand_num = torch.randint_like(b_seq_len, high=(N_CTX // 2), dtype=torch.int32, device="cuda") + b_seq_len += rand_num + b_ready_cache_len = torch.zeros_like(b_seq_len, dtype=torch.int32, device="cuda") + if N_CTX > 1: + b_ready_cache_len = torch.randint_like(b_seq_len, high=(N_CTX - 1) // 2, dtype=torch.int32, device="cuda") + b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() + q_lens = b_seq_len - b_ready_cache_len + q_start_loc = q_lens.cumsum(0) - q_lens + + q = torch.randn((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o = torch.zeros((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o1 = torch.zeros((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + infer_state = LlamaInferStateInfo() + infer_state.batch_size = Z + infer_state.max_len_in_batch = N_CTX + infer_state.total_token_num = Z * N_CTX + infer_state.b_req_idx = b_req_idx + infer_state.b_seq_len = b_seq_len + infer_state.b_ready_cache_len = b_ready_cache_len + infer_state.b_start_loc = q_start_loc + + context_attention_fwd( + q, + kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, :KV_HEADS, :], + kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, KV_HEADS:, :], + o, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.b_ready_cache_len, + infer_state.max_len_in_batch, + req_to_token_indexs, + ) + + batch_size = Z + head_dim = HEAD_DIM + q_heads = Q_HEADS + kv_heads = KV_HEADS + page_table = torch.empty((batch_size, N_CTX), dtype=torch.int32, device="cuda") + 
page_table.copy_(req_to_token_indexs[b_req_idx, :N_CTX]) + + q_starts = torch.zeros((Z + 1,)).int().cuda() + q_starts[1:] = torch.cumsum(b_seq_len - b_ready_cache_len, dim=0) + kv_starts = torch.zeros_like(q_starts) + kv_starts[1:] = torch.cumsum(b_seq_len, dim=0) + + k_cache = kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, :KV_HEADS, :] + v_cache = kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, KV_HEADS:, :] + o1 = flash_attn_with_kvcache( + q=q, + k_cache=k_cache.reshape(-1, 1, kv_heads, head_dim), + v_cache=v_cache.reshape(-1, 1, kv_heads, head_dim), + page_table=page_table, + cache_seqlens=infer_state.b_seq_len, + cu_seqlens_q=q_starts, + cu_seqlens_k_new=kv_starts, + max_seqlen_q=N_CTX, + causal=True, + window_size=(-1, -1), + softcap=0.0, + return_softmax_lse=False, + ) + + assert torch.allclose(o, o1, atol=1e-2, rtol=1e-2) + cos_sim1 = F.cosine_similarity(o, o1).mean() + print(f"cos_sim1: {cos_sim1}") + assert cos_sim1.item() == 1 + + k_cache_paged = k_cache.reshape(-1, page_size, kv_heads, head_dim) + v_cache_paged = v_cache.reshape(-1, page_size, kv_heads, head_dim) + + page_table_paged = torch.empty((batch_size, N_CTX // page_size), dtype=torch.int32, device="cuda") + page_table_paged.copy_(req_to_page_indexs[b_req_idx, : N_CTX // page_size]) + + o2 = flash_attn_with_kvcache( + q=q, + k_cache=k_cache_paged, + v_cache=v_cache_paged, + page_table=page_table_paged, + cache_seqlens=infer_state.b_seq_len, + cu_seqlens_q=q_starts, + cu_seqlens_k_new=kv_starts, + max_seqlen_q=N_CTX, + causal=True, + window_size=(-1, -1), + softcap=0.0, + return_softmax_lse=False, + ) + + assert torch.allclose(o1, o2, atol=1e-2, rtol=1e-2) + cos_sim2 = F.cosine_similarity(o1, o2).mean() + print(f"cos_sim2: {cos_sim2}") + assert cos_sim2.item() == 1 + + +if __name__ == "__main__": + test_context_attention_fwd_fa3_fp8(32, 16384, 32, 4, 128) diff --git a/unit_tests/models/llama/test_context_flashattention_nopad_flashinfer_paged.py b/unit_tests/models/llama/test_context_flashattention_nopad_flashinfer_paged.py new file mode 100644 index 000000000..763a80015 --- /dev/null +++ b/unit_tests/models/llama/test_context_flashattention_nopad_flashinfer_paged.py @@ -0,0 +1,214 @@ +import torch +import time +import pytest +import numpy as np +import torch.nn.functional as F +import flashinfer +from lightllm.utils.log_utils import init_logger +from lightllm.models.llama.triton_kernel.context_flashattention_nopad import ( + context_attention_fwd, + context_attention_fwd_no_prompt_cache, +) +from lightllm.models.llama.infer_struct import LlamaInferStateInfo + +logger = init_logger(__name__) + +seed = 42 +torch.manual_seed(seed) + +if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +@pytest.mark.parametrize( + "batch, seqlen, q_heads, kv_heads, head_dim", + [ + (a, b, c, d, e) + for a in [1, 16, 32, 128, 512] + for b in [16, 32, 512, 1024] + for c in [28] + for d in [4] + for e in [128] + ], +) +def test_context_attention_fwd(batch, seqlen, q_heads, kv_heads, head_dim): + Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim + dtype = torch.bfloat16 + page_size = 4 + kv = torch.randn((Z * N_CTX // page_size, page_size, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + max_input_len = Z * N_CTX + req_to_page_indexs = ( + torch.randperm(max_input_len // page_size, dtype=torch.int32).cuda().view(Z, N_CTX // page_size) + ) + req_to_token_indexs = ( + req_to_page_indexs.unsqueeze(-1) * page_size + torch.arange(page_size, dtype=torch.int32, device="cuda") + 
).reshape(Z, N_CTX) + + b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * N_CTX + b_ready_cache_len = torch.zeros_like(b_seq_len, dtype=torch.int32, device="cuda") + b_ready_cache_len = torch.randint_like(b_seq_len, high=N_CTX - 1, dtype=torch.int32, device="cuda") + b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() + q_lens = b_seq_len - b_ready_cache_len + q_start_loc = q_lens.cumsum(0) - q_lens + + q = torch.randn((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o = torch.zeros((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o1 = torch.zeros((q_lens.sum(), Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + infer_state = LlamaInferStateInfo() + infer_state.batch_size = Z + infer_state.max_len_in_batch = N_CTX + infer_state.total_token_num = Z * N_CTX + infer_state.b_req_idx = b_req_idx + infer_state.b_seq_len = b_seq_len + infer_state.b_ready_cache_len = b_ready_cache_len + infer_state.b_start_loc = q_start_loc + + context_attention_fwd( + q, + kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, :KV_HEADS, :], + kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, KV_HEADS:, :], + o, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.b_ready_cache_len, + infer_state.max_len_in_batch, + req_to_token_indexs, + ) + + batch_size = Z + head_dim = HEAD_DIM + q_heads = Q_HEADS + kv_heads = KV_HEADS + workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(0) + q_starts = torch.zeros((Z + 1,)).int().cuda() + q_starts[1:] = torch.cumsum(b_seq_len - b_ready_cache_len, dim=0) + + num_pages_per_seq = torch.ceil(b_seq_len.float() / page_size).int() + kv_starts = torch.zeros((Z + 1,)).int().cuda() + kv_starts[1:] = torch.cumsum(num_pages_per_seq, dim=0) + + q_indptr = q_starts.int() + kv_indptr = kv_starts.int() + + total_pages = num_pages_per_seq.sum().item() + kv_indices = torch.zeros(total_pages, dtype=torch.int32, device="cuda") + + # 设置kv_indices + b_start_loc = num_pages_per_seq.cumsum(0) - num_pages_per_seq + for req, sl, start in zip(b_req_idx, num_pages_per_seq, b_start_loc): + kv_indices[start : start + sl] = req_to_page_indexs[req][:sl] + + kv_last_page_len_buffer = torch.empty(batch_size, device="cuda:0", dtype=torch.int32) + wrapper = flashinfer.prefill.BatchPrefillWithPagedKVCacheWrapper( + workspace_buffer, + qo_indptr_buf=q_indptr, + paged_kv_indptr_buf=kv_indptr, + paged_kv_indices_buf=kv_indices, + paged_kv_last_page_len_buf=kv_last_page_len_buffer, + ) + + # 设置kv_last_page_len + kv_last_page_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda") + for i in range(Z): + seq_len = b_seq_len[i].item() + remainder = seq_len % page_size + kv_last_page_len[i] = remainder if remainder > 0 else page_size + + wrapper.plan( + q_indptr, + kv_indptr, + kv_indices, + kv_last_page_len, + q_heads, + kv_heads, + head_dim, + page_size, + causal=True, + pos_encoding_mode="NONE", + logits_soft_cap=0.0, + q_data_type=q.dtype, + kv_data_type=kv.dtype, + ) + k_cache = kv[:, :, :KV_HEADS, :] + v_cache = kv[:, :, KV_HEADS:, :] + wrapper.run(q, (k_cache, v_cache), out=o1, return_lse=False) + cos_sim1 = F.cosine_similarity(o, o1).mean() + assert cos_sim1 == 1.0 + + +@pytest.mark.parametrize( + "batch, seqlen, q_heads, kv_heads, head_dim", + [ + (a, b, c, d, e) + for a in [1, 16, 32, 128, 512] + for b in [16, 32, 512, 1024] + for c in [28] + for d in [4] + for e in [128] + ], +) +def test_context_attention_fwd_no_prompt_cache(batch, seqlen, q_heads, kv_heads, head_dim): + Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, 
seqlen, q_heads, kv_heads, head_dim + dtype = torch.bfloat16 + q = torch.randn((Z * N_CTX, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + k = torch.randn((Z * N_CTX, KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + v = torch.randn((Z * N_CTX, KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * N_CTX + b_start_loc = b_seq_len.cumsum(0) - b_seq_len + + o = torch.zeros((Z * N_CTX, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o1 = torch.zeros((Z * N_CTX, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + infer_state = LlamaInferStateInfo() + infer_state.batch_size = Z + infer_state.max_len_in_batch = N_CTX + infer_state.b_seq_len = b_seq_len + infer_state.b_start_loc = b_start_loc + + context_attention_fwd_no_prompt_cache( + q, + k, + v, + o, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + ) + + head_dim = HEAD_DIM + q_heads = Q_HEADS + kv_heads = KV_HEADS + workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(0) + q_starts = torch.zeros((Z + 1,)).int().cuda() + q_starts[1:] = torch.cumsum(b_seq_len, dim=0) + kv_starts = torch.zeros_like(q_starts) + kv_starts[1:] = torch.cumsum(b_seq_len, dim=0) + q_indptr = q_starts.int() + kv_indptr = kv_starts.int() + wrapper = flashinfer.prefill.BatchPrefillWithRaggedKVCacheWrapper( + workspace_buffer, + ) + wrapper.plan( + qo_indptr=q_indptr, + kv_indptr=kv_indptr, + num_qo_heads=q_heads, + num_kv_heads=kv_heads, + head_dim_qk=head_dim, + head_dim_vo=head_dim, + q_data_type=dtype, + causal=True, + ) + wrapper.run(q, k, v, out=o1, return_lse=False) + + # assert torch.allclose(o, o1, atol=1e-2, rtol=0) + cos_sim1 = F.cosine_similarity(o, o1).mean() + assert cos_sim1 == 1.0 + + +if __name__ == "__main__": + test_context_attention_fwd(32, 16384, 32, 4, 128) # 16384 is divisible by 4 diff --git a/unit_tests/models/llama/test_token_attention_nopad_fa3_paged.py b/unit_tests/models/llama/test_token_attention_nopad_fa3_paged.py new file mode 100644 index 000000000..1de2fbc34 --- /dev/null +++ b/unit_tests/models/llama/test_token_attention_nopad_fa3_paged.py @@ -0,0 +1,186 @@ +import torch +import time +import pytest +import numpy as np +import torch.nn.functional as F +from lightllm.utils.log_utils import init_logger +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.llama.triton_kernel.gqa_decode_flashattention_nopad import gqa_decode_attention_fwd +from lightllm.utils.sgl_utils import flash_attn_with_kvcache + +logger = init_logger(__name__) + +seed = 42 +torch.manual_seed(seed) + +if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def kv_quantize_per_head_fp8(kv_buffer: torch.Tensor, seq_lens): + device = kv_buffer.device + B = seq_lens.size(0) + min_fp8 = torch.finfo(torch.float8_e4m3fn).min + max_fp8 = torch.finfo(torch.float8_e4m3fn).max + _, S_max, H, D = kv_buffer.shape + seq_range = torch.arange(S_max, device=device)[None, :] + valid_mask = (seq_range < seq_lens[:, None]).view(B, S_max, 1, 1) + masked = kv_buffer * valid_mask + max_per_bh = masked.float().abs().amax(dim=(1, 3)) # [B, H] + scales = torch.where(max_per_bh > 0, max_per_bh / max_fp8, torch.ones_like(max_per_bh)) + scales_exp = scales.view(B, 1, H, 1) + q = (kv_buffer / scales_exp).clamp(min_fp8, max_fp8).to(torch.float8_e4m3fn) + return q, scales + + +def ref_token_attention_nopad(q, k, v, o, q_h, h_dim, infer_state, req_to_token_indexs): + from 
lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd + + total_token_num = infer_state.total_token_num + batch_size = infer_state.batch_size + calcu_shape1 = (batch_size, q_h, h_dim) + + att_m_tensor = torch.empty((q_h, total_token_num), dtype=torch.float32).cuda() + + token_att_fwd( + q.view(calcu_shape1), + k, + att_m_tensor, + req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + ) + + from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import ( + token_softmax_reducev_fwd, + ) + + token_softmax_reducev_fwd( + att_m_tensor, + v, + o, + req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + ) + return o + + +@pytest.mark.parametrize( + "batch, seqlen, q_heads, kv_heads, head_dim", + [ + (a, b, c, d, e) + for a in [1, 16, 32, 128, 512] + for b in [16, 32, 512, 1024] + for c in [28] + for d in [4] + for e in [128] + ], +) +def test_token_attention_nopad_fa3_fp8(batch, seqlen, q_heads, kv_heads, head_dim): + Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim + dtype = torch.bfloat16 + page_size = 4 + q = torch.randn((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + kv = torch.randn((Z * N_CTX, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + max_input_len = Z * N_CTX + req_to_page_indexs = ( + torch.randperm(max_input_len // page_size, dtype=torch.int32).cuda().view(Z, N_CTX // page_size) + ) + req_to_token_indexs = ( + req_to_page_indexs.unsqueeze(-1) * page_size + torch.arange(page_size, dtype=torch.int32, device="cuda") + ).reshape(Z, N_CTX) + b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * (N_CTX // 2) + rand_num = torch.randint_like(b_seq_len, high=(N_CTX // 2), dtype=torch.int32, device="cuda") + b_seq_len += rand_num + b_start_loc = b_seq_len.cumsum(0) - b_seq_len + b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() + + o = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o1 = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + infer_state = LlamaInferStateInfo() + infer_state.batch_size = Z + infer_state.max_len_in_batch = N_CTX + infer_state.total_token_num = Z * N_CTX + infer_state.b_req_idx = b_req_idx + infer_state.b_seq_len = b_seq_len + infer_state.b_start_loc = b_start_loc + + ref_token_attention_nopad( + q, + kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, :KV_HEADS, :], + kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, KV_HEADS:, :], + o, + Q_HEADS, + HEAD_DIM, + infer_state, + req_to_token_indexs, + ) + + batch_size = Z + head_dim = HEAD_DIM + q_heads = Q_HEADS + kv_heads = KV_HEADS + kv_starts = torch.zeros((Z + 1,)).int().cuda() + kv_starts[1:] = torch.cumsum(b_seq_len, dim=0) + q_starts = torch.arange(0, Z + 1).int().cuda() + page_table = torch.empty((batch_size, N_CTX), dtype=torch.int32).to(0) + page_table.copy_(req_to_token_indexs[b_req_idx, :N_CTX]) + + k_cache = kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, :KV_HEADS, :].contiguous() + v_cache = kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, KV_HEADS:, :].contiguous() + o1 = flash_attn_with_kvcache( + q=q, + k_cache=k_cache.view(-1, 1, kv_heads, head_dim), + v_cache=v_cache.view(-1, 1, kv_heads, head_dim), + page_table=page_table, + cache_seqlens=infer_state.b_seq_len, + cu_seqlens_q=q_starts, + cu_seqlens_k_new=kv_starts, + max_seqlen_q=1, + causal=False, + window_size=(-1, -1), + softcap=0.0, + return_softmax_lse=False, + ) + + # assert torch.allclose(o, o1, atol=1e-1, 
rtol=1e-1) + cos_sim1 = F.cosine_similarity(o, o1).mean() + print(cos_sim1) + assert cos_sim1 == 1 + + k_cache_paged = k_cache.reshape(-1, page_size, kv_heads, head_dim) + v_cache_paged = v_cache.reshape(-1, page_size, kv_heads, head_dim) + + page_table_paged = torch.empty((batch_size, N_CTX // page_size), dtype=torch.int32, device="cuda") + page_table_paged.copy_(req_to_page_indexs[b_req_idx, : N_CTX // page_size]) + + o2 = flash_attn_with_kvcache( + q=q, + k_cache=k_cache_paged, + v_cache=v_cache_paged, + page_table=page_table_paged, + cache_seqlens=infer_state.b_seq_len, + cu_seqlens_q=q_starts, + cu_seqlens_k_new=kv_starts, + max_seqlen_q=1, + causal=False, + window_size=(-1, -1), + softcap=0.0, + return_softmax_lse=False, + ) + + assert torch.allclose(o1, o2, atol=1e-2, rtol=1e-2) + cos_sim2 = F.cosine_similarity(o1, o2).mean() + print(cos_sim2) + assert cos_sim2.item() == 1 + + +if __name__ == "__main__": + test_token_attention_nopad_fa3_fp8(16, 16384, 28, 4, 128) diff --git a/unit_tests/models/llama/test_token_attention_nopad_flashinfer_paged.py b/unit_tests/models/llama/test_token_attention_nopad_flashinfer_paged.py new file mode 100644 index 000000000..9bb97be99 --- /dev/null +++ b/unit_tests/models/llama/test_token_attention_nopad_flashinfer_paged.py @@ -0,0 +1,169 @@ +import torch +import time +import pytest +import numpy as np +import torch.nn.functional as F +import flashinfer +from lightllm.utils.log_utils import init_logger +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from lightllm.models.llama.triton_kernel.gqa_decode_flashattention_nopad import gqa_decode_attention_fwd + +logger = init_logger(__name__) + +seed = 42 +torch.manual_seed(seed) + +if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + +def ref_token_attention_nopad(q, k, v, o, q_h, h_dim, infer_state, req_to_token_indexs): + from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd + + total_token_num = infer_state.total_token_num + batch_size = infer_state.batch_size + calcu_shape1 = (batch_size, q_h, h_dim) + + att_m_tensor = torch.empty((q_h, total_token_num), dtype=torch.float32).cuda() + + token_att_fwd( + q.view(calcu_shape1), + k, + att_m_tensor, + req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + ) + + from lightllm.models.llama.triton_kernel.token_attention_softmax_and_reducev import ( + token_softmax_reducev_fwd, + ) + + token_softmax_reducev_fwd( + att_m_tensor, + v, + o, + req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + ) + return o + + +@pytest.mark.parametrize( + "batch, seqlen, q_heads, kv_heads, head_dim", + [ + (a, b, c, d, e) + for a in [1, 16, 32, 128, 512] + for b in [16, 32, 512, 1024] + for c in [28] + for d in [4] + for e in [128] + ], +) +def test_token_attention_nopad(batch, seqlen, q_heads, kv_heads, head_dim): + Z, N_CTX, Q_HEADS, KV_HEADS, HEAD_DIM = batch, seqlen, q_heads, kv_heads, head_dim + dtype = torch.bfloat16 + page_size = 4 + q = torch.randn((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + kv = torch.randn((Z * N_CTX // page_size, page_size, 2 * KV_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + max_input_len = Z * N_CTX + req_to_page_indexs = ( + torch.randperm(max_input_len // page_size, dtype=torch.int32).cuda().view(Z, N_CTX // page_size) + ) + req_to_token_indexs = ( + req_to_page_indexs.unsqueeze(-1) * page_size + 
torch.arange(page_size, dtype=torch.int32, device="cuda") + ).reshape(Z, N_CTX) + + b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") * N_CTX + b_start_loc = torch.arange(Z).cuda().int() * N_CTX + b_req_idx = torch.randperm(Z, dtype=torch.int32).cuda() + + o = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + o1 = torch.zeros((Z, Q_HEADS, HEAD_DIM), dtype=dtype, device="cuda") + + infer_state = LlamaInferStateInfo() + infer_state.batch_size = Z + infer_state.max_len_in_batch = N_CTX + infer_state.total_token_num = Z * N_CTX + infer_state.b_req_idx = b_req_idx + infer_state.b_seq_len = b_seq_len + infer_state.b_start_loc = b_start_loc + + ref_token_attention_nopad( + q, + kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, :KV_HEADS, :], + kv.view(-1, 2 * KV_HEADS, HEAD_DIM)[:, KV_HEADS:, :], + o, + Q_HEADS, + HEAD_DIM, + infer_state, + req_to_token_indexs, + ) + # gqa_decode_attention_fwd( + # q, + # kv[:,:KV_HEADS,:], + # kv[:,KV_HEADS:,:], + # o, + # req_to_token_indexs, + # infer_state.b_req_idx, + # infer_state.b_seq_len, + # ) + + batch_size = Z + head_dim = HEAD_DIM + q_heads = Q_HEADS + kv_heads = KV_HEADS + workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(0) + + num_pages_per_seq = torch.ceil(b_seq_len.float() / page_size).int() + kv_indptr = torch.zeros(Z + 1, dtype=torch.int32, device="cuda") + kv_indptr[1:] = torch.cumsum(num_pages_per_seq, dim=0) + + # Fill the paged KV data indices + total_pages = kv_indptr[-1].item() + kv_indices = torch.zeros(total_pages, dtype=torch.int32, device="cuda") + b_start_loc = num_pages_per_seq.cumsum(0) - num_pages_per_seq + for req, sl, start in zip(b_req_idx, num_pages_per_seq, b_start_loc): + kv_indices[start : start + sl] = req_to_page_indexs[req][:sl] + + # Calculate last page lengths + kv_last_page_len = torch.zeros(Z, dtype=torch.int32, device="cuda") + for i in range(Z): + seq_len = b_seq_len[i].item() + remainder = seq_len % page_size + kv_last_page_len[i] = remainder if remainder > 0 else page_size + + kv_last_page_len_buffer = torch.empty(batch_size, device="cuda:0", dtype=torch.int32) + wrapper = flashinfer.decode.BatchDecodeWithPagedKVCacheWrapper( + workspace_buffer, + "NHD", + use_cuda_graph=True, + use_tensor_cores=True, + paged_kv_indptr_buffer=kv_indptr, + paged_kv_indices_buffer=kv_indices, + paged_kv_last_page_len_buffer=kv_last_page_len_buffer, + ) + wrapper.plan( + kv_indptr, + kv_indices, + kv_last_page_len, + q_heads, + kv_heads, + head_dim, + page_size, + q_data_type=dtype, + non_blocking=True, + ) + wrapper.run(q, (kv[:, :, :KV_HEADS, :], kv[:, :, KV_HEADS:, :]), out=o1, return_lse=False) + cos_sim = F.cosine_similarity(o, o1).mean() + assert cos_sim == 1.0 + + +if __name__ == "__main__": + test_token_attention_nopad(32, 16384, 32, 4, 128)
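Note for reviewers: the paged FA3 tests above compare a token-granularity page_table (one entry per KV slot, built by expanding req_to_page_indexs with page_idx * page_size + arange(page_size)) against a page-granularity page_table. The snippet below is a minimal standalone sketch of that expansion; expand_pages_to_tokens is an illustrative name, not a helper that exists in LightLLM.

# Minimal sketch of the page-to-token index expansion used by the paged tests.
# `expand_pages_to_tokens` is an illustrative name, not part of the LightLLM codebase.
import torch


def expand_pages_to_tokens(req_to_page_indexs: torch.Tensor, page_size: int) -> torch.Tensor:
    """[num_reqs, num_pages] page ids -> [num_reqs, num_pages * page_size] token slot ids,
    following the `page_idx * page_size + arange(page_size)` pattern in the tests."""
    offsets = torch.arange(page_size, dtype=req_to_page_indexs.dtype, device=req_to_page_indexs.device)
    return (req_to_page_indexs.unsqueeze(-1) * page_size + offsets).reshape(req_to_page_indexs.shape[0], -1)


if __name__ == "__main__":
    pages = torch.tensor([[3, 0], [1, 2]], dtype=torch.int32)
    tokens = expand_pages_to_tokens(pages, page_size=4)
    # request 0 uses pages 3 and 0 -> token slots [12..15, 0..3]
    assert tokens[0].tolist() == [12, 13, 14, 15, 0, 1, 2, 3]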
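The FlashInfer paged tests additionally derive per-request page counts, cumulative page offsets (kv_indptr), flattened page indices (kv_indices) and last-page lengths before calling wrapper.plan(...). The following is a hedged, self-contained recap of that bookkeeping under the same assumptions as the tests; build_paged_kv_metadata is an illustrative name, not a LightLLM or FlashInfer API.

# Standalone recap of the paged-KV metadata the FlashInfer tests build before wrapper.plan().
# `build_paged_kv_metadata` is an illustrative name, not a LightLLM or FlashInfer API.
import torch


def build_paged_kv_metadata(b_seq_len: torch.Tensor, req_to_page_indexs: torch.Tensor,
                            b_req_idx: torch.Tensor, page_size: int):
    # ceil(seq_len / page_size) pages per request.
    num_pages_per_seq = (b_seq_len + page_size - 1) // page_size
    kv_indptr = torch.zeros(b_seq_len.numel() + 1, dtype=torch.int32, device=b_seq_len.device)
    kv_indptr[1:] = torch.cumsum(num_pages_per_seq, dim=0)
    # Flatten the per-request page lists into one contiguous indices array.
    kv_indices = torch.cat(
        [req_to_page_indexs[req][:n] for req, n in zip(b_req_idx.tolist(), num_pages_per_seq.tolist())]
    )
    # A partially filled last page holds seq_len % page_size tokens, a full one holds page_size.
    kv_last_page_len = b_seq_len - (num_pages_per_seq - 1) * page_size
    return kv_indptr, kv_indices, kv_last_page_len


if __name__ == "__main__":
    page_size = 4
    b_seq_len = torch.tensor([5, 8], dtype=torch.int32)
    req_to_page_indexs = torch.tensor([[7, 9, 0], [2, 4, 6]], dtype=torch.int32)
    b_req_idx = torch.tensor([0, 1], dtype=torch.int32)
    indptr, indices, last_len = build_paged_kv_metadata(b_seq_len, req_to_page_indexs, b_req_idx, page_size)
    assert indptr.tolist() == [0, 2, 4]
    assert indices.tolist() == [7, 9, 2, 4]
    assert last_len.tolist() == [1, 4]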