From 11b012fd72d8e3b61992098b1cfa6c1acb6d13ea Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 9 Dec 2025 11:25:28 +0000 Subject: [PATCH 01/19] qwen3next --- lightllm/common/basemodel/cuda_graph.py | 17 +- .../basemodel/layer_weights/hf_load_utils.py | 2 +- .../layer_weights/meta_weights/__init__.py | 1 + .../meta_weights/parameter_weight.py | 44 + .../kv_cache_mem_manager/mem_manager.py | 170 +-- ...8,a_dtype=torch.bfloat16}_NVIDIA_H200.json | 38 + ...6,x_dtype=torch.bfloat16}_NVIDIA_H200.json | 38 + ...6,x_dtype=torch.bfloat16}_NVIDIA_H200.json | 7 + ...6,x_dtype=torch.bfloat16}_NVIDIA_H200.json | 7 + ...num=1,use_fp8_w8a8=false}_NVIDIA_H200.json | 83 ++ ...um=10,use_fp8_w8a8=false}_NVIDIA_H200.json | 83 ++ .../{topk_num=10}_NVIDIA_H200.json | 38 + ...rch.bfloat16,topk_num=10}_NVIDIA_H200.json | 56 + ...out_dtype=torch.bfloat16}_NVIDIA_H200.json | 110 ++ lightllm/models/__init__.py | 1 + lightllm/models/qwen2/model.py | 2 +- .../qwen3next/layer_infer/post_layer_infer.py | 16 + .../layer_infer/transformer_layer_infer.py | 345 ++++++ .../layer_weights/transformer_layer_weight.py | 177 +++ lightllm/models/qwen3next/mem_manager.py | 152 +++ lightllm/models/qwen3next/model.py | 97 ++ lightllm/models/qwen3next/req_manager.py | 42 + .../qwen3next/triton_kernel/causal_conv1d.py | 1057 +++++++++++++++++ .../qwen3next/triton_kernel/fla/__init__.py | 16 + .../triton_kernel/fla/ops/__init__.py | 15 + .../qwen3next/triton_kernel/fla/ops/chunk.py | 225 ++++ .../triton_kernel/fla/ops/chunk_delta_h.py | 257 ++++ .../triton_kernel/fla/ops/chunk_o.py | 167 +++ .../fla/ops/chunk_scaled_dot_kkt.py | 136 +++ .../qwen3next/triton_kernel/fla/ops/cumsum.py | 200 ++++ .../triton_kernel/fla/ops/fused_recurrent.py | 367 ++++++ .../qwen3next/triton_kernel/fla/ops/index.py | 30 + .../qwen3next/triton_kernel/fla/ops/l2norm.py | 137 +++ .../qwen3next/triton_kernel/fla/ops/op.py | 36 + .../triton_kernel/fla/ops/solve_tril.py | 271 +++++ .../qwen3next/triton_kernel/fla/ops/utils.py | 173 +++ .../triton_kernel/fla/ops/wy_fast.py | 122 ++ .../triton_kernel/fused_gdn_gating.py | 83 ++ .../qwen3next/triton_kernel/gated_rmsnorm.py | 174 +++ .../qwen3next/triton_kernel/gemma_rmsnorm.py | 144 +++ lightllm/server/api_cli.py | 1 + lightllm/server/api_http.py | 18 +- lightllm/server/api_models.py | 4 +- lightllm/server/api_server.py | 26 +- lightllm/server/api_start.py | 33 +- lightllm/server/core/objs/start_args_type.py | 59 +- .../dynamic_prompt/hybrid_radix_cache.py | 420 +++++++ .../router/dynamic_prompt/radix_cache.py | 4 +- .../server/router/model_infer/infer_batch.py | 11 +- .../model_infer/mode_backend/__init__.py | 1 + .../model_infer/mode_backend/base_backend.py | 7 +- .../impl_for_hybrid_radix_cache.py | 112 ++ .../server/router/model_infer/model_rpc.py | 10 +- lightllm/utils/device_utils.py | 2 +- test/benchmark/service/benchmark_gsm8k.py | 231 ++++ test/test_api/test_chat.py | 183 +++ 56 files changed, 6142 insertions(+), 116 deletions(-) create mode 100644 lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=128,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_align_fused:v1/{topk_num=10}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json create mode 100644 lightllm/models/qwen3next/layer_infer/post_layer_infer.py create mode 100644 lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py create mode 100644 lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py create mode 100644 lightllm/models/qwen3next/mem_manager.py create mode 100644 lightllm/models/qwen3next/model.py create mode 100644 lightllm/models/qwen3next/req_manager.py create mode 100644 lightllm/models/qwen3next/triton_kernel/causal_conv1d.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/__init__.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/__init__.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/index.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/op.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/utils.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla/ops/wy_fast.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py create mode 100644 lightllm/models/qwen3next/triton_kernel/gated_rmsnorm.py create mode 100644 lightllm/models/qwen3next/triton_kernel/gemma_rmsnorm.py create mode 100644 lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py create mode 100644 lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_hybrid_radix_cache.py create mode 100644 test/benchmark/service/benchmark_gsm8k.py create mode 100644 test/test_api/test_chat.py diff --git a/lightllm/common/basemodel/cuda_graph.py b/lightllm/common/basemodel/cuda_graph.py index c754fabce..f09417dad 100644 --- a/lightllm/common/basemodel/cuda_graph.py +++ b/lightllm/common/basemodel/cuda_graph.py @@ -3,6 +3,7 @@ import copy import bisect from typing import Optional +from tqdm import tqdm from lightllm.utils.log_utils import init_logger from lightllm.utils.envs_utils import get_env_start_args from lightllm.distributed import dist_group_manager, lightllm_capture_graph, CustomProcessGroup @@ -191,7 +192,12 @@ def warmup(self, model): model: TpPartBaseModel = model # decode cuda graph init - for batch_size in self.cuda_graph_batch_sizes[::-1]: + progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing CUDA graphs") + for batch_size in progress_bar: + # Get available memory info + avail_mem, total_mem = torch.cuda.mem_get_info() + avail_mem_gb = avail_mem / (1024 ** 3) + progress_bar.set_description(f"Capturing CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB") seq_len = 2 total_token_num = batch_size * seq_len max_len_in_batch = self.graph_max_len_in_batch @@ -246,7 +252,14 @@ def warmup_overlap(self, model): model: TpPartBaseModel = model - for batch_size in self.cuda_graph_batch_sizes[::-1]: + progress_bar = tqdm(self.cuda_graph_batch_sizes[::-1], desc="Capturing overlap CUDA graphs") + for batch_size in progress_bar: + # Get available memory info + avail_mem, total_mem = torch.cuda.mem_get_info() + avail_mem_gb = avail_mem / (1024 ** 3) + progress_bar.set_description( + f"Capturing overlap CUDA graphs - Batch: {batch_size}, AvailMem: {avail_mem_gb:.2f}GB" + ) decode_batches = [] for micro_batch_index in [0, 1]: # dummy decoding, capture the cudagraph diff --git a/lightllm/common/basemodel/layer_weights/hf_load_utils.py b/lightllm/common/basemodel/layer_weights/hf_load_utils.py index 8cf66a5ad..bb2d9aec4 100755 --- a/lightllm/common/basemodel/layer_weights/hf_load_utils.py +++ b/lightllm/common/basemodel/layer_weights/hf_load_utils.py @@ -60,7 +60,7 @@ def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_laye transformer_layer_list=transformer_layer_list, weight_dir=weight_dir, ) # noqa - worker = int(os.environ.get("LOADWORKER", 1)) + worker = int(os.environ.get("LOADWORKER", 16)) with Pool(worker) as p: iterator = p.imap_unordered(partial_func, candidate_files, chunksize=1) desc_str = f"pid {os.getpid()} Loading model weights with {worker} workers" diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py index b3dab0614..cb815cd86 100644 --- a/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py +++ b/lightllm/common/basemodel/layer_weights/meta_weights/__init__.py @@ -9,3 +9,4 @@ from .norm_weight import NormWeight, GEMMANormWeight, TpNormWeight from .fused_moe_weight_tp import create_tp_moe_wegiht_obj from .fused_moe_weight_ep import FusedMoeWeightEP +from .parameter_weight import ParameterWeight, TpParameterWeight diff --git a/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py b/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py new file mode 100644 index 000000000..65adcd469 --- /dev/null +++ b/lightllm/common/basemodel/layer_weights/meta_weights/parameter_weight.py @@ -0,0 +1,44 @@ +import torch +from typing import Dict +from .base_weight import BaseWeightTpl +from lightllm.utils.dist_utils import get_current_device_id + + +class ParameterWeight(BaseWeightTpl): + def __init__(self, weight_name: str, data_type: torch.dtype, bias_name: str = None): + super().__init__() + self.weight_name = weight_name + self.bias_name = bias_name + self.data_type_ = data_type + self.weight = None + self.bias = None + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None: + if self.weight_name in weights: + self.weight = weights[self.weight_name].to(self.data_type_).cuda(get_current_device_id()) + if self.bias_name in weights: + self.bias = weights[self.bias_name].to(self.data_type_).cuda(get_current_device_id()) + + def verify_load(self): + load_ok = True + # Verify weight. The weight must be not None. + load_ok = load_ok and self.weight is not None + # Verify bias. If bias_name is set, it must be not None. + if self.bias_name is not None: + load_ok = load_ok and self.bias is not None + return load_ok + + +class TpParameterWeight(ParameterWeight): + def __init__(self, weight_name: str, data_type: torch.dtype, split_n_embed: int, bias_name: str = None): + super().__init__(weight_name, data_type, bias_name) + self.split_n_embed = split_n_embed + + def load_hf_weights(self, weights: Dict[str, torch.Tensor]) -> None: + start = self.split_n_embed * self.tp_rank_ + end = self.split_n_embed * (self.tp_rank_ + 1) + + if self.weight_name in weights: + self.weight = weights[self.weight_name][start:end].to(self.data_type_).cuda(get_current_device_id()) + if self.bias_name in weights: + self.bias = weights[self.bias_name][start:end].to(self.data_type_).cuda(get_current_device_id()) diff --git a/lightllm/common/kv_cache_mem_manager/mem_manager.py b/lightllm/common/kv_cache_mem_manager/mem_manager.py index d8fd93009..23d5a5dfa 100755 --- a/lightllm/common/kv_cache_mem_manager/mem_manager.py +++ b/lightllm/common/kv_cache_mem_manager/mem_manager.py @@ -24,16 +24,9 @@ logger = init_logger(__name__) -class MemoryManager: - def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9): +class BaseAllocator: + def __init__(self, size, mem_manager_name=None): self.size = size - self.head_num = head_num - self.head_dim = head_dim - self.layer_num = layer_num - self.always_copy = always_copy - self.dtype = dtype - # profile the max total token num if the size is None - self.profile_size(mem_fraction) self.mem_state = torch.arange( 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True @@ -48,14 +41,95 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False self.can_use_mem_size = self.size # 用共享内存进行共享,router 模块读取进行精确的调度估计, nccl port 作为一个单机中单实列的标记。防止冲突。 - from lightllm.utils.envs_utils import get_unique_server_name - + if mem_manager_name is None: + mem_manager_name = get_unique_server_name() rank_in_node = get_current_rank_in_node() - self.shared_can_use_token_num = SharedInt( - f"{get_unique_server_name()}_mem_manger_can_use_token_num_{rank_in_node}" - ) + self.shared_can_use_token_num = SharedInt(f"{mem_manager_name}_mem_manger_can_use_token_num_{rank_in_node}") self.shared_can_use_token_num.set_value(self.can_use_mem_size) + self.HOLD_TOKEN_MEMINDEX = self.size + + def alloc(self, need_size) -> torch.Tensor: + if need_size > self.mark_end - self.mark_start: + logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}") + assert False, "error alloc state" + + start = self.mark_start + end = self.mark_start + need_size + self.mark_start += need_size + + self.can_use_mem_size -= need_size + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + + # 利用缓冲区返回,避免异步情况下的内存竞争 + if self._return_start + need_size > self._mem_state_return.shape[0]: + self._return_start = 0 + ans = self._mem_state_return[self._return_start : self._return_start + need_size] + ans.copy_(self.mem_state[start:end]) + self._return_start += need_size + return ans + + def free(self, free_index: Union[torch.Tensor, List[int]]): + """_summary_ + + Args: + free_index (torch.Tensor): _description_ + """ + end = self.mark_start + start = self.mark_start - len(free_index) + assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}" + + if isinstance(free_index, list): + self.mem_state.numpy()[start:end] = free_index + else: + # 从 gpu 到 cpu 的拷贝操作是流内阻塞操作 + self.mem_state[start:end] = free_index + + self.mark_start -= len(free_index) + + self.can_use_mem_size += len(free_index) + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + + if self.can_use_mem_size == len(self.mem_state): + logger.debug(f"freed all gpu mem size {self.can_use_mem_size}") + return + + def free_all(self): + self.can_use_mem_size = len(self.mem_state) + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + self.mem_state.numpy()[:] = list(range(0, len(self.mem_state))) + self.mark_start = 0 + self.mark_end = len(self.mem_state) + + def resize_mem(self, new_size): + """ + just for test code + """ + self.size = new_size + self.mem_state = torch.arange( + 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True + ) + self.mark_start = 0 + self.mark_end = self.size + self.can_use_mem_size = self.size + self.shared_can_use_token_num.set_value(self.can_use_mem_size) + return + + +class MemoryManager(BaseAllocator): + def __init__( + self, size, dtype, head_num, head_dim, layer_num, always_copy=False, mem_fraction=0.9, mem_manager_name=None + ): + self.size = size + self.head_num = head_num + self.head_dim = head_dim + self.layer_num = layer_num + self.always_copy = always_copy + self.dtype = dtype + # profile the max total token num if the size is None + self.profile_size(mem_fraction) + super().__init__(self.size, mem_manager_name) + self._init_buffers( self.size, dtype, @@ -63,7 +137,6 @@ def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=False head_dim, layer_num, ) - self.HOLD_TOKEN_MEMINDEX = self.size def get_cell_size(self): return 2 * self.head_num * self.head_dim * self.layer_num * torch._utils._element_size(self.dtype) @@ -326,59 +399,13 @@ def _write_kv_move_data_p2p(self, token_indexes: torch.Tensor, buffer_tensor: to def _free_buffers(self): self.kv_buffer = None - def alloc(self, need_size) -> torch.Tensor: - if need_size > self.mark_end - self.mark_start: - logger.error(f"warn no enough cache need_size {need_size} left_size {self.can_use_mem_size}") - assert False, "error alloc state" - - start = self.mark_start - end = self.mark_start + need_size - self.mark_start += need_size - - self.can_use_mem_size -= need_size - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - - # 利用缓冲区返回,避免异步情况下的内存竞争 - if self._return_start + need_size > self._mem_state_return.shape[0]: - self._return_start = 0 - ans = self._mem_state_return[self._return_start : self._return_start + need_size] - ans.copy_(self.mem_state[start:end]) - self._return_start += need_size - return ans - - def free(self, free_index: Union[torch.Tensor, List[int]]): - """_summary_ - - Args: - free_index (torch.Tensor): _description_ - """ - - end = self.mark_start - start = self.mark_start - len(free_index) - assert start >= 0, f"error free state start: {self.mark_start} free len {len(free_index)}" - - if isinstance(free_index, list): - self.mem_state.numpy()[start:end] = free_index - else: - # 从 gpu 到 cpu 的拷贝操作是流内阻塞操作 - self.mem_state[start:end] = free_index - - self.mark_start -= len(free_index) - - self.can_use_mem_size += len(free_index) - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - - if self.can_use_mem_size == len(self.mem_state): - logger.debug(f"freed all gpu mem size {self.can_use_mem_size}") - return + def get_index_kv_buffer(self, index): + return {"kv_buffer": self.kv_buffer[:, index]} - def free_all(self): - self.can_use_mem_size = len(self.mem_state) - self.shared_can_use_token_num.set_value(self.can_use_mem_size) - self.mem_state.numpy()[:] = list(range(0, len(self.mem_state))) - self.mark_start = 0 - self.mark_end = len(self.mem_state) + def load_index_kv_buffer(self, index, load_tensor_dict): + self.kv_buffer[:, index].copy_(load_tensor_dict["kv_buffer"]) + # 重写resize_mem方法,添加_free_buffers和_init_buffers调用 def resize_mem(self, new_size): """ just for test code @@ -389,14 +416,9 @@ def resize_mem(self, new_size): head_dim = self.head_dim layer_num = self.layer_num - self.size = new_size - self.mem_state = torch.arange( - 0, self.size, dtype=torch.int32, device="cpu", requires_grad=False, pin_memory=True - ) - self.mark_start = 0 - self.mark_end = self.size - self.can_use_mem_size = self.size - self.shared_can_use_token_num.set_value(self.can_use_mem_size) + # 调用父类的resize_mem + super().resize_mem(new_size) + self._free_buffers() self._init_buffers(size, dtype, head_num, head_dim, layer_num) return diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..b139a72ba --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,38 @@ +{ + "1": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "100": { + "BLK_HEADS": 16, + "num_warps": 2 + }, + "1024": { + "BLK_HEADS": 8, + "num_warps": 2 + }, + "128": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "16": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "256": { + "BLK_HEADS": 16, + "num_warps": 2 + }, + "32": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "64": { + "BLK_HEADS": 8, + "num_warps": 2 + }, + "8": { + "BLK_HEADS": 64, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..0b388b1a8 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,38 @@ +{ + "1024": { + "BLOCK_N": 256, + "num_warps": 2 + }, + "128": { + "BLOCK_N": 256, + "num_warps": 1 + }, + "2048": { + "BLOCK_N": 64, + "num_warps": 1 + }, + "256": { + "BLOCK_N": 256, + "num_warps": 1 + }, + "512": { + "BLOCK_N": 512, + "num_warps": 4 + }, + "64": { + "BLOCK_N": 256, + "num_warps": 1 + }, + "8": { + "BLOCK_N": 512, + "num_warps": 8 + }, + "800": { + "BLOCK_N": 64, + "num_warps": 1 + }, + "8192": { + "BLOCK_N": 128, + "num_warps": 2 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..10685c2e2 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "2048": { + "BLOCK_SIZE": 2048, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..3863d48e8 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=256,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "256": { + "BLOCK_SIZE": 256, + "num_stages": 1, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=128,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=128,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 000000000..dac851f69 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=128,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,83 @@ +{ + "10": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "1000": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "10240": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "1280": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "160": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "2560": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "320": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "640": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "80": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 64, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json new file mode 100644 index 000000000..b38406dc3 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -0,0 +1,83 @@ +{ + "1": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 32, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 4, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "1024": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "256": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE_K": 128, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 2, + "num_warps": 4 + }, + "8": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 16, + "BLOCK_SIZE_N": 64, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 5, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_align_fused:v1/{topk_num=10}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_align_fused:v1/{topk_num=10}_NVIDIA_H200.json new file mode 100644 index 000000000..8e0ff1cf8 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_align_fused:v1/{topk_num=10}_NVIDIA_H200.json @@ -0,0 +1,38 @@ +{ + "1": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "100": { + "BLOCK_SIZE": 256, + "num_warps": 8 + }, + "1024": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "128": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "16": { + "BLOCK_SIZE": 256, + "num_warps": 8 + }, + "256": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "32": { + "BLOCK_SIZE": 128, + "num_warps": 4 + }, + "64": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, + "8": { + "BLOCK_SIZE": 256, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json new file mode 100644 index 000000000..e459d2f32 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json @@ -0,0 +1,56 @@ +{ + "1": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "100": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 16 + }, + "1024": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 4, + "num_warps": 2 + }, + "128": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "16": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 2, + "num_warps": 16 + }, + "256": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "32": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, + "64": { + "BLOCK_DIM": 512, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, + "8": { + "BLOCK_DIM": 256, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json new file mode 100644 index 000000000..1a1eb3f74 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -0,0 +1,110 @@ +{ + "1": { + "BLOCK_M": 32, + "BLOCK_N": 64, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "10": { + "BLOCK_M": 1, + "BLOCK_N": 32, + "NUM_STAGES": 2, + "num_warps": 8 + }, + "100": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 8 + }, + "1000": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "1024": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, + "10240": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "128": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 4, + "num_warps": 4 + }, + "1280": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "16": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 2, + "num_warps": 8 + }, + "160": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 2, + "num_warps": 4 + }, + "256": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 8 + }, + "2560": { + "BLOCK_M": 8, + "BLOCK_N": 64, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "32": { + "BLOCK_M": 1, + "BLOCK_N": 32, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "320": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "64": { + "BLOCK_M": 1, + "BLOCK_N": 32, + "NUM_STAGES": 4, + "num_warps": 8 + }, + "640": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "8": { + "BLOCK_M": 1, + "BLOCK_N": 32, + "NUM_STAGES": 2, + "num_warps": 4 + }, + "80": { + "BLOCK_M": 1, + "BLOCK_N": 64, + "NUM_STAGES": 2, + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/models/__init__.py b/lightllm/models/__init__.py index 5237f8fd2..7798c10bf 100644 --- a/lightllm/models/__init__.py +++ b/lightllm/models/__init__.py @@ -8,6 +8,7 @@ from lightllm.models.qwen2.model import Qwen2TpPartModel from lightllm.models.qwen3.model import Qwen3TpPartModel from lightllm.models.qwen3_moe.model import Qwen3MOEModel +from lightllm.models.qwen3next.model import Qwen3NextTpPartModel from lightllm.models.chatglm2.model import ChatGlm2TpPartModel from lightllm.models.internlm.model import InternlmTpPartModel from lightllm.models.stablelm.model import StablelmTpPartModel diff --git a/lightllm/models/qwen2/model.py b/lightllm/models/qwen2/model.py index 5b756aadf..e64ea495c 100644 --- a/lightllm/models/qwen2/model.py +++ b/lightllm/models/qwen2/model.py @@ -17,7 +17,7 @@ def __init__(self, kvargs): def _init_config(self): super()._init_config() - if self.config["sliding_window"] is None: + if self.config.get("sliding_window") is None: self.config["sliding_window"] = self.max_total_token_num # rename key [SYM: to be confirmed] return diff --git a/lightllm/models/qwen3next/layer_infer/post_layer_infer.py b/lightllm/models/qwen3next/layer_infer/post_layer_infer.py new file mode 100644 index 000000000..73ea97457 --- /dev/null +++ b/lightllm/models/qwen3next/layer_infer/post_layer_infer.py @@ -0,0 +1,16 @@ +import os +import torch +import torch.functional as F +import torch.distributed as dist +import numpy as np + +from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer +from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight +from lightllm.models.qwen3next.triton_kernel.gemma_rmsnorm import gemma_rmsnorm_forward + + +class Qwen3NextPostLayerInfer(LlamaPostLayerInfer): + def _norm(self, input, infer_state, layer_weight: LlamaPreAndPostLayerWeight) -> torch.Tensor: + out = self.alloc_tensor(input.shape, input.dtype) + gemma_rmsnorm_forward(input, layer_weight.final_norm_weight_, self.eps_, out=out) + return out diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py new file mode 100644 index 000000000..0f49cba25 --- /dev/null +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -0,0 +1,345 @@ +import torch +import torch.nn.functional as F +import torch.distributed as dist +from lightllm.models.qwen3next.layer_weights.transformer_layer_weight import Qwen3NextTransformerLayerWeight +from lightllm.models.qwen3_moe.layer_infer.transformer_layer_infer import Qwen3MOETransformerLayerInfer +from functools import partial +from lightllm.utils.log_utils import init_logger +from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd +from lightllm.models.qwen3next.mem_manager import Qwen3NextMemoryManager +from lightllm.models.qwen3next.req_manager import Qwen3NextReqManager +from lightllm.models.llama.infer_struct import LlamaInferStateInfo +from typing import Tuple +from typing_extensions import override +from einops import rearrange +from lightllm.models.qwen3next.triton_kernel.gated_rmsnorm import gated_rmsnorm_forward +from lightllm.models.qwen3next.triton_kernel.causal_conv1d import causal_conv1d_fn, causal_conv1d_update +from lightllm.models.qwen3next.triton_kernel.fused_gdn_gating import fused_gdn_gating +from lightllm.models.qwen3next.triton_kernel.fla.ops.chunk import chunk_gated_delta_rule +from lightllm.models.qwen3next.triton_kernel.fla.ops.fused_recurrent import fused_recurrent_gated_delta_rule +from lightllm.distributed import all_reduce +from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.models.qwen3next.triton_kernel.gemma_rmsnorm import gemma_rmsnorm_forward + +logger = init_logger(__name__) + + +class Qwen3NextTransformerLayerInfer(Qwen3MOETransformerLayerInfer): + def __init__(self, layer_num, network_config, mode=[]): + super().__init__(layer_num, network_config, mode) + self.is_linear = (layer_num + 1) % network_config["full_attention_interval"] != 0 + self.partial_rotary_factor = network_config.get("partial_rotary_factor", 1.0) + + if self.is_linear: + self.linear_attn_infer = Qwen3NextGatedDeltaNetInfer(network_config, layer_num, self.tp_world_size_) + return + + @override + def _bind_norm(self): + pass + + def _ffn_with_shared_expert( + self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight + ) -> torch.Tensor: + input = input.view(-1, self.embed_dim_) + up_gate_out = layer_weight.shared_expert_gate_up_proj.mm(input) + ffn1_out = self.alloc_tensor((input.size(0), up_gate_out.size(1) // 2), input.dtype) + silu_and_mul_fwd(up_gate_out, ffn1_out) + ffn2_out = layer_weight.shared_expert_down_proj.mm(ffn1_out) + shared_expert_out = F.sigmoid(layer_weight.shared_expert_gate.mm(input)) * ffn2_out + moe_out = self._ffn(input, infer_state, layer_weight) + return shared_expert_out + moe_out + + @override + def _att_norm( + self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight + ) -> torch.Tensor: + out = self.alloc_tensor(input.shape, input.dtype) + gemma_rmsnorm_forward(input, layer_weight.att_norm_weight_.weight, self.eps_, out=out) + return out + + @override + def _ffn_norm( + self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight + ) -> torch.Tensor: + out = self.alloc_tensor(input.shape, input.dtype) + gemma_rmsnorm_forward(input, layer_weight.ffn_norm_weight_.weight, self.eps_, out=out) + return out + + @override + def _get_qkv( + self, + input: torch.Tensor, + infer_state: LlamaInferStateInfo, + layer_weight: Qwen3NextTransformerLayerWeight, + ) -> Tuple[torch.Tensor, torch.Tensor]: + input = input.view(-1, self.embed_dim_) + q = layer_weight.q_proj.mm(input) + cache_kv = layer_weight.kv_proj.mm( + input.view(-1, self.embed_dim_), + ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + + gemma_rmsnorm_forward( + q.view(-1, self.head_dim_), + layer_weight.q_norm_weight_.weight, + eps=self.eps_, + out=q.view(-1, self.head_dim_), + ) + + cache_kv[:, : self.tp_k_head_num_, :] = gemma_rmsnorm_forward( + cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, cache_kv.shape[-1]), + layer_weight.k_norm_weight_.weight, + eps=self.eps_, + ).view(-1, self.tp_k_head_num_, cache_kv.shape[-1]) + + rotary_emb_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + cache_kv[:, : self.tp_k_head_num_, :], + infer_state.position_cos, + infer_state.position_sin, + partial_rotary_factor=self.partial_rotary_factor, + ) + return q, cache_kv + + @override + def _get_o( + self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight + ) -> torch.Tensor: + input = input * layer_weight._gate + layer_weight._gate = None + o_tensor = layer_weight.o_proj.mm(input) + return o_tensor + + def _context_full_attn( + self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight + ): + q, cache_kv = self._get_qkv(input, infer_state, layer_weight) + input = None + self._post_cache_kv(cache_kv, infer_state, layer_weight) + o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight) + q = None + o = self._get_o(o, infer_state, layer_weight) + if self.tp_world_size_ > 1: + all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + return o + + def context_forward( + self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight + ): + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + if self.is_linear: + o = self.linear_attn_infer._linear_attn(input1, infer_state, layer_weight, is_prefill=True, infer_cls=self) + else: + layer_weight._gate = torch.sigmoid(layer_weight.o_gate_proj.mm(input1)) + o = self._context_full_attn(input1, infer_state, layer_weight) + input_embdings.add_(o.view(-1, self.embed_dim_)) + o = None + + input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) + ffn_out = self._ffn_with_shared_expert(input1, infer_state, layer_weight) + input1 = None + if self.tp_world_size_ > 1: + all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + return input_embdings + + def _token_full_attn(self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight): + q, cache_kv = self._get_qkv(input, infer_state, layer_weight) + input = None + self._post_cache_kv(cache_kv, infer_state, layer_weight) + o = self._token_attention_kernel(q, infer_state, layer_weight) + q = None + o = self._get_o(o, infer_state, layer_weight) + if self.tp_world_size_ > 1: + all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + return o + + def token_forward( + self, input_embdings, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight + ): + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + if self.is_linear: + o = self.linear_attn_infer._linear_attn(input1, infer_state, layer_weight, is_prefill=False, infer_cls=self) + else: + layer_weight._gate = torch.sigmoid(layer_weight.o_gate_proj.mm(input1)) + o = self._token_full_attn(input1, infer_state, layer_weight) + input_embdings.add_(o.view(-1, self.embed_dim_)) + o = None + + input1 = self._ffn_norm(input_embdings, infer_state, layer_weight) + ffn_out = self._ffn_with_shared_expert(input1, infer_state, layer_weight) + input1 = None + if self.tp_world_size_ > 1: + all_reduce(ffn_out, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) + input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) + return input_embdings + + +class Qwen3NextGatedDeltaNetInfer: + def __init__(self, network_config, layer_idx, tp_world_size_): + self.network_config_ = network_config + self.layer_idx_ = layer_idx + self.tp_world_size_ = tp_world_size_ + self.num_v_heads = self.network_config_["linear_num_value_heads"] + self.num_k_heads = self.network_config_["linear_num_key_heads"] + self.head_k_dim = self.network_config_["linear_key_head_dim"] + self.head_v_dim = self.network_config_["linear_value_head_dim"] + self.key_dim = self.head_k_dim * self.num_k_heads + self.value_dim = self.head_v_dim * self.num_v_heads + self.conv_kernel_dim = self.network_config_["linear_conv_kernel_dim"] + self.activation = self.network_config_["hidden_act"] + self.tp_qkvz_dim = (self.key_dim * 2 + self.value_dim * 2) // self.tp_world_size_ + self.tp_ba_dim = (self.num_v_heads * 2) // self.tp_world_size_ + self.tp_num_k_heads = self.num_k_heads // self.tp_world_size_ + self.tp_num_v_heads = self.num_v_heads // self.tp_world_size_ + self.tp_key_dim = self.key_dim // self.tp_world_size_ + self.tp_value_dim = self.value_dim // self.tp_world_size_ + assert self.num_v_heads % self.num_k_heads == 0, "num_v_heads must be divisible by num_k_heads" + self.num_v_heads_per_k_head = self.num_v_heads // self.num_k_heads + + def _fix_query_key_value_ba_ordering(self, mixed_qkvzba): + """ + Derives `query`, `key` and `value` tensors from `mixed_qkvzba`. + """ + mixed_qkvz, mixed_ba = torch.split(mixed_qkvzba, [self.tp_qkvz_dim, self.tp_ba_dim], dim=-1) + + mixed_qkvz = mixed_qkvz.view( + -1, + self.tp_num_k_heads, + self.head_k_dim + self.head_k_dim + (self.head_v_dim + self.head_v_dim) * self.num_v_heads_per_k_head, + ) + mixed_ba = mixed_ba.view(-1, self.tp_num_k_heads, 2 * self.num_v_heads_per_k_head) + + qkvz_split_list = [ + self.head_k_dim, + self.head_k_dim, + (self.num_v_heads_per_k_head * self.head_v_dim), + (self.num_v_heads_per_k_head * self.head_v_dim), + ] + (query, key, value, z) = torch.split(mixed_qkvz, qkvz_split_list, dim=2) + (b, a) = torch.split(mixed_ba, [self.num_v_heads_per_k_head, self.num_v_heads_per_k_head], dim=2) + + query = query.reshape(-1, self.tp_num_k_heads * self.head_k_dim) + key = key.reshape(-1, self.tp_num_k_heads * self.head_k_dim) + value = value.reshape(-1, self.tp_num_v_heads * self.head_v_dim) + z = z.reshape(-1, self.tp_num_v_heads, self.head_v_dim) + b = b.reshape(-1, self.tp_num_v_heads) + a = a.reshape(-1, self.tp_num_v_heads) + + return query, key, value, z, b, a + + def _rearrange_mixed_qkv(self, mixed_qkv): + if mixed_qkv is None: + return None, None, None + query, key, value = torch.split( + mixed_qkv, + [self.tp_key_dim, self.tp_key_dim, self.tp_value_dim], + dim=-1, + ) + query, key = map(lambda x: rearrange(x, "l (h d) -> 1 l h d", d=self.head_k_dim), (query, key)) + value = rearrange(value, "l (h d) -> 1 l h d", d=self.head_v_dim) + return query, key, value + + def _linear_attn( + self, + input: torch.Tensor, + infer_state: LlamaInferStateInfo, + layer_weight: Qwen3NextTransformerLayerWeight, + is_prefill: bool, + infer_cls: Qwen3NextTransformerLayerInfer, + ): + assert layer_weight.is_linear, "layer_weight must be linear" + assert isinstance(infer_state.mem_manager, Qwen3NextMemoryManager) + assert isinstance(infer_state.req_manager, Qwen3NextReqManager) + input = input.view(-1, infer_cls.embed_dim_) + buffer_idx = infer_state.req_manager.req_to_buffer_indexes[infer_state.b_req_idx] + conv_states, ssm_states = infer_state.mem_manager.get_state_cache_buffer(self.layer_idx_) + + mixed_qkvzba = layer_weight.linear_in_proj.mm(input) + q, k, v, z, b, a = self._fix_query_key_value_ba_ordering(mixed_qkvzba) + mixed_qkv = torch.cat([q, k, v], dim=-1) + + if is_prefill: + mixed_qkv = mixed_qkv.transpose(0, 1) + out_tensor = infer_cls.alloc_tensor(mixed_qkv.shape, mixed_qkv.dtype, device=mixed_qkv.device) + causal_conv1d_fn( + mixed_qkv, + layer_weight.linear_conv1d.mm_param.weight.transpose(0, 1), + layer_weight.linear_conv1d.mm_param.bias, + conv_states.transpose(1, 2), + infer_state.b1_cu_q_seq_len, + out=out_tensor, + cache_indices=buffer_idx, + activation=self.activation, + ) + mixed_qkv = out_tensor.transpose(0, 1) + else: + mixed_qkv = causal_conv1d_update( + mixed_qkv, + conv_states.transpose(1, 2), + layer_weight.linear_conv1d.mm_param.weight.transpose(0, 1), + layer_weight.linear_conv1d.mm_param.bias, + self.activation, + conv_state_indices=buffer_idx, + validate_data=True, + ) + + # Rearrange mixed_qkv to query, key, value + query, key, value = self._rearrange_mixed_qkv(mixed_qkv) + + # Compute beta and g + beta = b.sigmoid() + g = fused_gdn_gating(layer_weight.linear_A_log.weight, a, layer_weight.linear_dt_bias.weight) + g, beta = map(lambda x: rearrange(x, "l d -> 1 l d"), (g, beta)) + + if is_prefill: + initial_state = ssm_states[buffer_idx].contiguous() + (core_attn_out, last_recurrent_state,) = chunk_gated_delta_rule( + q=query, + k=key, + v=value, + g=g, + beta=beta, + initial_state=initial_state, + output_final_state=True, + cu_seqlens=infer_state.b1_cu_q_seq_len, + head_first=False, + use_qk_l2norm_in_kernel=True, + ) + # Update SSM state with final state + ssm_states[buffer_idx, ...] = last_recurrent_state.to(ssm_states.dtype) + else: + batch_size = input.shape[0] + cu_seqlens = torch.arange(0, batch_size + 1, dtype=torch.int32, device=input.device) + (core_attn_out, last_recurrent_state,) = fused_recurrent_gated_delta_rule( + q=query, + k=key, + v=value, + g=g, + beta=beta, + initial_state=ssm_states, + inplace_final_state=True, + cu_seqlens=cu_seqlens, + ssm_state_indices=buffer_idx, + use_qk_l2norm_in_kernel=True, + ) + + z_shape_og = z.shape + core_attn_out = core_attn_out.reshape(-1, core_attn_out.shape[-1]) + z = z.reshape(-1, z.shape[-1]) + norm_out = infer_cls.alloc_tensor(core_attn_out.shape, core_attn_out.dtype, device=core_attn_out.device) + gated_rmsnorm_forward( + core_attn_out, + layer_weight.linear_norm.weight, + layer_weight.linear_norm.bias, + infer_cls.eps_, + z, + out=norm_out, + ) + core_attn_out = norm_out.reshape(z_shape_og) + core_attn_out = rearrange(core_attn_out, "... h d -> ... (h d)") + + output = layer_weight.linear_out_proj.mm(core_attn_out) + if infer_cls.tp_world_size_ > 1: + all_reduce(output, group=infer_state.dist_group, op=dist.ReduceOp.SUM, async_op=False) + return output diff --git a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py new file mode 100644 index 000000000..da6168593 --- /dev/null +++ b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py @@ -0,0 +1,177 @@ +import os +import torch +import math +import numpy as np +from lightllm.common.basemodel import TransformerLayerWeight +from lightllm.models.qwen3_moe.layer_weights.transformer_layer_weight import Qwen3MOETransformerLayerWeight +from lightllm.utils.envs_utils import enable_env_vars +from lightllm.common.basemodel.layer_weights.meta_weights import ( + ROWMMWeight, + COLMMWeight, + NormWeight, +) +from functools import partial +from typing_extensions import override +from lightllm.common.basemodel.layer_weights.meta_weights import TpParameterWeight + + +class Qwen3NextTransformerLayerWeight(Qwen3MOETransformerLayerWeight): + def __init__(self, layer_num, data_type, network_config, mode=[], quant_cfg=None): + super().__init__(layer_num, data_type, network_config, mode, quant_cfg) + return + + @override + def _parse_config(self): + super()._parse_config() + self.full_attention_interval = self.network_config_["full_attention_interval"] + self.is_linear = (self.layer_num_ + 1) % self.full_attention_interval != 0 + if self.is_linear: + self._parse_linear_config() + return + + @override + def _init_weight(self): + self._init_moe() + self._init_shared_expert_weight() + + self.att_norm_weight_ = NormWeight( + self._att_norm_weight_name, self.data_type_, bias_name=self._att_norm_bias_name + ) + self.ffn_norm_weight_ = NormWeight( + self._ffn_norm_weight_name, self.data_type_, bias_name=self._ffn_norm_bias_name + ) + + if self.is_linear: + self._init_linear_weight() + else: + self._init_qkv() + self._init_o() + self.q_norm_weight_ = NormWeight(weight_name=self._q_norm_name, data_type=self.data_type_) + self.k_norm_weight_ = NormWeight(weight_name=self._k_norm_name, data_type=self.data_type_) + self._o_gate_weight_name = f"model.layers.{self.layer_num_}.self_attn.o_gate_proj.weight" + self.o_gate_proj = ROWMMWeight( + weight_names=self._o_gate_weight_name, + data_type=self.data_type_, + bias_names=None, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="o_gate_proj", + ) + self._gate = None + return + + @override + def load_hf_weights(self, weights): + if self.is_linear: + linear_conv1d_weight_name = f"model.layers.{self.layer_num_}.linear_attn.conv1d.weight" + linear_conv1d_bias_name = f"model.layers.{self.layer_num_}.linear_attn.conv1d.bias" + if linear_conv1d_weight_name in weights: + weights[linear_conv1d_weight_name] = self._parse_linear_conv1d( + weights[linear_conv1d_weight_name].squeeze(1) + ) + if linear_conv1d_bias_name in weights: + weights[linear_conv1d_bias_name] = self._parse_linear_conv1d(weights[linear_conv1d_bias_name]) + else: + self._split_q_with_gate(weights) + super().load_hf_weights(weights) + + def _init_shared_expert_weight(self): + prefix = f"model.layers.{self.layer_num_}.mlp.shared_expert" + self.shared_expert_gate_up_proj = ROWMMWeight( + weight_names=[f"{prefix}.gate_proj.weight", f"{prefix}.up_proj.weight"], + data_type=self.data_type_, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="shared_expert_gate_up_proj", + ) + self.shared_expert_down_proj = COLMMWeight( + weight_names=f"{prefix}.down_proj.weight", + data_type=self.data_type_, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="shared_expert_down_proj", + ) + self.shared_expert_gate = ROWMMWeight( + weight_names=f"model.layers.{self.layer_num_}.mlp.shared_expert_gate.weight", + data_type=self.data_type_, + bias_names=None, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="shared_expert_gate", + tp_rank=0, + tp_world_size=1, + ) + + def _split_q_with_gate(self, weights): + if self._q_weight_name in weights: + weight = weights[self._q_weight_name] + num_heads = self.tp_q_head_num_ * self.tp_world_size_ + weight = weight.view(num_heads * 2, self.head_dim, -1) + _q_proj = weight[0::2].reshape(-1, weight.shape[-1]) + _gate_proj = weight[1::2].reshape(-1, weight.shape[-1]) + weights[self._q_weight_name] = _q_proj + weights[self._o_gate_weight_name] = _gate_proj + + def _parse_linear_conv1d(self, weight): + qk_dim = self.linear_num_k_heads * self.linear_k_head_dim + v_dim = self.linear_num_v_heads * self.linear_v_head_dim + + q_bias, k_bias, v_bias = torch.split(weight, [qk_dim, qk_dim, v_dim], dim=0) + q_splits = q_bias.chunk(self.tp_world_size_, dim=0) + k_splits = k_bias.chunk(self.tp_world_size_, dim=0) + v_splits = v_bias.chunk(self.tp_world_size_, dim=0) + + new_weight = torch.cat( + [torch.cat([q_splits[i], k_splits[i], v_splits[i]], dim=0) for i in range(self.tp_world_size_)], dim=0 + ) + + return new_weight + + def _parse_linear_config(self): + self.linear_num_v_heads = self.network_config_["linear_num_value_heads"] + self.linear_num_k_heads = self.network_config_["linear_num_key_heads"] + self.linear_k_head_dim = self.network_config_["linear_key_head_dim"] + self.linear_v_head_dim = self.network_config_["linear_value_head_dim"] + + def _init_linear_weight(self): + prefix = f"model.layers.{self.layer_num_}.linear_attn" + self.linear_conv1d = ROWMMWeight( + weight_names=f"{prefix}.conv1d.weight", + data_type=self.data_type_, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="conv1d_weight", + ) + + self.linear_in_proj = ROWMMWeight( + weight_names=[f"{prefix}.in_proj_qkvz.weight", f"{prefix}.in_proj_ba.weight"], + data_type=self.data_type_, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="in_proj_weight", + ) + + self.linear_out_proj = COLMMWeight( + weight_names=f"{prefix}.out_proj.weight", + data_type=self.data_type_, + quant_cfg=self.quant_cfg, + layer_num=self.layer_num_, + name="out_proj_weight", + ) + + self.linear_dt_bias = TpParameterWeight( + weight_name=f"{prefix}.dt_bias", + data_type=torch.float32, + split_n_embed=self.linear_num_v_heads // self.tp_world_size_, + ) + + self.linear_A_log = TpParameterWeight( + weight_name=f"{prefix}.A_log", + data_type=torch.float32, + split_n_embed=self.linear_num_v_heads // self.tp_world_size_, + ) + + self.linear_norm = NormWeight( + weight_name=f"{prefix}.norm.weight", + data_type=self.data_type_, + ) diff --git a/lightllm/models/qwen3next/mem_manager.py b/lightllm/models/qwen3next/mem_manager.py new file mode 100644 index 000000000..ab8fdd6a6 --- /dev/null +++ b/lightllm/models/qwen3next/mem_manager.py @@ -0,0 +1,152 @@ +import torch +import numpy as np +from typing import Dict, List, Protocol, Set, Union, Tuple, Optional +from typing_extensions import override +from lightllm.utils.log_utils import init_logger +from lightllm.common.kv_cache_mem_manager.mem_manager import BaseAllocator, MemoryManager +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.server.router.model_infer.infer_batch import InferReq +from lightllm.utils.envs_utils import get_unique_server_name +from lightllm.utils.dist_utils import get_current_rank_in_node +from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt + +logger = init_logger(__name__) + + +class LayerCacheMemoryManager(BaseAllocator): + def __init__(self, size: int, dtype: torch.dtype, shape: Tuple[int, ...], layer_num: int, mem_manager_nmae: str): + super().__init__(size, mem_manager_nmae) + + self.dtype = dtype + self.shape = shape + self.layer_num = layer_num + + self._init_buffers( + self.size, + dtype, + shape, + ) + + def _init_buffers(self, size, dtype, shape): + self.buffer = torch.zeros((self.layer_num, size + 1, *shape), dtype=dtype, device="cuda") + + def get_cell_size(self): + return np.prod(self.shape) * self.layer_num * torch._utils._element_size(self.dtype) + + +class HaveStateBuffer(Protocol): + def alloc_state_cache_buffer(self, need_size): + ... + + def free_state_cache_buffer(self, free_buffer_indexes): + ... + + def get_state_cache_buffer(self, layer_index): + ... + + def get_state_cache_can_use_size(self): + ... + + +class Qwen3NextMemoryManager(MemoryManager, HaveStateBuffer): + def __init__( + self, + full_attn_cache_size, + linear_attn_cache_size, + dtype, + num_kv_heads, + head_dim, + layer_num, + mtp_layer_num, + full_attention_interval: int, + conv_state_dtype: torch.dtype, + conv_state_shape: Tuple[int, ...], + ssm_state_dtype: torch.dtype, + ssm_state_shape: Tuple[int, ...], + always_copy=False, + mem_fraction=0.9, + ): + self.full_attention_interval = full_attention_interval + + assert layer_num % full_attention_interval == 0 + self.layer_num = layer_num + self.mtp_layer_num = mtp_layer_num + self.full_attn_layer_num = layer_num // full_attention_interval + self.linear_attn_layer_num = layer_num - self.full_attn_layer_num + + self.conv_state_dtype = conv_state_dtype + self.conv_state_shape = conv_state_shape + self.ssm_state_dtype = ssm_state_dtype + self.ssm_state_shape = ssm_state_shape + + assert linear_attn_cache_size is not None + self.conv_state_mem_manager = LayerCacheMemoryManager( + linear_attn_cache_size, conv_state_dtype, conv_state_shape, self.linear_attn_layer_num, "conv_state" + ) + self.ssm_state_mem_manager = LayerCacheMemoryManager( + linear_attn_cache_size, ssm_state_dtype, ssm_state_shape, self.linear_attn_layer_num, "ssm_state" + ) + logger.info( + f"Linear attention state cache size: {linear_attn_cache_size}\n" + f"Conv state use : " + f"{self.conv_state_mem_manager.get_cell_size() * linear_attn_cache_size / 1024 ** 3} GB Memory.\n" + f"Ssm state use : " + f"{self.ssm_state_mem_manager.get_cell_size() * linear_attn_cache_size / 1024 ** 3} GB Memory.\n" + ) + self.EMPTY_BUFFER_INDEX = -1 + self.HOLD_BUFFER_INDEX = self.conv_state_mem_manager.HOLD_TOKEN_MEMINDEX + super().__init__(full_attn_cache_size, dtype, num_kv_heads, head_dim, layer_num, always_copy, mem_fraction) + + @override + def _init_buffers(self, size, dtype, head_num, head_dim, layer_num): + # kv_buffer = [None, None, None, kv_cache, None, None, None, kv_cache, ..., + # None, kv_cache, mtp_kv_cache, mtp_kv_cache] + self.kv_buffer = [None for _ in range(self.layer_num)] + for layer_id in range(self.full_attn_layer_num): + self.kv_buffer[(layer_id + 1) * self.full_attention_interval - 1] = torch.empty( + (size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda" + ) + + for _ in range(self.mtp_layer_num): + self.kv_buffer.append(torch.empty((size + 1, 2 * head_num, head_dim), dtype=dtype, device="cuda")) + + @override + def free_all(self): + super().free_all() + self.conv_state_mem_manager.free_all() + self.ssm_state_mem_manager.free_all() + return + + @override + def get_state_cache_buffer(self, layer_index) -> Tuple[torch.Tensor, torch.Tensor]: + assert layer_index < self.layer_num, "layer_index is out of range" + assert (layer_index + 1) % self.full_attention_interval != 0, "layer_index is not linear attention layer" + real_layer_index = layer_index - layer_index // self.full_attention_interval + return self.conv_state_mem_manager.buffer[real_layer_index], self.ssm_state_mem_manager.buffer[real_layer_index] + + @override + def free_state_cache_buffer(self, free_buffer_indexes: List[int], reset=True): + # conv_state 和 ssm_state 共享buffer_idx + self.conv_state_mem_manager.free(free_buffer_indexes) + if reset: + self.conv_state_mem_manager.buffer[:, free_buffer_indexes] = 0 + self.ssm_state_mem_manager.buffer[:, free_buffer_indexes] = 0 + + @override + def alloc_state_cache_buffer(self, need_size): + # conv_state 和 ssm_state 共享buffer_idx + buffer_indexes = self.conv_state_mem_manager.alloc(need_size) + return buffer_indexes + + @override + def get_state_cache_can_use_size(self): + return self.conv_state_mem_manager.can_use_mem_size + + @override + def copy_state_cache_buffer(self, src_idx, tgt_idx): + assert src_idx is not None and tgt_idx is not None + assert src_idx != tgt_idx + # Use slice operation and in-place copy for better performance + self.conv_state_mem_manager.buffer[:, tgt_idx].copy_(self.conv_state_mem_manager.buffer[:, src_idx]) + self.ssm_state_mem_manager.buffer[:, tgt_idx].copy_(self.ssm_state_mem_manager.buffer[:, src_idx]) + return diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py new file mode 100644 index 000000000..27289fc19 --- /dev/null +++ b/lightllm/models/qwen3next/model.py @@ -0,0 +1,97 @@ +import torch +from typing_extensions import override +from lightllm.models.registry import ModelRegistry +from lightllm.models.qwen3_moe.model import Qwen3MOEModel +from lightllm.models.qwen3next.layer_weights.transformer_layer_weight import Qwen3NextTransformerLayerWeight +from lightllm.models.qwen3next.layer_infer.transformer_layer_infer import Qwen3NextTransformerLayerInfer +from lightllm.models.qwen3next.layer_infer.post_layer_infer import Qwen3NextPostLayerInfer +from lightllm.utils.log_utils import init_logger +from lightllm.distributed.communication_op import dist_group_manager +from lightllm.utils.envs_utils import get_env_start_args +from lightllm.models.qwen3next.mem_manager import Qwen3NextMemoryManager +from lightllm.models.qwen3next.req_manager import Qwen3NextReqManager +from lightllm.server.core.objs.start_args_type import StartArgs + +logger = init_logger(__name__) + + +@ModelRegistry("qwen3_next") +class Qwen3NextTpPartModel(Qwen3MOEModel): + # weight class + transformer_weight_class = Qwen3NextTransformerLayerWeight + + # infer class + transformer_layer_infer_class = Qwen3NextTransformerLayerInfer + post_layer_infer_class = Qwen3NextPostLayerInfer + + def __init__(self, kvargs) -> None: + super().__init__(kvargs) + + @override + def autotune_layers(self): + return self.config["full_attention_interval"] + + @override + def _init_config(self): + super()._init_config() + self.num_kv_heads = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1) + + @override + def _init_custom(self): + super()._init_custom() + dist_group_manager.new_deepep_group(self.config["num_experts"], self.config["hidden_size"]) + + @override + def _init_mem_manager(self): + assert self.config["num_attention_heads"] % self.tp_world_size_ == 0 + + start_args: StartArgs = get_env_start_args() + + mtp_step = start_args.mtp_step + linear_attn_cache_size = start_args.linear_attn_cache_size + if linear_attn_cache_size is not None: + assert ( + linear_attn_cache_size >= start_args.running_max_req_size + ), "linear_attn_cache_size must be greater than running_max_req_size" + + self.num_linear_k_heads = self.config["linear_num_key_heads"] + self.num_linear_v_heads = self.config["linear_num_value_heads"] + self.head_linear_k_dim = self.config["linear_key_head_dim"] + self.head_linear_v_dim = self.config["linear_value_head_dim"] + + conv_kernel_size = self.config["linear_conv_kernel_dim"] + conv_dim = ( + self.head_linear_k_dim * self.num_linear_k_heads * 2 + self.head_linear_v_dim * self.num_linear_v_heads + ) + + self.mem_manager = Qwen3NextMemoryManager( + full_attn_cache_size=self.max_total_token_num, + linear_attn_cache_size=linear_attn_cache_size, + dtype=self.data_type, + num_kv_heads=self.num_kv_heads, + head_dim=self.config["head_dim"], + layer_num=self.config["n_layer"], + mtp_layer_num=start_args.mtp_step, + full_attention_interval=self.config["full_attention_interval"], + conv_state_dtype=self.data_type, + conv_state_shape=(conv_kernel_size - 1 + mtp_step, conv_dim // self.tp_world_size_), + ssm_state_dtype=self.data_type, + ssm_state_shape=( + # mtp_step + 1, + self.num_linear_v_heads // self.tp_world_size_, + self.head_linear_k_dim, + self.head_linear_v_dim, + ), + mem_fraction=self.mem_fraction, + ) + + @override + def _init_req_manager(self): + create_max_seq_len = 0 + + if self.batch_max_tokens is not None: + create_max_seq_len = max(create_max_seq_len, self.batch_max_tokens) + if self.max_seq_length is not None: + create_max_seq_len = max(create_max_seq_len, self.max_seq_length) + + self.req_manager = Qwen3NextReqManager(self.max_req_num, create_max_seq_len, self.mem_manager) diff --git a/lightllm/models/qwen3next/req_manager.py b/lightllm/models/qwen3next/req_manager.py new file mode 100644 index 000000000..31df6883f --- /dev/null +++ b/lightllm/models/qwen3next/req_manager.py @@ -0,0 +1,42 @@ +from typing import List, Dict +from typing_extensions import override +import torch + +from lightllm.common.req_manager import ReqManager +from lightllm.models.qwen3next.mem_manager import Qwen3NextMemoryManager +from lightllm.utils.envs_utils import get_env_start_args + + +class Qwen3NextReqManager(ReqManager): + def __init__(self, max_request_num, max_sequence_length, mem_manager: Qwen3NextMemoryManager): + super().__init__(max_request_num, max_sequence_length, mem_manager) + self.mem_manager: Qwen3NextMemoryManager = self.mem_manager + self.enable_dynamic_prompt_cache = not get_env_start_args().disable_dynamic_prompt_cache + + self.req_to_buffer_indexes = torch.zeros((max_request_num + 1), dtype=torch.int32, device="cuda") + self.req_to_buffer_indexes[:] = self.mem_manager.EMPTY_BUFFER_INDEX + self.req_to_buffer_indexes[self.HOLD_REQUEST_ID] = self.mem_manager.HOLD_BUFFER_INDEX + + @override + def alloc(self): + from lightllm.server.router.model_infer.infer_batch import g_infer_state_lock, g_infer_context + + req_idx = super().alloc() + g_infer_state_lock.acquire() + if g_infer_context.radix_cache is not None: + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(0, 1) + self.req_to_buffer_indexes[req_idx] = self.mem_manager.alloc_state_cache_buffer(1) + g_infer_state_lock.release() + return req_idx + + @override + def free(self, free_req_indexes: List[int], free_token_index): + super().free(free_req_indexes, free_token_index) + self.req_to_buffer_indexes[free_req_indexes] = self.mem_manager.EMPTY_BUFFER_INDEX + + @override + def free_all(self): + super().free_all() + self.req_to_buffer_indexes[:] = self.mem_manager.EMPTY_BUFFER_INDEX + self.req_to_buffer_indexes[self.HOLD_REQUEST_ID] = self.mem_manager.HOLD_BUFFER_INDEX + return diff --git a/lightllm/models/qwen3next/triton_kernel/causal_conv1d.py b/lightllm/models/qwen3next/triton_kernel/causal_conv1d.py new file mode 100644 index 000000000..202ce7460 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/causal_conv1d.py @@ -0,0 +1,1057 @@ +# SPDX-License-Identifier: Apache-2.0 +# Adapted from +# https://github.com/vllm-project/vllm/blob/v0.11.0rc1/vllm/model_executor/layers/mamba/ops/causal_conv1d.py +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +# Copyright (c) 2024, Tri Dao. +# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py + +from typing import Optional, Union + +import numpy as np +import torch + +import triton +import triton.language as tl + +PAD_SLOT_ID = -1 + + +@triton.jit() +def _causal_conv1d_fwd_kernel( # continuous batching + # Pointers to matrices + x_ptr, # (dim, cu_seqlen) holding `batch` of actual sequences + padded sequences + w_ptr, # (dim, width) + bias_ptr, + initial_states_ptr, # conv_states_ptr + cache_indices_ptr, # conv_state_indices_ptr + has_initial_states_ptr, + query_start_loc_ptr, + batch_ptr, + token_chunk_offset_ptr, + o_ptr, # (dim, seqlen) - actually pointing to x_ptr + # Matrix dimensions + batch: tl.int32, # actually padded_batch + dim: tl.constexpr, + seqlen: tl.int32, # cu_seqlen + num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines + # Strides + stride_x_seq: tl.constexpr, # stride to get to next sequence, + stride_x_dim: tl.constexpr, # stride to get to next feature-value, + stride_x_token: tl.constexpr, # stride to get to next token (same feature-index, same sequence-index) + stride_w_dim: tl.constexpr, # stride to get to next dim-axis value + stride_w_width: tl.constexpr, # stride to get to next width-axis value + stride_istate_seq: tl.constexpr, + stride_istate_dim: tl.constexpr, + stride_istate_token: tl.constexpr, + stride_o_seq: tl.constexpr, + stride_o_dim: tl.constexpr, + stride_o_token: tl.constexpr, + # others + pad_slot_id: tl.constexpr, + # Meta-parameters + HAS_BIAS: tl.constexpr, + KERNEL_WIDTH: tl.constexpr, + SILU_ACTIVATION: tl.constexpr, + HAS_INITIAL_STATES: tl.constexpr, + HAS_CACHE: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + USE_PAD_SLOT: tl.constexpr, + NP2_STATELEN: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + conv_states_ptr = initial_states_ptr + conv_state_indices_ptr = cache_indices_ptr + stride_conv_state_seq = stride_istate_seq + stride_conv_state_dim = stride_istate_dim + stride_conv_state_tok = stride_istate_token + state_len = KERNEL_WIDTH - 1 # can be passed via argument if it's not the same as this value + + # one program handles one chunk in a single sequence + # rather than mixing sequences - to make updating initial_states across sequences efficiently + + # single-sequence id + idx_seq = tl.load(batch_ptr + tl.program_id(0)) + chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0)) + + # BLOCK_N elements along the feature-dimension (channel) + idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) + + if idx_seq == pad_slot_id: + return + + sequence_start_index = tl.load(query_start_loc_ptr + idx_seq) + sequence_end_index = tl.load(query_start_loc_ptr + idx_seq + 1) + # find the actual sequence length + seqlen = sequence_end_index - sequence_start_index + + token_offset = BLOCK_M * chunk_offset + segment_len = min(BLOCK_M, seqlen - token_offset) + + # base of the sequence + x_base = x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim # [BLOCK_N,] + + if IS_CONTINUOUS_BATCHING: + # cache_idx + conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to(tl.int64) + else: + # cache_idx + conv_state_batch_coord = idx_seq + if USE_PAD_SLOT: # noqa + if conv_state_batch_coord == pad_slot_id: + # not processing as this is not the actual sequence + return + conv_states_base = ( + conv_states_ptr + (conv_state_batch_coord * stride_conv_state_seq) + (idx_feats * stride_conv_state_dim) + ) # [BLOCK_N,] + + w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] + + # Does 2 things: + # 1. READ prior-block init-state data - [done by every Triton programs] + # 2. update conv_state with new data [only by the Triton program handles chunk_offset=0] + if chunk_offset == 0: + # read from conv_states + load_init_state = False + if HAS_INITIAL_STATES: # the new HAS_INITIAL_STATES + load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(tl.int1) + if load_init_state: + # load from conv_states + prior_tokens = conv_states_base + (state_len - 1) * stride_conv_state_tok + mask_w = idx_feats < dim + if KERNEL_WIDTH == 2: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH == 3: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH == 4: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH == 5: + conv_states_ptrs = prior_tokens # [BLOCK_N] + # col3 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + conv_states_ptrs = prior_tokens - 3 * stride_conv_state_tok # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + else: + # prior-tokens are zeros + if KERNEL_WIDTH >= 2: # STRATEGY1 + # first chunk and does not have prior-token, so just set to 0 + col0 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) + if KERNEL_WIDTH >= 3: # STRATEGY1 + col1 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) + if KERNEL_WIDTH >= 4: # STRATEGY1 + col2 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) + # if KERNEL_WIDTH >= 5: # STRATEGY1 + # col3 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) + + # STEP 2: + # here prepare data for updating conv_state + if state_len <= seqlen: # SMALL_CACHE=True (only move part of 'x' into conv_state cache) + # just read from 'x' + # copy 'x' data to conv_state + # load only 'x' data (and set 0 before 'x' if seqlen < state_len) + idx_tokens_last = (seqlen - state_len) + tl.arange(0, NP2_STATELEN) # [BLOCK_M] + x_ptrs = ( + x_ptr + + ((sequence_start_index + idx_tokens_last) * stride_x_token)[:, None] + + (idx_feats * stride_x_dim)[None, :] + ) # [BLOCK_M,BLOCK_N,] + mask_x = ( + (idx_tokens_last >= 0)[:, None] & (idx_tokens_last < seqlen)[:, None] & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + new_conv_state = tl.load(x_ptrs, mask_x, 0.0) + idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + conv_states_ptrs_target = conv_states_base[None, :] + (idx_tokens_conv * stride_conv_state_tok)[:, None] + + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.debug_barrier() # NOTE: use this due to bug in Triton compiler + tl.store(conv_states_ptrs_target, new_conv_state, mask) + + else: + if load_init_state: + # update conv_state by shifting left, i.e. take last few cols from conv_state + cols from 'x' + idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + conv_states_ptrs_source = ( + conv_states_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = ( + (conv_state_batch_coord < num_cache_lines) + & ((idx_tokens_conv + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :] + ) + conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0) + + VAL = state_len - seqlen + + x_ptrs = x_base[None, :] + ((idx_tokens_conv - VAL) * stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] + + mask_x = ( + (idx_tokens_conv - VAL >= 0)[:, None] + & (idx_tokens_conv - VAL < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + # need this due to the bug in tl.where not enforcing this + # when data is the result of another tl.load + tl.debug_barrier() + new_conv_state = tl.where( + mask, conv_state, loaded_x + ) # BUG in 'tl.where' which requires a barrier before this + conv_states_ptrs_target = ( + conv_states_base + (idx_tokens_conv * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.store(conv_states_ptrs_target, new_conv_state, mask) + else: # load_init_state == False + # update conv_state by shifting left, BUT + # set cols prior to 'x' as zeros + cols from 'x' + idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + VAL = state_len - seqlen + + x_ptrs = x_base[None, :] + ((idx_tokens_conv - VAL) * stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] + + mask_x = ( + (idx_tokens_conv - VAL >= 0)[:, None] + & (idx_tokens_conv - VAL < seqlen)[:, None] + & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + new_conv_state = tl.load(x_ptrs, mask_x, 0.0) + + conv_states_ptrs_target = ( + conv_states_base + (idx_tokens_conv * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.store(conv_states_ptrs_target, new_conv_state, mask) + + else: # chunk_offset > 0 + # read prior-token data from `x` + load_init_state = True + prior_tokens = x_base + (token_offset - 1) * stride_x_token + mask_w = idx_feats < dim + if KERNEL_WIDTH == 2: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + if KERNEL_WIDTH == 3: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + if KERNEL_WIDTH == 4: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + if KERNEL_WIDTH == 5: + # ruff: noqa: F841 + conv_states_ptrs = prior_tokens # [BLOCK_N] + # col3 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + conv_states_ptrs = prior_tokens - 3 * stride_x_token # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") + + if HAS_BIAS: + bias = bias_ptr + idx_feats + mask_bias = idx_feats < dim + acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to(tl.float32) # [BLOCK_N] + else: + acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32) + + x_base_1d = x_base + token_offset * stride_x_token # starting of chunk + + # PRE-LOAD WEIGHTS + mask_w = idx_feats < dim + if KERNEL_WIDTH >= 2: + w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor + w_col0 = tl.load(w_ptrs, mask_w, other=0.0) + w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor + w_col1 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 3: + w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor + w_col2 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 4: + w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor + w_col3 = tl.load(w_ptrs, mask_w, other=0.0) + mask_x_1d = idx_feats < dim + for idx_token in range(segment_len): + acc = acc_preload + + matrix_w = w_col0 + matrix_x = col0 + for j in tl.static_range(KERNEL_WIDTH): + + if KERNEL_WIDTH == 2: + if j == 1: # KERNEL_WIDTH-1: + matrix_w = w_col1 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 3: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 4: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + + acc += matrix_x * matrix_w # [BLOCK_N] + + if KERNEL_WIDTH == 2: + col0 = matrix_x + elif KERNEL_WIDTH == 3: + col0 = col1 + col1 = matrix_x + elif KERNEL_WIDTH == 4: + col0 = col1 + col1 = col2 + col2 = matrix_x + + if SILU_ACTIVATION: + acc = acc / (1 + tl.exp(-acc)) + mask_1d = (idx_token < segment_len) & (idx_feats < dim) # token-index # feature-index + o_ptrs = o_ptr + (sequence_start_index + token_offset + idx_token) * stride_o_token + (idx_feats * stride_o_dim) + + tl.store(o_ptrs, acc, mask=mask_1d) + + +def causal_conv1d_fn( + x: torch.Tensor, + weight: torch.Tensor, + bias: Union[torch.Tensor, None], + conv_states: torch.Tensor, + query_start_loc: torch.Tensor, + out: torch.Tensor, + cache_indices: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = None, + activation: Optional[str] = "silu", + pad_slot_id: int = PAD_SLOT_ID, + metadata=None, + validate_data=True, +): + """support varlen + continuous batching when x is 2D tensor + + x: (dim,cu_seq_len) + cu_seq_len = total tokens of all seqs in that batch + sequences are concatenated from left to right for varlen + weight: (dim, width) + conv_states: (...,dim,width - 1) itype + updated inplace if provided + [it use `cache_indices` to get the index to the cache of conv_state for that sequence + + conv_state[cache_indices[i]] for seq-i - to be used as initial_state when has_initial_state[i] = True + and after that conv_state[cache_indices[i]] need to be shift-left and updated with values from 'x' + ] + query_start_loc: (batch + 1) int32 + The cumulative sequence lengths of the sequences in + the batch, used to index into sequence. prepended by 0. + if + x = [5, 1, 1, 1] <- continuous batching (batch=4) + then + query_start_loc = [0, 5, 6, 7, 8] <- the starting index of the next sequence; while the last value is + the ending index of the last sequence + [length(query_start_loc)-1 == batch] + for example: query_start_loc = torch.Tensor([0,10,16,17]), + x.shape=(dim,17) + cache_indices: (batch) int32 + indicates the corresponding state index, + like so: conv_state = conv_states[cache_indices[batch_id]] + has_initial_state: (batch) bool + indicates whether should the kernel take the current state as initial + state for the calculations + [single boolean for each sequence in the batch: True or False] + bias: (dim,) + activation: either None or "silu" or "swish" or True + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + + out: same shape as `x` + """ + if isinstance(activation, bool) and activation: + activation = "silu" + + args = None + # Store original dtype to cast back at the end + original_x_dtype = x.dtype + x = x.to(conv_states.dtype) + if metadata is not None: + nums_dict = metadata.nums_dict + args = nums_dict + batch_ptr = metadata.batch_ptr + token_chunk_offset_ptr = metadata.token_chunk_offset_ptr + else: + seqlens = np.diff(query_start_loc.to("cpu")) + args = seqlens + MAX_NUM_PROGRAMS = 1024 + + batch_ptr = torch.full( + (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device + ) # tracking which seq-idx the Triton program is handling + token_chunk_offset_ptr = torch.full( + (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device + ) # tracking BLOCK_M-based index in the sequence the Triton program is handling + + is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1) + dim, cu_seqlen = x.shape + _, width = weight.shape + state_len = width - 1 + np2_statelen = triton.next_power_of_2(state_len) + + padded_batch = query_start_loc.size(0) - 1 + stride_x_seq = 0 + stride_x_dim = x.stride(0) + stride_x_token = x.stride(1) + stride_w_dim = weight.stride(0) + stride_w_width = weight.stride(1) + stride_istate_seq = 0 + stride_istate_dim = 0 + stride_istate_token = 0 + num_cache_lines = 0 + if conv_states is not None: + # extensions to support vLLM: + # 1. conv_states is used to replaced initial_states + # 2. conv_states serve as a cache with num cache lines can be larger than batch size + # 3. mapping from sequence x[idx] to a cache line at index as specified via cache_indices[idx] + # 4. computation can be skipped if cache_indices[idx] == pad_slot_id + num_cache_lines = conv_states.size(0) + assert ( + num_cache_lines == conv_states.shape[0] + and dim == conv_states.shape[1] + and width - 1 <= conv_states.shape[2] + ), f"{num_cache_lines} {dim} {width} {conv_states.shape}" + stride_istate_seq = conv_states.stride(0) + stride_istate_dim = conv_states.stride(1) + stride_istate_token = conv_states.stride(2) + assert stride_istate_dim == 1 + if out.dim() == 2: + stride_o_seq = 0 + stride_o_dim = out.stride(0) + stride_o_token = out.stride(1) + else: + stride_o_seq = out.stride(0) + stride_o_dim = out.stride(1) + stride_o_token = out.stride(2) + + if validate_data: + assert x.dim() == 2 + assert query_start_loc is not None + assert query_start_loc.dim() == 1 + assert x.stride(0) == 1 or x.stride(1) == 1 + if bias is not None: + assert bias.dim() == 1 + assert dim == bias.size(0) + if cache_indices is not None: + assert cache_indices.dim() == 1 + assert padded_batch == cache_indices.size(0) + if has_initial_state is not None: + assert has_initial_state.size() == (padded_batch,) + assert conv_states is not None, "ERROR: `has_initial_state` is used, which needs also `conv_states`" + assert weight.stride(1) == 1 + assert (dim, width) == weight.shape + assert is_channel_last, "Need to run in channel-last layout" + + if metadata is None: + + def num_program(META, seqlens): + tot = 0 + + mlist = [] + offsetlist = [] # type: ignore + + nums = -(-seqlens // META["BLOCK_M"]) + + tot = nums.sum().item() + mlist = np.repeat(np.arange(len(nums)), nums) + for idx, num in enumerate(nums): + offsetlist.extend(range(num)) # chunk-idx if a sequence is split into multiple chunks + + if META["batch_ptr"].nelement() < len(mlist): + newlen = len(mlist) + 1 + META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) + META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) + + if META["batch_ptr"].nelement() >= len(mlist): + META["batch_ptr"][0 : len(mlist)].copy_(torch.from_numpy(np.array(mlist))) + META["token_chunk_offset_ptr"][0 : len(mlist)].copy_(torch.from_numpy(np.array(offsetlist))) + + META["batch_ptr"] = META["batch_ptr"].to(META["x_ptr"].device) + META["token_chunk_offset_ptr"] = META["token_chunk_offset_ptr"].to(META["x_ptr"].device) + return tot + + else: + + def num_program(META, nums_dict): + tot = nums_dict[META["BLOCK_M"]]["tot"] + + mlist = nums_dict[META["BLOCK_M"]]["mlist"] + mlist_len = nums_dict[META["BLOCK_M"]]["mlist_len"] + + offsetlist = nums_dict[META["BLOCK_M"]]["offsetlist"] + + if nums_dict[META["BLOCK_M"]]["batch_ptr"] is not None: + META["batch_ptr"] = nums_dict[META["BLOCK_M"]]["batch_ptr"] + META["token_chunk_offset_ptr"] = nums_dict[META["BLOCK_M"]]["token_chunk_offset_ptr"] + else: + if META["batch_ptr"].nelement() < mlist_len: + newlen = mlist_len + 1 + META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) + META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) + + if META["batch_ptr"].nelement() >= mlist_len: + META["batch_ptr"][0:mlist_len].copy_(mlist) + META["token_chunk_offset_ptr"][0:mlist_len].copy_(offsetlist) + return tot + + def grid(META): + return ( + num_program(META, args), + triton.cdiv(dim, META["BLOCK_N"]), + ) + + if batch_ptr.device != x.device: + batch_ptr = batch_ptr.to(x.device) + token_chunk_offset_ptr = token_chunk_offset_ptr.to(x.device) + + _causal_conv1d_fwd_kernel[grid]( + # Pointers to matrices + x, + weight, + bias, + conv_states, + cache_indices, + has_initial_state, + query_start_loc, + batch_ptr, + token_chunk_offset_ptr, + out, + # Matrix dimensions + padded_batch, + dim, + cu_seqlen, + num_cache_lines, + # stride + stride_x_seq, + stride_x_dim, + stride_x_token, + stride_w_dim, + stride_w_width, + stride_istate_seq, + stride_istate_dim, + stride_istate_token, + stride_o_seq, + stride_o_dim, + stride_o_token, + # others + pad_slot_id, + # META + HAS_BIAS=bias is not None, + KERNEL_WIDTH=width, + SILU_ACTIVATION=activation in ["silu", "swish"], + HAS_INITIAL_STATES=has_initial_state is not None, + HAS_CACHE=conv_states is not None, + IS_CONTINUOUS_BATCHING=cache_indices is not None, + USE_PAD_SLOT=pad_slot_id is not None, + NP2_STATELEN=np2_statelen, + # launch_cooperative_grid=True + BLOCK_M=8, + BLOCK_N=256, + num_stages=2, + ) + return out.to(original_x_dtype) + + +@triton.jit() +def _causal_conv1d_update_kernel( + # Pointers to matrices + x_ptr, # (batch, dim, seqlen) + w_ptr, # (dim, width) + bias_ptr, + conv_state_ptr, + cache_seqlens_ptr, # circular buffer + conv_state_indices_ptr, + num_accepted_tokens_ptr, + query_start_loc_ptr, # (batch + 1) + o_ptr, # (batch, dim, seqlen) + # Matrix dimensions + batch: int, + dim: tl.constexpr, + seqlen: tl.constexpr, + state_len: tl.constexpr, + num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines + # Strides + stride_x_seq: tl.constexpr, + stride_x_dim: tl.constexpr, + stride_x_token: tl.constexpr, + stride_w_dim: tl.constexpr, + stride_w_width: tl.constexpr, + stride_conv_state_seq: tl.constexpr, + stride_conv_state_dim: tl.constexpr, + stride_conv_state_tok: tl.constexpr, + stride_state_indices: tl.constexpr, + stride_o_seq: tl.constexpr, + stride_o_dim: tl.constexpr, + stride_o_token: tl.constexpr, + # others + pad_slot_id: tl.constexpr, + # Meta-parameters + HAS_BIAS: tl.constexpr, + KERNEL_WIDTH: tl.constexpr, + SILU_ACTIVATION: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, + NP2_STATELEN: tl.constexpr, + USE_PAD_SLOT: tl.constexpr, + BLOCK_N: tl.constexpr, +): + # ruff: noqa: E501 + idx_seq = tl.program_id(0) + if idx_seq >= batch: + return + + # [BLOCK_N,] elements along the feature-dimension (channel) + idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) + + if IS_CONTINUOUS_BATCHING: + # mask = idx_seq < batch + conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq * stride_state_indices).to(tl.int64) + else: + conv_state_batch_coord = idx_seq + if USE_PAD_SLOT: # noqa + if conv_state_batch_coord == pad_slot_id: + # not processing as this is not the actual sequence + return + + if IS_VARLEN: + query_start_index = tl.load(query_start_loc_ptr + idx_seq).to(tl.int64) + query_end_index = tl.load(query_start_loc_ptr + (idx_seq + 1)).to(tl.int64) + # revise state_len and seqlen + state_len = state_len - (seqlen - (query_end_index - query_start_index)) + seqlen = query_end_index - query_start_index + x_offset = query_start_index * stride_x_token + o_offset = query_start_index * stride_o_token + else: + query_start_index = idx_seq * seqlen + query_end_index = query_start_index + seqlen + x_offset = idx_seq * stride_x_seq + o_offset = idx_seq * stride_o_seq + + if query_start_index == query_end_index: + return + + if IS_SPEC_DECODING: + # The rolling of conv state: + # + # Before forward, the conv_state is: + # [history1, history2, ..., historyM]. + # + # After forward, the conv_state becomes: + # [history2, ..., historyM, draft1, draft2, ..., draftN]. + # + # After acceptance, it becomes: + # + # - accept 1 tokens: [history2, ..., historyM, draft1] + # - accept 2 tokens: [history3, ..., historyM, draft1, draft2] + # - and so on. + conv_state_token_offset = tl.load(num_accepted_tokens_ptr + idx_seq).to(tl.int64) - 1 + else: + conv_state_token_offset = 0 + + # STEP 1: READ init_state data + conv_states_base = ( + conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) + (idx_feats * stride_conv_state_dim) + ) + mask_w = idx_feats < dim + + prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok + if KERNEL_WIDTH >= 2: + conv_states_ptrs = prior_tokens # [BLOCK_N] + col0 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 3: + conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok # [BLOCK_N] + col1 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 4: + conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N] + col2 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 5: + conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N] + col3 = tl.load(conv_states_ptrs, mask_w, 0.0) + if KERNEL_WIDTH >= 6: + conv_states_ptrs = prior_tokens + 4 * stride_conv_state_tok # [BLOCK_N] + col4 = tl.load(conv_states_ptrs, mask_w, 0.0) + + # STEP 2: assume state_len > seqlen + idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M] + + # With speculative decoding, the conv_state updates works in a sliding + # window manner, at each forward pass, the tokens are shift by 1, so we + # load since idx_tokens + 1. + conv_state_ptrs_source = ( + conv_state_ptr + + (conv_state_batch_coord * stride_conv_state_seq) + + conv_state_token_offset * stride_conv_state_tok + + (idx_feats * stride_conv_state_dim)[None, :] + + ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * stride_conv_state_tok)[:, None] + ) # [BLOCK_M, BLOCK_N] + mask = ( + (conv_state_batch_coord < num_cache_lines) + & ((idx_tokens + seqlen) < state_len)[:, None] + & (idx_feats < dim)[None, :] + ) + conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) + + VAL = state_len - seqlen + x_base = x_ptr + x_offset + (idx_feats * stride_x_dim) # [BLOCK_N] + + x_ptrs = x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] + + mask_x = ( + (idx_tokens - VAL >= 0)[:, None] & (idx_tokens - VAL < seqlen)[:, None] & (idx_feats < dim)[None, :] + ) # token-index # token-index # feature-index + loaded_x = tl.load(x_ptrs, mask_x, 0.0) + tl.debug_barrier() + + new_conv_state = tl.where(mask, conv_state, loaded_x) + + conv_state_base = ( + conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) + (idx_feats * stride_conv_state_dim) + ) # [BLOCK_N,] + conv_state_ptrs_target = conv_state_base + (idx_tokens * stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N] + mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] + tl.store(conv_state_ptrs_target, new_conv_state, mask) + + # STEP 3: init accumulator + if HAS_BIAS: + bias = bias_ptr + idx_feats + mask_bias = idx_feats < dim + acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to(tl.float32) # [BLOCK_N] + else: + acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32) + + # STEP 4: + # PRE-LOAD WEIGHTS + # first kernel column, configured for weights to handle BLOCK_N features in range + w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] + mask_w = idx_feats < dim + if KERNEL_WIDTH >= 2: + w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor + w_col0 = tl.load(w_ptrs, mask_w, other=0.0) + w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor + w_col1 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 3: + w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor + w_col2 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 4: + w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor + w_col3 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 5: + w_ptrs = w_base + (4 * stride_w_width) # [BLOCK_N] tensor + w_col4 = tl.load(w_ptrs, mask_w, other=0.0) + if KERNEL_WIDTH >= 6: + w_ptrs = w_base + (5 * stride_w_width) # [BLOCK_N] tensor + w_col5 = tl.load(w_ptrs, mask_w, other=0.0) + + x_base_1d = x_base # starting of chunk [BLOCK_N] + mask_x_1d = idx_feats < dim + + # STEP 5: compute each token + for idx_token in tl.range(seqlen): + acc = acc_preload + + matrix_w = w_col0 + matrix_x = col0 + for j in tl.static_range(KERNEL_WIDTH): + if KERNEL_WIDTH == 2: + if j == 1: # KERNEL_WIDTH-1: + matrix_w = w_col1 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 3: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 4: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 5: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + matrix_x = col3 + elif j == 4: + matrix_w = w_col4 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + elif KERNEL_WIDTH == 6: + if j == 1: + matrix_w = w_col1 + matrix_x = col1 + elif j == 2: + matrix_w = w_col2 + matrix_x = col2 + elif j == 3: + matrix_w = w_col3 + matrix_x = col3 + elif j == 4: + matrix_w = w_col4 + matrix_x = col4 + elif j == 5: + matrix_w = w_col5 + x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] + matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) + + acc += matrix_x * matrix_w # [BLOCK_N] + + if KERNEL_WIDTH == 2: + col0 = matrix_x + elif KERNEL_WIDTH == 3: + col0 = col1 + col1 = matrix_x + elif KERNEL_WIDTH == 4: + col0 = col1 + col1 = col2 + col2 = matrix_x + elif KERNEL_WIDTH == 5: + col0 = col1 + col1 = col2 + col2 = col3 + col3 = matrix_x + elif KERNEL_WIDTH == 6: + col0 = col1 + col1 = col2 + col2 = col3 + col3 = col4 + col4 = matrix_x + + if SILU_ACTIVATION: + acc = acc / (1 + tl.exp(-acc)) + mask_1d = (idx_token < seqlen) & (idx_feats < dim) # token-index # feature-index + o_ptrs = o_ptr + o_offset + idx_token * stride_o_token + (idx_feats * stride_o_dim) + + tl.store(o_ptrs, acc, mask=mask_1d) + + +def causal_conv1d_update( + x: torch.Tensor, + conv_state: torch.Tensor, + weight: torch.Tensor, + bias: Optional[torch.Tensor] = None, + activation: Union[bool, str, None] = None, + cache_seqlens: Optional[torch.Tensor] = None, + conv_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + query_start_loc: Optional[torch.Tensor] = None, + max_query_len: int = -1, + pad_slot_id: int = PAD_SLOT_ID, + validate_data=False, +): + """ + x: Input tensor which can take the following shapes: + + - `[batch, dim]` - single token prediction + - `[batch, dim, seqlen]` - single or multiple tokens prediction + - `[num_tokens, dim]` - continuous batching, where num_tokens is + the total tokens of all sequences in that batch + + conv_state: (..., dim, state_len), where state_len >= width - 1 + weight: (dim, width) + bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the conv_state + starting at the index + @cache_seqlens % state_len. + conv_state_indices: (batch,), dtype int32 + If not None, the conv_state is a larger tensor along the batch dim, + and we are selecting the batch coords specified by conv_state_indices. + Useful for a continuous batching scenario. + num_accepted_tokens: (batch,), dtype int32 + If not None, it indicates the number of accepted tokens for each + sequence in the batch. + This is used in speculative decoding, where the conv_state is updated + in a sliding window manner. + query_start_loc: (batch + 1,) int32 + If not None, the inputs is given in a varlen fashion and this indicates + the starting index of each sequence in the batch. + max_query_len: int + If query_start_loc is not None, this indicates the maximum query + length in the batch. + pad_slot_id: int + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 + out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x` + """ + if validate_data: + assert cache_seqlens is None # not implemented yet - ok for vLLM + assert pad_slot_id is not None + assert x.stride(1) == 1 + if isinstance(activation, bool): + activation = "silu" if activation is True else None + elif activation is not None: + assert activation in ["silu", "swish"] + + original_x_dtype = x.dtype + x = x.to(conv_state.dtype) + unsqueeze = query_start_loc is None and x.dim() == 2 + if unsqueeze: + # make it (batch, dim, seqlen) with seqlen == 1 + x = x.unsqueeze(-1) + if query_start_loc is None: + batch, dim, seqlen = x.shape + else: + assert conv_state_indices is not None + batch = conv_state_indices.size(0) + dim = x.size(1) + seqlen = max_query_len + _, width = weight.shape + # conv_state: (..., dim, state_len), where state_len >= width - 1 + num_cache_lines, _, state_len = conv_state.size() + + if validate_data: + assert dim == weight.size(0) + assert ( + conv_state.stride(-2) == 1 + ), f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})" + assert state_len >= width - 1 + # when above happens, we don't shift-left to keep any records in conv_state + assert dim == conv_state.size(1) + if conv_state_indices is None: + assert conv_state.size(0) >= batch + else: + assert (batch,) == conv_state_indices.shape + + assert num_cache_lines >= batch + assert weight.stride(1) == 1 # Need this + assert cache_seqlens is None # not needed for vLLM - circular buffer + + # adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o' + out = x + stride_w_dim, stride_w_width = weight.stride() + + if query_start_loc is None: + # X (batch, dim, seqlen) + stride_x_seq, stride_x_dim, stride_x_token = x.stride() + stride_o_seq, stride_o_dim, stride_o_token = out.stride() + else: + # X (dim, cu_seqlen) + stride_x_token, stride_x_dim = x.stride() + stride_x_seq = 0 + stride_o_token, stride_o_dim = out.stride() + stride_o_seq = 0 + + stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride() + stride_state_indices = conv_state_indices.stride(0) if conv_state_indices is not None else 0 + if num_accepted_tokens is not None: + state_len = width - 1 + (seqlen - 1) # effective state_len needed + else: + state_len = width - 1 + np2_statelen = triton.next_power_of_2(state_len) + + def grid(META): + return ( + batch, + triton.cdiv(dim, META["BLOCK_N"]), + ) + + _causal_conv1d_update_kernel[grid]( + # Pointers to matrices + x, + weight, + bias, + conv_state, + cache_seqlens, + conv_state_indices, + num_accepted_tokens, + query_start_loc, + out, + # Matrix dimensions + batch, + dim, + seqlen, + state_len, + num_cache_lines, + # stride + stride_x_seq, + stride_x_dim, + stride_x_token, + stride_w_dim, + stride_w_width, + stride_istate_seq, + stride_istate_dim, + stride_istate_token, + stride_state_indices, + stride_o_seq, + stride_o_dim, + stride_o_token, + # others + pad_slot_id, + # META + HAS_BIAS=bias is not None, + KERNEL_WIDTH=width, + SILU_ACTIVATION=activation in ["silu", "swish"], + IS_VARLEN=query_start_loc is not None, + IS_CONTINUOUS_BATCHING=conv_state_indices is not None, + IS_SPEC_DECODING=num_accepted_tokens is not None, + NP2_STATELEN=np2_statelen, + USE_PAD_SLOT=pad_slot_id is not None, + BLOCK_N=256, + ) + if unsqueeze: + out = out.squeeze(-1) + return out.to(original_x_dtype) diff --git a/lightllm/models/qwen3next/triton_kernel/fla/__init__.py b/lightllm/models/qwen3next/triton_kernel/fla/__init__.py new file mode 100644 index 000000000..50f7f20b7 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/__init__.py @@ -0,0 +1,16 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +from .ops.chunk import chunk_gated_delta_rule +from .ops.fused_recurrent import fused_recurrent_gated_delta_rule + +__all__ = [ + "chunk_gated_delta_rule", + "fused_recurrent_gated_delta_rule", +] diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/__init__.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/__init__.py new file mode 100644 index 000000000..cd3b0962a --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/__init__.py @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +from .chunk import chunk_gated_delta_rule +from .fused_recurrent import fused_recurrent_gated_delta_rule + +__all__ = [ + "chunk_gated_delta_rule", + "fused_recurrent_gated_delta_rule", +] diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py new file mode 100644 index 000000000..22c81ae63 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py @@ -0,0 +1,225 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import warnings +from typing import Optional + +import torch +from einops import rearrange + +from .chunk_delta_h import chunk_gated_delta_rule_fwd_h +from .chunk_o import chunk_fwd_o +from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd +from .cumsum import chunk_local_cumsum +from .l2norm import l2norm_fwd +from .solve_tril import solve_tril +from .utils import SUPPRESS_LEVEL, input_guard +from .wy_fast import recompute_w_u_fwd + + +def chunk_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, +): + g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens) + # obtain WY representation. u is actually the new v. + A = chunk_scaled_dot_kkt_fwd(k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32) + A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) + w, u = recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + A=A, + g_cumsum=g, + cu_seqlens=cu_seqlens, + ) + h, v_new, final_state = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=g, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + ) + o = chunk_fwd_o( + q=q, + k=k, + v=v_new, + h=h, + g=g, + scale=scale, + cu_seqlens=cu_seqlens, + ) + if SUPPRESS_LEVEL < 3: + return g, o, A, final_state, None, None, None + elif SUPPRESS_LEVEL >= 3: + return g, o, A, final_state, w, h, v_new + + +class ChunkGatedDeltaRuleFunction(torch.autograd.Function): + @staticmethod + @input_guard + @torch.amp.custom_fwd(device_type="cuda") + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, + use_qk_l2norm_in_kernel: bool = False, + ): + if use_qk_l2norm_in_kernel: + q = l2norm_fwd(q) + k = l2norm_fwd(k) + + g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + ) + ctx.scale = scale + ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel + return o.to(q.dtype), final_state + + +@torch.compiler.disable +def chunk_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, + use_qk_l2norm_in_kernel: bool = False, +): + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g (torch.Tensor): + (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. + beta (torch.Tensor): + betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. + + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule + # inputs with equal lengths + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 + >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda') + >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid() + >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda')) + >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda') + >>> o, ht = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens + ) + """ + assert q.dtype == k.dtype == v.dtype + assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16." + assert len(beta.shape) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise." + + if head_first: + raise DeprecationWarning( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead.", + stacklevel=2, + ) + q, k, v, beta, g = map(lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g)) + if not head_first and q.shape[1] < q.shape[2]: + warnings.warn( + f"Input tensor shape suggests potential format mismatch" + f" seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...].", + stacklevel=2, + ) + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) + if scale is None: + scale = k.shape[-1] ** -0.5 + o, final_state = ChunkGatedDeltaRuleFunction.apply( + q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens, use_qk_l2norm_in_kernel + ) + if head_first: + o = rearrange(o, "b t h ... -> b h t ...") + return o, final_state diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py new file mode 100644 index 000000000..f20c95d90 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py @@ -0,0 +1,257 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +from typing import Optional + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices, prepare_chunk_offsets +from .op import exp +from .utils import is_nvidia_hopper, use_cuda_graph + +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16] + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "STORE_FINAL_STATE": lambda args: args["ht"] is not None, + "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"BV": BV}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4] + for num_stages in [2, 3, 4] + for BV in [32, 64] + ], + key=["H", "K", "V", "BT", "USE_G"], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( + k, + v, + w, + v_new, + g, + h, + h0, + ht, + cu_seqlens, + chunk_offsets, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + SAVE_NEW_VALUE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_h1 = tl.zeros([64, BV], dtype=tl.float32) + if K > 64: + b_h2 = tl.zeros([64, BV], dtype=tl.float32) + if K > 128: + b_h3 = tl.zeros([64, BV], dtype=tl.float32) + if K > 192: + b_h4 = tl.zeros([64, BV], dtype=tl.float32) + + # calculate offset + h += (boh * H + i_h) * K * V + v += (bos * H + i_h) * V + k += (bos * Hg + i_h // (H // Hg)) * K + w += (bos * H + i_h) * K + if SAVE_NEW_VALUE: + v_new += (bos * H + i_h) * V + stride_v = H * V + stride_h = H * K * V + stride_k = Hg * K + stride_w = H * K + if USE_INITIAL_STATE: + h0 = h0 + i_nh * K * V + if STORE_FINAL_STATE: + ht = ht + i_nh * K * V + + # load initial state + if USE_INITIAL_STATE: + p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32) + if K > 64: + p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32) + if K > 128: + p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32) + if K > 192: + p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32) + + # main recurrence + for i_t in range(NT): + p_h1 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_h2 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_h3 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_h4 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1)) + + p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v_new = ( + tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + if SAVE_NEW_VALUE + else None + ) + b_v_new = tl.zeros([BT, BV], dtype=tl.float32) + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v_new += tl.dot(b_w, b_h1.to(b_w.dtype)) + if K > 64: + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v_new += tl.dot(b_w, b_h2.to(b_w.dtype)) + if K > 128: + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v_new += tl.dot(b_w, b_h3.to(b_w.dtype)) + if K > 192: + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v_new += tl.dot(b_w, b_h4.to(b_w.dtype)) + b_v_new = -b_v_new + tl.load(p_v, boundary_check=(0, 1)) + + if SAVE_NEW_VALUE: + p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_v_new, b_v_new.to(p_v_new.dtype.element_ty), boundary_check=(0, 1)) + + if USE_G: + m_t = (i_t * BT + tl.arange(0, BT)) < T + last_idx = min((i_t + 1) * BT, T) - 1 + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_v_new = b_v_new * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None] + b_g_last = exp(b_g_last) + b_h1 = b_h1 * b_g_last + if K > 64: + b_h2 = b_h2 * b_g_last + if K > 128: + b_h3 = b_h3 * b_g_last + if K > 192: + b_h4 = b_h4 * b_g_last + b_v_new = b_v_new.to(k.dtype.element_ty) + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h1 += tl.dot(b_k, b_v_new) + if K > 64: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h2 += tl.dot(b_k, b_v_new) + if K > 128: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h3 += tl.dot(b_k, b_v_new) + if K > 192: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h4 += tl.dot(b_k, b_v_new) + + # epilogue + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_gated_delta_rule_fwd_h( + k: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + g: Optional[torch.Tensor] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + chunk_size: int = 64, # SY: remove this argument and force chunk size 64? + save_new_value: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, Hg, K, V = *k.shape, u.shape[-1] + H = u.shape[-2] + BT = chunk_size + + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + # N: the actual number of sequences in the batch with either equal or variable lengths + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT) + assert K <= 256, "current kernel does not support head dimension larger than 256." + + h = k.new_empty(B, NT, H, K, V) + final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None + + v_new = torch.empty_like(u) if save_new_value else None + + def grid(meta): + return (triton.cdiv(V, meta["BV"]), N * H) + + chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid]( + k=k, + v=u, + w=w, + v_new=v_new, + g=g, + h=h, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + ) + return h, v_new, final_state diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py new file mode 100644 index 000000000..73c2e1f19 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py @@ -0,0 +1,167 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# ruff: noqa: E501 + +from typing import Optional + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices +from .op import exp +from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper + +BKV_LIST = [64, 128] if check_shared_mem() else [32, 64] +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] + + +@triton.heuristics( + {"USE_G": lambda args: args["g"] is not None, "IS_VARLEN": lambda args: args["cu_seqlens"] is not None} +) +@triton.autotune( + configs=[ + triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages) + for BK in BKV_LIST + for BV in BKV_LIST + for num_warps in NUM_WARPS + for num_stages in [2, 3, 4] + ], + key=["H", "K", "V", "BT"], +) +@triton.jit(do_not_specialize=["T"]) +def chunk_fwd_kernel_o( + q, + k, + v, + h, + g, + o, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + # offset calculation + q += (bos * Hg + i_h // (H // Hg)) * K + k += (bos * Hg + i_h // (H // Hg)) * K + v += (bos * H + i_h) * V + o += (bos * H + i_h) * V + h += (i_tg * H + i_h).to(tl.int64) * K * V + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_A = tl.zeros([BT, BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + + # [BT, BK] @ [BK, BV] -> [BT, BV] + b_o += tl.dot(b_q, b_h) + # [BT, BK] @ [BK, BT] -> [BT, BT] + b_A += tl.dot(b_q, b_k) + + if USE_G: + g += bos * H + i_h + p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_o = b_o * exp(b_g)[:, None] + b_A = b_A * exp(b_g[:, None] - b_g[None, :]) + + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t) + b_A = tl.where(m_A, b_A, 0) + + p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # to fix mma -> mma layout conversion + # already solved by triton v3.2 or higher + b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_fwd_o( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + h: torch.Tensor, + g: Optional[torch.Tensor] = None, # cumsum of log decay + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, +) -> torch.Tensor: + B, T, Hg, K, V = *q.shape, v.shape[-1] + H = v.shape[-2] + if FLA_GDN_FIX_BT: + BT = 64 + else: + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + if scale is None: + scale = k.shape[-1] ** -0.5 + + o = torch.empty_like(v) + + def grid(meta): + return (triton.cdiv(V, meta["BV"]), NT, B * H) + + chunk_fwd_kernel_o[grid]( + q, + k, + v, + h, + g, + o, + cu_seqlens, + chunk_indices, + scale, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + ) + return o diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py new file mode 100644 index 000000000..aa545e8ec --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +from typing import Optional + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices +from .op import exp + + +@triton.heuristics( + {"IS_VARLEN": lambda args: args["cu_seqlens"] is not None, "USE_G": lambda args: args["g_cumsum"] is not None} +) +@triton.autotune( + configs=[ + triton.Config({"BK": BK}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64, 128] + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=["H", "K", "BT", "IS_VARLEN"], +) +@triton.jit(do_not_specialize=["T"]) +def chunk_scaled_dot_kkt_fwd_kernel( + k, + beta, + g_cumsum, + A, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_G: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + + p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr( + k + (bos * Hg + i_h // (H // Hg)) * K, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0) + ) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = b_k * b_beta[:, None] + b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k)) + + if USE_G: + p_g = tl.make_block_ptr(g_cumsum + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_g_diff = b_g[:, None] - b_g[None, :] + b_A = b_A * exp(b_g_diff) + + m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t) + b_A = tl.where(m_A, b_A, 0) + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_scaled_dot_kkt_fwd( + k: torch.Tensor, + beta: torch.Tensor, + g_cumsum: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + r""" + Compute beta * K * K^T. + + Args: + k (torch.Tensor): + The key tensor of shape `[B, T, H, K]`. + beta (torch.Tensor): + The beta tensor of shape `[B, T, H]`. + g_cumsum (torch.Tensor): + The cumulative sum of the gate tensor of shape `[B, T, H]`. + Default: None + cu_seqlens (torch.LongTensor): + The cumulative sequence lengths of the input tensor. + Default: None + chunk_size (int): + The chunk size. Default: 64. + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float32` + + Returns: + beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size. + """ + + B, T, Hg, K = k.shape + + H = beta.shape[-1] + BT = chunk_size + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype) + chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)]( + k=k, + beta=beta, + g_cumsum=g_cumsum, + A=A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + Hg=Hg, + K=K, + BT=BT, + ) + return A diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py new file mode 100644 index 000000000..9cd6a6545 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py @@ -0,0 +1,200 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import warnings +from typing import Optional + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices +from .utils import check_shared_mem, input_guard + +BS_LIST = [32, 64] if check_shared_mem() else [16, 32] + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], + key=["B", "H", "BT", "IS_VARLEN", "REVERSE"], +) +@triton.jit(do_not_specialize=["T"]) +def chunk_local_cumsum_scalar_kernel( + s, + o, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + BT: tl.constexpr, + REVERSE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + else: + p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + # [BT] + b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32) + b_o = tl.cumsum(b_s, axis=0) + if REVERSE: + b_z = tl.sum(b_s, axis=0) + b_o = -b_o + b_z[None] + b_s + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,)) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.autotune( + configs=[triton.Config({"BS": BS}, num_warps=num_warps) for BS in BS_LIST for num_warps in [2, 4, 8]], + key=["B", "H", "S", "BT", "IS_VARLEN", "REVERSE"], +) +@triton.jit(do_not_specialize=["T"]) +def chunk_local_cumsum_vector_kernel( + s, + o, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + REVERSE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + o_i = tl.arange(0, BT) + if REVERSE: + m_s = tl.where(o_i[:, None] <= o_i[None, :], 1.0, 0.0) + else: + m_s = tl.where(o_i[:, None] >= o_i[None, :], 1.0, 0.0) + + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + (bos * H + i_h * T) * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h * T) * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + else: + p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + b_o = tl.dot(m_s, b_s, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_local_cumsum_scalar( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False, + output_dtype: Optional[torch.dtype] = torch.float, +) -> torch.Tensor: + if head_first: + B, H, T = g.shape + else: + B, T, H = g.shape + assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2" + BT = chunk_size + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + grid = (NT, B * H) + chunk_local_cumsum_scalar_kernel[grid]( + g_org, g, cu_seqlens, chunk_indices, T=T, B=B, H=H, BT=BT, HEAD_FIRST=head_first, REVERSE=reverse + ) + return g + + +def chunk_local_cumsum_vector( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False, + output_dtype: Optional[torch.dtype] = torch.float, +) -> torch.Tensor: + if head_first: + B, H, T, S = g.shape + else: + B, T, H, S = g.shape + BT = chunk_size + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2" + + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + + def grid(meta): + return (triton.cdiv(meta["S"], meta["BS"]), NT, B * H) + + # keep cumulative normalizer in fp32 + # this kernel is equivalent to + # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1) + chunk_local_cumsum_vector_kernel[grid]( + g_org, g, cu_seqlens, chunk_indices, T=T, B=B, H=H, S=S, BT=BT, HEAD_FIRST=head_first, REVERSE=reverse + ) + return g + + +@input_guard +def chunk_local_cumsum( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False, + output_dtype: Optional[torch.dtype] = torch.float, + **kwargs, +) -> torch.Tensor: + if not head_first and g.shape[1] < g.shape[2]: + warnings.warn( + f"Input tensor shape suggests potential format mismatch" + f" seq_len ({g.shape[1]}) < num_heads ({g.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...].", + stacklevel=2, + ) + if cu_seqlens is not None: + assert g.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided" + if len(g.shape) == 3: + return chunk_local_cumsum_scalar(g, chunk_size, reverse, cu_seqlens, head_first, output_dtype) + elif len(g.shape) == 4: + return chunk_local_cumsum_vector(g, chunk_size, reverse, cu_seqlens, head_first, output_dtype) + else: + raise ValueError( + f"Unsupported input shape {g.shape}. " + f"which should be (B, T, H, D) if `head_first=False` " + f"or (B, H, T, D) otherwise" + ) diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py new file mode 100644 index 000000000..4ff18d4f6 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py @@ -0,0 +1,367 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +from typing import Optional + +import torch + +import triton +import triton.language as tl + +from .op import exp + + +@triton.heuristics( + { + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None, + "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None, + } +) +@triton.jit(do_not_specialize=["N", "T"]) +def fused_recurrent_gated_delta_rule_fwd_kernel( + q, + k, + v, + g, + beta, + o, + h0, + ht, + cu_seqlens, + ssm_state_indices, + num_accepted_tokens, + scale, + N: tl.int64, # num of sequences + T: tl.int64, # num of tokens + B: tl.constexpr, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + stride_init_state_token: tl.constexpr, + stride_final_state_token: tl.constexpr, + stride_indices_seq: tl.constexpr, + stride_indices_tok: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace + IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar, + USE_QK_L2NORM_IN_KERNEL: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_hv = i_nh // HV, i_nh % HV + i_h = i_hv // (HV // H) + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + if T == 0: + # no tokens to process for this sequence + return + + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + p_q = q + (bos * H + i_h) * K + o_k + p_k = k + (bos * H + i_h) * K + o_k + p_v = v + (bos * HV + i_hv) * V + o_v + if IS_BETA_HEADWISE: + p_beta = beta + (bos * HV + i_hv) * V + o_v + else: + p_beta = beta + bos * HV + i_hv + p_g = g + bos * HV + i_hv + p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[:, None] & mask_v[None, :] + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + if IS_CONTINUOUS_BATCHING: + if IS_SPEC_DECODING: + i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1 + else: + i_t = 0 + p_h0 = ( + h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(tl.int64) * stride_init_state_token + ) + else: + p_h0 = h0 + bos * HV * K * V + p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for i_t in range(0, T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_g = tl.load(p_g).to(tl.float32) + + if USE_QK_L2NORM_IN_KERNEL: + b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6) + b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6) + b_q = b_q * scale + # [BK, BV] + b_h *= exp(b_g) + # [BV] + b_v -= tl.sum(b_h * b_k[:, None], 0) + if IS_BETA_HEADWISE: + b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + b_v *= b_beta + # [BK, BV] + b_h += b_k[:, None] * b_v[None, :] + # [BV] + b_o = tl.sum(b_h * b_q[:, None], 0) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + # keep the states for multi-query tokens + if INPLACE_FINAL_STATE: + p_ht = ( + ht + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(tl.int64) * stride_final_state_token + ) + else: + p_ht = ht + (bos + i_t) * stride_final_state_token + p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + p_q += H * K + p_k += H * K + p_o += HV * V + p_v += HV * V + p_g += HV + p_beta += HV * (V if IS_BETA_HEADWISE else 1) + + +def fused_recurrent_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, + ssm_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + HV = v.shape[2] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + num_stages = 3 + num_warps = 1 + + o = q.new_empty(NK, *v.shape) + if inplace_final_state: + final_state = initial_state + else: + final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype) + + stride_init_state_token = initial_state.stride(0) + stride_final_state_token = final_state.stride(0) + + if ssm_state_indices is None: + stride_indices_seq, stride_indices_tok = 1, 1 + elif ssm_state_indices.ndim == 1: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1 + else: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride() + + grid = (NK, NV, N * HV) + fused_recurrent_gated_delta_rule_fwd_kernel[grid]( + q=q, + k=k, + v=v, + g=g, + beta=beta, + o=o, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + num_accepted_tokens=num_accepted_tokens, + scale=scale, + N=N, + T=T, + B=B, + H=H, + HV=HV, + K=K, + V=V, + BK=BK, + BV=BV, + stride_init_state_token=stride_init_state_token, + stride_final_state_token=stride_final_state_token, + stride_indices_seq=stride_indices_seq, + stride_indices_tok=stride_indices_tok, + IS_BETA_HEADWISE=beta.ndim == v.ndim, + USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + INPLACE_FINAL_STATE=inplace_final_state, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.squeeze(0) + return o, final_state + + +class FusedRecurrentFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, + ssm_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + use_qk_l2norm_in_kernel: bool = False, + ): + o, final_state = fused_recurrent_gated_delta_rule_fwd( + q=q.contiguous(), + k=k.contiguous(), + v=v.contiguous(), + g=g.contiguous(), + beta=beta.contiguous(), + scale=scale, + initial_state=initial_state, + inplace_final_state=inplace_final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + num_accepted_tokens=num_accepted_tokens, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + + return o, final_state + + +def fused_recurrent_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor = None, + scale: float = None, + initial_state: torch.Tensor = None, + inplace_final_state: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, + ssm_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, T, HV, V]`. + GVA is applied if `HV > H`. + g (torch.Tensor): + g (decays) of shape `[B, T, HV]`. + beta (torch.Tensor): + betas of shape `[B, T, HV]`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, HV, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + inplace_final_state: bool: + Whether to store the final state in-place to save memory. + Default: `True`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + ssm_state_indices (Optional[torch.Tensor]): + Indices to map the input sequences to the initial/final states. + num_accepted_tokens (Optional[torch.Tensor]): + Number of accepted tokens for each sequence during decoding. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HV, V]`. + final_state (torch.Tensor): + Final state of shape `[N, HV, K, V]`. + + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule + # inputs with equal lengths + >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512 + >>> q = torch.randn(B, T, H, K, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, HV, V, device='cuda') + >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda')) + >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid() + >>> h0 = torch.randn(B, HV, K, V, device='cuda') + >>> o, ht = fused_gated_recurrent_delta_rule( + q, k, v, g, beta, + initial_state=h0, + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = fused_gated_recurrent_delta_rule( + q, k, v, g, beta, + initial_state=h0, + cu_seqlens=cu_seqlens + ) + """ + if cu_seqlens is not None and q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) + if scale is None: + scale = k.shape[-1] ** -0.5 + else: + assert scale > 0, "scale must be positive" + if beta is None: + beta = torch.ones_like(q[..., 0]) + o, final_state = FusedRecurrentFunction.apply( + q, + k, + v, + g, + beta, + scale, + initial_state, + inplace_final_state, + cu_seqlens, + ssm_state_indices, + num_accepted_tokens, + use_qk_l2norm_in_kernel, + ) + return o, final_state diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/index.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/index.py new file mode 100644 index 000000000..8b1d59fc6 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/index.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import torch + +import triton + +from .utils import tensor_cache + + +@tensor_cache +def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return cu_seqlens[1:] - cu_seqlens[:-1] + + +@tensor_cache +def prepare_chunk_indices(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor: + indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()]) + return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens) + + +@tensor_cache +def prepare_chunk_offsets(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor: + return torch.cat([cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]).cumsum(-1) diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py new file mode 100644 index 000000000..7225cd4ae --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import os +from typing import Optional + +import torch + +import triton +import triton.language as tl + +BT_LIST = [8, 16, 32, 64, 128] + +USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0")) + + +@triton.autotune(configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32]], key=["D"]) +@triton.jit +def l2norm_fwd_kernel1( + x, + y, + D, + BD: tl.constexpr, + eps, +): + i_t = tl.program_id(0) + x += i_t * D + y += i_t * D + # Compute mean and variance + cols = tl.arange(0, BD) + mask = cols < D + b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32) + b_var = tl.sum(b_x * b_x, axis=0) + b_rstd = 1 / tl.sqrt(b_var + eps) + # tl.store(Rstd + i_t, rstd) + # Normalize and apply linear transformation + b_y = b_x * b_rstd + tl.store(y + cols, b_y, mask=mask) + + +@triton.autotune( + configs=[triton.Config({"BT": BT}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST], + key=["D"], +) +@triton.jit(do_not_specialize=["NB"]) +def l2norm_fwd_kernel( + x, + y, + eps, + NB, + T, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, +): + i_t = tl.program_id(0) + p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32) + b_var = tl.sum(b_x * b_x, axis=1) + b_y = b_x / tl.sqrt(b_var + eps)[:, None] + p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr): + xoffset = tl.program_id(0) * MBLOCK + row_idx = xoffset + tl.arange(0, MBLOCK)[:, None] + xmask = row_idx < M + rindex = tl.arange(0, N)[None, :] + xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32) + square = tl.broadcast_to(xs * xs, [MBLOCK, N]) + square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None] + rsqrt = tl.rsqrt(square_sum + eps) + tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask) + + +def l2norm_fwd(x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None): + x_shape_og = x.shape + x = x.view(-1, x.shape[-1]) + # allocate output + if output_dtype is None: + y = torch.empty_like(x) + else: + y = torch.empty_like(x, dtype=output_dtype) + assert y.stride(-1) == 1 + T, D = x.shape[0], x.shape[-1] + # rstd = torch.empty((T,), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D)) + if D > BD: + raise RuntimeError("This layer doesn't support feature dim >= 64KB.") + + if not USE_DEFAULT_FLA_NORM: + MBLOCK = 32 + # M, N = x.shape + l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK),)]( + x, + y, + eps, + T, + D, + MBLOCK, + ) + else: + if D <= 512: + NB = triton.cdiv(T, 2048) + + def grid(meta): + return (triton.cdiv(T, meta["BT"]),) + + l2norm_fwd_kernel[grid]( + x, + y, + eps, + NB=NB, + T=T, + D=D, + BD=BD, + ) + else: + l2norm_fwd_kernel1[(T,)]( + x, + y, + eps=eps, + D=D, + BD=BD, + ) + + return y.view(x_shape_og) diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/op.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/op.py new file mode 100644 index 000000000..ec0999455 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/op.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import os + +import triton +import triton.language as tl + + +@triton.jit +def div_normal(x, y): + return x / y + + +div = div_normal +exp = tl.exp +log = tl.log +log2 = tl.log2 + + +if not hasattr(tl, "gather"): + + @triton.jit + def gather(src, index, axis, _builder=None): + # This is a fallback implementation when tl.gather is not supported + # In order to pass triton compiler, there is no actual gather operation + return src + +else: + gather = tl.gather diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py new file mode 100644 index 000000000..46e4d5082 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py @@ -0,0 +1,271 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +from typing import Optional + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices +from .utils import input_guard + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3, 4, 5] + ], + key=["BT"], +) +@triton.jit(do_not_specialize=["T"]) +def solve_tril_16x16_kernel( + A, + Ad, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + A = A + (bos * H + i_h) * BT + Ad = Ad + (bos * H + i_h) * 16 + + offset = (i_t * 16) % BT + p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * 16, offset), (16, 16), (1, 0)) + p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), (1, 0)) + b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32) + b_A = -tl.where(tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :], b_A, 0) + + o_i = tl.arange(0, 16) + for i in range(1, min(16, T - i_t * 16)): + b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) + mask = o_i == i + b_A = tl.where(mask[:, None], b_a, b_A) + b_A += o_i[:, None] == o_i[None, :] + tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3, 4, 5] + ], + key=["H", "BT", "IS_VARLEN"], +) +@triton.jit(do_not_specialize=["T"]) +def merge_16x16_to_32x32_inverse_kernel( + A, Ad, Ai, cu_seqlens, chunk_indices, T, H: tl.constexpr, BT: tl.constexpr, IS_VARLEN: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + A += (bos * H + i_h) * 32 + Ad += (bos * H + i_h) * 16 + Ai += (bos * H + i_h) * 32 + + p_A_21 = tl.make_block_ptr(A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)) + p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32, 0), (16, 16), (1, 0)) + p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)) + p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32, 0), (16, 16), (1, 0)) + p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16), (16, 16), (1, 0)) + p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)) + + A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) + Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32) + Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32) + Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision="ieee"), Ai_11, input_precision="ieee") + tl.store(p_Ai_11, Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_22, Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_21, Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4, 5] + ], + key=["H", "BT", "IS_VARLEN"], +) +@triton.jit(do_not_specialize=["T"]) +def merge_16x16_to_64x64_inverse_kernel( + A, Ad, Ai, cu_seqlens, chunk_indices, T, H: tl.constexpr, BT: tl.constexpr, IS_VARLEN: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + A += (bos * H + i_h) * 64 + Ad += (bos * H + i_h) * 16 + Ai += (bos * H + i_h) * 64 + + p_A_21 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)) + p_A_32 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0)) + p_A_31 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)) + p_A_43 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0)) + p_A_42 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0)) + p_A_41 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)) + p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64, 0), (16, 16), (1, 0)) + p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)) + p_Ad_33 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)) + p_Ad_44 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)) + + A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) + A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32) + A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32) + A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32) + A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32) + A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32) + + Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32) + Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32) + Ai_33 = tl.load(p_Ad_33, boundary_check=(0, 1)).to(tl.float32) + Ai_44 = tl.load(p_Ad_44, boundary_check=(0, 1)).to(tl.float32) + + Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision="ieee"), Ai_11, input_precision="ieee") + Ai_32 = -tl.dot(tl.dot(Ai_33, A_32, input_precision="ieee"), Ai_22, input_precision="ieee") + Ai_43 = -tl.dot(tl.dot(Ai_44, A_43, input_precision="ieee"), Ai_33, input_precision="ieee") + + Ai_31 = -tl.dot( + Ai_33, + tl.dot(A_31, Ai_11, input_precision="ieee") + tl.dot(A_32, Ai_21, input_precision="ieee"), + input_precision="ieee", + ) + Ai_42 = -tl.dot( + Ai_44, + tl.dot(A_42, Ai_22, input_precision="ieee") + tl.dot(A_43, Ai_32, input_precision="ieee"), + input_precision="ieee", + ) + Ai_41 = -tl.dot( + Ai_44, + tl.dot(A_41, Ai_11, input_precision="ieee") + + tl.dot(A_42, Ai_21, input_precision="ieee") + + tl.dot(A_43, Ai_31, input_precision="ieee"), + input_precision="ieee", + ) + + p_Ai_11 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 0), (16, 16), (1, 0)) + p_Ai_22 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 16), (16, 16), (1, 0)) + p_Ai_33 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 32), (16, 16), (1, 0)) + p_Ai_44 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 48), (16, 16), (1, 0)) + p_Ai_21 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)) + p_Ai_31 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)) + p_Ai_32 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0)) + p_Ai_41 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)) + p_Ai_42 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0)) + p_Ai_43 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0)) + tl.store(p_Ai_11, Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_22, Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_33, Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_44, Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_21, Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_31, Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_32, Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_41, Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_42, Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_43, Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + fill_zeros = tl.zeros((16, 16), dtype=tl.float32) + p_Ai_12 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 16), (16, 16), (1, 0)) + p_Ai_13 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 32), (16, 16), (1, 0)) + p_Ai_14 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 48), (16, 16), (1, 0)) + p_Ai_23 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 32), (16, 16), (1, 0)) + p_Ai_24 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 48), (16, 16), (1, 0)) + p_Ai_34 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 48), (16, 16), (1, 0)) + tl.store(p_Ai_12, fill_zeros.to(p_Ai_12.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_13, fill_zeros.to(p_Ai_13.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_14, fill_zeros.to(p_Ai_14.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_23, fill_zeros.to(p_Ai_23.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_24, fill_zeros.to(p_Ai_24.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_34, fill_zeros.to(p_Ai_34.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + +@input_guard +def solve_tril( + A: torch.Tensor, cu_seqlens: Optional[torch.Tensor] = None, output_dtype: torch.dtype = torch.float +) -> torch.Tensor: + """ + Compute the inverse of the lower triangular matrix + A should be strictly lower triangular, i.e., A.triu() == 0. + + Args: + A (torch.Tensor): + [B, T, H, K] + cu_seqlens (torch.Tensor): + The cumulative sequence lengths of the input tensor. + Default: None. + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float` + + Returns: + (I + A)^-1 with the same shape as A + """ + assert A.shape[-1] in [16, 32, 64] + + B, T, H, BT = A.shape + Ad = torch.empty(B, T, H, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype) + + chunk_indices = prepare_chunk_indices(cu_seqlens, 16) if cu_seqlens is not None else None + NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, 16) + solve_tril_16x16_kernel[NT, B * H]( + A=A, + Ad=Ad, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + BT=BT, + ) + if BT == 16: + return Ad + + Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype) + merge_fn = merge_16x16_to_32x32_inverse_kernel if BT == 32 else merge_16x16_to_64x64_inverse_kernel + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT) + merge_fn[NT, B * H]( + A=A, + Ad=Ad, + Ai=Ai, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + BT=BT, + ) + return Ai diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/utils.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/utils.py new file mode 100644 index 000000000..d8f29f287 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/utils.py @@ -0,0 +1,173 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import contextlib +import functools +import logging +import os +from enum import Enum +from typing import Any, Callable, Literal, Optional + +import torch + +import triton + +logger = logging.getLogger(__name__) + +COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1" +FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1" +FLA_GDN_FIX_BT = os.getenv("FLA_GDN_FIX_BT", "0") == "1" + +SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0")) + + +def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: + """ + A decorator that caches the most recent results of a function with tensor inputs. + + This decorator will store the output of the decorated function for the most recent set of input tensors. + The cache is limited to a fixed size (default is 4). When the cache is full, the oldest entry will be removed. + + Args: + fn (Callable[..., torch.Tensor]): + The function to be decorated. It should take tensor inputs and return tensor outputs. + + Returns: + Callable[..., torch.Tensor]: + A wrapped version of the input function with single-entry caching. + """ + + cache_entries: tuple[Optional[tuple], Optional[dict], Any] = [] + cache_size = 4 + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + nonlocal cache_entries, cache_size + for i, entry in enumerate(cache_entries): + last_args, last_kwargs, last_result = entry + if ( + len(args) == len(last_args) + and len(kwargs) == len(last_kwargs) + and all(a is b for a, b in zip(args, last_args)) + and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()) + ): + cache_entries = cache_entries[:i] + cache_entries[i + 1 :] + [(args, kwargs, last_result)] + return last_result + + result = fn(*args, **kwargs) + + if len(cache_entries) >= cache_size: + cache_entries = cache_entries[1:] + cache_entries.append((args, kwargs, result)) + return result + + return wrapper + + +def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: + """ + A decorator to make sure all input tensors are contiguous and set the device based on input tensors. + """ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args) + contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()} + + tensor = None + for arg in args: + if isinstance(arg, torch.Tensor): + tensor = arg + break + if tensor is None: + for value in kwargs.values(): + if isinstance(value, torch.Tensor): + tensor = value + break + + if tensor is not None: + ctx = torch.cuda.device(tensor.device.index) + else: + ctx = contextlib.nullcontext() + + with ctx: + return fn(*contiguous_args, **contiguous_kwargs) + + return wrapper + + +@functools.cache +def get_available_device() -> str: + try: + return triton.runtime.driver.active.get_current_target().backend + except BaseException: + return "cpu" + + +@functools.cache +def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]: + device = get_available_device() + mapping = { + "cuda": "nvidia", + "hip": "amd", + "xpu": "intel", + } + # return the mapped value, or the original if not found + return mapping.get(device, device) + + +# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'. +# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs. +# Therefore, we need to check the triton backend to determine the actual GPU vendor. +device = get_available_device() if get_available_device() != "hip" else "cuda" +device_torch_lib = getattr(torch, device) +device_platform = _check_platform() + +is_amd = device_platform == "amd" +is_intel = device_platform == "intel" +is_nvidia = device_platform == "nvidia" +is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0) +is_nvidia_hopper = is_nvidia and ( + "NVIDIA H" in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9 +) +use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1" + + +def get_all_max_shared_mem(): + try: + return [ + triton.runtime.driver.active.utils.get_device_properties(i)["max_shared_mem"] + for i in range(device_torch_lib.device_count()) + ] + except BaseException: + return [-1] + + +class Backend(Enum): + ADA = 101376 # RTX 4090 + AMPERE = 166912 # A100 + HOPPER = 232448 # H100 + DEFAULT = 102400 # Default + + @classmethod + def get_shared_memory(cls, arch: str) -> int: + try: + return cls[arch.upper()].value + except KeyError: + return cls.DEFAULT.value + + +@functools.cache +def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool: + try: + device_shared_mem_list = get_all_max_shared_mem() + max_shared_memory = device_shared_mem_list[tensor_idx] + return max_shared_memory >= Backend.get_shared_memory(arch) + except Exception: + return False diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/wy_fast.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/wy_fast.py new file mode 100644 index 000000000..dec8d2ffc --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/wy_fast.py @@ -0,0 +1,122 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# ruff: noqa: E501 +from typing import Optional + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"], +) +@triton.jit(do_not_specialize=["T"]) +def recompute_w_u_fwd_kernel( + k, + v, + beta, + w, + u, + A, + g, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_g = tl.exp(tl.load(p_g, boundary_check=(0,))) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr( + k + (bos * Hg + i_h // (H // Hg)) * K, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0) + ) + p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype) + b_w = tl.dot(b_A, b_kb) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + +def recompute_w_u_fwd( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + g_cumsum: torch.Tensor, + A: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor], +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, Hg, K, V = *k.shape, v.shape[-1] + H = v.shape[-2] + BT = A.shape[-1] + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + BK = 64 + BV = 64 + u = torch.empty_like(v) + w = k.new_empty(B, T, H, K) + recompute_w_u_fwd_kernel[(NT, B * H)]( + k=k, + v=v, + beta=beta, + w=w, + u=u, + A=A, + g=g_cumsum, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return w, u diff --git a/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py b/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py new file mode 100644 index 000000000..99a5e2f70 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py @@ -0,0 +1,83 @@ +import torch +import triton +import triton.language as tl +from lightllm.common.triton_utils.autotuner import autotune + +# g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) +@triton.jit +def fused_gdn_gating_kernel( + g, + A_log, + a, + dt_bias, + seq_len, + NUM_HEADS: tl.constexpr, + beta: tl.constexpr, + threshold: tl.constexpr, + BLK_HEADS: tl.constexpr, +): + i_b, i_s, i_d = tl.program_id(0), tl.program_id(1), tl.program_id(2) + head_off = i_d * BLK_HEADS + tl.arange(0, BLK_HEADS) + off = i_b * seq_len * NUM_HEADS + i_s * NUM_HEADS + head_off + mask = head_off < NUM_HEADS + blk_A_log = tl.load(A_log + head_off, mask=mask) + blk_a = tl.load(a + off, mask=mask) + blk_bias = tl.load(dt_bias + head_off, mask=mask) + # If the model is loaded in fp16, without the .float() here, A might be -inf + x = blk_a.to(tl.float32) + blk_bias.to(tl.float32) + softplus_x = tl.where(beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x) + blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x + tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) + + +def _get_fused_gdn_gating_configs(): + return [{"BLK_HEADS": bh, "num_warps": nw} for bh in [8, 16, 32, 64] for nw in [1, 2, 4]] + + +def _get_fused_gdn_gating_static_key(a: torch.Tensor): + # group by head size and input dtype + return {"NUM_HEADS": a.shape[1], "a_dtype": str(a.dtype)} + + +@autotune( + kernel_name="fused_gdn_gating:v1", + configs_gen_func=_get_fused_gdn_gating_configs, + static_key_func=_get_fused_gdn_gating_static_key, + run_key_func=lambda a: a.shape[0], +) +def fused_gdn_gating( + A_log: torch.Tensor, + a: torch.Tensor, + dt_bias: torch.Tensor, + beta: float = 1.0, + threshold: float = 20.0, + run_config: dict | None = None, +) -> torch.Tensor: + batch, num_heads = a.shape + seq_len = 1 + + # default heuristic when autotune is disabled + if not run_config: + # choose the largest block size that does not exceed num_heads + candidate_blk = [8, 16, 32, 64] + blk_heads = max([c for c in candidate_blk if c <= max(8, num_heads)] or [8]) + run_config = {"BLK_HEADS": blk_heads, "num_warps": 1} + + BLK_HEADS = run_config["BLK_HEADS"] + num_warps = run_config.get("num_warps", 1) + + grid = (batch, seq_len, triton.cdiv(num_heads, BLK_HEADS)) + g = torch.empty_like(a, dtype=torch.float32) + fused_gdn_gating_kernel[grid]( + g, + A_log, + a, + dt_bias, + seq_len, + num_heads, + beta, + threshold, + BLK_HEADS, + num_warps=num_warps, + ) + return g diff --git a/lightllm/models/qwen3next/triton_kernel/gated_rmsnorm.py b/lightllm/models/qwen3next/triton_kernel/gated_rmsnorm.py new file mode 100644 index 000000000..89db5e00c --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/gated_rmsnorm.py @@ -0,0 +1,174 @@ +import triton +import triton.language as tl +import torch +from lightllm.common.triton_utils.autotuner import autotune + + +@triton.heuristics( + { + "HAS_BIAS": lambda args: args["B"] is not None, + } +) +@triton.jit +def gated_rmsnorm_forward_kernel( + X, # pointer to the input + Y, # pointer to the output + W, # pointer to the weights + B, # pointer to the biases + Z, # pointer to the other branch (required, not optional) + Rstd, # pointer to the 1/std + stride_x_row, # how much to increase the pointer when moving by 1 row + stride_y_row, + stride_z_row, + M, # number of rows in X + N, # number of columns in X + eps, # epsilon to avoid division by zero + BLOCK_N: tl.constexpr, + HAS_BIAS: tl.constexpr, + NORM_BEFORE_GATE: tl.constexpr, +): + # Map the program id to the row of X and Y it should compute. + row = tl.program_id(0) + group = tl.program_id(1) + X += row * stride_x_row + group * N + Y += row * stride_y_row + group * N + Z += row * stride_z_row + group * N + Rstd += group * M + W += group * N + if HAS_BIAS: + B += group * N + # Compute variance (RMS norm doesn't use mean) + cols = tl.arange(0, BLOCK_N) + x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32) + if not NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=cols < N).to(tl.float32) + x *= z * tl.sigmoid(z) + # RMS norm: compute variance directly without mean subtraction + xbar = tl.where(cols < N, x, 0.0) + var = tl.sum(xbar * xbar, axis=0) / N + rstd = 1 / tl.sqrt(var + eps) + tl.store(Rstd + row, rstd) + # Normalize and apply linear transformation + mask = cols < N + w = tl.load(W + cols, mask=mask).to(tl.float32) + if HAS_BIAS: + b = tl.load(B + cols, mask=mask).to(tl.float32) + # RMS norm: normalize without mean subtraction + x_hat = x * rstd + y = x_hat * w + b if HAS_BIAS else x_hat * w + if NORM_BEFORE_GATE: + z = tl.load(Z + cols, mask=mask).to(tl.float32) + y *= z * tl.sigmoid(z) + # Write output + tl.store(Y + cols, y, mask=mask) + + +def _get_gated_rmsnorm_configs(): + """Generate configurations for autotuning gated RMSNorm kernel.""" + configs = [] + # Different BLOCK_N sizes (powers of 2) + for block_n in [64, 128, 256, 512, 1024, 2048, 4096]: + # Different number of warps + for num_warps in [1, 2, 4, 8]: + # Skip configurations that are likely to be inefficient + if block_n >= 2048 and num_warps > 4: + continue + if block_n <= 128 and num_warps > 2: + continue + configs.append({"BLOCK_N": block_n, "num_warps": num_warps}) + return configs + + +def _get_gated_rmsnorm_static_key(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor): + """Generate static key for caching autotuned configurations.""" + M, N = x.shape + return { + "x_dtype": str(x.dtype), + "weight_dtype": str(weight.dtype), + "N": N, + "has_bias": bias is not None, + } + + +@autotune( + kernel_name="gated_rmsnorm_forward:v1", + configs_gen_func=_get_gated_rmsnorm_configs, + static_key_func=_get_gated_rmsnorm_static_key, + run_key_func=lambda x: x.shape[0], +) +def gated_rmsnorm_forward( + x: torch.Tensor, + weight: torch.Tensor, + bias: torch.Tensor, + eps: float, + z: torch.Tensor, + out: torch.Tensor = None, + group_size: int = None, + norm_before_gate: bool = True, + run_config: dict = None, +): + M, N = x.shape + if group_size is None: + group_size = N + assert N % group_size == 0 + ngroups = N // group_size + assert x.stride(-1) == 1 + # z is required for gated_rmsnorm + assert z is not None, "z cannot be None for gated_rmsnorm_forward" + assert z.stride(-1) == 1 + assert z.shape == (M, N) + assert weight.shape == (N,) + assert weight.stride(-1) == 1 + if bias is not None: + assert bias.stride(-1) == 1 + assert bias.shape == (N,) + # allocate output + if out is not None: + assert out.shape == x.shape + else: + out = torch.empty_like(x) + assert out.stride(-1) == 1 + # For RMS norm, we still need rstd for the kernel + rstd = torch.empty((ngroups * M,), dtype=torch.float32, device=x.device) + + # Default heuristic when autotune is disabled or no config provided + if not run_config: + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) + if group_size > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_N // 256, 1), 8) + run_config = {"BLOCK_N": BLOCK_N, "num_warps": num_warps} + + BLOCK_N = run_config["BLOCK_N"] + num_warps = run_config["num_warps"] + + # Validate BLOCK_N against group_size + if group_size > BLOCK_N: + # Fall back to largest valid BLOCK_N + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(group_size)) + if group_size > BLOCK_N: + raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.") + + grid = (M, ngroups) + gated_rmsnorm_forward_kernel[grid]( + x, + out, + weight, + bias, + z, + rstd, + x.stride(0), + out.stride(0), + z.stride(0), + M, + group_size, + eps, + BLOCK_N=BLOCK_N, + NORM_BEFORE_GATE=norm_before_gate, + num_warps=num_warps, + ) + return out diff --git a/lightllm/models/qwen3next/triton_kernel/gemma_rmsnorm.py b/lightllm/models/qwen3next/triton_kernel/gemma_rmsnorm.py new file mode 100644 index 000000000..210e78db1 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/gemma_rmsnorm.py @@ -0,0 +1,144 @@ +import torch + +import triton +import triton.language as tl +import os + +from lightllm.common.triton_utils.autotuner import autotune + + +@triton.jit +def _gemma_rmsnorm_fwd_kernel( + x_ptr, + w_ptr, + y_ptr, + x_stride0, + x_stride1, + y_stride0, + y_stride1, + N: tl.constexpr, + EPS: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + row = tl.program_id(0) + x_ptr = x_ptr + row * x_stride0 + y_ptr = y_ptr + row * y_stride0 + + _sum = tl.zeros([BLOCK_SIZE], dtype=tl.float32) + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + cols * x_stride1, mask=cols < N, other=0.0).to(tl.float32) + _sum += x * x + + var = tl.sum(_sum, axis=0) / N + rstd = 1 / tl.sqrt(var + EPS) + # Normalize and apply linear transformation + for off in range(0, N, BLOCK_SIZE): + cols = off + tl.arange(0, BLOCK_SIZE) + mask = cols < N + w = tl.load(w_ptr + cols, mask=mask).to(tl.float32) + x = tl.load(x_ptr + cols * x_stride1, mask=mask, other=0.0).to(tl.float32) + x_hat = x * rstd + w = w + 1.0 + y = x_hat * w + # Write output + tl.store(y_ptr + cols * y_stride1, y.to(y_ptr.dtype.element_ty), mask=mask) + + +def _get_gemma_rmsnorm_configs(): + """Generate configurations for autotuning gemma RMSNorm kernel.""" + configs = [] + # Different BLOCK_SIZE values (powers of 2) + for block_size in [128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768, 65536, 65536 * 2]: + # Different number of warps + for num_warps in [1, 2, 4, 8]: + for num_stages in [1, 2, 3, 4, 5]: + configs.append({"BLOCK_SIZE": block_size, "num_warps": num_warps, "num_stages": num_stages}) + return configs + + +def _get_gemma_rmsnorm_static_key(x: torch.Tensor, w: torch.Tensor): + """Generate static key for caching autotuned configurations.""" + N = x.shape[-1] + return { + "x_dtype": str(x.dtype), + "weight_dtype": str(w.dtype), + "N": N, + } + + +@autotune( + kernel_name="gemma_rmsnorm_forward:v1", + configs_gen_func=_get_gemma_rmsnorm_configs, + static_key_func=_get_gemma_rmsnorm_static_key, + run_key_func=lambda x: x.shape[-1], +) +def gemma_rmsnorm_forward(x, w, eps, out=None, run_config: dict = None): + # Inplace gemma RMS Norm + # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16) + # See https://github.com/huggingface/transformers/pull/29402 + N = x.shape[-1] + y = torch.empty_like(x) if out is None else out + x_arg = x.view(-1, N) + y_arg = y.view(-1, N) + + M, _ = x_arg.shape + + # Default heuristic when autotune is disabled or no config provided + if not run_config: + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BLOCK_SIZE = min(MAX_FUSED_SIZE, triton.next_power_of_2(N)) + if N > BLOCK_SIZE: + raise RuntimeError("This gemma rmsnorm doesn't support feature dim >= 64KB.") + # heuristics for number of warps + num_warps = min(max(BLOCK_SIZE // 256, 1), 8) + run_config = {"BLOCK_SIZE": BLOCK_SIZE, "num_warps": num_warps, "num_stages": 1} + + BLOCK_SIZE = run_config["BLOCK_SIZE"] + num_warps = run_config["num_warps"] + num_stages = run_config["num_stages"] + + _gemma_rmsnorm_fwd_kernel[(M,)]( + x_arg, + w, + y_arg, + x_stride0=x.stride(0), + x_stride1=x.stride(1), + y_stride0=y.stride(0), + y_stride1=y.stride(1), + N=N, + EPS=eps, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + num_stages=num_stages, + ) + + return y + + +def _gemma_rmsnorm_fwd_torch(x, weight, eps): + original_dtype = x.dtype + x = x.to(torch.float32) + x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) + x = x * (1.0 + weight.float()) + return x.to(original_dtype) + + +def test_rms_norm(M, N, dtype, eps=1e-5, device="cuda"): + # create data + x_shape = (M, N) + w_shape = (x_shape[-1],) + weight = torch.rand(w_shape, dtype=dtype, device="cuda") + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") + # forward pass + y_tri = gemma_rmsnorm_forward(x, weight, eps) + y_ref = _gemma_rmsnorm_fwd_torch(x, weight, eps) + + # compare + print("type:", y_tri.dtype, y_ref.dtype) + print("max delta:", torch.max(torch.abs(y_tri - y_ref))) + # Use appropriate tolerance based on dtype + atol = 1e-2 if dtype == torch.float32 else 5e-2 + assert torch.allclose(y_tri, y_ref, atol=atol, rtol=0) + return diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index bf0e89887..fbdbcd03d 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -572,4 +572,5 @@ def make_argument_parser() -> argparse.ArgumentParser: default=False, help="""Enable prefix prompt cache fetch for data parallel inference, disabled by default.""", ) + parser.add_argument("--linear_attn_cache_size", type=int, default=2000, help="""The size of linear attn cache. """) return parser diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 8bda50fb7..29f271c5e 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -33,7 +33,7 @@ import uuid from PIL import Image import multiprocessing as mp -from typing import AsyncGenerator, Union +from typing import Any, AsyncGenerator, Union from typing import Callable from lightllm.server import TokenLoad from fastapi import BackgroundTasks, FastAPI, Request, WebSocket, WebSocketDisconnect @@ -130,6 +130,22 @@ def get_model_name(): return {"model_name": g_objs.args.model_name} +@app.get("/get_server_info") +@app.post("/get_server_info") +def get_server_info(): + # 将 StartArgs 转换为字典格式 + from dataclasses import asdict + + server_info: dict[str, Any] = asdict(g_objs.args) + return {**server_info} + + +@app.get("/get_weight_version") +@app.post("/get_weight_version") +def get_weight_version(): + return {"weight_version": g_objs.args.weight_version} + + @app.get("/healthz", summary="Check server health") @app.get("/health", summary="Check server health") @app.head("/health", summary="Check server health") diff --git a/lightllm/server/api_models.py b/lightllm/server/api_models.py index abd29dc92..d378dd6c5 100644 --- a/lightllm/server/api_models.py +++ b/lightllm/server/api_models.py @@ -72,7 +72,7 @@ class CompletionRequest(BaseModel): # prompt: string or tokens prompt: Union[str, List[str], List[int], List[List[int]]] suffix: Optional[str] = None - max_tokens: Optional[int] = 16 + max_tokens: Optional[int] = 16000 temperature: Optional[float] = 1.0 top_p: Optional[float] = 1.0 n: Optional[int] = 1 @@ -112,7 +112,7 @@ class ChatCompletionRequest(BaseModel): stream: Optional[bool] = False stream_options: Optional[StreamOptions] = None stop: Optional[Union[str, List[str]]] = None - max_tokens: Optional[int] = 16 + max_tokens: Optional[int] = 16000 presence_penalty: Optional[float] = 0.0 frequency_penalty: Optional[float] = 0.0 logit_bias: Optional[Dict[str, float]] = None diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py index b4447d808..b0a1189d3 100755 --- a/lightllm/server/api_server.py +++ b/lightllm/server/api_server.py @@ -1,15 +1,33 @@ import torch from .api_cli import make_argument_parser +from lightllm.server.core.objs.start_args_type import StartArgs +from lightllm.utils.log_utils import init_logger -if __name__ == "__main__": - torch.multiprocessing.set_start_method("spawn") # this code will not be ok for settings to fork to subprocess - parser = make_argument_parser() - args = parser.parse_args() +logger = init_logger(__name__) + + +def launch_server(args: StartArgs): from .api_start import pd_master_start, normal_or_p_d_start, config_server_start + try: + # this code will not be ok for settings to fork to subprocess + torch.multiprocessing.set_start_method("spawn") + except RuntimeError as e: + logger.warning(f"Failed to set start method: {e}") + except Exception as e: + logger.error(f"Failed to set start method: {e}") + raise e + if args.run_mode == "pd_master": pd_master_start(args) elif args.run_mode == "config_server": config_server_start(args) else: normal_or_p_d_start(args) + + +if __name__ == "__main__": + parser = make_argument_parser() + args = parser.parse_args() + + launch_server(StartArgs(**vars(args))) diff --git a/lightllm/server/api_start.py b/lightllm/server/api_start.py index 9cc3d38c2..ddff04615 100644 --- a/lightllm/server/api_start.py +++ b/lightllm/server/api_start.py @@ -16,6 +16,7 @@ from lightllm.utils.process_check import is_process_active from lightllm.utils.multinode_utils import send_and_receive_node_ip from lightllm.utils.shm_size_check import check_recommended_shm_size +from lightllm.server.core.objs.start_args_type import StartArgs logger = init_logger(__name__) @@ -51,20 +52,38 @@ def signal_handler(sig, frame): process_manager.terminate_all_processes() logger.info("All processes have been terminated gracefully.") sys.exit(0) + elif sig == signal.SIGHUP: + logger.info("Received SIGHUP (terminal closed), shutting down gracefully...") + if http_server_process and http_server_process.poll() is None: + http_server_process.send_signal(signal.SIGTERM) + + start_time = time.time() + while (time.time() - start_time) < 60: + if not is_process_active(http_server_process.pid): + logger.info("httpserver exit") + break + time.sleep(1) + + if time.time() - start_time < 60: + logger.info("HTTP server has exited gracefully") + else: + logger.warning("HTTP server did not exit in time, killing it...") + kill_recursive(http_server_process) + + process_manager.terminate_all_processes() + logger.info("All processes have been terminated gracefully due to terminal closure.") + sys.exit(0) signal.signal(signal.SIGTERM, signal_handler) signal.signal(signal.SIGINT, signal_handler) + signal.signal(signal.SIGHUP, signal_handler) logger.info(f"start process pid {os.getpid()}") logger.info(f"http server pid {http_server_process.pid}") return -def normal_or_p_d_start(args): - from lightllm.server.core.objs.start_args_type import StartArgs - - args: StartArgs = args - +def normal_or_p_d_start(args: StartArgs): set_unique_server_name(args) if not args.disable_shm_warning: @@ -376,7 +395,7 @@ def normal_or_p_d_start(args): return -def pd_master_start(args): +def pd_master_start(args: StartArgs): set_unique_server_name(args) if args.run_mode != "pd_master": return @@ -439,7 +458,7 @@ def pd_master_start(args): http_server_process.wait() -def config_server_start(args): +def config_server_start(args: StartArgs): set_unique_server_name(args) if args.run_mode != "config_server": return diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 71cafd6c4..ef970c412 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -1,37 +1,42 @@ from dataclasses import dataclass, field from typing import List, Optional, Tuple -# 只是为了更好的编程提示 +# 服务启动参数 @dataclass class StartArgs: run_mode: str = field( default="normal", - metadata={"choices": ["normal", "prefill", "decode", "pd_master", "nixl_prefill", "nixl_decode"]}, + metadata={ + "choices": ["normal", "prefill", "decode", "nixl_prefill", "nixl_decode", "pd_master", "config_server"] + }, ) host: str = field(default="127.0.0.1") port: int = field(default=8000) + httpserver_workers: int = field(default=1) zmq_mode: str = field( default="ipc:///tmp/", metadata={"help": "use socket mode or ipc mode, only can be set in ['tcp://', 'ipc:///tmp/']"}, ) - pd_master_ip: str = field(default="127.0.0.1") + pd_master_ip: str = field(default="0.0.0.0") pd_master_port: int = field(default=1212) config_server_host: str = field(default=None) config_server_port: int = field(default=None) pd_decode_rpyc_port: int = field(default=42000) - select_p_d_node_strategy: str = field(default=None) + select_p_d_node_strategy: str = field( + default="round_robin", metadata={"choices": ["random", "round_robin", "adaptive_load"]} + ) model_name: str = field(default="default_model_name") model_dir: Optional[str] = field(default=None) - tokenizer_mode: str = field(default="slow") + tokenizer_mode: str = field(default="fast") load_way: str = field(default="HF") max_total_token_num: Optional[int] = field(default=None) mem_fraction: float = field(default=0.9) batch_max_tokens: Optional[int] = field(default=None) - eos_id: List[int] = field(default_factory=list) + eos_id: Optional[List[int]] = field(default=None) tool_call_parser: Optional[str] = field( - default=None, metadata={"choices": ["llama3", "qwen25", "mistral", "deepseekv3", "kimi_k2", "qwen"]} + default=None, metadata={"choices": ["qwen25", "llama3", "mistral", "deepseekv3", "qwen"]} ) chat_template: Optional[str] = field(default=None) running_max_req_size: int = field(default=1000) @@ -39,11 +44,11 @@ class StartArgs: dp: int = field(default=1) nnodes: int = field(default=1) node_rank: int = field(default=0) - max_req_total_len: int = field(default=2048 + 1024) + max_req_total_len: int = field(default=16384) nccl_host: str = field(default="127.0.0.1") nccl_port: int = field(default=28765) use_config_server_to_init_nccl: bool = field(default=False) - mode: List[str] = field(default_factory=list) + mode: List[str] = field(default_factory=lambda: []) trust_remote_code: bool = field(default=False) disable_log_stats: bool = field(default=False) log_stats_interval: int = field(default=10) @@ -52,11 +57,11 @@ class StartArgs: router_max_wait_tokens: int = field(default=1) disable_aggressive_schedule: bool = field(default=False) disable_dynamic_prompt_cache: bool = field(default=False) - chunked_prefill_size: int = field(default=8192) + chunked_prefill_size: int = field(default=4096) disable_chunked_prefill: bool = field(default=False) diverse_mode: bool = field(default=False) token_healing_mode: bool = field(default=False) - output_constraint_mode: str = field(default="none", metadata={"choices": ["none", "simple", "xgrammar"]}) + output_constraint_mode: str = field(default="none", metadata={"choices": ["outlines", "xgrammar", "none"]}) first_token_constraint_mode: bool = field(default=False) enable_multimodal: bool = field(default=False) enable_multimodal_audio: bool = field(default=False) @@ -75,11 +80,11 @@ class StartArgs: health_monitor: bool = field(default=False) metric_gateway: Optional[str] = field(default=None) job_name: str = field(default="lightllm") - grouping_key: List[str] = field(default_factory=list) + grouping_key: List[str] = field(default_factory=lambda: []) push_interval: int = field(default=10) visual_infer_batch_size: int = field(default=1) visual_send_batch_size: int = field(default=1) - visual_gpu_ids: List[int] = field(default_factory=lambda: [0]) + visual_gpu_ids: Optional[List[int]] = field(default=None) visual_tp: int = field(default=1) visual_dp: int = field(default=1) visual_nccl_ports: List[int] = field(default_factory=lambda: [29500]) @@ -88,10 +93,10 @@ class StartArgs: graph_max_batch_size: int = field(default=256) graph_split_batch_size: int = field(default=32) graph_grow_step_size: int = field(default=16) - graph_max_len_in_batch: int = field(default=8192) - quant_type: Optional[str] = field(default=None) + graph_max_len_in_batch: int = field(default=0) + quant_type: Optional[str] = field(default="none") quant_cfg: Optional[str] = field(default=None) - vit_quant_type: Optional[str] = field(default=None) + vit_quant_type: Optional[str] = field(default="none") vit_quant_cfg: Optional[str] = field(default=None) enable_flashinfer_prefill: bool = field(default=False) enable_flashinfer_decode: bool = field(default=False) @@ -101,7 +106,9 @@ class StartArgs: ) ep_redundancy_expert_config_path: Optional[str] = field(default=None) auto_update_redundancy_expert: bool = field(default=False) - mtp_mode: Optional[str] = field(default=None) + mtp_mode: Optional[str] = field( + default=None, metadata={"choices": ["deepseekv3_vanilla", "deepseekv3_eagle", None]} + ) mtp_draft_model_dir: Optional[str] = field(default=None) mtp_step: int = field(default=0) kv_quant_calibration_config_path: Optional[str] = field(default=None) @@ -110,7 +117,7 @@ class StartArgs: pd_node_id: int = field(default=-1) enable_cpu_cache: bool = field(default=False) cpu_cache_storage_size: float = field(default=2) - cpu_cache_token_page_size: int = field(default=64) + cpu_cache_token_page_size: int = field(default=256) enable_disk_cache: bool = field(default=False) disk_cache_storage_size: float = field(default=10) disk_cache_dir: Optional[str] = field(default=None) @@ -131,3 +138,19 @@ class StartArgs: # kernel setting enable_fa3: bool = field(default=False) + + httpserver_workers: int = field(default=1) + disable_shm_warning: bool = field(default=False) + dp_balancer: str = field(default="bs_balancer", metadata={"choices": ["round_robin", "bs_balancer"]}) + enable_custom_allgather: bool = field(default=False) + enable_fused_shared_experts: bool = field(default=False) + enable_mps: bool = field(default=False) + multinode_router_gloo_port: int = field(default=20001) + schedule_time_interval: float = field(default=0.03) + use_dynamic_prompt_cache: bool = field(default=False) + disable_custom_allreduce: bool = field(default=False) + + # hybrid attention model + linear_attn_cache_size: int = field(default=2000) + + weight_version: str = "default" diff --git a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py new file mode 100644 index 000000000..95f4be1fa --- /dev/null +++ b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py @@ -0,0 +1,420 @@ +import torch +import numpy as np +import collections +import xxhash +import threading +import time +from typing import Tuple, Dict, Set, List, Optional, Union +from typing_extensions import override +from sortedcontainers import SortedSet +from abc import ABC, abstractmethod +import math +from dataclasses import dataclass, field + +from .shared_arr import SharedArray +from .radix_cache import UniqueTimeIdGenerator +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + +time_gen = UniqueTimeIdGenerator() + + +class HybridRadixNode: + def __init__(self): + # Core data + self.edge: Tuple[int, ...] = () + self.childrens_list: List["HybridRadixNode"] = [] + self.parent: Optional["HybridRadixNode"] = None + + # LightLLM specific + self.token_id_key = torch.zeros((0,), device="cpu", dtype=torch.int64) + self.token_mem_index_value = torch.zeros((0,), device="cpu", dtype=torch.int64) + self.buffer_idx: Optional[int] = None + + # Node metadata + self.node_value_len: int = 0 + self.node_prefix_total_len: int = 0 + self.ref_counter: int = 0 + self.time_id: int = 0 + + # Eviction metadata + self.hit_count: int = 0 + self.insert_time: float = 0.0 + self.last_access: float = 0.0 + self.node_id: int = 0 + + def is_leaf(self) -> bool: + return len(self.childrens_list) == 0 + + def has_buffer(self) -> bool: + return self.buffer_idx is not None + + def is_referenced(self) -> bool: + return self.ref_counter > 0 + + def collect_path_values(self) -> torch.Tensor: + """Collect all values from root to this node.""" + segments = [] + node = self + while node.parent is not None: + if len(node.token_mem_index_value) > 0: + segments.append(node.token_mem_index_value) + node = node.parent + + if not segments: + return torch.zeros((0,), device="cpu", dtype=torch.int64) + + # Reverse order and concatenate + segments.reverse() + return torch.cat(segments, dim=0) + + def update_time(self): + self.time_id = time_gen.generate_time_id() + self.last_access = time.time() + + def remove_child(self, child_node: "HybridRadixNode"): + child_node.parent = None + self.childrens_list.remove(child_node) + + def get_kv_cache_compare_key(self): + return (self.is_referenced(), not self.is_leaf(), self.has_buffer(), self.time_id) + + def get_buffer_compare_key(self): + return self.time_id + + def add_and_return_new_child(self, token_id_key, token_mem_index_value, buffer_idx): + child = HybridRadixNode() + child.token_id_key = token_id_key + child.token_mem_index_value = token_mem_index_value + child.buffer_idx = buffer_idx + self.childrens_list.append(child) + child.parent = self + + new_len = len(child.token_mem_index_value) + child.node_value_len = new_len + child.node_prefix_total_len = self.node_prefix_total_len + new_len + return child + + +class HybridRadixCache: + def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None): + from lightllm.models.qwen3next.mem_manager import Qwen3NextMemoryManager + + self.mem_manager: Qwen3NextMemoryManager = mem_manager + + self._key_dtype = torch.int64 + self._value_dtype = torch.int64 + + self.root_node = HybridRadixNode() + self.root_node.token_id_key = torch.zeros((0,), device="cpu", dtype=self._key_dtype) + self.root_node.token_mem_index_value = torch.zeros((0,), device="cpu", dtype=self._value_dtype) + self.root_node.ref_counter = 1 # 初始化为 1 保证永远不会被 evict 掉 + + self.evict_kv_cache_tree_set: Set[HybridRadixNode] = SortedSet(key=lambda x: x.get_kv_cache_compare_key()) + self.evict_buffer_tree_set: Set[HybridRadixNode] = SortedSet(key=lambda x: x.get_buffer_compare_key()) + self.evict_kv_cache_tree_set.add(self.root_node) + self.evict_buffer_tree_set.add(self.root_node) + + self.refed_tokens_num = SharedArray(f"{unique_name}_refed_tokens_num_{rank_in_node}", (1,), dtype=np.int64) + self.refed_tokens_num.arr[0] = 0 + self.tree_total_tokens_num = SharedArray( + f"{unique_name}_tree_total_tokens_num_{rank_in_node}", (1,), dtype=np.int64 + ) + self.tree_total_tokens_num.arr[0] = 0 + self.tree_total_buffers_num = SharedArray( + f"{unique_name}_tree_total_buffers_num_{rank_in_node}", (1,), dtype=np.int64 + ) + self.tree_total_buffers_num.arr[0] = 0 + + def update_evict_info(self, node: HybridRadixNode): + # Update time once at the beginning + if node == self.root_node: + return + + if node.has_buffer(): + # Remove and re-add to update position in sorted set + try: + self.evict_buffer_tree_set.discard(node) + except ValueError: + pass + if not node.is_leaf(): + # Remove and re-add to update position in sorted set + try: + self.evict_kv_cache_tree_set.discard(node) + except ValueError: + pass + + node.update_time() + + if node.has_buffer(): + self.evict_buffer_tree_set.add(node) + if not node.is_leaf(): + self.evict_kv_cache_tree_set.add(node) + return + + def insert(self, key, value, buffer_idx: int) -> Tuple[int, Optional[HybridRadixNode]]: + logger.info( + f"insert key len: {len(key)}, value len: {len(value)}, buffer_idx: {buffer_idx} key[:10]: {key[:10]}" + ) + assert key is not None and value is not None and buffer_idx is not None + assert len(key) == len(value) and len(key) >= 1 + + return self._insert_helper(self.root_node, key, value, buffer_idx, len(key), 0) + + def _insert_helper( + self, node: HybridRadixNode, key, value, buffer_idx, key_len, prefix_len + ) -> Tuple[int, Optional[HybridRadixNode]]: + # 插入的前提是已经完全覆盖当前节点 + # 遍历当前的所有子节点,找到第一个完全匹配的节点,继续插入 + # 如果找不到完全匹配的节点,则直接插入 + for child in node.childrens_list: + if key_len < child.node_value_len: + continue + if torch.equal(child.token_id_key, key[0 : child.node_value_len]): + # 完全匹配,继续向下插入 + return self._insert_helper( + child, + key[child.node_value_len :], + value[child.node_value_len :], + buffer_idx, + key_len - child.node_value_len, + prefix_len + child.node_value_len, + ) + + # 没有找到完全匹配的节点,直接插入 + # Prevent set corruption by removing node before modifying it (which changes is_leaf status) + if node != self.root_node: + try: + self.evict_kv_cache_tree_set.discard(node) + except ValueError: + pass + if node.has_buffer(): + try: + self.evict_buffer_tree_set.discard(node) + except ValueError: + pass + + new_child = node.add_and_return_new_child(key, value, buffer_idx) + new_child.update_time() + self.evict_kv_cache_tree_set.add(new_child) + self.evict_buffer_tree_set.add(new_child) + self.update_evict_info(node) + self.tree_total_tokens_num.arr[0] += len(value) + self.tree_total_buffers_num.arr[0] += 1 + return prefix_len, new_child + + def match_prefix(self, key, update_refs=False): + logger.info(f"match_prefix key len: {len(key)}, update_refs: {update_refs} key[:10]: {key[:10]}") + if len(key) == 0: + return None, 0, None + + ans_value_list = [] + tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) + + if tree_node != self.root_node and tree_node is not None: + if len(ans_value_list) != 0: + value = torch.cat(ans_value_list, dim=0) + else: + value = torch.zeros((0,), device="cpu", dtype=self._value_dtype) + logger.info(f"match_prefix success len: {len(value)}") + return tree_node, len(value), value + else: + logger.info("match_prefix failed") + return None, 0, None + + def _match_prefix_helper( + self, node: HybridRadixNode, key, ans_value_list, update_refs=False + ) -> Optional[HybridRadixNode]: + # 匹配的前提是已经完全覆盖当前节点 + # 遍历所有节点,假设完全匹配key, 则返回。 + + if len(key) == 0: + return node + + for child in node.childrens_list: + if len(key) < child.node_value_len: + continue + if torch.equal(child.token_id_key, key[0 : child.node_value_len]): + # 完全匹配,继续向下匹配 + ans_value_list.append(child.token_mem_index_value) + match_node = self._match_prefix_helper( + child, + key[child.node_value_len :], + ans_value_list, + update_refs=update_refs, + ) + if match_node is not None: + if update_refs: + self.add_node_ref_counter(child) + self.update_evict_info(child) + return match_node + else: + ans_value_list.pop() + return node + + def evict_kv_cache(self, need_remove_tokens, evict_memindexes, evict_buffer_indexes): + logger.info( + f"evict_kv_cache need: {need_remove_tokens}" + f"total: {self.tree_total_tokens_num.arr[0]}" + f"refed: {self.refed_tokens_num.arr[0]}" + ) + if self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0] < need_remove_tokens: + assert False, f"""can not free tree tokens {need_remove_tokens}, + tree_total_tokens_num {self.tree_total_tokens_num.arr[0]}, + refed_tokens_num {self.refed_tokens_num.arr[0]}""" + num_evicted = 0 + while num_evicted < need_remove_tokens: + node: HybridRadixNode = self.evict_kv_cache_tree_set.pop(0) + assert ( + node.ref_counter == 0 and len(node.childrens_list) == 0 and node != self.root_node + ), "error evict tree node state" + num_evicted += len(node.token_mem_index_value) + evict_memindexes.append(node.token_mem_index_value) + if node.has_buffer(): + evict_buffer_indexes.append(node.buffer_idx) + self.tree_total_buffers_num.arr[0] -= 1 + # update total token num + self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value) + parent_node: HybridRadixNode = node.parent + + # Prevent set corruption by removing parent before modifying it + if parent_node != self.root_node: + try: + self.evict_kv_cache_tree_set.discard(parent_node) + except ValueError: + pass + if parent_node.has_buffer(): + try: + self.evict_buffer_tree_set.discard(parent_node) + except ValueError: + pass + + parent_node.remove_child(node) + self.update_evict_info(parent_node) + return + + def evict_buffer_cache(self, need_remove_buffers, evict_buffer_indexes): + if self.tree_total_buffers_num.arr[0] < need_remove_buffers: + assert False, f"""can not free tree buffers {need_remove_buffers}, + tree_total_buffers_num {self.tree_total_buffers_num.arr[0]}""" + num_evicted = 0 + while num_evicted < need_remove_buffers: + node: HybridRadixNode = self.evict_buffer_tree_set.pop(0) + assert node.has_buffer() and node != self.root_node, "error evict buffer node state" + num_evicted += 1 + evict_buffer_indexes.append(node.buffer_idx) + node.buffer_idx = None + self.update_evict_info(node) + return + + def free_radix_cache_to_get_enough_token(self, need_token_num, need_buffer_num=0): + logger.info( + f"free_radix_cache need_token: {need_token_num}" + f"need_buffer: {need_buffer_num}" + f"can_use: {self.mem_manager.can_use_mem_size}" + f"state_cache_can_use: {self.mem_manager.get_state_cache_can_use_size()}" + ) + if need_token_num > self.mem_manager.can_use_mem_size: + need_evict_token_num = need_token_num - self.mem_manager.can_use_mem_size + if need_evict_token_num > 0: + evict_memindexes = [] + evict_buffer_indexes = [] + self.evict_kv_cache(need_evict_token_num, evict_memindexes, evict_buffer_indexes) + evict_memindexes = torch.concat(evict_memindexes) + self.mem_manager.free(evict_memindexes) + self.mem_manager.free_state_cache_buffer(evict_buffer_indexes) + + if need_buffer_num > self.mem_manager.get_state_cache_can_use_size(): + need_evict_buffer_num = need_buffer_num - self.mem_manager.get_state_cache_can_use_size() + if need_evict_buffer_num > 0: + evict_buffer_indexes = [] + self.evict_buffer_cache(need_evict_buffer_num, evict_buffer_indexes) + self.mem_manager.free_state_cache_buffer(evict_buffer_indexes) + return + + def get_tree_total_tokens_num(self): + return self.tree_total_tokens_num.arr[0] + + def get_refed_tokens_num(self): + return self.refed_tokens_num.arr[0] + + def add_node_ref_counter(self, node: HybridRadixNode): + if node is None: + return + + while node is not None: + if node != self.root_node: + try: + self.evict_kv_cache_tree_set.discard(node) + except ValueError: + pass + + if node.ref_counter == 0: + self.refed_tokens_num.arr[0] += len(node.token_mem_index_value) + node.ref_counter += 1 + + if node != self.root_node: + self.evict_kv_cache_tree_set.add(node) + + node = node.parent + return + + def dec_node_ref_counter(self, node: HybridRadixNode): + if node is None: + return + + while node is not None: + if node != self.root_node: + try: + self.evict_kv_cache_tree_set.discard(node) + except ValueError: + pass + + if node.ref_counter == 1: + self.refed_tokens_num.arr[0] -= len(node.token_mem_index_value) + node.ref_counter -= 1 + + if node != self.root_node: + self.evict_kv_cache_tree_set.add(node) + + node = node.parent + return + + +class _RadixCacheReadOnlyClient: + """ + router 端只读用的客户端,用于从共享内存中读取树结构中的信息,用于进行prompt cache 的调度估计。 + """ + + def __init__(self, unique_name, total_token_num, rank_in_node): + self.refed_tokens_num = SharedArray(f"{unique_name}_refed_tokens_num_{rank_in_node}", (1,), dtype=np.int64) + self.tree_total_tokens_num = SharedArray( + f"{unique_name}_tree_total_tokens_num_{rank_in_node}", (1,), dtype=np.int64 + ) + + def get_refed_tokens_num(self): + return self.refed_tokens_num.arr[0] + + def get_tree_total_tokens_num(self): + return self.tree_total_tokens_num.arr[0] + + def get_unrefed_tokens_num(self): + return self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0] + + +class RadixCacheReadOnlyClient: + def __init__(self, unique_name, total_token_num, node_world_size, dp_world_size): + self.dp_rank_clients: List[_RadixCacheReadOnlyClient] = [ + _RadixCacheReadOnlyClient(unique_name, total_token_num, rank_in_node) + for rank_in_node in range(0, node_world_size, dp_world_size) + ] + + def get_refed_tokens_num(self, dp_rank_in_node): + return self.dp_rank_clients[dp_rank_in_node].get_refed_tokens_num() + + def get_tree_total_tokens_num(self, dp_rank_in_node): + return self.dp_rank_clients[dp_rank_in_node].get_tree_total_tokens_num() + + def get_unrefed_tokens_num(self, dp_rank_in_node): + return self.dp_rank_clients[dp_rank_in_node].get_unrefed_tokens_num() diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index c51774898..0c7fd1c70 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -125,7 +125,7 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None) ) self.tree_total_tokens_num.arr[0] = 0 - def insert(self, key, value=None) -> Tuple[int, Optional[TreeNode]]: + def insert(self, key, value=None, buffer_idx=None) -> Tuple[int, Optional[TreeNode]]: if value is None: value = key @@ -494,7 +494,7 @@ def _print_helper(self, node: TreeNode, indent): self._print_helper(child, indent=indent + 2) return - def free_radix_cache_to_get_enough_token(self, need_token_num): + def free_radix_cache_to_get_enough_token(self, need_token_num, need_buffer_num=0): assert self.mem_manager is not None if need_token_num > self.mem_manager.can_use_mem_size: need_evict_token_num = need_token_num - self.mem_manager.can_use_mem_size diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index ab2965887..5919bfb5c 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -111,7 +111,10 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq"): # .cpu() 是 流内阻塞操作 value = self.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu() - prefix_len, _ = self.radix_cache.insert(key, value) + buffer_idx = None + if hasattr(self.req_manager, "req_to_buffer_indexes"): + buffer_idx = self.req_manager.req_to_buffer_indexes[req.req_idx].cpu() + prefix_len, _ = self.radix_cache.insert(key, value, buffer_idx) old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len]) if req.shared_kv_node is not None: @@ -399,6 +402,12 @@ def _match_radix_cache(self): g_infer_context.req_manager.req_to_token_indexs[self.req_idx, 0:ready_cache_len] = value_tensor self.cur_kv_len = int(ready_cache_len) # 序列化问题, 该对象可能为numpy.int64,用 int(*)转换 self.shm_req.prompt_cache_len = self.cur_kv_len # 记录 prompt cache 的命中长度 + # NOTE 仅用于 Qwen3Next 的 HybridRadixCache. + if hasattr(share_node, "buffer_idx") and share_node.buffer_idx is not None: + cur_buffer_idx = g_infer_context.req_manager.req_to_buffer_indexes[self.req_idx] + g_infer_context.req_manager.mem_manager.copy_state_cache_buffer( + share_node.buffer_idx, cur_buffer_idx + ) self.shm_req.shm_cur_kv_len = self.cur_kv_len return diff --git a/lightllm/server/router/model_infer/mode_backend/__init__.py b/lightllm/server/router/model_infer/mode_backend/__init__.py index 82f3a8ddf..56293a7fa 100644 --- a/lightllm/server/router/model_infer/mode_backend/__init__.py +++ b/lightllm/server/router/model_infer/mode_backend/__init__.py @@ -1,6 +1,7 @@ from .chunked_prefill.impl import ChunkedPrefillBackend from .chunked_prefill.impl_for_first_token_constraint_mode import FirstTokenConstraintBackend from .chunked_prefill.impl_for_outlines_constraint_mode import OutlinesConstraintBackend +from .chunked_prefill.impl_for_hybrid_radix_cache import HybridRadixCacheBackend from .chunked_prefill.impl_for_return_all_prompt_logprobs import ReturnPromptLogProbBackend from .chunked_prefill.impl_for_reward_model import RewardModelBackend from .chunked_prefill.impl_for_token_healing import TokenHealingBackend diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index a780c4da0..838bc0c06 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -10,6 +10,8 @@ from lightllm.utils.log_utils import init_logger from lightllm.models import get_model from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache +from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridRadixCache + from lightllm.server.router.model_infer.infer_batch import InferReq, InferReqUpdatePack from lightllm.server.router.token_load import TokenLoad from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock @@ -163,8 +165,11 @@ def init_model(self, kvargs): self.model, self.is_multimodal = get_model(model_cfg, model_kvargs) self.model: TpPartBaseModel = self.model # for easy typing set_random_seed(2147483647) + is_hybrid_model = model_cfg.get("model_type", "") in ["qwen3_next"] + # Use HybridRadixCacheV2 as default for hybrid models + radix_cache_class = RadixCache if not is_hybrid_model else HybridRadixCache self.radix_cache = ( - RadixCache( + radix_cache_class( get_unique_server_name(), self.model.mem_manager.size, self.rank_in_node, diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_hybrid_radix_cache.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_hybrid_radix_cache.py new file mode 100644 index 000000000..d716d2035 --- /dev/null +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_hybrid_radix_cache.py @@ -0,0 +1,112 @@ +import torch +from .impl import ChunkedPrefillBackend +from typing import List +from typing_extensions import override +from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq +from lightllm.common.basemodel.infer_lock import g_infer_state_lock +from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventPack +from lightllm.server.router.model_infer.mode_backend.pre import ( + prepare_prefill_inputs, +) +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) + + +class HybridRadixCacheBackend(ChunkedPrefillBackend): + def __init__(self) -> None: + super().__init__() + logger.info("Using HybridRadixCacheBackend for hybrid attention model.") + + @override + def init_model(self, kvargs): + from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridRadixCache + + super().init_model(kvargs) + assert isinstance(self.radix_cache, HybridRadixCache) + return + + def prefill_normal( + self, + event_pack: OverlapEventPack, + prefill_reqs: List[InferReq], + ): + # 第一阶段: 模型推理 + model_input, run_reqs = prepare_prefill_inputs( + prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal + ) + with torch.cuda.stream(g_infer_context.get_overlap_stream()): + model_output = self.model.forward(model_input) + _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( + logits=model_output.logits, + b_req_idx=model_input.b_req_idx, + b_mtp_index=model_input.b_mtp_index, + run_reqs=run_reqs, + is_prefill=True, + b_prefill_has_output_cpu=model_input.b_prefill_has_output_cpu, + mask_func=self.prefill_mask_func, + ) + sync_event = torch.cuda.Event() + sync_event.record() + + # 第二阶段 + event_pack.notify_post_handle_and_wait_pre_post_handle() + update_packs = self._pre_post_handle(run_reqs, is_chuncked_mode=not self.disable_chunked_prefill) + + # 第三阶段 + event_pack.notify_forward_and_wait_post_handle() + sync_event.synchronize() + self._post_handle( + run_reqs=run_reqs, + next_token_ids=next_token_ids_cpu, + next_token_logprobs=next_token_logprobs_cpu, + run_reqs_update_packs=update_packs, + extra_post_req_handle_func=self.extra_post_req_handle_func, + nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func, + ) + + # 对于 Qwen3Next 模型,在 chunked_prefill 过程中调用 radix_cache.insert 保存中间状态 + if not self.disable_chunked_prefill: + for req in run_reqs: + if not req.is_multi_chat_req and req.cur_kv_len < req.get_cur_total_len(): + self._handle_qwen3next_radix_cache_insert(req) + + # 第四阶段 + event_pack.notify_pre_post_handle() + return + + def _handle_qwen3next_radix_cache_insert(self, req: "InferReq"): + # 在 chunked_prefill 过程中,为 Qwen3Next 模型保存 state buffer 中间状态到 radix cache + from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridRadixCache + + is_qwen3next = isinstance(self.radix_cache, HybridRadixCache) + + if not is_qwen3next or self.radix_cache is None: + return + + g_infer_state_lock.acquire() + try: + # 获取当前 chunked_prefill 处理的 token IDs + input_token_ids = req.get_input_token_ids() + key = torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu") + + # 获取对应的 token 索引 + value = self.model.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].cpu() + + buffer_idx = self.model.req_manager.req_to_buffer_indexes[req.req_idx].cpu() + + # 确保有足够的空间用于新的 buffer + self.radix_cache.free_radix_cache_to_get_enough_token(0, 1) + + # 分配新的 buffer 并复制当前 buffer 的内容 + new_buffer_idx = self.model.req_manager.mem_manager.alloc_state_cache_buffer(1)[0] + self.model.req_manager.mem_manager.copy_state_cache_buffer(buffer_idx, new_buffer_idx) + self.model.req_manager.req_to_buffer_indexes[req.req_idx] = new_buffer_idx + + _, new_shared_kv_node = self.radix_cache.insert(key, value, buffer_idx) + + self.radix_cache.dec_node_ref_counter(req.shared_kv_node) + self.radix_cache.add_node_ref_counter(new_shared_kv_node) + req.shared_kv_node = new_shared_kv_node + finally: + g_infer_state_lock.release() diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index 55fe7a415..02edf9ff1 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -7,10 +7,12 @@ import setproctitle from datetime import timedelta from typing import Dict, List, Tuple +from transformers import PretrainedConfig from lightllm.server.router.model_infer.mode_backend import ( ChunkedPrefillBackend, FirstTokenConstraintBackend, OutlinesConstraintBackend, + HybridRadixCacheBackend, ReturnPromptLogProbBackend, RewardModelBackend, TokenHealingBackend, @@ -123,7 +125,13 @@ def init_model(self, kvargs): is_nixl_prefill_node = self.args.run_mode == "nixl_prefill" is_nixl_decode_node = self.args.run_mode == "nixl_decode" - if is_prefill_node: + model_cfg, _ = PretrainedConfig.get_config_dict(kvargs["weight_dir"]) + is_hybrid_model = model_cfg.get("model_type", "") in ["qwen3_next"] + use_hybrid_radix_cache = is_hybrid_model and not self.args.disable_dynamic_prompt_cache + + if use_hybrid_radix_cache: + self.backend = HybridRadixCacheBackend() + elif is_prefill_node: if self.args.dp > 1: self.backend = DPChunkedForPrefillNode(self.info_queue) else: diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py index 022c5ab40..e2a210621 100644 --- a/lightllm/utils/device_utils.py +++ b/lightllm/utils/device_utils.py @@ -90,7 +90,7 @@ def get_current_device_name(): gpu_name = gpu_name.replace(" ", "_") return gpu_name else: - return None + raise RuntimeError("No GPU available") @lru_cache(maxsize=None) diff --git a/test/benchmark/service/benchmark_gsm8k.py b/test/benchmark/service/benchmark_gsm8k.py new file mode 100644 index 000000000..def3fbcb5 --- /dev/null +++ b/test/benchmark/service/benchmark_gsm8k.py @@ -0,0 +1,231 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/main/benchmark/gsm8k/bench_other.py +import argparse +import ast +import json +import os +import re +import time +from concurrent.futures import ThreadPoolExecutor +from typing import Optional + +import numpy as np +import requests +from tqdm import tqdm + +INVALID = -9999999 + + +def read_jsonl(filename: str): + """Read a JSONL file.""" + with open(filename) as fin: + for line in fin: + if line.startswith("#"): + continue + yield json.loads(line) + + +def dump_state_text(filename: str, states: list, mode: str = "w"): + """Dump program state in a text file.""" + with open(filename, mode) as fout: + for i, s in enumerate(states): + if isinstance(s, str): + fout.write(f"==== {i} ====\n{s}\n") + else: + fout.write(f"==== {i} ====\n{str(s)}\n") + + +def download_and_cache_file(url: str, filename: Optional[str] = None): + """Read and cache a file from a url.""" + if filename is None: + filename = os.path.join("/tmp", url.split("/")[-1]) + + # Check if the cache file already exists + if os.path.exists(filename): + return filename + + print(f"Downloading from {url} to {filename}") + + # Stream the response to show the progress bar + response = requests.get(url, stream=True) + response.raise_for_status() # Check for request errors + + # Total size of the file in bytes + total_size = int(response.headers.get("content-length", 0)) + chunk_size = 1024 # Download in chunks of 1KB + + # Use tqdm to display the progress bar + with open(filename, "wb") as file, tqdm( + desc="Downloading", + total=total_size, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as bar: + for chunk in response.iter_content(chunk_size=chunk_size): + size = file.write(chunk) + bar.update(size) + + return filename + + +def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None): + """Call LightLLM API for text generation.""" + assert url is not None + + data = { + "inputs": prompt, + "parameters": { + "temperature": temperature, + "max_new_tokens": max_tokens, + "stop_sequences": stop, + }, + } + res = requests.post(url, json=data) + assert res.status_code == 200, f"API request failed with status code {res.status_code}: {res.text}" + + response_json = res.json() + if "generated_text" not in response_json: + raise ValueError(f"Invalid API response format. Expected 'generated_text' key, got: {response_json.keys()}") + if not isinstance(response_json["generated_text"], list) or len(response_json["generated_text"]) == 0: + raise ValueError( + f"Invalid API response format. 'generated_text' should be a non-empty list," + f" got: {response_json['generated_text']}" + ) + + pred = response_json["generated_text"][0] + return pred + + +def get_one_example(lines, i, include_answer): + ret = "Question: " + lines[i]["question"] + "\nAnswer:" + if include_answer: + ret += " " + lines[i]["answer"] + return ret + + +def get_few_shot_examples(lines, k): + ret = "" + for i in range(k): + ret += get_one_example(lines, i, True) + "\n\n" + return ret + + +def get_answer_value(answer_str): + answer_str = answer_str.replace(",", "") + numbers = re.findall(r"\d+", answer_str) + if len(numbers) < 1: + return INVALID + try: + return ast.literal_eval(numbers[-1]) + except SyntaxError: + return INVALID + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument("--parallel", type=int, default=64) + parser.add_argument("--host", type=str, default="http://127.0.0.1") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--num-shots", type=int, default=5) + parser.add_argument("--num-questions", type=int, default=200) + parser.add_argument("--result-file", type=str, default="result.jsonl") + parser.add_argument("--data-path", type=str, default="test.jsonl") + return parser.parse_args() + + +def main(args): + # LightLLM API URL + url = f"{args.host}:{args.port}/generate" + + # Read data + url_data = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl" + filename = download_and_cache_file(url_data) + lines = list(read_jsonl(filename)) + + # Construct prompts + num_questions = args.num_questions + num_shots = args.num_shots + few_shot_examples = get_few_shot_examples(lines, num_shots) + + # Ensure we have enough samples and avoid data leakage + # Test questions should start after few-shot examples + max_available = len(lines) - num_shots + if num_questions > max_available: + print( + f"Warning: Requested {num_questions} questions, " + f"but only {max_available} available after reserving {num_shots} for few-shot. " + f"Using {max_available} questions." + ) + num_questions = max_available + + questions = [] + labels = [] + for i in range(num_shots, num_shots + num_questions): + questions.append(get_one_example(lines, i, False)) + labels.append(get_answer_value(lines[i]["answer"])) + # assert all(l != INVALID for l in labels) + + states = [None] * len(labels) + + # Run requests using thread pool + def get_one_answer(i): + answer = call_generate_lightllm( + prompt=few_shot_examples + questions[i], + temperature=0, + max_tokens=256, + stop=["Question", "Assistant:", "<|separator|>"], + url=url, + ) + states[i] = answer + + tic = time.perf_counter() + if args.parallel == 1: + for i in tqdm(range(len(questions))): + get_one_answer(i) + else: + with ThreadPoolExecutor(args.parallel) as executor: + list( + tqdm( + executor.map(get_one_answer, list(range(len(questions)))), + total=len(questions), + ) + ) + + latency = time.perf_counter() - tic + + preds = [] + for i in range(len(states)): + preds.append(get_answer_value(states[i])) + + # Compute accuracy + acc = np.mean(np.array(preds) == np.array(labels)) + invalid = np.mean(np.array(preds) == INVALID) + + # Print results + print(f"Accuracy: {acc:.3f}") + print(f"Invalid: {invalid:.3f}") + print(f"Latency: {latency:.3f} s") + + # Dump results + dump_state_text("tmp_output_lightllm.txt", states) + + with open(args.result_file, "a") as fout: + value = { + "task": "gsm8k", + "backend": "lightllm", + "num_gpus": 1, + "latency": round(latency, 3), + "accuracy": round(acc, 3), + "num_requests": args.num_questions, + "other": { + "num_questions": args.num_questions, + "parallel": args.parallel, + }, + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/test/test_api/test_chat.py b/test/test_api/test_chat.py new file mode 100644 index 000000000..8ad7d04fe --- /dev/null +++ b/test/test_api/test_chat.py @@ -0,0 +1,183 @@ +from openai import OpenAI +from datetime import datetime +import argparse +import threading +import random +from typing import List + + +class OpenAIMultiTurnChat: + def __init__(self, api_key: str, model: str = "gpt-3.5-turbo", base_url: str = None, client_id: int = 0): + """ + 初始化 OpenAI 多轮对话 + + Args: + api_key: OpenAI API 密钥 + model: 使用的模型名称 + base_url: API 基础 URL(如果使用代理或其他服务) + client_id: 客户端 ID(用于并发测试) + """ + self.client = OpenAI(api_key=api_key, base_url=base_url) + self.model = model + self.conversation_history = [] + self.client_id = client_id + + def add_message(self, role: str, content: str): + """添加消息到对话历史""" + self.conversation_history.append({"role": role, "content": content}) + + def get_response(self, user_message: str, verbose: bool = True) -> str: + """获取 AI 回复(流式)""" + self.add_message("user", user_message) + + try: + response = self.client.chat.completions.create( + model=self.model, messages=self.conversation_history, max_tokens=1000, stream=True + ) + + assistant_reply = "" + if verbose: + print(f"AI (用户_{self.client_id}): ", end="", flush=True) + + for chunk in response: + if chunk.choices[0].delta.content is not None: + content = chunk.choices[0].delta.content + assistant_reply += content + if verbose: + print(content, end="", flush=True) + + if verbose: + print() # 换行 + self.add_message("assistant", assistant_reply) + + return assistant_reply + + except Exception as e: + if verbose: + print(f"请求失败: {e}") + return "请求失败,请检查网络连接或 API 密钥" + + def start_conversation(self, system_prompt: str = None): + """开始新的对话""" + self.conversation_history = [] + + if system_prompt: + self.add_message("system", system_prompt) + + print(f"开始多轮对话 - 用户_{self.client_id} (输入 'quit' 或 'exit' 退出)") + print("-" * 50) + + while True: + user_input = input(f"用户_{self.client_id}: ").strip() + + if user_input.lower() in ["quit", "exit", "退出"]: + print("对话结束") + break + + if not user_input: + continue + + self.get_response(user_input) + print() + + +class ParallelChatManager: + """并发对话管理器""" + + def __init__(self, api_key: str, model: str, base_url: str, parallel: int, system_prompt: str = None): + """ + 初始化并发对话管理器 + + Args: + api_key: OpenAI API 密钥 + model: 使用的模型名称 + base_url: API 基础 URL + parallel: 并发客户端数量 + system_prompt: 系统提示词 + """ + self.clients: List[OpenAIMultiTurnChat] = [] + self.parallel = parallel + self.system_prompt = system_prompt + + # 创建多个客户端实例 + for i in range(parallel): + client = OpenAIMultiTurnChat(api_key=api_key, model=model, base_url=base_url, client_id=i) + if system_prompt: + client.add_message("system", system_prompt) + self.clients.append(client) + + def parallel_request(self, user_message: str): + """并发发送请求""" + responses = [None] * self.parallel + threads = [] + + # 随机选择一个客户端来打印输出 + verbose_client_id = random.randint(0, self.parallel - 1) + + def worker(client_idx: int): + verbose = client_idx == verbose_client_id + response = self.clients[client_idx].get_response(user_message, verbose=verbose) + responses[client_idx] = response + + # 启动所有线程 + for i in range(self.parallel): + thread = threading.Thread(target=worker, args=(i,)) + threads.append(thread) + thread.start() + + # 等待所有线程完成 + for thread in threads: + thread.join() + + return responses + + def start_conversation(self): + """开始并发对话""" + print(f"开始并发多轮对话 (并发数: {self.parallel})") + print("所有客户端输入相同内容,随机显示其中一个客户端的输出") + print("输入 'quit' 或 'exit' 退出") + print("-" * 50) + + while True: + user_input = input("用户输入: ").strip() + + if user_input.lower() in ["quit", "exit", "退出"]: + print("对话结束") + break + + if not user_input: + continue + + print(f"\n[并发请求中... 并发数: {self.parallel}]") + self.parallel_request(user_input) + print() + + +# 使用示例 +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="OpenAI 多轮对话客户端") + parser.add_argument("--port", type=int, default=13688, help="服务端口号 (默认: 13688)") + parser.add_argument("--host", type=str, default="localhost", help="服务主机地址 (默认: localhost)") + parser.add_argument("--api-key", type=str, default="", help="API 密钥 (默认: 空)") + parser.add_argument("--model", type=str, default="gpt-3.5-turbo", help="模型名称 (默认: gpt-3.5-turbo)") + parser.add_argument("--system-prompt", type=str, default="你是一个有用的助手。", help="系统提示词") + parser.add_argument("--parallel", type=int, default=1, help="并发客户端数量 (默认: 1, 不并发)") + + args = parser.parse_args() + + base_url = f"http://{args.host}:{args.port}/v1" + + if args.parallel > 1: + # 并发模式 + manager = ParallelChatManager( + api_key=args.api_key, + model=args.model, + base_url=base_url, + parallel=args.parallel, + system_prompt=args.system_prompt, + ) + manager.start_conversation() + else: + # 单客户端模式 + chat = OpenAIMultiTurnChat(api_key=args.api_key, model=args.model, base_url=base_url) + chat.start_conversation(args.system_prompt) From ee5c4df6d37b277ef0d422eee2105bd3e369451d Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 9 Dec 2025 14:20:54 +0000 Subject: [PATCH 02/19] add radix cache hit rate --- .../router/dynamic_prompt/radix_cache.py | 65 +++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index 0c7fd1c70..d172999de 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -125,6 +125,16 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None) ) self.tree_total_tokens_num.arr[0] = 0 + # Hit rate tracking + self.match_prefix_total_calls = SharedArray( + f"{unique_name}_match_prefix_total_calls_{rank_in_node}", (1,), dtype=np.int64 + ) + self.match_prefix_total_calls.arr[0] = 0 + self.match_prefix_hit_tokens = SharedArray( + f"{unique_name}_match_prefix_hit_tokens_{rank_in_node}", (1,), dtype=np.int64 + ) + self.match_prefix_hit_tokens.arr[0] = 0 + def insert(self, key, value=None, buffer_idx=None) -> Tuple[int, Optional[TreeNode]]: if value is None: value = key @@ -232,6 +242,10 @@ def _insert_helper_no_recursion( def match_prefix(self, key, update_refs=False): assert len(key) != 0 + + # Track total calls + self.match_prefix_total_calls.arr[0] += 1 + ans_value_list = [] tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) if tree_node != self.root_node: @@ -239,6 +253,10 @@ def match_prefix(self, key, update_refs=False): value = torch.concat(ans_value_list) else: value = torch.zeros((0,), device="cpu", dtype=self._value_dtype) + + # Track hit tokens + self.match_prefix_hit_tokens.arr[0] += len(value) + return tree_node, len(value), value else: self.dec_node_ref_counter(self.root_node) @@ -509,6 +527,24 @@ def release_mem(mem_index): self.mem_manager.free(mem_index) return + def get_match_prefix_hit_rate(self): + """Get the hit rate as a ratio of hit tokens to total requested tokens""" + total_calls = self.match_prefix_total_calls.arr[0] + if total_calls == 0: + return 0.0 + # We calculate hit rate as the average hit tokens per call + # Note: This is a simplified metric. For true hit rate, you might want to track total requested tokens + total_hit_tokens = self.match_prefix_hit_tokens.arr[0] + return total_hit_tokens / total_calls if total_calls > 0 else 0.0 + + def get_match_prefix_stats(self): + """Get detailed match_prefix statistics""" + return { + "total_calls": self.match_prefix_total_calls.arr[0], + "total_hit_tokens": self.match_prefix_hit_tokens.arr[0], + "hit_rate": self.get_match_prefix_hit_rate(), + } + class _RadixCacheReadOnlyClient: """ @@ -520,6 +556,13 @@ def __init__(self, unique_name, total_token_num, rank_in_node): self.tree_total_tokens_num = SharedArray( f"{unique_name}_tree_total_tokens_num_{rank_in_node}", (1,), dtype=np.int64 ) + # Hit rate tracking + self.match_prefix_total_calls = SharedArray( + f"{unique_name}_match_prefix_total_calls_{rank_in_node}", (1,), dtype=np.int64 + ) + self.match_prefix_hit_tokens = SharedArray( + f"{unique_name}_match_prefix_hit_tokens_{rank_in_node}", (1,), dtype=np.int64 + ) def get_refed_tokens_num(self): return self.refed_tokens_num.arr[0] @@ -530,6 +573,22 @@ def get_tree_total_tokens_num(self): def get_unrefed_tokens_num(self): return self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0] + def get_match_prefix_hit_rate(self): + """Get the hit rate as a ratio of hit tokens to total calls""" + total_calls = self.match_prefix_total_calls.arr[0] + if total_calls == 0: + return 0.0 + total_hit_tokens = self.match_prefix_hit_tokens.arr[0] + return total_hit_tokens / total_calls if total_calls > 0 else 0.0 + + def get_match_prefix_stats(self): + """Get detailed match_prefix statistics""" + return { + "total_calls": self.match_prefix_total_calls.arr[0], + "total_hit_tokens": self.match_prefix_hit_tokens.arr[0], + "hit_rate": self.get_match_prefix_hit_rate(), + } + class RadixCacheReadOnlyClient: def __init__(self, unique_name, total_token_num, node_world_size, dp_world_size): @@ -546,3 +605,9 @@ def get_tree_total_tokens_num(self, dp_rank_in_node): def get_unrefed_tokens_num(self, dp_rank_in_node): return self.dp_rank_clients[dp_rank_in_node].get_unrefed_tokens_num() + + def get_match_prefix_hit_rate(self, dp_rank_in_node): + return self.dp_rank_clients[dp_rank_in_node].get_match_prefix_hit_rate() + + def get_match_prefix_stats(self, dp_rank_in_node): + return self.dp_rank_clients[dp_rank_in_node].get_match_prefix_stats() From b158299512a269c9d16d897f2274e7d293ffeec7 Mon Sep 17 00:00:00 2001 From: sufubao Date: Tue, 9 Dec 2025 15:32:03 +0000 Subject: [PATCH 03/19] hhh --- lightllm/models/qwen3next/mem_manager.py | 3 + lightllm/server/api_http.py | 2 +- .../dynamic_prompt/hybrid_radix_cache.py | 420 ------------------ .../router/dynamic_prompt/radix_cache.py | 21 +- lightllm/server/router/manager.py | 12 +- .../server/router/model_infer/infer_batch.py | 15 +- .../model_infer/mode_backend/base_backend.py | 8 +- .../impl_for_hybrid_radix_cache.py | 61 ++- lightllm/utils/log_utils.py | 13 + 9 files changed, 87 insertions(+), 468 deletions(-) delete mode 100644 lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py diff --git a/lightllm/models/qwen3next/mem_manager.py b/lightllm/models/qwen3next/mem_manager.py index ab8fdd6a6..adf42bd59 100644 --- a/lightllm/models/qwen3next/mem_manager.py +++ b/lightllm/models/qwen3next/mem_manager.py @@ -47,6 +47,9 @@ def get_state_cache_buffer(self, layer_index): def get_state_cache_can_use_size(self): ... + def copy_state_cache_buffer(self, src_idx, tgt_idx): + pass + class Qwen3NextMemoryManager(MemoryManager, HaveStateBuffer): def __init__( diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 29f271c5e..32db64174 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -267,7 +267,7 @@ async def completions(request: CompletionRequest, raw_request: Request) -> Respo return create_error_response( HTTPStatus.EXPECTATION_FAILED, "service in pd mode dont recv reqs from http interface" ) - + logger.info(f"completions request: {request}") resp = await completions_impl(request, raw_request) return resp diff --git a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py deleted file mode 100644 index 95f4be1fa..000000000 --- a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py +++ /dev/null @@ -1,420 +0,0 @@ -import torch -import numpy as np -import collections -import xxhash -import threading -import time -from typing import Tuple, Dict, Set, List, Optional, Union -from typing_extensions import override -from sortedcontainers import SortedSet -from abc import ABC, abstractmethod -import math -from dataclasses import dataclass, field - -from .shared_arr import SharedArray -from .radix_cache import UniqueTimeIdGenerator -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - -time_gen = UniqueTimeIdGenerator() - - -class HybridRadixNode: - def __init__(self): - # Core data - self.edge: Tuple[int, ...] = () - self.childrens_list: List["HybridRadixNode"] = [] - self.parent: Optional["HybridRadixNode"] = None - - # LightLLM specific - self.token_id_key = torch.zeros((0,), device="cpu", dtype=torch.int64) - self.token_mem_index_value = torch.zeros((0,), device="cpu", dtype=torch.int64) - self.buffer_idx: Optional[int] = None - - # Node metadata - self.node_value_len: int = 0 - self.node_prefix_total_len: int = 0 - self.ref_counter: int = 0 - self.time_id: int = 0 - - # Eviction metadata - self.hit_count: int = 0 - self.insert_time: float = 0.0 - self.last_access: float = 0.0 - self.node_id: int = 0 - - def is_leaf(self) -> bool: - return len(self.childrens_list) == 0 - - def has_buffer(self) -> bool: - return self.buffer_idx is not None - - def is_referenced(self) -> bool: - return self.ref_counter > 0 - - def collect_path_values(self) -> torch.Tensor: - """Collect all values from root to this node.""" - segments = [] - node = self - while node.parent is not None: - if len(node.token_mem_index_value) > 0: - segments.append(node.token_mem_index_value) - node = node.parent - - if not segments: - return torch.zeros((0,), device="cpu", dtype=torch.int64) - - # Reverse order and concatenate - segments.reverse() - return torch.cat(segments, dim=0) - - def update_time(self): - self.time_id = time_gen.generate_time_id() - self.last_access = time.time() - - def remove_child(self, child_node: "HybridRadixNode"): - child_node.parent = None - self.childrens_list.remove(child_node) - - def get_kv_cache_compare_key(self): - return (self.is_referenced(), not self.is_leaf(), self.has_buffer(), self.time_id) - - def get_buffer_compare_key(self): - return self.time_id - - def add_and_return_new_child(self, token_id_key, token_mem_index_value, buffer_idx): - child = HybridRadixNode() - child.token_id_key = token_id_key - child.token_mem_index_value = token_mem_index_value - child.buffer_idx = buffer_idx - self.childrens_list.append(child) - child.parent = self - - new_len = len(child.token_mem_index_value) - child.node_value_len = new_len - child.node_prefix_total_len = self.node_prefix_total_len + new_len - return child - - -class HybridRadixCache: - def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None): - from lightllm.models.qwen3next.mem_manager import Qwen3NextMemoryManager - - self.mem_manager: Qwen3NextMemoryManager = mem_manager - - self._key_dtype = torch.int64 - self._value_dtype = torch.int64 - - self.root_node = HybridRadixNode() - self.root_node.token_id_key = torch.zeros((0,), device="cpu", dtype=self._key_dtype) - self.root_node.token_mem_index_value = torch.zeros((0,), device="cpu", dtype=self._value_dtype) - self.root_node.ref_counter = 1 # 初始化为 1 保证永远不会被 evict 掉 - - self.evict_kv_cache_tree_set: Set[HybridRadixNode] = SortedSet(key=lambda x: x.get_kv_cache_compare_key()) - self.evict_buffer_tree_set: Set[HybridRadixNode] = SortedSet(key=lambda x: x.get_buffer_compare_key()) - self.evict_kv_cache_tree_set.add(self.root_node) - self.evict_buffer_tree_set.add(self.root_node) - - self.refed_tokens_num = SharedArray(f"{unique_name}_refed_tokens_num_{rank_in_node}", (1,), dtype=np.int64) - self.refed_tokens_num.arr[0] = 0 - self.tree_total_tokens_num = SharedArray( - f"{unique_name}_tree_total_tokens_num_{rank_in_node}", (1,), dtype=np.int64 - ) - self.tree_total_tokens_num.arr[0] = 0 - self.tree_total_buffers_num = SharedArray( - f"{unique_name}_tree_total_buffers_num_{rank_in_node}", (1,), dtype=np.int64 - ) - self.tree_total_buffers_num.arr[0] = 0 - - def update_evict_info(self, node: HybridRadixNode): - # Update time once at the beginning - if node == self.root_node: - return - - if node.has_buffer(): - # Remove and re-add to update position in sorted set - try: - self.evict_buffer_tree_set.discard(node) - except ValueError: - pass - if not node.is_leaf(): - # Remove and re-add to update position in sorted set - try: - self.evict_kv_cache_tree_set.discard(node) - except ValueError: - pass - - node.update_time() - - if node.has_buffer(): - self.evict_buffer_tree_set.add(node) - if not node.is_leaf(): - self.evict_kv_cache_tree_set.add(node) - return - - def insert(self, key, value, buffer_idx: int) -> Tuple[int, Optional[HybridRadixNode]]: - logger.info( - f"insert key len: {len(key)}, value len: {len(value)}, buffer_idx: {buffer_idx} key[:10]: {key[:10]}" - ) - assert key is not None and value is not None and buffer_idx is not None - assert len(key) == len(value) and len(key) >= 1 - - return self._insert_helper(self.root_node, key, value, buffer_idx, len(key), 0) - - def _insert_helper( - self, node: HybridRadixNode, key, value, buffer_idx, key_len, prefix_len - ) -> Tuple[int, Optional[HybridRadixNode]]: - # 插入的前提是已经完全覆盖当前节点 - # 遍历当前的所有子节点,找到第一个完全匹配的节点,继续插入 - # 如果找不到完全匹配的节点,则直接插入 - for child in node.childrens_list: - if key_len < child.node_value_len: - continue - if torch.equal(child.token_id_key, key[0 : child.node_value_len]): - # 完全匹配,继续向下插入 - return self._insert_helper( - child, - key[child.node_value_len :], - value[child.node_value_len :], - buffer_idx, - key_len - child.node_value_len, - prefix_len + child.node_value_len, - ) - - # 没有找到完全匹配的节点,直接插入 - # Prevent set corruption by removing node before modifying it (which changes is_leaf status) - if node != self.root_node: - try: - self.evict_kv_cache_tree_set.discard(node) - except ValueError: - pass - if node.has_buffer(): - try: - self.evict_buffer_tree_set.discard(node) - except ValueError: - pass - - new_child = node.add_and_return_new_child(key, value, buffer_idx) - new_child.update_time() - self.evict_kv_cache_tree_set.add(new_child) - self.evict_buffer_tree_set.add(new_child) - self.update_evict_info(node) - self.tree_total_tokens_num.arr[0] += len(value) - self.tree_total_buffers_num.arr[0] += 1 - return prefix_len, new_child - - def match_prefix(self, key, update_refs=False): - logger.info(f"match_prefix key len: {len(key)}, update_refs: {update_refs} key[:10]: {key[:10]}") - if len(key) == 0: - return None, 0, None - - ans_value_list = [] - tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) - - if tree_node != self.root_node and tree_node is not None: - if len(ans_value_list) != 0: - value = torch.cat(ans_value_list, dim=0) - else: - value = torch.zeros((0,), device="cpu", dtype=self._value_dtype) - logger.info(f"match_prefix success len: {len(value)}") - return tree_node, len(value), value - else: - logger.info("match_prefix failed") - return None, 0, None - - def _match_prefix_helper( - self, node: HybridRadixNode, key, ans_value_list, update_refs=False - ) -> Optional[HybridRadixNode]: - # 匹配的前提是已经完全覆盖当前节点 - # 遍历所有节点,假设完全匹配key, 则返回。 - - if len(key) == 0: - return node - - for child in node.childrens_list: - if len(key) < child.node_value_len: - continue - if torch.equal(child.token_id_key, key[0 : child.node_value_len]): - # 完全匹配,继续向下匹配 - ans_value_list.append(child.token_mem_index_value) - match_node = self._match_prefix_helper( - child, - key[child.node_value_len :], - ans_value_list, - update_refs=update_refs, - ) - if match_node is not None: - if update_refs: - self.add_node_ref_counter(child) - self.update_evict_info(child) - return match_node - else: - ans_value_list.pop() - return node - - def evict_kv_cache(self, need_remove_tokens, evict_memindexes, evict_buffer_indexes): - logger.info( - f"evict_kv_cache need: {need_remove_tokens}" - f"total: {self.tree_total_tokens_num.arr[0]}" - f"refed: {self.refed_tokens_num.arr[0]}" - ) - if self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0] < need_remove_tokens: - assert False, f"""can not free tree tokens {need_remove_tokens}, - tree_total_tokens_num {self.tree_total_tokens_num.arr[0]}, - refed_tokens_num {self.refed_tokens_num.arr[0]}""" - num_evicted = 0 - while num_evicted < need_remove_tokens: - node: HybridRadixNode = self.evict_kv_cache_tree_set.pop(0) - assert ( - node.ref_counter == 0 and len(node.childrens_list) == 0 and node != self.root_node - ), "error evict tree node state" - num_evicted += len(node.token_mem_index_value) - evict_memindexes.append(node.token_mem_index_value) - if node.has_buffer(): - evict_buffer_indexes.append(node.buffer_idx) - self.tree_total_buffers_num.arr[0] -= 1 - # update total token num - self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value) - parent_node: HybridRadixNode = node.parent - - # Prevent set corruption by removing parent before modifying it - if parent_node != self.root_node: - try: - self.evict_kv_cache_tree_set.discard(parent_node) - except ValueError: - pass - if parent_node.has_buffer(): - try: - self.evict_buffer_tree_set.discard(parent_node) - except ValueError: - pass - - parent_node.remove_child(node) - self.update_evict_info(parent_node) - return - - def evict_buffer_cache(self, need_remove_buffers, evict_buffer_indexes): - if self.tree_total_buffers_num.arr[0] < need_remove_buffers: - assert False, f"""can not free tree buffers {need_remove_buffers}, - tree_total_buffers_num {self.tree_total_buffers_num.arr[0]}""" - num_evicted = 0 - while num_evicted < need_remove_buffers: - node: HybridRadixNode = self.evict_buffer_tree_set.pop(0) - assert node.has_buffer() and node != self.root_node, "error evict buffer node state" - num_evicted += 1 - evict_buffer_indexes.append(node.buffer_idx) - node.buffer_idx = None - self.update_evict_info(node) - return - - def free_radix_cache_to_get_enough_token(self, need_token_num, need_buffer_num=0): - logger.info( - f"free_radix_cache need_token: {need_token_num}" - f"need_buffer: {need_buffer_num}" - f"can_use: {self.mem_manager.can_use_mem_size}" - f"state_cache_can_use: {self.mem_manager.get_state_cache_can_use_size()}" - ) - if need_token_num > self.mem_manager.can_use_mem_size: - need_evict_token_num = need_token_num - self.mem_manager.can_use_mem_size - if need_evict_token_num > 0: - evict_memindexes = [] - evict_buffer_indexes = [] - self.evict_kv_cache(need_evict_token_num, evict_memindexes, evict_buffer_indexes) - evict_memindexes = torch.concat(evict_memindexes) - self.mem_manager.free(evict_memindexes) - self.mem_manager.free_state_cache_buffer(evict_buffer_indexes) - - if need_buffer_num > self.mem_manager.get_state_cache_can_use_size(): - need_evict_buffer_num = need_buffer_num - self.mem_manager.get_state_cache_can_use_size() - if need_evict_buffer_num > 0: - evict_buffer_indexes = [] - self.evict_buffer_cache(need_evict_buffer_num, evict_buffer_indexes) - self.mem_manager.free_state_cache_buffer(evict_buffer_indexes) - return - - def get_tree_total_tokens_num(self): - return self.tree_total_tokens_num.arr[0] - - def get_refed_tokens_num(self): - return self.refed_tokens_num.arr[0] - - def add_node_ref_counter(self, node: HybridRadixNode): - if node is None: - return - - while node is not None: - if node != self.root_node: - try: - self.evict_kv_cache_tree_set.discard(node) - except ValueError: - pass - - if node.ref_counter == 0: - self.refed_tokens_num.arr[0] += len(node.token_mem_index_value) - node.ref_counter += 1 - - if node != self.root_node: - self.evict_kv_cache_tree_set.add(node) - - node = node.parent - return - - def dec_node_ref_counter(self, node: HybridRadixNode): - if node is None: - return - - while node is not None: - if node != self.root_node: - try: - self.evict_kv_cache_tree_set.discard(node) - except ValueError: - pass - - if node.ref_counter == 1: - self.refed_tokens_num.arr[0] -= len(node.token_mem_index_value) - node.ref_counter -= 1 - - if node != self.root_node: - self.evict_kv_cache_tree_set.add(node) - - node = node.parent - return - - -class _RadixCacheReadOnlyClient: - """ - router 端只读用的客户端,用于从共享内存中读取树结构中的信息,用于进行prompt cache 的调度估计。 - """ - - def __init__(self, unique_name, total_token_num, rank_in_node): - self.refed_tokens_num = SharedArray(f"{unique_name}_refed_tokens_num_{rank_in_node}", (1,), dtype=np.int64) - self.tree_total_tokens_num = SharedArray( - f"{unique_name}_tree_total_tokens_num_{rank_in_node}", (1,), dtype=np.int64 - ) - - def get_refed_tokens_num(self): - return self.refed_tokens_num.arr[0] - - def get_tree_total_tokens_num(self): - return self.tree_total_tokens_num.arr[0] - - def get_unrefed_tokens_num(self): - return self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0] - - -class RadixCacheReadOnlyClient: - def __init__(self, unique_name, total_token_num, node_world_size, dp_world_size): - self.dp_rank_clients: List[_RadixCacheReadOnlyClient] = [ - _RadixCacheReadOnlyClient(unique_name, total_token_num, rank_in_node) - for rank_in_node in range(0, node_world_size, dp_world_size) - ] - - def get_refed_tokens_num(self, dp_rank_in_node): - return self.dp_rank_clients[dp_rank_in_node].get_refed_tokens_num() - - def get_tree_total_tokens_num(self, dp_rank_in_node): - return self.dp_rank_clients[dp_rank_in_node].get_tree_total_tokens_num() - - def get_unrefed_tokens_num(self, dp_rank_in_node): - return self.dp_rank_clients[dp_rank_in_node].get_unrefed_tokens_num() diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index d172999de..e9bcb979c 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -31,6 +31,11 @@ def __init__(self): self.node_value_len = 0 self.node_prefix_total_len = 0 + # 用于混合线性注意力模型, 例如Qwen3Next + # 在混合线性注意力情景中,buffer_idx 可以有值也可以为None + # 但是如果为None则不能作为最终改的匹配节点 + self.buffer_idx = None + def get_compare_key(self): return (0 if self.ref_counter == 0 else 1, len(self.children), self.time_id) @@ -135,7 +140,7 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None) ) self.match_prefix_hit_tokens.arr[0] = 0 - def insert(self, key, value=None, buffer_idx=None) -> Tuple[int, Optional[TreeNode]]: + def insert(self, key, value=None) -> Tuple[int, Optional[TreeNode]]: if value is None: value = key @@ -338,12 +343,13 @@ def _match_prefix_helper_no_recursion( else: assert False, "error state" - def evict(self, need_remove_tokens, evict_callback): + def evict(self, need_remove_tokens, need_remove_buffers, evict_callback): if self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0] < need_remove_tokens: assert False, f"""can not free tree tokens {need_remove_tokens}, tree_total_tokens_num {self.tree_total_tokens_num.arr[0]}, refed_tokens_num {self.refed_tokens_num.arr[0]}""" num_evicted = 0 + release_buffers = [] while num_evicted < need_remove_tokens: node: TreeNode = self.evict_tree_set.pop(0) assert ( @@ -351,6 +357,7 @@ def evict(self, need_remove_tokens, evict_callback): ), "error evict tree node state" num_evicted += len(node.token_mem_index_value) evict_callback(node.token_mem_index_value) + release_buffers.append(node.buffer_idx) # update total token num self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value) parent_node: TreeNode = node.parent @@ -358,7 +365,7 @@ def evict(self, need_remove_tokens, evict_callback): if parent_node.is_leaf(): self.evict_tree_set.add(parent_node) - return + return release_buffers def _try_merge(self, child_node: TreeNode) -> Optional[TreeNode]: """ @@ -512,9 +519,9 @@ def _print_helper(self, node: TreeNode, indent): self._print_helper(child, indent=indent + 2) return - def free_radix_cache_to_get_enough_token(self, need_token_num, need_buffer_num=0): + def free_radix_cache_to_get_enough_token(self, need_token_num, need_evict_buffer_num=0): assert self.mem_manager is not None - if need_token_num > self.mem_manager.can_use_mem_size: + if need_token_num > self.mem_manager.can_use_mem_size or need_evict_buffer_num > 0: need_evict_token_num = need_token_num - self.mem_manager.can_use_mem_size release_mems = [] @@ -522,10 +529,10 @@ def release_mem(mem_index): release_mems.append(mem_index) return - self.evict(need_evict_token_num, release_mem) + release_buffers = self.evict(need_evict_token_num, need_evict_buffer_num, release_mem) mem_index = torch.concat(release_mems) self.mem_manager.free(mem_index) - return + return release_buffers def get_match_prefix_hit_rate(self): """Get the hit rate as a ratio of hit tokens to total requested tokens""" diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index 89c46d9ed..ae5092855 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -250,13 +250,23 @@ async def loop_for_fwd( frozen_token_num = self.shared_token_load.get_frozened_token_count(d_i) estimated_peak_token_count = self.shared_token_load.get_estimated_peak_token_count(d_i) paused_req_num = self._get_paused_req_num_in_dp_index(dp_index=d_i) + + # Get hit rate from radix cache if available + hit_rate = 0.0 + if self.radix_cache_client is not None: + try: + hit_rate = self.radix_cache_client.get_match_prefix_hit_rate(d_i) + except Exception as e: + logger.warning(f"Failed to get hit rate from radix cache: {e}") + logger.debug( f"dp_i {d_i} current batch size: {len(self.running_batch.reqs)} \n" f"dp_i {d_i} paused req num: {paused_req_num} \n" f"dp_i {d_i} frozen token num: {frozen_token_num} \n" f"dp_i {d_i} estimated_peak_token_count: {estimated_peak_token_count} \n" f"dp_i {d_i} token used ratio: {token_ratio1} not contain prompt cache tree unrefed token\n" - f"dp_i {d_i} token used ratio: {token_ratio2} contain prompt cache tree unrefed token" + f"dp_i {d_i} token used ratio: {token_ratio2} contain prompt cache tree unrefed token\n" + f"dp_i {d_i} match_prefix hit_rate: {hit_rate:.4f}" ) self.metric_client.gauge_set("lightllm_batch_pause_size", paused_req_num) # pd decode mode need to update token_load more frequently diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 5919bfb5c..22acb77f4 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -34,6 +34,8 @@ class InferenceContext: overlap_stream: torch.cuda.Stream = None # 一些情况下推理进程进行异步折叠操作的异步流对象。 cpu_kv_cache_stream: torch.cuda.Stream = None # 用 cpu kv cache 操作的 stream + use_hybrid_radix_cache: bool = False + def register( self, backend, req_manager: ReqManager, radix_cache: RadixCache, shm_req_manager: ShmReqManager, vocab_size: int ): @@ -114,7 +116,7 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq"): buffer_idx = None if hasattr(self.req_manager, "req_to_buffer_indexes"): buffer_idx = self.req_manager.req_to_buffer_indexes[req.req_idx].cpu() - prefix_len, _ = self.radix_cache.insert(key, value, buffer_idx) + prefix_len, _ = self.radix_cache.insert(key, value, buffer_idx=buffer_idx) old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len]) if req.shared_kv_node is not None: @@ -395,6 +397,13 @@ def _match_radix_cache(self): key = torch.tensor(input_token_ids, dtype=torch.int64, device="cpu") key = key[0 : len(key) - 1] # 最后一个不需要,因为需要一个额外的token,让其在prefill的时候输出下一个token的值 share_node, kv_len, value_tensor = g_infer_context.radix_cache.match_prefix(key, update_refs=True) + + if share_node is not None: + if g_infer_context.use_hybrid_radix_cache: + if share_node.buffer_idx is None: + g_infer_context.radix_cache.dec_node_ref_counter(share_node) + share_node = None + if share_node is not None: self.shared_kv_node = share_node ready_cache_len = share_node.node_prefix_total_len @@ -402,8 +411,8 @@ def _match_radix_cache(self): g_infer_context.req_manager.req_to_token_indexs[self.req_idx, 0:ready_cache_len] = value_tensor self.cur_kv_len = int(ready_cache_len) # 序列化问题, 该对象可能为numpy.int64,用 int(*)转换 self.shm_req.prompt_cache_len = self.cur_kv_len # 记录 prompt cache 的命中长度 - # NOTE 仅用于 Qwen3Next 的 HybridRadixCache. - if hasattr(share_node, "buffer_idx") and share_node.buffer_idx is not None: + + if g_infer_context.use_hybrid_radix_cache: cur_buffer_idx = g_infer_context.req_manager.req_to_buffer_indexes[self.req_idx] g_infer_context.req_manager.mem_manager.copy_state_cache_buffer( share_node.buffer_idx, cur_buffer_idx diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 838bc0c06..b311ab15c 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -10,7 +10,6 @@ from lightllm.utils.log_utils import init_logger from lightllm.models import get_model from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache -from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridRadixCache from lightllm.server.router.model_infer.infer_batch import InferReq, InferReqUpdatePack from lightllm.server.router.token_load import TokenLoad @@ -42,6 +41,8 @@ from lightllm.server.pd_io_struct import NIXLChunckedTransTaskRet from .multi_level_kv_cache import MultiLevelKvCacheModule +logger = init_logger(__name__) + class ModeBackend: def __init__(self) -> None: @@ -165,11 +166,8 @@ def init_model(self, kvargs): self.model, self.is_multimodal = get_model(model_cfg, model_kvargs) self.model: TpPartBaseModel = self.model # for easy typing set_random_seed(2147483647) - is_hybrid_model = model_cfg.get("model_type", "") in ["qwen3_next"] - # Use HybridRadixCacheV2 as default for hybrid models - radix_cache_class = RadixCache if not is_hybrid_model else HybridRadixCache self.radix_cache = ( - radix_cache_class( + RadixCache( get_unique_server_name(), self.model.mem_manager.size, self.rank_in_node, diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_hybrid_radix_cache.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_hybrid_radix_cache.py index d716d2035..e91034680 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_hybrid_radix_cache.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_hybrid_radix_cache.py @@ -1,6 +1,9 @@ +from numpy import ndarray + + import torch from .impl import ChunkedPrefillBackend -from typing import List +from typing import Any, List from typing_extensions import override from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq from lightllm.common.basemodel.infer_lock import g_infer_state_lock @@ -17,6 +20,7 @@ class HybridRadixCacheBackend(ChunkedPrefillBackend): def __init__(self) -> None: super().__init__() logger.info("Using HybridRadixCacheBackend for hybrid attention model.") + g_infer_context.use_hybrid_radix_cache = True @override def init_model(self, kvargs): @@ -65,48 +69,43 @@ def prefill_normal( nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func, ) - # 对于 Qwen3Next 模型,在 chunked_prefill 过程中调用 radix_cache.insert 保存中间状态 if not self.disable_chunked_prefill: for req in run_reqs: - if not req.is_multi_chat_req and req.cur_kv_len < req.get_cur_total_len(): - self._handle_qwen3next_radix_cache_insert(req) + # NOTE 忽略完整的prefill, 因为请求文本全是system prompt 的情况应该比较小 + if req.cur_kv_len < req.get_cur_total_len() - 1: + self._handle_radix_cache_insert(req) # 第四阶段 event_pack.notify_pre_post_handle() return - def _handle_qwen3next_radix_cache_insert(self, req: "InferReq"): - # 在 chunked_prefill 过程中,为 Qwen3Next 模型保存 state buffer 中间状态到 radix cache - from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridRadixCache - - is_qwen3next = isinstance(self.radix_cache, HybridRadixCache) + def _handle_radix_cache_insert(self, req: "InferReq"): + from lightllm.models.qwen3next.mem_manager import HaveStateBuffer + from lightllm.models.qwen3next.req_manager import Qwen3NextReqManager - if not is_qwen3next or self.radix_cache is None: - return + assert isinstance(self.model.req_manager.mem_manager, HaveStateBuffer) + assert isinstance(self.model.req_manager, Qwen3NextReqManager) - g_infer_state_lock.acquire() - try: - # 获取当前 chunked_prefill 处理的 token IDs - input_token_ids = req.get_input_token_ids() - key = torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu") + # 获取当前 chunked_prefill 处理的 token IDs + input_token_ids: Any | ndarray[Any, Any] = req.get_input_token_ids() + key = torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu") - # 获取对应的 token 索引 - value = self.model.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].cpu() + # 获取对应的 token 索引 + value = self.model.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].cpu() - buffer_idx = self.model.req_manager.req_to_buffer_indexes[req.req_idx].cpu() + buffer_idx = self.model.req_manager.req_to_buffer_indexes[req.req_idx].cpu() - # 确保有足够的空间用于新的 buffer - self.radix_cache.free_radix_cache_to_get_enough_token(0, 1) + # 确保有足够的空间用于新的 buffer + release_buffers = self.radix_cache.free_radix_cache_to_get_enough_token(0, 1) - # 分配新的 buffer 并复制当前 buffer 的内容 - new_buffer_idx = self.model.req_manager.mem_manager.alloc_state_cache_buffer(1)[0] - self.model.req_manager.mem_manager.copy_state_cache_buffer(buffer_idx, new_buffer_idx) - self.model.req_manager.req_to_buffer_indexes[req.req_idx] = new_buffer_idx + # 分配新的 buffer 并复制当前 buffer 的内容 + self.model.req_manager.mem_manager.free_state_cache_buffer(release_buffers) + new_buffer_idx = self.model.req_manager.mem_manager.alloc_state_cache_buffer(1)[0] + self.model.req_manager.mem_manager.copy_state_cache_buffer(buffer_idx, new_buffer_idx) + self.model.req_manager.req_to_buffer_indexes[req.req_idx] = new_buffer_idx - _, new_shared_kv_node = self.radix_cache.insert(key, value, buffer_idx) + _, new_shared_kv_node = self.radix_cache.insert(key, value, buffer_idx) - self.radix_cache.dec_node_ref_counter(req.shared_kv_node) - self.radix_cache.add_node_ref_counter(new_shared_kv_node) - req.shared_kv_node = new_shared_kv_node - finally: - g_infer_state_lock.release() + self.radix_cache.dec_node_ref_counter(req.shared_kv_node) + self.radix_cache.add_node_ref_counter(new_shared_kv_node) + req.shared_kv_node = new_shared_kv_node diff --git a/lightllm/utils/log_utils.py b/lightllm/utils/log_utils.py index f15309d5c..0e528dd45 100644 --- a/lightllm/utils/log_utils.py +++ b/lightllm/utils/log_utils.py @@ -29,6 +29,17 @@ def format(self, record): return msg +class RankFilter(logging.Filter): + def filter(self, record): + from lightllm.utils.dist_utils import get_current_rank_in_dp + + try: + rank = get_current_rank_in_dp() + return rank == 0 + except: + return False + + _root_logger = logging.getLogger("lightllm") _default_handler = None _default_file_handler = None @@ -45,6 +56,7 @@ def _setup_logger(): _default_handler = logging.StreamHandler(sys.stdout) _default_handler.flush = sys.stdout.flush # type: ignore _default_handler.setLevel(_LOG_LEVEL) + _default_handler.addFilter(RankFilter()) _root_logger.addHandler(_default_handler) if _default_file_handler is None and _LOG_DIR is not None: @@ -56,6 +68,7 @@ def _setup_logger(): _default_file_handler = logging.FileHandler(_LOG_DIR + "/default.log") _default_file_handler.setLevel(_LOG_LEVEL) _default_file_handler.setFormatter(fmt) + _default_file_handler.addFilter(RankFilter()) _root_logger.addHandler(_default_file_handler) _default_handler.setFormatter(fmt) From 299ff47833ccb2c385c7db59384f6af6a9482ada Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 10 Dec 2025 07:03:10 +0000 Subject: [PATCH 04/19] reset --- .../router/dynamic_prompt/radix_cache.py | 84 ++----------------- 1 file changed, 6 insertions(+), 78 deletions(-) diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index e9bcb979c..c51774898 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -31,11 +31,6 @@ def __init__(self): self.node_value_len = 0 self.node_prefix_total_len = 0 - # 用于混合线性注意力模型, 例如Qwen3Next - # 在混合线性注意力情景中,buffer_idx 可以有值也可以为None - # 但是如果为None则不能作为最终改的匹配节点 - self.buffer_idx = None - def get_compare_key(self): return (0 if self.ref_counter == 0 else 1, len(self.children), self.time_id) @@ -130,16 +125,6 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None) ) self.tree_total_tokens_num.arr[0] = 0 - # Hit rate tracking - self.match_prefix_total_calls = SharedArray( - f"{unique_name}_match_prefix_total_calls_{rank_in_node}", (1,), dtype=np.int64 - ) - self.match_prefix_total_calls.arr[0] = 0 - self.match_prefix_hit_tokens = SharedArray( - f"{unique_name}_match_prefix_hit_tokens_{rank_in_node}", (1,), dtype=np.int64 - ) - self.match_prefix_hit_tokens.arr[0] = 0 - def insert(self, key, value=None) -> Tuple[int, Optional[TreeNode]]: if value is None: value = key @@ -247,10 +232,6 @@ def _insert_helper_no_recursion( def match_prefix(self, key, update_refs=False): assert len(key) != 0 - - # Track total calls - self.match_prefix_total_calls.arr[0] += 1 - ans_value_list = [] tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) if tree_node != self.root_node: @@ -258,10 +239,6 @@ def match_prefix(self, key, update_refs=False): value = torch.concat(ans_value_list) else: value = torch.zeros((0,), device="cpu", dtype=self._value_dtype) - - # Track hit tokens - self.match_prefix_hit_tokens.arr[0] += len(value) - return tree_node, len(value), value else: self.dec_node_ref_counter(self.root_node) @@ -343,13 +320,12 @@ def _match_prefix_helper_no_recursion( else: assert False, "error state" - def evict(self, need_remove_tokens, need_remove_buffers, evict_callback): + def evict(self, need_remove_tokens, evict_callback): if self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0] < need_remove_tokens: assert False, f"""can not free tree tokens {need_remove_tokens}, tree_total_tokens_num {self.tree_total_tokens_num.arr[0]}, refed_tokens_num {self.refed_tokens_num.arr[0]}""" num_evicted = 0 - release_buffers = [] while num_evicted < need_remove_tokens: node: TreeNode = self.evict_tree_set.pop(0) assert ( @@ -357,7 +333,6 @@ def evict(self, need_remove_tokens, need_remove_buffers, evict_callback): ), "error evict tree node state" num_evicted += len(node.token_mem_index_value) evict_callback(node.token_mem_index_value) - release_buffers.append(node.buffer_idx) # update total token num self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value) parent_node: TreeNode = node.parent @@ -365,7 +340,7 @@ def evict(self, need_remove_tokens, need_remove_buffers, evict_callback): if parent_node.is_leaf(): self.evict_tree_set.add(parent_node) - return release_buffers + return def _try_merge(self, child_node: TreeNode) -> Optional[TreeNode]: """ @@ -519,9 +494,9 @@ def _print_helper(self, node: TreeNode, indent): self._print_helper(child, indent=indent + 2) return - def free_radix_cache_to_get_enough_token(self, need_token_num, need_evict_buffer_num=0): + def free_radix_cache_to_get_enough_token(self, need_token_num): assert self.mem_manager is not None - if need_token_num > self.mem_manager.can_use_mem_size or need_evict_buffer_num > 0: + if need_token_num > self.mem_manager.can_use_mem_size: need_evict_token_num = need_token_num - self.mem_manager.can_use_mem_size release_mems = [] @@ -529,28 +504,10 @@ def release_mem(mem_index): release_mems.append(mem_index) return - release_buffers = self.evict(need_evict_token_num, need_evict_buffer_num, release_mem) + self.evict(need_evict_token_num, release_mem) mem_index = torch.concat(release_mems) self.mem_manager.free(mem_index) - return release_buffers - - def get_match_prefix_hit_rate(self): - """Get the hit rate as a ratio of hit tokens to total requested tokens""" - total_calls = self.match_prefix_total_calls.arr[0] - if total_calls == 0: - return 0.0 - # We calculate hit rate as the average hit tokens per call - # Note: This is a simplified metric. For true hit rate, you might want to track total requested tokens - total_hit_tokens = self.match_prefix_hit_tokens.arr[0] - return total_hit_tokens / total_calls if total_calls > 0 else 0.0 - - def get_match_prefix_stats(self): - """Get detailed match_prefix statistics""" - return { - "total_calls": self.match_prefix_total_calls.arr[0], - "total_hit_tokens": self.match_prefix_hit_tokens.arr[0], - "hit_rate": self.get_match_prefix_hit_rate(), - } + return class _RadixCacheReadOnlyClient: @@ -563,13 +520,6 @@ def __init__(self, unique_name, total_token_num, rank_in_node): self.tree_total_tokens_num = SharedArray( f"{unique_name}_tree_total_tokens_num_{rank_in_node}", (1,), dtype=np.int64 ) - # Hit rate tracking - self.match_prefix_total_calls = SharedArray( - f"{unique_name}_match_prefix_total_calls_{rank_in_node}", (1,), dtype=np.int64 - ) - self.match_prefix_hit_tokens = SharedArray( - f"{unique_name}_match_prefix_hit_tokens_{rank_in_node}", (1,), dtype=np.int64 - ) def get_refed_tokens_num(self): return self.refed_tokens_num.arr[0] @@ -580,22 +530,6 @@ def get_tree_total_tokens_num(self): def get_unrefed_tokens_num(self): return self.tree_total_tokens_num.arr[0] - self.refed_tokens_num.arr[0] - def get_match_prefix_hit_rate(self): - """Get the hit rate as a ratio of hit tokens to total calls""" - total_calls = self.match_prefix_total_calls.arr[0] - if total_calls == 0: - return 0.0 - total_hit_tokens = self.match_prefix_hit_tokens.arr[0] - return total_hit_tokens / total_calls if total_calls > 0 else 0.0 - - def get_match_prefix_stats(self): - """Get detailed match_prefix statistics""" - return { - "total_calls": self.match_prefix_total_calls.arr[0], - "total_hit_tokens": self.match_prefix_hit_tokens.arr[0], - "hit_rate": self.get_match_prefix_hit_rate(), - } - class RadixCacheReadOnlyClient: def __init__(self, unique_name, total_token_num, node_world_size, dp_world_size): @@ -612,9 +546,3 @@ def get_tree_total_tokens_num(self, dp_rank_in_node): def get_unrefed_tokens_num(self, dp_rank_in_node): return self.dp_rank_clients[dp_rank_in_node].get_unrefed_tokens_num() - - def get_match_prefix_hit_rate(self, dp_rank_in_node): - return self.dp_rank_clients[dp_rank_in_node].get_match_prefix_hit_rate() - - def get_match_prefix_stats(self, dp_rank_in_node): - return self.dp_rank_clients[dp_rank_in_node].get_match_prefix_stats() From 1f574f2fd77a87804b0c6f67792ef8d01870fb60 Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 10 Dec 2025 09:12:02 +0000 Subject: [PATCH 05/19] draft --- lightllm/common/basemodel/infer_struct.py | 3 + .../layer_infer/transformer_layer_infer.py | 3 +- lightllm/models/qwen3next/mem_manager.py | 30 +---- lightllm/models/qwen3next/model.py | 25 ++-- lightllm/models/qwen3next/req_manager.py | 42 ------ lightllm/server/core/objs/start_args_type.py | 2 - .../dynamic_prompt/hybrid_radix_cache.py | 124 ++++++++++++++++++ .../router/dynamic_prompt/radix_cache.py | 6 + .../server/router/model_infer/infer_batch.py | 24 +--- .../model_infer/mode_backend/__init__.py | 1 - .../model_infer/mode_backend/base_backend.py | 7 +- .../mode_backend/chunked_prefill/impl.py | 5 + .../impl_for_hybrid_radix_cache.py | 111 ---------------- .../server/router/model_infer/model_rpc.py | 10 +- lightllm/utils/log_utils.py | 4 +- 15 files changed, 175 insertions(+), 222 deletions(-) delete mode 100644 lightllm/models/qwen3next/req_manager.py create mode 100644 lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py delete mode 100644 lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_hybrid_radix_cache.py diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index 85d3d8c46..34a450a9f 100755 --- a/lightllm/common/basemodel/infer_struct.py +++ b/lightllm/common/basemodel/infer_struct.py @@ -88,6 +88,9 @@ def __init__(self): self.dp_output_split_sizes: List[List[int]] = None self.dp_input_split_sizes: List[List[int]] = None + # 专门用于管理混合注意力模型的buffer + self.buffer_indexes: torch.Tensor = None + def init_some_extra_state(self, model, input_ids: torch.Tensor): if self.is_prefill: ( diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index 0f49cba25..08e7ced77 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -7,7 +7,6 @@ from lightllm.utils.log_utils import init_logger from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd from lightllm.models.qwen3next.mem_manager import Qwen3NextMemoryManager -from lightllm.models.qwen3next.req_manager import Qwen3NextReqManager from lightllm.models.llama.infer_struct import LlamaInferStateInfo from typing import Tuple from typing_extensions import override @@ -250,7 +249,7 @@ def _linear_attn( ): assert layer_weight.is_linear, "layer_weight must be linear" assert isinstance(infer_state.mem_manager, Qwen3NextMemoryManager) - assert isinstance(infer_state.req_manager, Qwen3NextReqManager) + input = input.view(-1, infer_cls.embed_dim_) buffer_idx = infer_state.req_manager.req_to_buffer_indexes[infer_state.b_req_idx] conv_states, ssm_states = infer_state.mem_manager.get_state_cache_buffer(self.layer_idx_) diff --git a/lightllm/models/qwen3next/mem_manager.py b/lightllm/models/qwen3next/mem_manager.py index adf42bd59..cf45bff49 100644 --- a/lightllm/models/qwen3next/mem_manager.py +++ b/lightllm/models/qwen3next/mem_manager.py @@ -9,6 +9,7 @@ from lightllm.utils.envs_utils import get_unique_server_name from lightllm.utils.dist_utils import get_current_rank_in_node from lightllm.server.router.dynamic_prompt.shared_arr import SharedInt +from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridMemManager logger = init_logger(__name__) @@ -34,24 +35,7 @@ def get_cell_size(self): return np.prod(self.shape) * self.layer_num * torch._utils._element_size(self.dtype) -class HaveStateBuffer(Protocol): - def alloc_state_cache_buffer(self, need_size): - ... - - def free_state_cache_buffer(self, free_buffer_indexes): - ... - - def get_state_cache_buffer(self, layer_index): - ... - - def get_state_cache_can_use_size(self): - ... - - def copy_state_cache_buffer(self, src_idx, tgt_idx): - pass - - -class Qwen3NextMemoryManager(MemoryManager, HaveStateBuffer): +class Qwen3NextMemoryManager(HybridMemManager): def __init__( self, full_attn_cache_size, @@ -121,14 +105,14 @@ def free_all(self): return @override - def get_state_cache_buffer(self, layer_index) -> Tuple[torch.Tensor, torch.Tensor]: + def get_buffer(self, layer_index) -> Tuple[torch.Tensor, torch.Tensor]: assert layer_index < self.layer_num, "layer_index is out of range" assert (layer_index + 1) % self.full_attention_interval != 0, "layer_index is not linear attention layer" real_layer_index = layer_index - layer_index // self.full_attention_interval return self.conv_state_mem_manager.buffer[real_layer_index], self.ssm_state_mem_manager.buffer[real_layer_index] @override - def free_state_cache_buffer(self, free_buffer_indexes: List[int], reset=True): + def free_buffer(self, free_buffer_indexes: List[int], reset=True): # conv_state 和 ssm_state 共享buffer_idx self.conv_state_mem_manager.free(free_buffer_indexes) if reset: @@ -136,17 +120,17 @@ def free_state_cache_buffer(self, free_buffer_indexes: List[int], reset=True): self.ssm_state_mem_manager.buffer[:, free_buffer_indexes] = 0 @override - def alloc_state_cache_buffer(self, need_size): + def alloc_buffer(self, need_size): # conv_state 和 ssm_state 共享buffer_idx buffer_indexes = self.conv_state_mem_manager.alloc(need_size) return buffer_indexes @override - def get_state_cache_can_use_size(self): + def get_buffer_can_use_size(self): return self.conv_state_mem_manager.can_use_mem_size @override - def copy_state_cache_buffer(self, src_idx, tgt_idx): + def copy_buffer(self, src_idx, tgt_idx): assert src_idx is not None and tgt_idx is not None assert src_idx != tgt_idx # Use slice operation and in-place copy for better performance diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index 27289fc19..a31e4b16d 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -9,8 +9,8 @@ from lightllm.distributed.communication_op import dist_group_manager from lightllm.utils.envs_utils import get_env_start_args from lightllm.models.qwen3next.mem_manager import Qwen3NextMemoryManager -from lightllm.models.qwen3next.req_manager import Qwen3NextReqManager from lightllm.server.core.objs.start_args_type import StartArgs +from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput logger = init_logger(__name__) @@ -25,6 +25,7 @@ class Qwen3NextTpPartModel(Qwen3MOEModel): post_layer_infer_class = Qwen3NextPostLayerInfer def __init__(self, kvargs) -> None: + self.mem_manager: Qwen3NextMemoryManager = None super().__init__(kvargs) @override @@ -85,13 +86,15 @@ def _init_mem_manager(self): mem_fraction=self.mem_fraction, ) - @override - def _init_req_manager(self): - create_max_seq_len = 0 - - if self.batch_max_tokens is not None: - create_max_seq_len = max(create_max_seq_len, self.batch_max_tokens) - if self.max_seq_length is not None: - create_max_seq_len = max(create_max_seq_len, self.max_seq_length) - - self.req_manager = Qwen3NextReqManager(self.max_req_num, create_max_seq_len, self.mem_manager) + def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0): + from lightllm.common.basemodel.infer_lock import g_infer_state_lock + from lightllm.common.basemodel.infer_context import g_infer_context + + infer_state = super()._create_inferstate(model_input, microbatch_index) + g_infer_state_lock.acquire() + if g_infer_context.radix_cache is not None: + g_infer_context.radix_cache.free_radix_cache_to_get_enough_buffer(infer_state.batch_size) + buffer_indexes = self.mem_manager.alloc_buffer(infer_state.batch_size) + g_infer_state_lock.release() + infer_state.buffer_indexes = buffer_indexes + return infer_state diff --git a/lightllm/models/qwen3next/req_manager.py b/lightllm/models/qwen3next/req_manager.py deleted file mode 100644 index 31df6883f..000000000 --- a/lightllm/models/qwen3next/req_manager.py +++ /dev/null @@ -1,42 +0,0 @@ -from typing import List, Dict -from typing_extensions import override -import torch - -from lightllm.common.req_manager import ReqManager -from lightllm.models.qwen3next.mem_manager import Qwen3NextMemoryManager -from lightllm.utils.envs_utils import get_env_start_args - - -class Qwen3NextReqManager(ReqManager): - def __init__(self, max_request_num, max_sequence_length, mem_manager: Qwen3NextMemoryManager): - super().__init__(max_request_num, max_sequence_length, mem_manager) - self.mem_manager: Qwen3NextMemoryManager = self.mem_manager - self.enable_dynamic_prompt_cache = not get_env_start_args().disable_dynamic_prompt_cache - - self.req_to_buffer_indexes = torch.zeros((max_request_num + 1), dtype=torch.int32, device="cuda") - self.req_to_buffer_indexes[:] = self.mem_manager.EMPTY_BUFFER_INDEX - self.req_to_buffer_indexes[self.HOLD_REQUEST_ID] = self.mem_manager.HOLD_BUFFER_INDEX - - @override - def alloc(self): - from lightllm.server.router.model_infer.infer_batch import g_infer_state_lock, g_infer_context - - req_idx = super().alloc() - g_infer_state_lock.acquire() - if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(0, 1) - self.req_to_buffer_indexes[req_idx] = self.mem_manager.alloc_state_cache_buffer(1) - g_infer_state_lock.release() - return req_idx - - @override - def free(self, free_req_indexes: List[int], free_token_index): - super().free(free_req_indexes, free_token_index) - self.req_to_buffer_indexes[free_req_indexes] = self.mem_manager.EMPTY_BUFFER_INDEX - - @override - def free_all(self): - super().free_all() - self.req_to_buffer_indexes[:] = self.mem_manager.EMPTY_BUFFER_INDEX - self.req_to_buffer_indexes[self.HOLD_REQUEST_ID] = self.mem_manager.HOLD_BUFFER_INDEX - return diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index ef970c412..79fc14dd7 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -63,8 +63,6 @@ class StartArgs: token_healing_mode: bool = field(default=False) output_constraint_mode: str = field(default="none", metadata={"choices": ["outlines", "xgrammar", "none"]}) first_token_constraint_mode: bool = field(default=False) - enable_multimodal: bool = field(default=False) - enable_multimodal_audio: bool = field(default=False) enable_tpsp_mix_mode: bool = field(default=False) enable_dp_prefill_balance: bool = field(default=False) enable_decode_microbatch_overlap: bool = field(default=False) diff --git a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py new file mode 100644 index 000000000..5858ddf0a --- /dev/null +++ b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py @@ -0,0 +1,124 @@ +from typing import Set, Protocol, List + +import torch +from sortedcontainers import SortedSet + +from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode +from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager +from lightllm.server.router.model_infer.infer_batch import InferReq + + +class HybridMemManager(MemoryManager): + def alloc_buffer(self, need_size): + ... + + def free_buffer(self, free_buffer_indexes): + ... + + def get_buffer(self, layer_index): + ... + + def get_buffer_can_use_size(self): + ... + + def copy_buffer(self, src_idx, tgt_idx): + ... + + +class HybridRadixCache(RadixCache): + def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None): + self.mem_manager: HybridMemManager = mem_manager + super().__init__(unique_name, total_token_num, rank_in_node, mem_manager) + self.evict_buffer_set: Set[TreeNode] = SortedSet(key=lambda x: x.time_id) + self.evict_buffer_set.add(self.root_node) + + def free_radix_cache_to_get_enough_buffer(self, need_buffer_num): + if need_buffer_num > self.mem_manager.get_buffer_can_use_size(): + need_evict_buffer_num = need_buffer_num - self.mem_manager.get_buffer_can_use_size() + + release_mems = [] + + def release_mem(mem_index): + release_mems.append(mem_index) + return + + release_buffers = [] + + def release_buffer(buffer_idx): + release_buffers.append(buffer_idx) + return + + self.evict_buffer(need_evict_buffer_num, release_buffer, release_mem) + self.mem_manager.free_buffer(release_buffers) + if len(release_mems) > 0: + mem_index = torch.concat(release_mems) + self.mem_manager.free(mem_index) + return + + def evict_buffer(self, need_evict_buffer_num, evict_buffer_callback, evict_token_callback): + while need_evict_buffer_num > 0: + node = self.evict_buffer_set.pop() + if node.buffer_idx is not None: + evict_buffer_callback(node.buffer_idx) + need_evict_buffer_num -= 1 + else: + # 在混合注意力模型的情景里,只能匹配 buffer_idx 不为 None的节点 + # 假如 buffer_idx 为 None,则当做匹配失败。 + # 所以可以直接把这个节点给释放掉 + if node.is_leaf() and node.ref_counter == 0: + self._remove_leaf_node(node) + return + + def insert_for_hybrid_radix_cache(self, reqs: List["InferReq"]): + from lightllm.server.router.model_infer.infer_batch import g_infer_context + from lightllm.common.basemodel.infer_lock import g_infer_state_lock + + # 确保有足够的空间用于新的 buffer + g_infer_state_lock.acquire() + self.free_radix_cache_to_get_enough_buffer(len(reqs)) + new_buffer_indexes = self.mem_manager.alloc_buffer(len(reqs)) + g_infer_state_lock.release() + + for i, req in enumerate(reqs): + input_token_ids = req.get_input_token_ids() + key = torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu") + value = g_infer_context.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].cpu() + buffer_idx = req.buffer_idx + + # 分配新的 buffer 并复制当前 buffer 的内容 + self.mem_manager.copy_buffer(buffer_idx, new_buffer_indexes[i]) + req.buffer_idx = new_buffer_indexes[i] + + _, new_shared_kv_node = self.insert(key, value) + new_shared_kv_node.buffer_idx = buffer_idx + self.dec_node_ref_counter(req.shared_kv_node) + self.add_node_ref_counter(new_shared_kv_node) + req.shared_kv_node = new_shared_kv_node + + def match_prefix(self, key, update_refs=False): + assert len(key) != 0 + ans_value_list = [] + tree_node = self._match_prefix_helper(self.root_node, key, ans_value_list, update_refs=update_refs) + + while tree_node != self.root_node and tree_node.buffer_idx is None: + self.dec_node_ref_counter(tree_node) + if tree_node.is_leaf() and tree_node.ref_counter == 0: + tree_node = self._remove_leaf_node(tree_node) + else: + tree_node = tree_node.parent + ans_value_list.pop() + + if tree_node == self.root_node: + return None, 0, None + + value = torch.concat(ans_value_list) + return tree_node, len(value), value + + def _remove_leaf_node(self, node: TreeNode): + self.evict_tree_set.discard(node) + self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value) + parent_node: TreeNode = node.parent + parent_node.remove_child(node) + if parent_node.is_leaf(): + self.evict_tree_set.add(parent_node) + return parent_node diff --git a/lightllm/server/router/dynamic_prompt/radix_cache.py b/lightllm/server/router/dynamic_prompt/radix_cache.py index c51774898..12b15e7dc 100644 --- a/lightllm/server/router/dynamic_prompt/radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/radix_cache.py @@ -31,6 +31,12 @@ def __init__(self): self.node_value_len = 0 self.node_prefix_total_len = 0 + # 专门用于管理混合注意力模型(例如 Qwen3Next), + # 该类模型每个请求需要管理一个唯一的buffer_idx, + # 放在这里让该类模型能够复用当前的radix_cache代码。 + # 纯注意力模型该 buffer_idx 始终保持为 None + self.buffer_idx = None + def get_compare_key(self): return (0 if self.ref_counter == 0 else 1, len(self.children), self.time_id) diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 22acb77f4..5d013e4c7 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -113,10 +113,8 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq"): # .cpu() 是 流内阻塞操作 value = self.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu() - buffer_idx = None - if hasattr(self.req_manager, "req_to_buffer_indexes"): - buffer_idx = self.req_manager.req_to_buffer_indexes[req.req_idx].cpu() - prefix_len, _ = self.radix_cache.insert(key, value, buffer_idx=buffer_idx) + prefix_len, node = self.radix_cache.insert(key, value) + node.buffer_idx = req.buffer_idx old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len]) if req.shared_kv_node is not None: @@ -345,6 +343,10 @@ def __init__( self.nixl_pd_task_failed_num: int = 0 self.nixl_trans_device_id: int = -1 + # 可以用于请求在整个生命周期维护单一大小的buffer的场景 + # 例如混合注意力模型 Qwen3Next + self.buffer_idx = -1 + # 在开启 enable_cpu_cache 的情况下,当请求结束后,会将请求的 kv cache # 卸载到 cpu cache 中,该标志变量用于标记请求的卸载任务的状态 self.cpu_cache_task_status: "InferReq._CpuCacheTaskStatus" = InferReq._CpuCacheTaskStatus.NOT_STARTED @@ -397,13 +399,6 @@ def _match_radix_cache(self): key = torch.tensor(input_token_ids, dtype=torch.int64, device="cpu") key = key[0 : len(key) - 1] # 最后一个不需要,因为需要一个额外的token,让其在prefill的时候输出下一个token的值 share_node, kv_len, value_tensor = g_infer_context.radix_cache.match_prefix(key, update_refs=True) - - if share_node is not None: - if g_infer_context.use_hybrid_radix_cache: - if share_node.buffer_idx is None: - g_infer_context.radix_cache.dec_node_ref_counter(share_node) - share_node = None - if share_node is not None: self.shared_kv_node = share_node ready_cache_len = share_node.node_prefix_total_len @@ -411,12 +406,7 @@ def _match_radix_cache(self): g_infer_context.req_manager.req_to_token_indexs[self.req_idx, 0:ready_cache_len] = value_tensor self.cur_kv_len = int(ready_cache_len) # 序列化问题, 该对象可能为numpy.int64,用 int(*)转换 self.shm_req.prompt_cache_len = self.cur_kv_len # 记录 prompt cache 的命中长度 - - if g_infer_context.use_hybrid_radix_cache: - cur_buffer_idx = g_infer_context.req_manager.req_to_buffer_indexes[self.req_idx] - g_infer_context.req_manager.mem_manager.copy_state_cache_buffer( - share_node.buffer_idx, cur_buffer_idx - ) + self.buffer_idx = share_node.buffer_idx self.shm_req.shm_cur_kv_len = self.cur_kv_len return diff --git a/lightllm/server/router/model_infer/mode_backend/__init__.py b/lightllm/server/router/model_infer/mode_backend/__init__.py index 56293a7fa..82f3a8ddf 100644 --- a/lightllm/server/router/model_infer/mode_backend/__init__.py +++ b/lightllm/server/router/model_infer/mode_backend/__init__.py @@ -1,7 +1,6 @@ from .chunked_prefill.impl import ChunkedPrefillBackend from .chunked_prefill.impl_for_first_token_constraint_mode import FirstTokenConstraintBackend from .chunked_prefill.impl_for_outlines_constraint_mode import OutlinesConstraintBackend -from .chunked_prefill.impl_for_hybrid_radix_cache import HybridRadixCacheBackend from .chunked_prefill.impl_for_return_all_prompt_logprobs import ReturnPromptLogProbBackend from .chunked_prefill.impl_for_reward_model import RewardModelBackend from .chunked_prefill.impl_for_token_healing import TokenHealingBackend diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index b311ab15c..9e78389b2 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -10,7 +10,7 @@ from lightllm.utils.log_utils import init_logger from lightllm.models import get_model from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache - +from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridRadixCache from lightllm.server.router.model_infer.infer_batch import InferReq, InferReqUpdatePack from lightllm.server.router.token_load import TokenLoad from lightllm.common.basemodel.infer_lock import g_infer_state_lock, InferStateLock @@ -100,6 +100,7 @@ def init_model(self, kvargs): self.is_multinode_tp = self.args.nnodes > 1 and self.args.dp == 1 self.is_nixl_pd_mode = self.run_mode in ["nixl_prefill", "nixl_decode"] self.is_nixl_decode_mode = self.run_mode == "nixl_decode" + self.is_hybrid_model = kvargs.get("is_hybrid_model", False) self.logger = init_logger(__name__) @@ -141,6 +142,7 @@ def init_model(self, kvargs): wait_events.append(self.multi_level_cache_module) model_cfg, _ = PretrainedConfig.get_config_dict(self.weight_dir) + self.is_hybrid_model = model_cfg.get("model_type", "") in ["qwen3_next"] model_kvargs = { "weight_dir": self.weight_dir, @@ -166,8 +168,9 @@ def init_model(self, kvargs): self.model, self.is_multimodal = get_model(model_cfg, model_kvargs) self.model: TpPartBaseModel = self.model # for easy typing set_random_seed(2147483647) + radix_cache_class = HybridRadixCache if self.is_hybrid_model else RadixCache self.radix_cache = ( - RadixCache( + radix_cache_class( get_unique_server_name(), self.model.mem_manager.size, self.rank_in_node, diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index 2c3cfaf11..ebfc80841 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -24,6 +24,7 @@ from lightllm.utils.log_utils import init_logger from lightllm.utils.dist_utils import get_current_device_id from lightllm.utils.envs_utils import get_env_start_args +from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridRadixCache from .control_state import ControlState logger = init_logger(__name__) @@ -138,6 +139,10 @@ def prefill_normal( extra_post_req_handle_func=self.extra_post_req_handle_func, nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func, ) + + if isinstance(g_infer_context.radix_cache, HybridRadixCache): + g_infer_context.radix_cache.insert_for_hybrid_radix_cache(run_reqs) + # 第四阶段 event_pack.notify_pre_post_handle() return diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_hybrid_radix_cache.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_hybrid_radix_cache.py deleted file mode 100644 index e91034680..000000000 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_hybrid_radix_cache.py +++ /dev/null @@ -1,111 +0,0 @@ -from numpy import ndarray - - -import torch -from .impl import ChunkedPrefillBackend -from typing import Any, List -from typing_extensions import override -from lightllm.server.router.model_infer.infer_batch import g_infer_context, InferReq -from lightllm.common.basemodel.infer_lock import g_infer_state_lock -from lightllm.server.router.model_infer.mode_backend.overlap_events import OverlapEventPack -from lightllm.server.router.model_infer.mode_backend.pre import ( - prepare_prefill_inputs, -) -from lightllm.utils.log_utils import init_logger - -logger = init_logger(__name__) - - -class HybridRadixCacheBackend(ChunkedPrefillBackend): - def __init__(self) -> None: - super().__init__() - logger.info("Using HybridRadixCacheBackend for hybrid attention model.") - g_infer_context.use_hybrid_radix_cache = True - - @override - def init_model(self, kvargs): - from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridRadixCache - - super().init_model(kvargs) - assert isinstance(self.radix_cache, HybridRadixCache) - return - - def prefill_normal( - self, - event_pack: OverlapEventPack, - prefill_reqs: List[InferReq], - ): - # 第一阶段: 模型推理 - model_input, run_reqs = prepare_prefill_inputs( - prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal - ) - with torch.cuda.stream(g_infer_context.get_overlap_stream()): - model_output = self.model.forward(model_input) - _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( - logits=model_output.logits, - b_req_idx=model_input.b_req_idx, - b_mtp_index=model_input.b_mtp_index, - run_reqs=run_reqs, - is_prefill=True, - b_prefill_has_output_cpu=model_input.b_prefill_has_output_cpu, - mask_func=self.prefill_mask_func, - ) - sync_event = torch.cuda.Event() - sync_event.record() - - # 第二阶段 - event_pack.notify_post_handle_and_wait_pre_post_handle() - update_packs = self._pre_post_handle(run_reqs, is_chuncked_mode=not self.disable_chunked_prefill) - - # 第三阶段 - event_pack.notify_forward_and_wait_post_handle() - sync_event.synchronize() - self._post_handle( - run_reqs=run_reqs, - next_token_ids=next_token_ids_cpu, - next_token_logprobs=next_token_logprobs_cpu, - run_reqs_update_packs=update_packs, - extra_post_req_handle_func=self.extra_post_req_handle_func, - nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func, - ) - - if not self.disable_chunked_prefill: - for req in run_reqs: - # NOTE 忽略完整的prefill, 因为请求文本全是system prompt 的情况应该比较小 - if req.cur_kv_len < req.get_cur_total_len() - 1: - self._handle_radix_cache_insert(req) - - # 第四阶段 - event_pack.notify_pre_post_handle() - return - - def _handle_radix_cache_insert(self, req: "InferReq"): - from lightllm.models.qwen3next.mem_manager import HaveStateBuffer - from lightllm.models.qwen3next.req_manager import Qwen3NextReqManager - - assert isinstance(self.model.req_manager.mem_manager, HaveStateBuffer) - assert isinstance(self.model.req_manager, Qwen3NextReqManager) - - # 获取当前 chunked_prefill 处理的 token IDs - input_token_ids: Any | ndarray[Any, Any] = req.get_input_token_ids() - key = torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu") - - # 获取对应的 token 索引 - value = self.model.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].cpu() - - buffer_idx = self.model.req_manager.req_to_buffer_indexes[req.req_idx].cpu() - - # 确保有足够的空间用于新的 buffer - release_buffers = self.radix_cache.free_radix_cache_to_get_enough_token(0, 1) - - # 分配新的 buffer 并复制当前 buffer 的内容 - self.model.req_manager.mem_manager.free_state_cache_buffer(release_buffers) - new_buffer_idx = self.model.req_manager.mem_manager.alloc_state_cache_buffer(1)[0] - self.model.req_manager.mem_manager.copy_state_cache_buffer(buffer_idx, new_buffer_idx) - self.model.req_manager.req_to_buffer_indexes[req.req_idx] = new_buffer_idx - - _, new_shared_kv_node = self.radix_cache.insert(key, value, buffer_idx) - - self.radix_cache.dec_node_ref_counter(req.shared_kv_node) - self.radix_cache.add_node_ref_counter(new_shared_kv_node) - req.shared_kv_node = new_shared_kv_node diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py index 02edf9ff1..55fe7a415 100644 --- a/lightllm/server/router/model_infer/model_rpc.py +++ b/lightllm/server/router/model_infer/model_rpc.py @@ -7,12 +7,10 @@ import setproctitle from datetime import timedelta from typing import Dict, List, Tuple -from transformers import PretrainedConfig from lightllm.server.router.model_infer.mode_backend import ( ChunkedPrefillBackend, FirstTokenConstraintBackend, OutlinesConstraintBackend, - HybridRadixCacheBackend, ReturnPromptLogProbBackend, RewardModelBackend, TokenHealingBackend, @@ -125,13 +123,7 @@ def init_model(self, kvargs): is_nixl_prefill_node = self.args.run_mode == "nixl_prefill" is_nixl_decode_node = self.args.run_mode == "nixl_decode" - model_cfg, _ = PretrainedConfig.get_config_dict(kvargs["weight_dir"]) - is_hybrid_model = model_cfg.get("model_type", "") in ["qwen3_next"] - use_hybrid_radix_cache = is_hybrid_model and not self.args.disable_dynamic_prompt_cache - - if use_hybrid_radix_cache: - self.backend = HybridRadixCacheBackend() - elif is_prefill_node: + if is_prefill_node: if self.args.dp > 1: self.backend = DPChunkedForPrefillNode(self.info_queue) else: diff --git a/lightllm/utils/log_utils.py b/lightllm/utils/log_utils.py index 0e528dd45..6bbe87373 100644 --- a/lightllm/utils/log_utils.py +++ b/lightllm/utils/log_utils.py @@ -56,7 +56,7 @@ def _setup_logger(): _default_handler = logging.StreamHandler(sys.stdout) _default_handler.flush = sys.stdout.flush # type: ignore _default_handler.setLevel(_LOG_LEVEL) - _default_handler.addFilter(RankFilter()) + # _default_handler.addFilter(RankFilter()) _root_logger.addHandler(_default_handler) if _default_file_handler is None and _LOG_DIR is not None: @@ -68,7 +68,7 @@ def _setup_logger(): _default_file_handler = logging.FileHandler(_LOG_DIR + "/default.log") _default_file_handler.setLevel(_LOG_LEVEL) _default_file_handler.setFormatter(fmt) - _default_file_handler.addFilter(RankFilter()) + # _default_file_handler.addFilter(RankFilter()) _root_logger.addHandler(_default_file_handler) _default_handler.setFormatter(fmt) From 00fb9d7a1e87bb264da9769424ba32a2a0c0db40 Mon Sep 17 00:00:00 2001 From: sufubao Date: Wed, 10 Dec 2025 11:45:26 +0000 Subject: [PATCH 06/19] tmp --- lightllm/common/req_manager.py | 1 + .../layer_infer/transformer_layer_infer.py | 4 +-- lightllm/models/qwen3next/mem_manager.py | 3 +- lightllm/models/qwen3next/model.py | 31 +++++++++++++++-- lightllm/models/qwen3next/req_manager.py | 33 +++++++++++++++++++ .../dynamic_prompt/hybrid_radix_cache.py | 3 +- lightllm/server/router/manager.py | 9 ----- .../server/router/model_infer/infer_batch.py | 5 +++ .../model_infer/mode_backend/base_backend.py | 1 - 9 files changed, 71 insertions(+), 19 deletions(-) create mode 100644 lightllm/models/qwen3next/req_manager.py diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 40c8aa993..572191089 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -7,6 +7,7 @@ from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args from lightllm.utils.config_utils import get_vocab_size +from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridMemManager logger = init_logger(__name__) diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index 08e7ced77..a0f7f98cf 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -251,8 +251,8 @@ def _linear_attn( assert isinstance(infer_state.mem_manager, Qwen3NextMemoryManager) input = input.view(-1, infer_cls.embed_dim_) - buffer_idx = infer_state.req_manager.req_to_buffer_indexes[infer_state.b_req_idx] - conv_states, ssm_states = infer_state.mem_manager.get_state_cache_buffer(self.layer_idx_) + buffer_idx = infer_state.buffer_indexes + conv_states, ssm_states = infer_state.mem_manager.get_buffer(self.layer_idx_) mixed_qkvzba = layer_weight.linear_in_proj.mm(input) q, k, v, z, b, a = self._fix_query_key_value_ba_ordering(mixed_qkvzba) diff --git a/lightllm/models/qwen3next/mem_manager.py b/lightllm/models/qwen3next/mem_manager.py index cf45bff49..7eef28a93 100644 --- a/lightllm/models/qwen3next/mem_manager.py +++ b/lightllm/models/qwen3next/mem_manager.py @@ -50,6 +50,7 @@ def __init__( conv_state_shape: Tuple[int, ...], ssm_state_dtype: torch.dtype, ssm_state_shape: Tuple[int, ...], + max_req_num: int, always_copy=False, mem_fraction=0.9, ): @@ -80,8 +81,6 @@ def __init__( f"Ssm state use : " f"{self.ssm_state_mem_manager.get_cell_size() * linear_attn_cache_size / 1024 ** 3} GB Memory.\n" ) - self.EMPTY_BUFFER_INDEX = -1 - self.HOLD_BUFFER_INDEX = self.conv_state_mem_manager.HOLD_TOKEN_MEMINDEX super().__init__(full_attn_cache_size, dtype, num_kv_heads, head_dim, layer_num, always_copy, mem_fraction) @override diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index a31e4b16d..5b8ad3769 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -11,6 +11,7 @@ from lightllm.models.qwen3next.mem_manager import Qwen3NextMemoryManager from lightllm.server.core.objs.start_args_type import StartArgs from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput +from lightllm.models.qwen3next.req_manager import Qwen3NextReqManager logger = init_logger(__name__) @@ -83,18 +84,42 @@ def _init_mem_manager(self): self.head_linear_k_dim, self.head_linear_v_dim, ), + max_req_num=self.max_req_num, mem_fraction=self.mem_fraction, ) + @override def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0): from lightllm.common.basemodel.infer_lock import g_infer_state_lock - from lightllm.common.basemodel.infer_context import g_infer_context + from lightllm.server.router.model_infer.infer_batch import g_infer_context infer_state = super()._create_inferstate(model_input, microbatch_index) + + buffer_indexes = self.req_manager.req_to_buffer_indexes[model_input.b_req_idx] + empty_indexes = buffer_indexes == self.req_manager.EMPTY_BUFFER_INDEX + num_empty = empty_indexes.sum() + if num_empty == 0: + return infer_state + g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_buffer(infer_state.batch_size) - buffer_indexes = self.mem_manager.alloc_buffer(infer_state.batch_size) + g_infer_context.radix_cache.free_radix_cache_to_get_enough_buffer(num_empty) + new_buffer_indexes = self.mem_manager.alloc_buffer(num_empty).cuda() g_infer_state_lock.release() + + buffer_indexes[empty_indexes] = new_buffer_indexes + self.req_manager.req_to_buffer_indexes[model_input.b_req_idx] = buffer_indexes infer_state.buffer_indexes = buffer_indexes return infer_state + + @override + def _init_req_manager(self): + create_max_seq_len = 0 + + if self.batch_max_tokens is not None: + create_max_seq_len = max(create_max_seq_len, self.batch_max_tokens) + if self.max_seq_length is not None: + create_max_seq_len = max(create_max_seq_len, self.max_seq_length) + + self.req_manager = Qwen3NextReqManager(self.max_req_num, create_max_seq_len, self.mem_manager) + return diff --git a/lightllm/models/qwen3next/req_manager.py b/lightllm/models/qwen3next/req_manager.py new file mode 100644 index 000000000..82aae3fa5 --- /dev/null +++ b/lightllm/models/qwen3next/req_manager.py @@ -0,0 +1,33 @@ +from typing import override, List + +import torch + +from lightllm.common.req_manager import ReqManager +from lightllm.models.qwen3next.mem_manager import Qwen3NextMemoryManager + + +class Qwen3NextReqManager(ReqManager): + def __init__(self, max_request_num, max_sequence_length, mem_manager: Qwen3NextMemoryManager): + super().__init__(max_request_num, max_sequence_length, mem_manager) + self.EMPTY_BUFFER_INDEX = -1 + self.req_to_buffer_indexes = torch.zeros((self.max_request_num + 1), dtype=torch.int32, device="cuda") + self.req_to_buffer_indexes[:] = self.EMPTY_BUFFER_INDEX + + @override + def free(self, free_req_indexes: List[int], free_token_index): + self.free_buffer(free_req_indexes) + super().free(free_req_indexes, free_token_index) + + @override + def free_all(self): + self.req_to_buffer_indexes[:] = self.EMPTY_BUFFER_INDEX + super().free_all() + return + + def free_buffer(self, free_req_indexes: List[int]): + from lightllm.server.router.model_infer.infer_batch import g_infer_context + + if g_infer_context.radix_cache is None: + self.mem_manager.free_buffer(self.req_to_buffer_indexes[free_req_indexes]) + self.req_to_buffer_indexes[free_req_indexes] = self.EMPTY_BUFFER_INDEX + return diff --git a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py index 5858ddf0a..e1d839cdc 100644 --- a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py @@ -5,7 +5,6 @@ from lightllm.server.router.dynamic_prompt.radix_cache import RadixCache, TreeNode from lightllm.common.kv_cache_mem_manager.mem_manager import MemoryManager -from lightllm.server.router.model_infer.infer_batch import InferReq class HybridMemManager(MemoryManager): @@ -69,7 +68,7 @@ def evict_buffer(self, need_evict_buffer_num, evict_buffer_callback, evict_token self._remove_leaf_node(node) return - def insert_for_hybrid_radix_cache(self, reqs: List["InferReq"]): + def insert_for_hybrid_radix_cache(self, reqs): from lightllm.server.router.model_infer.infer_batch import g_infer_context from lightllm.common.basemodel.infer_lock import g_infer_state_lock diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index ae5092855..e2a64f6e0 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -251,14 +251,6 @@ async def loop_for_fwd( estimated_peak_token_count = self.shared_token_load.get_estimated_peak_token_count(d_i) paused_req_num = self._get_paused_req_num_in_dp_index(dp_index=d_i) - # Get hit rate from radix cache if available - hit_rate = 0.0 - if self.radix_cache_client is not None: - try: - hit_rate = self.radix_cache_client.get_match_prefix_hit_rate(d_i) - except Exception as e: - logger.warning(f"Failed to get hit rate from radix cache: {e}") - logger.debug( f"dp_i {d_i} current batch size: {len(self.running_batch.reqs)} \n" f"dp_i {d_i} paused req num: {paused_req_num} \n" @@ -266,7 +258,6 @@ async def loop_for_fwd( f"dp_i {d_i} estimated_peak_token_count: {estimated_peak_token_count} \n" f"dp_i {d_i} token used ratio: {token_ratio1} not contain prompt cache tree unrefed token\n" f"dp_i {d_i} token used ratio: {token_ratio2} contain prompt cache tree unrefed token\n" - f"dp_i {d_i} match_prefix hit_rate: {hit_rate:.4f}" ) self.metric_client.gauge_set("lightllm_batch_pause_size", paused_req_num) # pd decode mode need to update token_load more frequently diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index 5d013e4c7..f05000742 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -183,8 +183,10 @@ def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool): if pause_reqs: g_infer_state_lock.acquire() + pause_req_ids = [] free_token_index = [] for req in pause_reqs: + pause_req_ids.append(req.req_id) if self.args.diverse_mode: # 发生暂停的时候,需要清除 diverse 模式下的主从关系 req.clear_master_slave_state() @@ -201,6 +203,9 @@ def pause_reqs(self, pause_reqs: List["InferReq"], is_master_in_dp: bool): free_token_index = custom_cat(free_token_index) self.req_manager.free_token(free_token_index) + if hasattr(self.req_manager, "free_buffer"): + self.req_manager.free_buffer(pause_req_ids) + g_infer_state_lock.release() return self diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 9e78389b2..eb9f98cc7 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -192,7 +192,6 @@ def init_model(self, kvargs): shm_req_manager=self.shm_req_manager, vocab_size=self.model.vocab_size, ) - # 初始化 dp 模式使用的通信 tensor, 对于非dp模式,不会使用到 if self.dp_size > 1: self.dp_reduce_tensor = torch.tensor([0], dtype=torch.int32, device="cuda", requires_grad=False) From 4af3ac532e1fd6befc590d8b0502196be4282937 Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 11 Dec 2025 02:11:47 +0000 Subject: [PATCH 07/19] done --- lightllm/models/qwen3next/model.py | 16 --------------- lightllm/models/qwen3next/req_manager.py | 20 +++++++++++++++++++ .../mode_backend/chunked_prefill/impl.py | 4 ++++ 3 files changed, 24 insertions(+), 16 deletions(-) diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index 5b8ad3769..dc56532e8 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -90,25 +90,9 @@ def _init_mem_manager(self): @override def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0): - from lightllm.common.basemodel.infer_lock import g_infer_state_lock - from lightllm.server.router.model_infer.infer_batch import g_infer_context - infer_state = super()._create_inferstate(model_input, microbatch_index) buffer_indexes = self.req_manager.req_to_buffer_indexes[model_input.b_req_idx] - empty_indexes = buffer_indexes == self.req_manager.EMPTY_BUFFER_INDEX - num_empty = empty_indexes.sum() - if num_empty == 0: - return infer_state - - g_infer_state_lock.acquire() - if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_buffer(num_empty) - new_buffer_indexes = self.mem_manager.alloc_buffer(num_empty).cuda() - g_infer_state_lock.release() - - buffer_indexes[empty_indexes] = new_buffer_indexes - self.req_manager.req_to_buffer_indexes[model_input.b_req_idx] = buffer_indexes infer_state.buffer_indexes = buffer_indexes return infer_state diff --git a/lightllm/models/qwen3next/req_manager.py b/lightllm/models/qwen3next/req_manager.py index 82aae3fa5..3e62afc99 100644 --- a/lightllm/models/qwen3next/req_manager.py +++ b/lightllm/models/qwen3next/req_manager.py @@ -31,3 +31,23 @@ def free_buffer(self, free_req_indexes: List[int]): self.mem_manager.free_buffer(self.req_to_buffer_indexes[free_req_indexes]) self.req_to_buffer_indexes[free_req_indexes] = self.EMPTY_BUFFER_INDEX return + + def alloc_buffer(self, req_indexes: List[int]): + from lightllm.common.basemodel.infer_lock import g_infer_state_lock + from lightllm.server.router.model_infer.infer_batch import g_infer_context + + cur_buffer_indexes = self.req_to_buffer_indexes[req_indexes] + empty_indexes = cur_buffer_indexes == self.EMPTY_BUFFER_INDEX + num_empty = empty_indexes.sum() + if num_empty == 0: + return + + g_infer_state_lock.acquire() + if g_infer_context.radix_cache is not None: + g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(num_empty) + new_buffer_indexes = self.mem_manager.alloc_buffer(num_empty).cuda() + g_infer_state_lock.release() + + cur_buffer_indexes[empty_indexes] = new_buffer_indexes + self.req_to_buffer_indexes[req_indexes] = cur_buffer_indexes + return diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index ebfc80841..bcb00fa8c 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -110,6 +110,10 @@ def prefill_normal( model_input, run_reqs = prepare_prefill_inputs( prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal ) + + if hasattr(g_infer_context.req_manager, "req_to_buffer_indexes"): + g_infer_context.req_manager.alloc_buffer(model_input.b_req_idx) + with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( From b5396325b3b00aaeb9da397a2a0709d6f709043a Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 11 Dec 2025 03:55:24 +0000 Subject: [PATCH 08/19] fix --- lightllm/common/basemodel/infer_struct.py | 3 -- .../layer_infer/transformer_layer_infer.py | 2 +- lightllm/models/qwen3next/mem_manager.py | 3 +- lightllm/models/qwen3next/model.py | 8 ----- lightllm/models/qwen3next/req_manager.py | 32 ++++++++----------- .../dynamic_prompt/hybrid_radix_cache.py | 6 ++-- .../mode_backend/chunked_prefill/impl.py | 10 ++---- 7 files changed, 21 insertions(+), 43 deletions(-) diff --git a/lightllm/common/basemodel/infer_struct.py b/lightllm/common/basemodel/infer_struct.py index 34a450a9f..85d3d8c46 100755 --- a/lightllm/common/basemodel/infer_struct.py +++ b/lightllm/common/basemodel/infer_struct.py @@ -88,9 +88,6 @@ def __init__(self): self.dp_output_split_sizes: List[List[int]] = None self.dp_input_split_sizes: List[List[int]] = None - # 专门用于管理混合注意力模型的buffer - self.buffer_indexes: torch.Tensor = None - def init_some_extra_state(self, model, input_ids: torch.Tensor): if self.is_prefill: ( diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index a0f7f98cf..eb0caeaa6 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -251,7 +251,7 @@ def _linear_attn( assert isinstance(infer_state.mem_manager, Qwen3NextMemoryManager) input = input.view(-1, infer_cls.embed_dim_) - buffer_idx = infer_state.buffer_indexes + buffer_idx = infer_state.req_manager.req_to_buffer_indexes[infer_state.b_req_idx] conv_states, ssm_states = infer_state.mem_manager.get_buffer(self.layer_idx_) mixed_qkvzba = layer_weight.linear_in_proj.mm(input) diff --git a/lightllm/models/qwen3next/mem_manager.py b/lightllm/models/qwen3next/mem_manager.py index 7eef28a93..51d7d277e 100644 --- a/lightllm/models/qwen3next/mem_manager.py +++ b/lightllm/models/qwen3next/mem_manager.py @@ -121,8 +121,7 @@ def free_buffer(self, free_buffer_indexes: List[int], reset=True): @override def alloc_buffer(self, need_size): # conv_state 和 ssm_state 共享buffer_idx - buffer_indexes = self.conv_state_mem_manager.alloc(need_size) - return buffer_indexes + return self.conv_state_mem_manager.alloc(need_size) @override def get_buffer_can_use_size(self): diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index dc56532e8..5e6cfc2ef 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -88,14 +88,6 @@ def _init_mem_manager(self): mem_fraction=self.mem_fraction, ) - @override - def _create_inferstate(self, model_input: ModelInput, microbatch_index: int = 0): - infer_state = super()._create_inferstate(model_input, microbatch_index) - - buffer_indexes = self.req_manager.req_to_buffer_indexes[model_input.b_req_idx] - infer_state.buffer_indexes = buffer_indexes - return infer_state - @override def _init_req_manager(self): create_max_seq_len = 0 diff --git a/lightllm/models/qwen3next/req_manager.py b/lightllm/models/qwen3next/req_manager.py index 3e62afc99..05a09ed65 100644 --- a/lightllm/models/qwen3next/req_manager.py +++ b/lightllm/models/qwen3next/req_manager.py @@ -24,30 +24,26 @@ def free_all(self): super().free_all() return - def free_buffer(self, free_req_indexes: List[int]): - from lightllm.server.router.model_infer.infer_batch import g_infer_context - - if g_infer_context.radix_cache is None: - self.mem_manager.free_buffer(self.req_to_buffer_indexes[free_req_indexes]) - self.req_to_buffer_indexes[free_req_indexes] = self.EMPTY_BUFFER_INDEX - return - - def alloc_buffer(self, req_indexes: List[int]): + @override + def alloc(self): from lightllm.common.basemodel.infer_lock import g_infer_state_lock from lightllm.server.router.model_infer.infer_batch import g_infer_context - cur_buffer_indexes = self.req_to_buffer_indexes[req_indexes] - empty_indexes = cur_buffer_indexes == self.EMPTY_BUFFER_INDEX - num_empty = empty_indexes.sum() - if num_empty == 0: - return + req_index = super().alloc() g_infer_state_lock.acquire() if g_infer_context.radix_cache is not None: - g_infer_context.radix_cache.free_radix_cache_to_get_enough_token(num_empty) - new_buffer_indexes = self.mem_manager.alloc_buffer(num_empty).cuda() + g_infer_context.radix_cache.free_radix_cache_to_get_enough_buffer(1) + new_buffer_index = self.mem_manager.alloc_buffer(1) + self.req_to_buffer_indexes[req_index] = new_buffer_index g_infer_state_lock.release() - cur_buffer_indexes[empty_indexes] = new_buffer_indexes - self.req_to_buffer_indexes[req_indexes] = cur_buffer_indexes + return req_index + + def free_buffer(self, free_req_indexes: List[int]): + from lightllm.server.router.model_infer.infer_batch import g_infer_context + + if g_infer_context.radix_cache is None: + self.mem_manager.free_buffer(self.req_to_buffer_indexes[free_req_indexes]) + self.req_to_buffer_indexes[free_req_indexes] = self.EMPTY_BUFFER_INDEX return diff --git a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py index e1d839cdc..225821139 100644 --- a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py @@ -82,14 +82,12 @@ def insert_for_hybrid_radix_cache(self, reqs): input_token_ids = req.get_input_token_ids() key = torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu") value = g_infer_context.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].cpu() - buffer_idx = req.buffer_idx # 分配新的 buffer 并复制当前 buffer 的内容 - self.mem_manager.copy_buffer(buffer_idx, new_buffer_indexes[i]) - req.buffer_idx = new_buffer_indexes[i] + self.mem_manager.copy_buffer(req.buffer_idx, new_buffer_indexes[i]) _, new_shared_kv_node = self.insert(key, value) - new_shared_kv_node.buffer_idx = buffer_idx + new_shared_kv_node.buffer_idx = new_buffer_indexes[i] self.dec_node_ref_counter(req.shared_kv_node) self.add_node_ref_counter(new_shared_kv_node) req.shared_kv_node = new_shared_kv_node diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index bcb00fa8c..a95d84116 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -111,9 +111,6 @@ def prefill_normal( prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal ) - if hasattr(g_infer_context.req_manager, "req_to_buffer_indexes"): - g_infer_context.req_manager.alloc_buffer(model_input.b_req_idx) - with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( @@ -132,6 +129,9 @@ def prefill_normal( event_pack.notify_post_handle_and_wait_pre_post_handle() update_packs = self._pre_post_handle(run_reqs, is_chuncked_mode=not self.disable_chunked_prefill) + if isinstance(g_infer_context.radix_cache, HybridRadixCache): + g_infer_context.radix_cache.insert_for_hybrid_radix_cache(run_reqs) + # 第三阶段 event_pack.notify_forward_and_wait_post_handle() sync_event.synchronize() @@ -143,10 +143,6 @@ def prefill_normal( extra_post_req_handle_func=self.extra_post_req_handle_func, nixl_prefill_chuncked_handle_func=self.nixl_prefill_chuncked_handle_func, ) - - if isinstance(g_infer_context.radix_cache, HybridRadixCache): - g_infer_context.radix_cache.insert_for_hybrid_radix_cache(run_reqs) - # 第四阶段 event_pack.notify_pre_post_handle() return From 68e1cee3dc1c5e6d84a9f1aae0a9b3b21ab3cc1d Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 11 Dec 2025 04:11:54 +0000 Subject: [PATCH 09/19] fix cudagraph --- lightllm/models/qwen3next/mem_manager.py | 1 + lightllm/models/qwen3next/req_manager.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/lightllm/models/qwen3next/mem_manager.py b/lightllm/models/qwen3next/mem_manager.py index 51d7d277e..399ef487c 100644 --- a/lightllm/models/qwen3next/mem_manager.py +++ b/lightllm/models/qwen3next/mem_manager.py @@ -68,6 +68,7 @@ def __init__( self.ssm_state_shape = ssm_state_shape assert linear_attn_cache_size is not None + self.HOLD_BUFFER_INDEX = linear_attn_cache_size self.conv_state_mem_manager = LayerCacheMemoryManager( linear_attn_cache_size, conv_state_dtype, conv_state_shape, self.linear_attn_layer_num, "conv_state" ) diff --git a/lightllm/models/qwen3next/req_manager.py b/lightllm/models/qwen3next/req_manager.py index 05a09ed65..ae1e961c6 100644 --- a/lightllm/models/qwen3next/req_manager.py +++ b/lightllm/models/qwen3next/req_manager.py @@ -11,7 +11,8 @@ def __init__(self, max_request_num, max_sequence_length, mem_manager: Qwen3NextM super().__init__(max_request_num, max_sequence_length, mem_manager) self.EMPTY_BUFFER_INDEX = -1 self.req_to_buffer_indexes = torch.zeros((self.max_request_num + 1), dtype=torch.int32, device="cuda") - self.req_to_buffer_indexes[:] = self.EMPTY_BUFFER_INDEX + self.req_to_buffer_indexes[:-1] = self.EMPTY_BUFFER_INDEX + self.req_to_buffer_indexes[-1] = self.mem_manager.HOLD_BUFFER_INDEX @override def free(self, free_req_indexes: List[int], free_token_index): @@ -20,7 +21,7 @@ def free(self, free_req_indexes: List[int], free_token_index): @override def free_all(self): - self.req_to_buffer_indexes[:] = self.EMPTY_BUFFER_INDEX + self.req_to_buffer_indexes[:-1] = self.EMPTY_BUFFER_INDEX super().free_all() return From 599732a0ce6abc7f2c00fbd1f6c3bbd2a0357ead Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 11 Dec 2025 08:02:05 +0000 Subject: [PATCH 10/19] update kernel --- .../layer_infer/transformer_layer_infer.py | 20 +- .../qwen3next/triton_kernel/causal_conv1d.py | 1025 +---------------- .../qwen3next/triton_kernel/fla/__init__.py | 8 - .../qwen3next/triton_kernel/fla/ops/chunk.py | 24 +- .../triton_kernel/fla/ops/chunk_delta_h.py | 120 +- .../triton_kernel/fla/ops/chunk_o.py | 27 +- .../fla/ops/chunk_scaled_dot_kkt.py | 45 +- .../qwen3next/triton_kernel/fla/ops/cumsum.py | 96 +- .../triton_kernel/fla/ops/fused_recurrent.py | 46 +- .../qwen3next/triton_kernel/fla/ops/l2norm.py | 8 +- .../qwen3next/triton_kernel/fla/ops/op.py | 41 +- .../triton_kernel/fla/ops/solve_tril.py | 482 +++++--- .../qwen3next/triton_kernel/fla/ops/utils.py | 18 +- .../triton_kernel/fla/ops/wy_fast.py | 47 +- .../triton_kernel/fla_bak/__init__.py | 15 + .../qwen3next/triton_kernel/fla_bak/chunk.py | 225 ++++ .../triton_kernel/fla_bak/chunk_delta_h.py | 257 +++++ .../triton_kernel/fla_bak/chunk_o.py | 167 +++ .../fla_bak/chunk_scaled_dot_kkt.py | 136 +++ .../qwen3next/triton_kernel/fla_bak/cumsum.py | 200 ++++ .../triton_kernel/fla_bak/fused_recurrent.py | 367 ++++++ .../qwen3next/triton_kernel/fla_bak/index.py | 30 + .../qwen3next/triton_kernel/fla_bak/l2norm.py | 137 +++ .../qwen3next/triton_kernel/fla_bak/op.py | 36 + .../triton_kernel/fla_bak/solve_tril.py | 271 +++++ .../qwen3next/triton_kernel/fla_bak/utils.py | 173 +++ .../triton_kernel/fla_bak/wy_fast.py | 122 ++ lightllm/server/api_cli.py | 2 +- 28 files changed, 2868 insertions(+), 1277 deletions(-) create mode 100644 lightllm/models/qwen3next/triton_kernel/fla_bak/__init__.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla_bak/chunk.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla_bak/chunk_delta_h.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla_bak/chunk_o.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla_bak/chunk_scaled_dot_kkt.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla_bak/cumsum.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla_bak/fused_recurrent.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla_bak/index.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla_bak/l2norm.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla_bak/op.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla_bak/solve_tril.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla_bak/utils.py create mode 100644 lightllm/models/qwen3next/triton_kernel/fla_bak/wy_fast.py diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index eb0caeaa6..858e71685 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -14,8 +14,9 @@ from lightllm.models.qwen3next.triton_kernel.gated_rmsnorm import gated_rmsnorm_forward from lightllm.models.qwen3next.triton_kernel.causal_conv1d import causal_conv1d_fn, causal_conv1d_update from lightllm.models.qwen3next.triton_kernel.fused_gdn_gating import fused_gdn_gating -from lightllm.models.qwen3next.triton_kernel.fla.ops.chunk import chunk_gated_delta_rule -from lightllm.models.qwen3next.triton_kernel.fla.ops.fused_recurrent import fused_recurrent_gated_delta_rule +from lightllm.models.qwen3next.triton_kernel.fla.ops import chunk_gated_delta_rule +from lightllm.models.qwen3next.triton_kernel.fla.ops import fused_recurrent_gated_delta_rule + from lightllm.distributed import all_reduce from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd from lightllm.models.qwen3next.triton_kernel.gemma_rmsnorm import gemma_rmsnorm_forward @@ -260,15 +261,13 @@ def _linear_attn( if is_prefill: mixed_qkv = mixed_qkv.transpose(0, 1) - out_tensor = infer_cls.alloc_tensor(mixed_qkv.shape, mixed_qkv.dtype, device=mixed_qkv.device) - causal_conv1d_fn( + out_tensor = causal_conv1d_fn( mixed_qkv, layer_weight.linear_conv1d.mm_param.weight.transpose(0, 1), - layer_weight.linear_conv1d.mm_param.bias, - conv_states.transpose(1, 2), - infer_state.b1_cu_q_seq_len, - out=out_tensor, + bias=layer_weight.linear_conv1d.mm_param.bias, + query_start_loc=infer_state.b1_cu_q_seq_len, cache_indices=buffer_idx, + conv_states=conv_states.transpose(1, 2), activation=self.activation, ) mixed_qkv = out_tensor.transpose(0, 1) @@ -277,10 +276,9 @@ def _linear_attn( mixed_qkv, conv_states.transpose(1, 2), layer_weight.linear_conv1d.mm_param.weight.transpose(0, 1), - layer_weight.linear_conv1d.mm_param.bias, - self.activation, + bias=layer_weight.linear_conv1d.mm_param.bias, + activation=self.activation, conv_state_indices=buffer_idx, - validate_data=True, ) # Rearrange mixed_qkv to query, key, value diff --git a/lightllm/models/qwen3next/triton_kernel/causal_conv1d.py b/lightllm/models/qwen3next/triton_kernel/causal_conv1d.py index 202ce7460..c6d099a2d 100644 --- a/lightllm/models/qwen3next/triton_kernel/causal_conv1d.py +++ b/lightllm/models/qwen3next/triton_kernel/causal_conv1d.py @@ -1,388 +1,33 @@ -# SPDX-License-Identifier: Apache-2.0 -# Adapted from -# https://github.com/vllm-project/vllm/blob/v0.11.0rc1/vllm/model_executor/layers/mamba/ops/causal_conv1d.py -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/attention/mamba/causal_conv1d.py -# Copyright (c) 2024, Tri Dao. -# Adapted from https://github.com/Dao-AILab/causal-conv1d/blob/main/causal_conv1d/causal_conv1d_interface.py +from typing import Optional -from typing import Optional, Union - -import numpy as np import torch -import triton -import triton.language as tl - -PAD_SLOT_ID = -1 - - -@triton.jit() -def _causal_conv1d_fwd_kernel( # continuous batching - # Pointers to matrices - x_ptr, # (dim, cu_seqlen) holding `batch` of actual sequences + padded sequences - w_ptr, # (dim, width) - bias_ptr, - initial_states_ptr, # conv_states_ptr - cache_indices_ptr, # conv_state_indices_ptr - has_initial_states_ptr, - query_start_loc_ptr, - batch_ptr, - token_chunk_offset_ptr, - o_ptr, # (dim, seqlen) - actually pointing to x_ptr - # Matrix dimensions - batch: tl.int32, # actually padded_batch - dim: tl.constexpr, - seqlen: tl.int32, # cu_seqlen - num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines - # Strides - stride_x_seq: tl.constexpr, # stride to get to next sequence, - stride_x_dim: tl.constexpr, # stride to get to next feature-value, - stride_x_token: tl.constexpr, # stride to get to next token (same feature-index, same sequence-index) - stride_w_dim: tl.constexpr, # stride to get to next dim-axis value - stride_w_width: tl.constexpr, # stride to get to next width-axis value - stride_istate_seq: tl.constexpr, - stride_istate_dim: tl.constexpr, - stride_istate_token: tl.constexpr, - stride_o_seq: tl.constexpr, - stride_o_dim: tl.constexpr, - stride_o_token: tl.constexpr, - # others - pad_slot_id: tl.constexpr, - # Meta-parameters - HAS_BIAS: tl.constexpr, - KERNEL_WIDTH: tl.constexpr, - SILU_ACTIVATION: tl.constexpr, - HAS_INITIAL_STATES: tl.constexpr, - HAS_CACHE: tl.constexpr, - IS_CONTINUOUS_BATCHING: tl.constexpr, - USE_PAD_SLOT: tl.constexpr, - NP2_STATELEN: tl.constexpr, - BLOCK_M: tl.constexpr, - BLOCK_N: tl.constexpr, -): - conv_states_ptr = initial_states_ptr - conv_state_indices_ptr = cache_indices_ptr - stride_conv_state_seq = stride_istate_seq - stride_conv_state_dim = stride_istate_dim - stride_conv_state_tok = stride_istate_token - state_len = KERNEL_WIDTH - 1 # can be passed via argument if it's not the same as this value - - # one program handles one chunk in a single sequence - # rather than mixing sequences - to make updating initial_states across sequences efficiently - - # single-sequence id - idx_seq = tl.load(batch_ptr + tl.program_id(0)) - chunk_offset = tl.load(token_chunk_offset_ptr + tl.program_id(0)) - - # BLOCK_N elements along the feature-dimension (channel) - idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) - - if idx_seq == pad_slot_id: - return - - sequence_start_index = tl.load(query_start_loc_ptr + idx_seq) - sequence_end_index = tl.load(query_start_loc_ptr + idx_seq + 1) - # find the actual sequence length - seqlen = sequence_end_index - sequence_start_index - - token_offset = BLOCK_M * chunk_offset - segment_len = min(BLOCK_M, seqlen - token_offset) - - # base of the sequence - x_base = x_ptr + sequence_start_index * stride_x_token + idx_feats * stride_x_dim # [BLOCK_N,] - - if IS_CONTINUOUS_BATCHING: - # cache_idx - conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq).to(tl.int64) - else: - # cache_idx - conv_state_batch_coord = idx_seq - if USE_PAD_SLOT: # noqa - if conv_state_batch_coord == pad_slot_id: - # not processing as this is not the actual sequence - return - conv_states_base = ( - conv_states_ptr + (conv_state_batch_coord * stride_conv_state_seq) + (idx_feats * stride_conv_state_dim) - ) # [BLOCK_N,] - - w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] - - # Does 2 things: - # 1. READ prior-block init-state data - [done by every Triton programs] - # 2. update conv_state with new data [only by the Triton program handles chunk_offset=0] - if chunk_offset == 0: - # read from conv_states - load_init_state = False - if HAS_INITIAL_STATES: # the new HAS_INITIAL_STATES - load_init_state = tl.load(has_initial_states_ptr + idx_seq).to(tl.int1) - if load_init_state: - # load from conv_states - prior_tokens = conv_states_base + (state_len - 1) * stride_conv_state_tok - mask_w = idx_feats < dim - if KERNEL_WIDTH == 2: - conv_states_ptrs = prior_tokens # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0) - if KERNEL_WIDTH == 3: - conv_states_ptrs = prior_tokens # [BLOCK_N] - col1 = tl.load(conv_states_ptrs, mask_w, 0.0) - conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0) - if KERNEL_WIDTH == 4: - conv_states_ptrs = prior_tokens # [BLOCK_N] - col2 = tl.load(conv_states_ptrs, mask_w, 0.0) - conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N] - col1 = tl.load(conv_states_ptrs, mask_w, 0.0) - conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0) - if KERNEL_WIDTH == 5: - conv_states_ptrs = prior_tokens # [BLOCK_N] - # col3 = tl.load(conv_states_ptrs, mask_w, 0.0) - conv_states_ptrs = prior_tokens - 1 * stride_conv_state_tok # [BLOCK_N] - col2 = tl.load(conv_states_ptrs, mask_w, 0.0) - conv_states_ptrs = prior_tokens - 2 * stride_conv_state_tok # [BLOCK_N] - col1 = tl.load(conv_states_ptrs, mask_w, 0.0) - conv_states_ptrs = prior_tokens - 3 * stride_conv_state_tok # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0) - else: - # prior-tokens are zeros - if KERNEL_WIDTH >= 2: # STRATEGY1 - # first chunk and does not have prior-token, so just set to 0 - col0 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) - if KERNEL_WIDTH >= 3: # STRATEGY1 - col1 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) - if KERNEL_WIDTH >= 4: # STRATEGY1 - col2 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) - # if KERNEL_WIDTH >= 5: # STRATEGY1 - # col3 = tl.zeros((BLOCK_N,), dtype=x_ptr.dtype.element_ty) - - # STEP 2: - # here prepare data for updating conv_state - if state_len <= seqlen: # SMALL_CACHE=True (only move part of 'x' into conv_state cache) - # just read from 'x' - # copy 'x' data to conv_state - # load only 'x' data (and set 0 before 'x' if seqlen < state_len) - idx_tokens_last = (seqlen - state_len) + tl.arange(0, NP2_STATELEN) # [BLOCK_M] - x_ptrs = ( - x_ptr - + ((sequence_start_index + idx_tokens_last) * stride_x_token)[:, None] - + (idx_feats * stride_x_dim)[None, :] - ) # [BLOCK_M,BLOCK_N,] - mask_x = ( - (idx_tokens_last >= 0)[:, None] & (idx_tokens_last < seqlen)[:, None] & (idx_feats < dim)[None, :] - ) # token-index # token-index # feature-index - loaded_x = tl.load(x_ptrs, mask_x, 0.0) - new_conv_state = tl.load(x_ptrs, mask_x, 0.0) - idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] - conv_states_ptrs_target = conv_states_base[None, :] + (idx_tokens_conv * stride_conv_state_tok)[:, None] - - mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :] - tl.debug_barrier() # NOTE: use this due to bug in Triton compiler - tl.store(conv_states_ptrs_target, new_conv_state, mask) - - else: - if load_init_state: - # update conv_state by shifting left, i.e. take last few cols from conv_state + cols from 'x' - idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] - - conv_states_ptrs_source = ( - conv_states_ptr - + (conv_state_batch_coord * stride_conv_state_seq) - + (idx_feats * stride_conv_state_dim)[None, :] - + ((idx_tokens_conv + seqlen) * stride_conv_state_tok)[:, None] - ) # [BLOCK_M, BLOCK_N] - mask = ( - (conv_state_batch_coord < num_cache_lines) - & ((idx_tokens_conv + seqlen) < state_len)[:, None] - & (idx_feats < dim)[None, :] - ) - conv_state = tl.load(conv_states_ptrs_source, mask, other=0.0) - - VAL = state_len - seqlen - - x_ptrs = x_base[None, :] + ((idx_tokens_conv - VAL) * stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] - - mask_x = ( - (idx_tokens_conv - VAL >= 0)[:, None] - & (idx_tokens_conv - VAL < seqlen)[:, None] - & (idx_feats < dim)[None, :] - ) # token-index # token-index # feature-index - loaded_x = tl.load(x_ptrs, mask_x, 0.0) - # need this due to the bug in tl.where not enforcing this - # when data is the result of another tl.load - tl.debug_barrier() - new_conv_state = tl.where( - mask, conv_state, loaded_x - ) # BUG in 'tl.where' which requires a barrier before this - conv_states_ptrs_target = ( - conv_states_base + (idx_tokens_conv * stride_conv_state_tok)[:, None] - ) # [BLOCK_M, BLOCK_N] - mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :] - tl.store(conv_states_ptrs_target, new_conv_state, mask) - else: # load_init_state == False - # update conv_state by shifting left, BUT - # set cols prior to 'x' as zeros + cols from 'x' - idx_tokens_conv = tl.arange(0, NP2_STATELEN) # [BLOCK_M] - - VAL = state_len - seqlen - - x_ptrs = x_base[None, :] + ((idx_tokens_conv - VAL) * stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] - - mask_x = ( - (idx_tokens_conv - VAL >= 0)[:, None] - & (idx_tokens_conv - VAL < seqlen)[:, None] - & (idx_feats < dim)[None, :] - ) # token-index # token-index # feature-index - new_conv_state = tl.load(x_ptrs, mask_x, 0.0) - - conv_states_ptrs_target = ( - conv_states_base + (idx_tokens_conv * stride_conv_state_tok)[:, None] - ) # [BLOCK_M, BLOCK_N] - mask = (idx_tokens_conv < state_len)[:, None] & (idx_feats < dim)[None, :] - tl.store(conv_states_ptrs_target, new_conv_state, mask) - - else: # chunk_offset > 0 - # read prior-token data from `x` - load_init_state = True - prior_tokens = x_base + (token_offset - 1) * stride_x_token - mask_w = idx_feats < dim - if KERNEL_WIDTH == 2: - conv_states_ptrs = prior_tokens # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") - if KERNEL_WIDTH == 3: - conv_states_ptrs = prior_tokens # [BLOCK_N] - col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") - conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") - if KERNEL_WIDTH == 4: - conv_states_ptrs = prior_tokens # [BLOCK_N] - col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") - conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] - col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") - conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") - if KERNEL_WIDTH == 5: - # ruff: noqa: F841 - conv_states_ptrs = prior_tokens # [BLOCK_N] - # col3 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") - conv_states_ptrs = prior_tokens - 1 * stride_x_token # [BLOCK_N] - col2 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") - conv_states_ptrs = prior_tokens - 2 * stride_x_token # [BLOCK_N] - col1 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") - conv_states_ptrs = prior_tokens - 3 * stride_x_token # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0, cache_modifier=".ca") - - if HAS_BIAS: - bias = bias_ptr + idx_feats - mask_bias = idx_feats < dim - acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to(tl.float32) # [BLOCK_N] - else: - acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32) - - x_base_1d = x_base + token_offset * stride_x_token # starting of chunk - - # PRE-LOAD WEIGHTS - mask_w = idx_feats < dim - if KERNEL_WIDTH >= 2: - w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor - w_col0 = tl.load(w_ptrs, mask_w, other=0.0) - w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor - w_col1 = tl.load(w_ptrs, mask_w, other=0.0) - if KERNEL_WIDTH >= 3: - w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor - w_col2 = tl.load(w_ptrs, mask_w, other=0.0) - if KERNEL_WIDTH >= 4: - w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor - w_col3 = tl.load(w_ptrs, mask_w, other=0.0) - mask_x_1d = idx_feats < dim - for idx_token in range(segment_len): - acc = acc_preload - - matrix_w = w_col0 - matrix_x = col0 - for j in tl.static_range(KERNEL_WIDTH): - - if KERNEL_WIDTH == 2: - if j == 1: # KERNEL_WIDTH-1: - matrix_w = w_col1 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) - elif KERNEL_WIDTH == 3: - if j == 1: - matrix_w = w_col1 - matrix_x = col1 - elif j == 2: - matrix_w = w_col2 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) - elif KERNEL_WIDTH == 4: - if j == 1: - matrix_w = w_col1 - matrix_x = col1 - elif j == 2: - matrix_w = w_col2 - matrix_x = col2 - elif j == 3: - matrix_w = w_col3 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) - - acc += matrix_x * matrix_w # [BLOCK_N] - - if KERNEL_WIDTH == 2: - col0 = matrix_x - elif KERNEL_WIDTH == 3: - col0 = col1 - col1 = matrix_x - elif KERNEL_WIDTH == 4: - col0 = col1 - col1 = col2 - col2 = matrix_x - - if SILU_ACTIVATION: - acc = acc / (1 + tl.exp(-acc)) - mask_1d = (idx_token < segment_len) & (idx_feats < dim) # token-index # feature-index - o_ptrs = o_ptr + (sequence_start_index + token_offset + idx_token) * stride_o_token + (idx_feats * stride_o_dim) - - tl.store(o_ptrs, acc, mask=mask_1d) +from sgl_kernel import causal_conv1d_fwd +from sgl_kernel import causal_conv1d_update as causal_conv1d_update_kernel def causal_conv1d_fn( x: torch.Tensor, weight: torch.Tensor, - bias: Union[torch.Tensor, None], - conv_states: torch.Tensor, - query_start_loc: torch.Tensor, - out: torch.Tensor, + bias: Optional[torch.Tensor] = None, + query_start_loc: Optional[torch.Tensor] = None, cache_indices: Optional[torch.Tensor] = None, has_initial_state: Optional[torch.Tensor] = None, + conv_states: Optional[torch.Tensor] = None, activation: Optional[str] = "silu", - pad_slot_id: int = PAD_SLOT_ID, - metadata=None, - validate_data=True, + pad_slot_id: int = -1, + **kwargs, ): - """support varlen + continuous batching when x is 2D tensor - - x: (dim,cu_seq_len) - cu_seq_len = total tokens of all seqs in that batch + """ + x: (batch, dim, seqlen) or (dim,cu_seq_len) for varlen sequences are concatenated from left to right for varlen weight: (dim, width) - conv_states: (...,dim,width - 1) itype - updated inplace if provided - [it use `cache_indices` to get the index to the cache of conv_state for that sequence - - conv_state[cache_indices[i]] for seq-i - to be used as initial_state when has_initial_state[i] = True - and after that conv_state[cache_indices[i]] need to be shift-left and updated with values from 'x' - ] + bias: (dim,) query_start_loc: (batch + 1) int32 The cumulative sequence lengths of the sequences in the batch, used to index into sequence. prepended by 0. - if - x = [5, 1, 1, 1] <- continuous batching (batch=4) - then - query_start_loc = [0, 5, 6, 7, 8] <- the starting index of the next sequence; while the last value is - the ending index of the last sequence - [length(query_start_loc)-1 == batch] for example: query_start_loc = torch.Tensor([0,10,16,17]), x.shape=(dim,17) cache_indices: (batch) int32 @@ -391,497 +36,37 @@ def causal_conv1d_fn( has_initial_state: (batch) bool indicates whether should the kernel take the current state as initial state for the calculations - [single boolean for each sequence in the batch: True or False] - bias: (dim,) - activation: either None or "silu" or "swish" or True + conv_states: (...,dim,width - 1) itype + updated inplace if provided + activation: either None or "silu" or "swish" pad_slot_id: int - if cache_indices is passed, lets the kernel identify padded - entries that will not be processed, - for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] - in this case, the kernel will not process entries at - indices 0 and 3 - - out: same shape as `x` - """ - if isinstance(activation, bool) and activation: - activation = "silu" - - args = None - # Store original dtype to cast back at the end - original_x_dtype = x.dtype - x = x.to(conv_states.dtype) - if metadata is not None: - nums_dict = metadata.nums_dict - args = nums_dict - batch_ptr = metadata.batch_ptr - token_chunk_offset_ptr = metadata.token_chunk_offset_ptr - else: - seqlens = np.diff(query_start_loc.to("cpu")) - args = seqlens - MAX_NUM_PROGRAMS = 1024 - - batch_ptr = torch.full( - (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device - ) # tracking which seq-idx the Triton program is handling - token_chunk_offset_ptr = torch.full( - (MAX_NUM_PROGRAMS,), PAD_SLOT_ID, dtype=torch.int32, device=x.device - ) # tracking BLOCK_M-based index in the sequence the Triton program is handling - - is_channel_last = (x.stride(0) == 1) & (x.stride(1) > 1) - dim, cu_seqlen = x.shape - _, width = weight.shape - state_len = width - 1 - np2_statelen = triton.next_power_of_2(state_len) - - padded_batch = query_start_loc.size(0) - 1 - stride_x_seq = 0 - stride_x_dim = x.stride(0) - stride_x_token = x.stride(1) - stride_w_dim = weight.stride(0) - stride_w_width = weight.stride(1) - stride_istate_seq = 0 - stride_istate_dim = 0 - stride_istate_token = 0 - num_cache_lines = 0 - if conv_states is not None: - # extensions to support vLLM: - # 1. conv_states is used to replaced initial_states - # 2. conv_states serve as a cache with num cache lines can be larger than batch size - # 3. mapping from sequence x[idx] to a cache line at index as specified via cache_indices[idx] - # 4. computation can be skipped if cache_indices[idx] == pad_slot_id - num_cache_lines = conv_states.size(0) - assert ( - num_cache_lines == conv_states.shape[0] - and dim == conv_states.shape[1] - and width - 1 <= conv_states.shape[2] - ), f"{num_cache_lines} {dim} {width} {conv_states.shape}" - stride_istate_seq = conv_states.stride(0) - stride_istate_dim = conv_states.stride(1) - stride_istate_token = conv_states.stride(2) - assert stride_istate_dim == 1 - if out.dim() == 2: - stride_o_seq = 0 - stride_o_dim = out.stride(0) - stride_o_token = out.stride(1) - else: - stride_o_seq = out.stride(0) - stride_o_dim = out.stride(1) - stride_o_token = out.stride(2) - - if validate_data: - assert x.dim() == 2 - assert query_start_loc is not None - assert query_start_loc.dim() == 1 - assert x.stride(0) == 1 or x.stride(1) == 1 - if bias is not None: - assert bias.dim() == 1 - assert dim == bias.size(0) - if cache_indices is not None: - assert cache_indices.dim() == 1 - assert padded_batch == cache_indices.size(0) - if has_initial_state is not None: - assert has_initial_state.size() == (padded_batch,) - assert conv_states is not None, "ERROR: `has_initial_state` is used, which needs also `conv_states`" - assert weight.stride(1) == 1 - assert (dim, width) == weight.shape - assert is_channel_last, "Need to run in channel-last layout" - - if metadata is None: - - def num_program(META, seqlens): - tot = 0 - - mlist = [] - offsetlist = [] # type: ignore - - nums = -(-seqlens // META["BLOCK_M"]) - - tot = nums.sum().item() - mlist = np.repeat(np.arange(len(nums)), nums) - for idx, num in enumerate(nums): - offsetlist.extend(range(num)) # chunk-idx if a sequence is split into multiple chunks - - if META["batch_ptr"].nelement() < len(mlist): - newlen = len(mlist) + 1 - META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) - META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) - - if META["batch_ptr"].nelement() >= len(mlist): - META["batch_ptr"][0 : len(mlist)].copy_(torch.from_numpy(np.array(mlist))) - META["token_chunk_offset_ptr"][0 : len(mlist)].copy_(torch.from_numpy(np.array(offsetlist))) - - META["batch_ptr"] = META["batch_ptr"].to(META["x_ptr"].device) - META["token_chunk_offset_ptr"] = META["token_chunk_offset_ptr"].to(META["x_ptr"].device) - return tot - - else: - - def num_program(META, nums_dict): - tot = nums_dict[META["BLOCK_M"]]["tot"] - - mlist = nums_dict[META["BLOCK_M"]]["mlist"] - mlist_len = nums_dict[META["BLOCK_M"]]["mlist_len"] - - offsetlist = nums_dict[META["BLOCK_M"]]["offsetlist"] - - if nums_dict[META["BLOCK_M"]]["batch_ptr"] is not None: - META["batch_ptr"] = nums_dict[META["BLOCK_M"]]["batch_ptr"] - META["token_chunk_offset_ptr"] = nums_dict[META["BLOCK_M"]]["token_chunk_offset_ptr"] - else: - if META["batch_ptr"].nelement() < mlist_len: - newlen = mlist_len + 1 - META["batch_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) - META["token_chunk_offset_ptr"].resize_(newlen).fill_(PAD_SLOT_ID) - - if META["batch_ptr"].nelement() >= mlist_len: - META["batch_ptr"][0:mlist_len].copy_(mlist) - META["token_chunk_offset_ptr"][0:mlist_len].copy_(offsetlist) - return tot + if cache_indices is passed, lets the kernel identify padded + entries that will not be processed, + for example: cache_indices = [pad_slot_id, 1, 20, pad_slot_id] + in this case, the kernel will not process entries at + indices 0 and 3 - def grid(META): - return ( - num_program(META, args), - triton.cdiv(dim, META["BLOCK_N"]), - ) - if batch_ptr.device != x.device: - batch_ptr = batch_ptr.to(x.device) - token_chunk_offset_ptr = token_chunk_offset_ptr.to(x.device) + out: (batch, dim, seqlen) + """ + if activation not in [None, "silu", "swish"]: + raise NotImplementedError("activation must be None, silu, or swish") + if x.stride(-1) != 1: + x = x.contiguous() + bias = bias.contiguous() if bias is not None else None - _causal_conv1d_fwd_kernel[grid]( - # Pointers to matrices + causal_conv1d_fwd( x, weight, bias, conv_states, + query_start_loc, cache_indices, has_initial_state, - query_start_loc, - batch_ptr, - token_chunk_offset_ptr, - out, - # Matrix dimensions - padded_batch, - dim, - cu_seqlen, - num_cache_lines, - # stride - stride_x_seq, - stride_x_dim, - stride_x_token, - stride_w_dim, - stride_w_width, - stride_istate_seq, - stride_istate_dim, - stride_istate_token, - stride_o_seq, - stride_o_dim, - stride_o_token, - # others + activation in ["silu", "swish"], pad_slot_id, - # META - HAS_BIAS=bias is not None, - KERNEL_WIDTH=width, - SILU_ACTIVATION=activation in ["silu", "swish"], - HAS_INITIAL_STATES=has_initial_state is not None, - HAS_CACHE=conv_states is not None, - IS_CONTINUOUS_BATCHING=cache_indices is not None, - USE_PAD_SLOT=pad_slot_id is not None, - NP2_STATELEN=np2_statelen, - # launch_cooperative_grid=True - BLOCK_M=8, - BLOCK_N=256, - num_stages=2, - ) - return out.to(original_x_dtype) - - -@triton.jit() -def _causal_conv1d_update_kernel( - # Pointers to matrices - x_ptr, # (batch, dim, seqlen) - w_ptr, # (dim, width) - bias_ptr, - conv_state_ptr, - cache_seqlens_ptr, # circular buffer - conv_state_indices_ptr, - num_accepted_tokens_ptr, - query_start_loc_ptr, # (batch + 1) - o_ptr, # (batch, dim, seqlen) - # Matrix dimensions - batch: int, - dim: tl.constexpr, - seqlen: tl.constexpr, - state_len: tl.constexpr, - num_cache_lines: tl.constexpr, # added to support vLLM larger cache lines - # Strides - stride_x_seq: tl.constexpr, - stride_x_dim: tl.constexpr, - stride_x_token: tl.constexpr, - stride_w_dim: tl.constexpr, - stride_w_width: tl.constexpr, - stride_conv_state_seq: tl.constexpr, - stride_conv_state_dim: tl.constexpr, - stride_conv_state_tok: tl.constexpr, - stride_state_indices: tl.constexpr, - stride_o_seq: tl.constexpr, - stride_o_dim: tl.constexpr, - stride_o_token: tl.constexpr, - # others - pad_slot_id: tl.constexpr, - # Meta-parameters - HAS_BIAS: tl.constexpr, - KERNEL_WIDTH: tl.constexpr, - SILU_ACTIVATION: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_CONTINUOUS_BATCHING: tl.constexpr, - IS_SPEC_DECODING: tl.constexpr, - NP2_STATELEN: tl.constexpr, - USE_PAD_SLOT: tl.constexpr, - BLOCK_N: tl.constexpr, -): - # ruff: noqa: E501 - idx_seq = tl.program_id(0) - if idx_seq >= batch: - return - - # [BLOCK_N,] elements along the feature-dimension (channel) - idx_feats = tl.program_id(1) * BLOCK_N + tl.arange(0, BLOCK_N) - - if IS_CONTINUOUS_BATCHING: - # mask = idx_seq < batch - conv_state_batch_coord = tl.load(conv_state_indices_ptr + idx_seq * stride_state_indices).to(tl.int64) - else: - conv_state_batch_coord = idx_seq - if USE_PAD_SLOT: # noqa - if conv_state_batch_coord == pad_slot_id: - # not processing as this is not the actual sequence - return - - if IS_VARLEN: - query_start_index = tl.load(query_start_loc_ptr + idx_seq).to(tl.int64) - query_end_index = tl.load(query_start_loc_ptr + (idx_seq + 1)).to(tl.int64) - # revise state_len and seqlen - state_len = state_len - (seqlen - (query_end_index - query_start_index)) - seqlen = query_end_index - query_start_index - x_offset = query_start_index * stride_x_token - o_offset = query_start_index * stride_o_token - else: - query_start_index = idx_seq * seqlen - query_end_index = query_start_index + seqlen - x_offset = idx_seq * stride_x_seq - o_offset = idx_seq * stride_o_seq - - if query_start_index == query_end_index: - return - - if IS_SPEC_DECODING: - # The rolling of conv state: - # - # Before forward, the conv_state is: - # [history1, history2, ..., historyM]. - # - # After forward, the conv_state becomes: - # [history2, ..., historyM, draft1, draft2, ..., draftN]. - # - # After acceptance, it becomes: - # - # - accept 1 tokens: [history2, ..., historyM, draft1] - # - accept 2 tokens: [history3, ..., historyM, draft1, draft2] - # - and so on. - conv_state_token_offset = tl.load(num_accepted_tokens_ptr + idx_seq).to(tl.int64) - 1 - else: - conv_state_token_offset = 0 - - # STEP 1: READ init_state data - conv_states_base = ( - conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) + (idx_feats * stride_conv_state_dim) - ) - mask_w = idx_feats < dim - - prior_tokens = conv_states_base + conv_state_token_offset * stride_conv_state_tok - if KERNEL_WIDTH >= 2: - conv_states_ptrs = prior_tokens # [BLOCK_N] - col0 = tl.load(conv_states_ptrs, mask_w, 0.0) - if KERNEL_WIDTH >= 3: - conv_states_ptrs = prior_tokens + 1 * stride_conv_state_tok # [BLOCK_N] - col1 = tl.load(conv_states_ptrs, mask_w, 0.0) - if KERNEL_WIDTH >= 4: - conv_states_ptrs = prior_tokens + 2 * stride_conv_state_tok # [BLOCK_N] - col2 = tl.load(conv_states_ptrs, mask_w, 0.0) - if KERNEL_WIDTH >= 5: - conv_states_ptrs = prior_tokens + 3 * stride_conv_state_tok # [BLOCK_N] - col3 = tl.load(conv_states_ptrs, mask_w, 0.0) - if KERNEL_WIDTH >= 6: - conv_states_ptrs = prior_tokens + 4 * stride_conv_state_tok # [BLOCK_N] - col4 = tl.load(conv_states_ptrs, mask_w, 0.0) - - # STEP 2: assume state_len > seqlen - idx_tokens = tl.arange(0, NP2_STATELEN) # [BLOCK_M] - - # With speculative decoding, the conv_state updates works in a sliding - # window manner, at each forward pass, the tokens are shift by 1, so we - # load since idx_tokens + 1. - conv_state_ptrs_source = ( - conv_state_ptr - + (conv_state_batch_coord * stride_conv_state_seq) - + conv_state_token_offset * stride_conv_state_tok - + (idx_feats * stride_conv_state_dim)[None, :] - + ((idx_tokens + (1 if IS_SPEC_DECODING else seqlen)) * stride_conv_state_tok)[:, None] - ) # [BLOCK_M, BLOCK_N] - mask = ( - (conv_state_batch_coord < num_cache_lines) - & ((idx_tokens + seqlen) < state_len)[:, None] - & (idx_feats < dim)[None, :] ) - conv_state = tl.load(conv_state_ptrs_source, mask, other=0.0) - - VAL = state_len - seqlen - x_base = x_ptr + x_offset + (idx_feats * stride_x_dim) # [BLOCK_N] - - x_ptrs = x_base[None, :] + ((idx_tokens - VAL) * stride_x_token)[:, None] # [BLOCK_M, BLOCK_N] - - mask_x = ( - (idx_tokens - VAL >= 0)[:, None] & (idx_tokens - VAL < seqlen)[:, None] & (idx_feats < dim)[None, :] - ) # token-index # token-index # feature-index - loaded_x = tl.load(x_ptrs, mask_x, 0.0) - tl.debug_barrier() - - new_conv_state = tl.where(mask, conv_state, loaded_x) - - conv_state_base = ( - conv_state_ptr + (conv_state_batch_coord * stride_conv_state_seq) + (idx_feats * stride_conv_state_dim) - ) # [BLOCK_N,] - conv_state_ptrs_target = conv_state_base + (idx_tokens * stride_conv_state_tok)[:, None] # [BLOCK_M, BLOCK_N] - mask = (idx_tokens < state_len)[:, None] & (idx_feats < dim)[None, :] - tl.store(conv_state_ptrs_target, new_conv_state, mask) - - # STEP 3: init accumulator - if HAS_BIAS: - bias = bias_ptr + idx_feats - mask_bias = idx_feats < dim - acc_preload = tl.load(bias, mask=mask_bias, other=0.0).to(tl.float32) # [BLOCK_N] - else: - acc_preload = tl.zeros((BLOCK_N,), dtype=tl.float32) - - # STEP 4: - # PRE-LOAD WEIGHTS - # first kernel column, configured for weights to handle BLOCK_N features in range - w_base = w_ptr + (idx_feats * stride_w_dim) # [BLOCK_N,] - mask_w = idx_feats < dim - if KERNEL_WIDTH >= 2: - w_ptrs = w_base + (0 * stride_w_width) # [BLOCK_N] tensor - w_col0 = tl.load(w_ptrs, mask_w, other=0.0) - w_ptrs = w_base + (1 * stride_w_width) # [BLOCK_N] tensor - w_col1 = tl.load(w_ptrs, mask_w, other=0.0) - if KERNEL_WIDTH >= 3: - w_ptrs = w_base + (2 * stride_w_width) # [BLOCK_N] tensor - w_col2 = tl.load(w_ptrs, mask_w, other=0.0) - if KERNEL_WIDTH >= 4: - w_ptrs = w_base + (3 * stride_w_width) # [BLOCK_N] tensor - w_col3 = tl.load(w_ptrs, mask_w, other=0.0) - if KERNEL_WIDTH >= 5: - w_ptrs = w_base + (4 * stride_w_width) # [BLOCK_N] tensor - w_col4 = tl.load(w_ptrs, mask_w, other=0.0) - if KERNEL_WIDTH >= 6: - w_ptrs = w_base + (5 * stride_w_width) # [BLOCK_N] tensor - w_col5 = tl.load(w_ptrs, mask_w, other=0.0) - - x_base_1d = x_base # starting of chunk [BLOCK_N] - mask_x_1d = idx_feats < dim - - # STEP 5: compute each token - for idx_token in tl.range(seqlen): - acc = acc_preload - - matrix_w = w_col0 - matrix_x = col0 - for j in tl.static_range(KERNEL_WIDTH): - if KERNEL_WIDTH == 2: - if j == 1: # KERNEL_WIDTH-1: - matrix_w = w_col1 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) - elif KERNEL_WIDTH == 3: - if j == 1: - matrix_w = w_col1 - matrix_x = col1 - elif j == 2: - matrix_w = w_col2 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) - elif KERNEL_WIDTH == 4: - if j == 1: - matrix_w = w_col1 - matrix_x = col1 - elif j == 2: - matrix_w = w_col2 - matrix_x = col2 - elif j == 3: - matrix_w = w_col3 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) - elif KERNEL_WIDTH == 5: - if j == 1: - matrix_w = w_col1 - matrix_x = col1 - elif j == 2: - matrix_w = w_col2 - matrix_x = col2 - elif j == 3: - matrix_w = w_col3 - matrix_x = col3 - elif j == 4: - matrix_w = w_col4 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) - elif KERNEL_WIDTH == 6: - if j == 1: - matrix_w = w_col1 - matrix_x = col1 - elif j == 2: - matrix_w = w_col2 - matrix_x = col2 - elif j == 3: - matrix_w = w_col3 - matrix_x = col3 - elif j == 4: - matrix_w = w_col4 - matrix_x = col4 - elif j == 5: - matrix_w = w_col5 - x_ptrs_1d = x_base_1d + idx_token * stride_x_token # [BLOCK_N] - matrix_x = tl.load(x_ptrs_1d, mask=mask_x_1d) - - acc += matrix_x * matrix_w # [BLOCK_N] - - if KERNEL_WIDTH == 2: - col0 = matrix_x - elif KERNEL_WIDTH == 3: - col0 = col1 - col1 = matrix_x - elif KERNEL_WIDTH == 4: - col0 = col1 - col1 = col2 - col2 = matrix_x - elif KERNEL_WIDTH == 5: - col0 = col1 - col1 = col2 - col2 = col3 - col3 = matrix_x - elif KERNEL_WIDTH == 6: - col0 = col1 - col1 = col2 - col2 = col3 - col3 = col4 - col4 = matrix_x - - if SILU_ACTIVATION: - acc = acc / (1 + tl.exp(-acc)) - mask_1d = (idx_token < seqlen) & (idx_feats < dim) # token-index # feature-index - o_ptrs = o_ptr + o_offset + idx_token * stride_o_token + (idx_feats * stride_o_dim) - - tl.store(o_ptrs, acc, mask=mask_1d) + return x def causal_conv1d_update( @@ -889,24 +74,14 @@ def causal_conv1d_update( conv_state: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, - activation: Union[bool, str, None] = None, + activation: Optional[str] = None, cache_seqlens: Optional[torch.Tensor] = None, conv_state_indices: Optional[torch.Tensor] = None, - num_accepted_tokens: Optional[torch.Tensor] = None, - query_start_loc: Optional[torch.Tensor] = None, - max_query_len: int = -1, - pad_slot_id: int = PAD_SLOT_ID, - validate_data=False, + pad_slot_id: int = -1, ): """ - x: Input tensor which can take the following shapes: - - - `[batch, dim]` - single token prediction - - `[batch, dim, seqlen]` - single or multiple tokens prediction - - `[num_tokens, dim]` - continuous batching, where num_tokens is - the total tokens of all sequences in that batch - - conv_state: (..., dim, state_len), where state_len >= width - 1 + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 weight: (dim, width) bias: (dim,) cache_seqlens: (batch,), dtype int32. @@ -918,140 +93,30 @@ def causal_conv1d_update( If not None, the conv_state is a larger tensor along the batch dim, and we are selecting the batch coords specified by conv_state_indices. Useful for a continuous batching scenario. - num_accepted_tokens: (batch,), dtype int32 - If not None, it indicates the number of accepted tokens for each - sequence in the batch. - This is used in speculative decoding, where the conv_state is updated - in a sliding window manner. - query_start_loc: (batch + 1,) int32 - If not None, the inputs is given in a varlen fashion and this indicates - the starting index of each sequence in the batch. - max_query_len: int - If query_start_loc is not None, this indicates the maximum query - length in the batch. pad_slot_id: int if cache_indices is passed, lets the kernel identify padded entries that will not be processed, for example: cache_indices = [pad_slot_id, 1 ,20 ,pad_slot_id] in this case, the kernel will not process entries at indices 0 and 3 - out: (batch, dim) or (batch, dim, seqlen) or (num_tokens, dim), same shape as `x` + out: (batch, dim) or (batch, dim, seqlen) """ - if validate_data: - assert cache_seqlens is None # not implemented yet - ok for vLLM - assert pad_slot_id is not None - assert x.stride(1) == 1 - if isinstance(activation, bool): - activation = "silu" if activation is True else None - elif activation is not None: - assert activation in ["silu", "swish"] - - original_x_dtype = x.dtype - x = x.to(conv_state.dtype) - unsqueeze = query_start_loc is None and x.dim() == 2 + if activation not in [None, "silu", "swish"]: + raise NotImplementedError(f"activation must be None, silu, or swish, actual: {activation}") + activation_val = activation in ["silu", "swish"] + unsqueeze = x.dim() == 2 if unsqueeze: - # make it (batch, dim, seqlen) with seqlen == 1 x = x.unsqueeze(-1) - if query_start_loc is None: - batch, dim, seqlen = x.shape - else: - assert conv_state_indices is not None - batch = conv_state_indices.size(0) - dim = x.size(1) - seqlen = max_query_len - _, width = weight.shape - # conv_state: (..., dim, state_len), where state_len >= width - 1 - num_cache_lines, _, state_len = conv_state.size() - - if validate_data: - assert dim == weight.size(0) - assert ( - conv_state.stride(-2) == 1 - ), f"ERROR: expect contiguous along feat-dim of conv_state (currently stride={conv_state.stride()})" - assert state_len >= width - 1 - # when above happens, we don't shift-left to keep any records in conv_state - assert dim == conv_state.size(1) - if conv_state_indices is None: - assert conv_state.size(0) >= batch - else: - assert (batch,) == conv_state_indices.shape - - assert num_cache_lines >= batch - assert weight.stride(1) == 1 # Need this - assert cache_seqlens is None # not needed for vLLM - circular buffer - - # adopt the strategy in vLLM that overwrite on 'x' directly, rather than creating a new tensor 'o' - out = x - stride_w_dim, stride_w_width = weight.stride() - - if query_start_loc is None: - # X (batch, dim, seqlen) - stride_x_seq, stride_x_dim, stride_x_token = x.stride() - stride_o_seq, stride_o_dim, stride_o_token = out.stride() - else: - # X (dim, cu_seqlen) - stride_x_token, stride_x_dim = x.stride() - stride_x_seq = 0 - stride_o_token, stride_o_dim = out.stride() - stride_o_seq = 0 - - stride_istate_seq, stride_istate_dim, stride_istate_token = conv_state.stride() - stride_state_indices = conv_state_indices.stride(0) if conv_state_indices is not None else 0 - if num_accepted_tokens is not None: - state_len = width - 1 + (seqlen - 1) # effective state_len needed - else: - state_len = width - 1 - np2_statelen = triton.next_power_of_2(state_len) - - def grid(META): - return ( - batch, - triton.cdiv(dim, META["BLOCK_N"]), - ) - - _causal_conv1d_update_kernel[grid]( - # Pointers to matrices + causal_conv1d_update_kernel( x, + conv_state, weight, bias, - conv_state, + activation_val, cache_seqlens, conv_state_indices, - num_accepted_tokens, - query_start_loc, - out, - # Matrix dimensions - batch, - dim, - seqlen, - state_len, - num_cache_lines, - # stride - stride_x_seq, - stride_x_dim, - stride_x_token, - stride_w_dim, - stride_w_width, - stride_istate_seq, - stride_istate_dim, - stride_istate_token, - stride_state_indices, - stride_o_seq, - stride_o_dim, - stride_o_token, - # others pad_slot_id, - # META - HAS_BIAS=bias is not None, - KERNEL_WIDTH=width, - SILU_ACTIVATION=activation in ["silu", "swish"], - IS_VARLEN=query_start_loc is not None, - IS_CONTINUOUS_BATCHING=conv_state_indices is not None, - IS_SPEC_DECODING=num_accepted_tokens is not None, - NP2_STATELEN=np2_statelen, - USE_PAD_SLOT=pad_slot_id is not None, - BLOCK_N=256, ) if unsqueeze: - out = out.squeeze(-1) - return out.to(original_x_dtype) + x = x.squeeze(-1) + return x diff --git a/lightllm/models/qwen3next/triton_kernel/fla/__init__.py b/lightllm/models/qwen3next/triton_kernel/fla/__init__.py index 50f7f20b7..0e89cf9f7 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/__init__.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/__init__.py @@ -6,11 +6,3 @@ # The original source code was licensed under the MIT license and included # the following copyright notice: # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang - -from .ops.chunk import chunk_gated_delta_rule -from .ops.fused_recurrent import fused_recurrent_gated_delta_rule - -__all__ = [ - "chunk_gated_delta_rule", - "fused_recurrent_gated_delta_rule", -] diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py index 22c81ae63..944c71e2b 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py @@ -8,7 +8,6 @@ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang # ruff: noqa: E501 import warnings -from typing import Optional import torch from einops import rearrange @@ -32,11 +31,11 @@ def chunk_gated_delta_rule_fwd( scale: float, initial_state: torch.Tensor, output_final_state: bool, - cu_seqlens: Optional[torch.LongTensor] = None, + cu_seqlens: torch.LongTensor | None = None, ): g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens) # obtain WY representation. u is actually the new v. - A = chunk_scaled_dot_kkt_fwd(k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32) + A = chunk_scaled_dot_kkt_fwd(k=k, beta=beta, g=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32) A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) w, u = recompute_w_u_fwd( k=k, @@ -84,7 +83,7 @@ def forward( scale: float, initial_state: torch.Tensor, output_final_state: bool, - cu_seqlens: Optional[torch.LongTensor] = None, + cu_seqlens: torch.LongTensor | None = None, use_qk_l2norm_in_kernel: bool = False, ): if use_qk_l2norm_in_kernel: @@ -117,7 +116,7 @@ def chunk_gated_delta_rule( scale: float = None, initial_state: torch.Tensor = None, output_final_state: bool = False, - cu_seqlens: Optional[torch.LongTensor] = None, + cu_seqlens: torch.LongTensor | None = None, head_first: bool = False, use_qk_l2norm_in_kernel: bool = False, ): @@ -197,8 +196,8 @@ def chunk_gated_delta_rule( q, k, v, beta, g = map(lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g)) if not head_first and q.shape[1] < q.shape[2]: warnings.warn( - f"Input tensor shape suggests potential format mismatch" - f" seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "Input tensor shape suggests potential" + f" format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " "This may indicate the inputs were passed in head-first format [B, H, T, ...] " "when head_first=False was specified. " "Please verify your input tensor format matches the expected shape [B, T, H, ...].", @@ -218,7 +217,16 @@ def chunk_gated_delta_rule( if scale is None: scale = k.shape[-1] ** -0.5 o, final_state = ChunkGatedDeltaRuleFunction.apply( - q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens, use_qk_l2norm_in_kernel + q, + k, + v, + g, + beta, + scale, + initial_state, + output_final_state, + cu_seqlens, + use_qk_l2norm_in_kernel, ) if head_first: o = rearrange(o, "b t h ... -> b h t ...") diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py index f20c95d90..e4c020f49 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py @@ -7,7 +7,6 @@ # the following copyright notice: # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang # ruff: noqa: E501 -from typing import Optional import torch @@ -16,14 +15,15 @@ from .index import prepare_chunk_indices, prepare_chunk_offsets from .op import exp -from .utils import is_nvidia_hopper, use_cuda_graph +from .utils import use_cuda_graph -NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16] +NUM_WARPS = [2, 4, 8, 16] @triton.heuristics( { "USE_G": lambda args: args["g"] is not None, + "USE_GK": lambda args: args["gk"] is not None, "USE_INITIAL_STATE": lambda args: args["h0"] is not None, "STORE_FINAL_STATE": lambda args: args["ht"] is not None, "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None, @@ -37,7 +37,7 @@ for num_stages in [2, 3, 4] for BV in [32, 64] ], - key=["H", "K", "V", "BT", "USE_G"], + key=["H", "K", "V", "BT"], use_cuda_graph=use_cuda_graph, ) @triton.jit(do_not_specialize=["T"]) @@ -47,6 +47,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( w, v_new, g, + gk, h, h0, ht, @@ -60,6 +61,7 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( BT: tl.constexpr, BV: tl.constexpr, USE_G: tl.constexpr, + USE_GK: tl.constexpr, USE_INITIAL_STATE: tl.constexpr, STORE_FINAL_STATE: tl.constexpr, SAVE_NEW_VALUE: tl.constexpr, @@ -68,7 +70,10 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( i_v, i_nh = tl.program_id(0), tl.program_id(1) i_n, i_h = i_nh // H, i_nh % H if IS_VARLEN: - bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) T = eos - bos NT = tl.cdiv(T, BT) boh = tl.load(chunk_offsets + i_n).to(tl.int32) @@ -87,12 +92,12 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( b_h4 = tl.zeros([64, BV], dtype=tl.float32) # calculate offset - h += (boh * H + i_h) * K * V - v += (bos * H + i_h) * V - k += (bos * Hg + i_h // (H // Hg)) * K - w += (bos * H + i_h) * K + h += ((boh * H + i_h) * K * V).to(tl.int64) + v += ((bos * H + i_h) * V).to(tl.int64) + k += ((bos * Hg + i_h // (H // Hg)) * K).to(tl.int64) + w += ((bos * H + i_h) * K).to(tl.int64) if SAVE_NEW_VALUE: - v_new += (bos * H + i_h) * V + v_new += ((bos * H + i_h) * V).to(tl.int64) stride_v = H * V stride_h = H * K * V stride_k = Hg * K @@ -130,66 +135,93 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( p_h4 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1)) - p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_v_new = ( - tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - if SAVE_NEW_VALUE - else None - ) - b_v_new = tl.zeros([BT, BV], dtype=tl.float32) p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)) b_w = tl.load(p_w, boundary_check=(0, 1)) - b_v_new += tl.dot(b_w, b_h1.to(b_w.dtype)) + b_v = tl.dot(b_w, b_h1.to(b_w.dtype)) if K > 64: p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)) b_w = tl.load(p_w, boundary_check=(0, 1)) - b_v_new += tl.dot(b_w, b_h2.to(b_w.dtype)) + b_v += tl.dot(b_w, b_h2.to(b_w.dtype)) if K > 128: p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)) b_w = tl.load(p_w, boundary_check=(0, 1)) - b_v_new += tl.dot(b_w, b_h3.to(b_w.dtype)) + b_v += tl.dot(b_w, b_h3.to(b_w.dtype)) if K > 192: p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)) b_w = tl.load(p_w, boundary_check=(0, 1)) - b_v_new += tl.dot(b_w, b_h4.to(b_w.dtype)) - b_v_new = -b_v_new + tl.load(p_v, boundary_check=(0, 1)) + b_v += tl.dot(b_w, b_h4.to(b_w.dtype)) + p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) - b_v if SAVE_NEW_VALUE: - p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - tl.store(p_v_new, b_v_new.to(p_v_new.dtype.element_ty), boundary_check=(0, 1)) + p_v = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_v, b_v.to(p_v.dtype.element_ty), boundary_check=(0, 1)) + last_idx = min((i_t + 1) * BT, T) - 1 if USE_G: m_t = (i_t * BT + tl.arange(0, BT)) < T - last_idx = min((i_t + 1) * BT, T) - 1 b_g_last = tl.load(g + bos * H + last_idx * H + i_h) p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) b_g = tl.load(p_g, boundary_check=(0,)) - b_v_new = b_v_new * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None] + b_v = b_v * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None] b_g_last = exp(b_g_last) - b_h1 = b_h1 * b_g_last + b_h1 *= b_g_last if K > 64: - b_h2 = b_h2 * b_g_last + b_h2 *= b_g_last if K > 128: - b_h3 = b_h3 * b_g_last + b_h3 *= b_g_last if K > 192: - b_h4 = b_h4 * b_g_last - b_v_new = b_v_new.to(k.dtype.element_ty) + b_h4 *= b_g_last + + if USE_GK: + o_k1 = tl.arange(0, 64) + b_gk_last1 = tl.load( + gk + (bos + last_idx) * H * K + i_h * K + o_k1, + mask=(o_k1 < K), + other=0.0, + ) + b_h1 *= exp(b_gk_last1)[:, None] + if K > 64: + o_k2 = 64 + o_k1 + b_gk_last2 = tl.load( + gk + (bos + last_idx) * H * K + i_h * K + o_k2, + mask=(o_k2 < K), + other=0.0, + ) + b_h2 *= exp(b_gk_last2)[:, None] + if K > 128: + o_k3 = 128 + o_k1 + b_gk_last3 = tl.load( + gk + (bos + last_idx) * H * K + i_h * K + o_k3, + mask=(o_k3 < K), + other=0.0, + ) + b_h3 *= exp(b_gk_last3)[:, None] + if K > 192: + o_k4 = 192 + o_k1 + b_gk_last4 = tl.load( + gk + (bos + last_idx) * H * K + i_h * K + o_k4, + mask=(o_k4 < K), + other=0.0, + ) + b_h4 *= exp(b_gk_last4)[:, None] + b_v = b_v.to(k.dtype.element_ty) + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1)) - b_h1 += tl.dot(b_k, b_v_new) + b_h1 += tl.dot(b_k, b_v) if K > 64: p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1)) - b_h2 += tl.dot(b_k, b_v_new) + b_h2 += tl.dot(b_k, b_v) if K > 128: p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1)) - b_h3 += tl.dot(b_k, b_v_new) + b_h3 += tl.dot(b_k, b_v) if K > 192: p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)) b_k = tl.load(p_k, boundary_check=(0, 1)) - b_h4 += tl.dot(b_k, b_v_new) - + b_h4 += tl.dot(b_k, b_v) # epilogue if STORE_FINAL_STATE: p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) @@ -209,13 +241,16 @@ def chunk_gated_delta_rule_fwd_h( k: torch.Tensor, w: torch.Tensor, u: torch.Tensor, - g: Optional[torch.Tensor] = None, - initial_state: Optional[torch.Tensor] = None, + g: torch.Tensor | None = None, + gk: torch.Tensor | None = None, + initial_state: torch.Tensor | None = None, output_final_state: bool = False, chunk_size: int = 64, # SY: remove this argument and force chunk size 64? save_new_value: bool = True, - cu_seqlens: Optional[torch.LongTensor] = None, + cu_seqlens: torch.LongTensor | None = None, ) -> tuple[torch.Tensor, torch.Tensor]: + # This kernel is slightly different from fla to support Q/K with different head numbers. + # In fla, Q/K always have the same head number, so Hg is always equal to H. B, T, Hg, K, V = *k.shape, u.shape[-1] H = u.shape[-2] BT = chunk_size @@ -225,7 +260,11 @@ def chunk_gated_delta_rule_fwd_h( if cu_seqlens is None: N, NT, chunk_offsets = B, triton.cdiv(T, BT), None else: - N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT) + N, NT, chunk_offsets = ( + len(cu_seqlens) - 1, + len(chunk_indices), + prepare_chunk_offsets(cu_seqlens, BT), + ) assert K <= 256, "current kernel does not support head dimension larger than 256." h = k.new_empty(B, NT, H, K, V) @@ -242,6 +281,7 @@ def grid(meta): w=w, v_new=v_new, g=g, + gk=gk, h=h, h0=initial_state, ht=final_state, diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py index 73c2e1f19..b70185a05 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py @@ -9,7 +9,6 @@ # ruff: noqa: E501 -from typing import Optional import torch @@ -25,7 +24,10 @@ @triton.heuristics( - {"USE_G": lambda args: args["g"] is not None, "IS_VARLEN": lambda args: args["cu_seqlens"] is not None} + { + "USE_G": lambda args: args["g"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } ) @triton.autotune( configs=[ @@ -64,8 +66,14 @@ def chunk_fwd_kernel_o( if IS_VARLEN: i_tg = i_t - i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) - bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) T = eos - bos NT = tl.cdiv(T, BT) else: @@ -126,17 +134,14 @@ def chunk_fwd_o( k: torch.Tensor, v: torch.Tensor, h: torch.Tensor, - g: Optional[torch.Tensor] = None, # cumsum of log decay - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, + g: torch.Tensor | None = None, # cumsum of log decay + scale: float | None = None, + cu_seqlens: torch.LongTensor | None = None, chunk_size: int = 64, ) -> torch.Tensor: B, T, Hg, K, V = *q.shape, v.shape[-1] H = v.shape[-2] - if FLA_GDN_FIX_BT: - BT = 64 - else: - BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + BT = 64 if FLA_GDN_FIX_BT else min(chunk_size, max(16, triton.next_power_of_2(T))) chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) if scale is None: diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py index aa545e8ec..31ba1c7df 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py @@ -7,7 +7,6 @@ # the following copyright notice: # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang # ruff: noqa: E501 -from typing import Optional import torch @@ -19,7 +18,10 @@ @triton.heuristics( - {"IS_VARLEN": lambda args: args["cu_seqlens"] is not None, "USE_G": lambda args: args["g_cumsum"] is not None} + { + "USE_G": lambda args: args["g"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } ) @triton.autotune( configs=[ @@ -34,7 +36,7 @@ def chunk_scaled_dot_kkt_fwd_kernel( k, beta, - g_cumsum, + g, A, cu_seqlens, chunk_indices, @@ -50,8 +52,14 @@ def chunk_scaled_dot_kkt_fwd_kernel( i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H if IS_VARLEN: - i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) - bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) T = eos - bos else: bos, eos = i_b * T, i_b * T + T @@ -64,14 +72,19 @@ def chunk_scaled_dot_kkt_fwd_kernel( b_A = tl.zeros([BT, BT], dtype=tl.float32) for i_k in range(tl.cdiv(K, BK)): p_k = tl.make_block_ptr( - k + (bos * Hg + i_h // (H // Hg)) * K, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0) + k + (bos * Hg + i_h // (H // Hg)) * K, + (T, K), + (Hg * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), ) b_k = tl.load(p_k, boundary_check=(0, 1)) b_kb = b_k * b_beta[:, None] b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k)) if USE_G: - p_g = tl.make_block_ptr(g_cumsum + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) b_g = tl.load(p_g, boundary_check=(0,)) b_g_diff = b_g[:, None] - b_g[None, :] b_A = b_A * exp(b_g_diff) @@ -84,9 +97,9 @@ def chunk_scaled_dot_kkt_fwd_kernel( def chunk_scaled_dot_kkt_fwd( k: torch.Tensor, - beta: torch.Tensor, - g_cumsum: Optional[torch.Tensor] = None, - cu_seqlens: Optional[torch.LongTensor] = None, + g: torch.Tensor | None = None, + beta: torch.Tensor | None = None, + cu_seqlens: torch.LongTensor | None = None, chunk_size: int = 64, output_dtype: torch.dtype = torch.float32, ) -> torch.Tensor: @@ -98,9 +111,8 @@ def chunk_scaled_dot_kkt_fwd( The key tensor of shape `[B, T, H, K]`. beta (torch.Tensor): The beta tensor of shape `[B, T, H]`. - g_cumsum (torch.Tensor): - The cumulative sum of the gate tensor of shape `[B, T, H]`. - Default: None + g (torch.Tensor): + The cumulative sum of the gate tensor of shape `[B, T, H]`. Default: `None`. cu_seqlens (torch.LongTensor): The cumulative sequence lengths of the input tensor. Default: None @@ -112,18 +124,19 @@ def chunk_scaled_dot_kkt_fwd( Returns: beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size. """ - + # This kernel is slightly different from fla to support Q/K with different head numbers. + # In fla, Q/K always have the same head number, so Hg is always equal to H. B, T, Hg, K = k.shape - H = beta.shape[-1] BT = chunk_size chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype) chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)]( k=k, + g=g, beta=beta, - g_cumsum=g_cumsum, A=A, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices, diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py index 9cd6a6545..98dbccae8 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py @@ -8,7 +8,6 @@ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang # ruff: noqa: E501 import warnings -from typing import Optional import torch @@ -43,8 +42,14 @@ def chunk_local_cumsum_scalar_kernel( i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H if IS_VARLEN: - i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) - bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) T = eos - bos else: bos, eos = i_b * T, i_b * T + T @@ -88,8 +93,14 @@ def chunk_local_cumsum_vector_kernel( i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_b, i_h = i_bh // H, i_bh % H if IS_VARLEN: - i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) - bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) T = eos - bos else: bos, eos = i_b * T, i_b * T + T @@ -101,11 +112,39 @@ def chunk_local_cumsum_vector_kernel( m_s = tl.where(o_i[:, None] >= o_i[None, :], 1.0, 0.0) if HEAD_FIRST: - p_s = tl.make_block_ptr(s + (bos * H + i_h * T) * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) - p_o = tl.make_block_ptr(o + (bos * H + i_h * T) * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_s = tl.make_block_ptr( + s + (bos * H + i_h * T) * S, + (T, S), + (S, 1), + (i_t * BT, i_s * BS), + (BT, BS), + (1, 0), + ) + p_o = tl.make_block_ptr( + o + (bos * H + i_h * T) * S, + (T, S), + (S, 1), + (i_t * BT, i_s * BS), + (BT, BS), + (1, 0), + ) else: - p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) - p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_s = tl.make_block_ptr( + s + (bos * H + i_h) * S, + (T, S), + (H * S, 1), + (i_t * BT, i_s * BS), + (BT, BS), + (1, 0), + ) + p_o = tl.make_block_ptr( + o + (bos * H + i_h) * S, + (T, S), + (H * S, 1), + (i_t * BT, i_s * BS), + (BT, BS), + (1, 0), + ) # [BT, BS] b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) b_o = tl.dot(m_s, b_s, allow_tf32=False) @@ -116,9 +155,9 @@ def chunk_local_cumsum_scalar( g: torch.Tensor, chunk_size: int, reverse: bool = False, - cu_seqlens: Optional[torch.Tensor] = None, + cu_seqlens: torch.Tensor | None = None, head_first: bool = False, - output_dtype: Optional[torch.dtype] = torch.float, + output_dtype: torch.dtype | None = torch.float, ) -> torch.Tensor: if head_first: B, H, T = g.shape @@ -131,7 +170,16 @@ def chunk_local_cumsum_scalar( g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) grid = (NT, B * H) chunk_local_cumsum_scalar_kernel[grid]( - g_org, g, cu_seqlens, chunk_indices, T=T, B=B, H=H, BT=BT, HEAD_FIRST=head_first, REVERSE=reverse + g_org, + g, + cu_seqlens, + chunk_indices, + T=T, + B=B, + H=H, + BT=BT, + HEAD_FIRST=head_first, + REVERSE=reverse, ) return g @@ -140,9 +188,9 @@ def chunk_local_cumsum_vector( g: torch.Tensor, chunk_size: int, reverse: bool = False, - cu_seqlens: Optional[torch.Tensor] = None, + cu_seqlens: torch.Tensor | None = None, head_first: bool = False, - output_dtype: Optional[torch.dtype] = torch.float, + output_dtype: torch.dtype | None = torch.float, ) -> torch.Tensor: if head_first: B, H, T, S = g.shape @@ -162,7 +210,17 @@ def grid(meta): # this kernel is equivalent to # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1) chunk_local_cumsum_vector_kernel[grid]( - g_org, g, cu_seqlens, chunk_indices, T=T, B=B, H=H, S=S, BT=BT, HEAD_FIRST=head_first, REVERSE=reverse + g_org, + g, + cu_seqlens, + chunk_indices, + T=T, + B=B, + H=H, + S=S, + BT=BT, + HEAD_FIRST=head_first, + REVERSE=reverse, ) return g @@ -172,15 +230,15 @@ def chunk_local_cumsum( g: torch.Tensor, chunk_size: int, reverse: bool = False, - cu_seqlens: Optional[torch.Tensor] = None, + cu_seqlens: torch.Tensor | None = None, head_first: bool = False, - output_dtype: Optional[torch.dtype] = torch.float, + output_dtype: torch.dtype | None = torch.float, **kwargs, ) -> torch.Tensor: if not head_first and g.shape[1] < g.shape[2]: warnings.warn( - f"Input tensor shape suggests potential format mismatch" - f" seq_len ({g.shape[1]}) < num_heads ({g.shape[2]}). " + f"Input tensor shape suggests potential format mismatch: " + f"seq_len ({g.shape[1]}) < num_heads ({g.shape[2]}). " "This may indicate the inputs were passed in head-first format [B, H, T, ...] " "when head_first=False was specified. " "Please verify your input tensor format matches the expected shape [B, T, H, ...].", diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py index 4ff18d4f6..e399b3c0a 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/fused_recurrent.py @@ -7,7 +7,6 @@ # the following copyright notice: # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang # ruff: noqa: E501 -from typing import Optional import torch @@ -59,12 +58,16 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( IS_VARLEN: tl.constexpr, IS_CONTINUOUS_BATCHING: tl.constexpr, IS_SPEC_DECODING: tl.constexpr, + IS_KDA: tl.constexpr, ): i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) i_n, i_hv = i_nh // HV, i_nh % HV i_h = i_hv // (HV // H) if IS_VARLEN: - bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int64), + tl.load(cu_seqlens + i_n + 1).to(tl.int64), + ) all = T T = eos - bos else: @@ -85,7 +88,12 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( p_beta = beta + (bos * HV + i_hv) * V + o_v else: p_beta = beta + bos * HV + i_hv - p_g = g + bos * HV + i_hv + + if not IS_KDA: + p_g = g + bos * HV + i_hv + else: + p_gk = g + (bos * HV + i_hv) * K + o_k + p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v mask_k = o_k < K @@ -111,14 +119,18 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) - b_g = tl.load(p_g).to(tl.float32) if USE_QK_L2NORM_IN_KERNEL: b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6) b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6) b_q = b_q * scale # [BK, BV] - b_h *= exp(b_g) + if not IS_KDA: + b_g = tl.load(p_g).to(tl.float32) + b_h *= exp(b_g) + else: + b_gk = tl.load(p_gk).to(tl.float32) + b_h *= exp(b_gk[:, None]) # [BV] b_v -= tl.sum(b_h * b_k[:, None], 0) if IS_BETA_HEADWISE: @@ -146,7 +158,10 @@ def fused_recurrent_gated_delta_rule_fwd_kernel( p_k += H * K p_o += HV * V p_v += HV * V - p_g += HV + if not IS_KDA: + p_g += HV + else: + p_gk += HV * K p_beta += HV * (V if IS_BETA_HEADWISE else 1) @@ -159,9 +174,9 @@ def fused_recurrent_gated_delta_rule_fwd( scale: float, initial_state: torch.Tensor, inplace_final_state: bool = True, - cu_seqlens: Optional[torch.LongTensor] = None, - ssm_state_indices: Optional[torch.Tensor] = None, - num_accepted_tokens: Optional[torch.Tensor] = None, + cu_seqlens: torch.LongTensor | None = None, + ssm_state_indices: torch.Tensor | None = None, + num_accepted_tokens: torch.Tensor | None = None, use_qk_l2norm_in_kernel: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: B, T, H, K, V = *k.shape, v.shape[-1] @@ -219,6 +234,7 @@ def fused_recurrent_gated_delta_rule_fwd( IS_BETA_HEADWISE=beta.ndim == v.ndim, USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, INPLACE_FINAL_STATE=inplace_final_state, + IS_KDA=False, num_warps=num_warps, num_stages=num_stages, ) @@ -238,9 +254,9 @@ def forward( scale: float, initial_state: torch.Tensor, inplace_final_state: bool = True, - cu_seqlens: Optional[torch.LongTensor] = None, - ssm_state_indices: Optional[torch.Tensor] = None, - num_accepted_tokens: Optional[torch.Tensor] = None, + cu_seqlens: torch.LongTensor | None = None, + ssm_state_indices: torch.Tensor | None = None, + num_accepted_tokens: torch.Tensor | None = None, use_qk_l2norm_in_kernel: bool = False, ): o, final_state = fused_recurrent_gated_delta_rule_fwd( @@ -270,9 +286,9 @@ def fused_recurrent_gated_delta_rule( scale: float = None, initial_state: torch.Tensor = None, inplace_final_state: bool = True, - cu_seqlens: Optional[torch.LongTensor] = None, - ssm_state_indices: Optional[torch.Tensor] = None, - num_accepted_tokens: Optional[torch.Tensor] = None, + cu_seqlens: torch.LongTensor | None = None, + ssm_state_indices: torch.Tensor | None = None, + num_accepted_tokens: torch.Tensor | None = None, use_qk_l2norm_in_kernel: bool = False, ) -> tuple[torch.Tensor, torch.Tensor]: r""" diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py index 7225cd4ae..0c67bcf9b 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py @@ -8,7 +8,6 @@ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang import os -from typing import Optional import torch @@ -20,7 +19,10 @@ USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0")) -@triton.autotune(configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32]], key=["D"]) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32]], + key=["D"], +) @triton.jit def l2norm_fwd_kernel1( x, @@ -81,7 +83,7 @@ def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr): tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask) -def l2norm_fwd(x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None): +def l2norm_fwd(x: torch.Tensor, eps: float = 1e-6, output_dtype: torch.dtype | None = None): x_shape_og = x.shape x = x.view(-1, x.shape[-1]) # allocate output diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/op.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/op.py index ec0999455..d35c71ab9 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/op.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/op.py @@ -12,25 +12,46 @@ import triton import triton.language as tl +from .utils import is_gather_supported -@triton.jit -def div_normal(x, y): - return x / y - - -div = div_normal exp = tl.exp log = tl.log log2 = tl.log2 -if not hasattr(tl, "gather"): +if not is_gather_supported: @triton.jit def gather(src, index, axis, _builder=None): - # This is a fallback implementation when tl.gather is not supported - # In order to pass triton compiler, there is no actual gather operation - return src + """ + Gather operation that works when tl.gather is not supported. + This is a fallback implementation that returns None. + Just to make triton compiler happy. + """ + return None else: gather = tl.gather + +if hasattr(triton.language, "_experimental_make_tensor_descriptor"): + # For Triton 3.3.x + make_tensor_descriptor = triton.language._experimental_make_tensor_descriptor +elif hasattr(triton.language, "make_tensor_descriptor"): + # For Triton 3.4.x and later + make_tensor_descriptor = triton.language.make_tensor_descriptor +else: + """ + Fallback implementation when TMA is not supported. + Returns None to indicate TMA descriptors are unavailable. + Just make triton compiler happy. + """ + + @triton.jit + def make_tensor_descriptor( + base, + shape, + strides, + block_shape, + _builder=None, + ): + return None diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py index 46e4d5082..adfd8b11e 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py @@ -7,7 +7,8 @@ # the following copyright notice: # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang # ruff: noqa: E501 -from typing import Optional + +import os import torch @@ -15,7 +16,14 @@ import triton.language as tl from .index import prepare_chunk_indices -from .utils import input_guard +from .op import make_tensor_descriptor +from .utils import input_guard, is_amd, is_tma_supported + +FLA_TRIL_PRECISION = os.environ.get("FLA_TRIL_PRECISION", "ieee") +ALLOWED_TRIL_PRECISIONS = ["ieee", "tf32"] if is_amd else ["ieee", "tf32", "tf32x3"] +assert ( + FLA_TRIL_PRECISION in ALLOWED_TRIL_PRECISIONS +), f"FLA_TRIL_PRECISION must be one of {ALLOWED_TRIL_PRECISIONS}, but got {FLA_TRIL_PRECISION}" @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) @@ -30,40 +38,63 @@ @triton.jit(do_not_specialize=["T"]) def solve_tril_16x16_kernel( A, - Ad, + Ai, cu_seqlens, chunk_indices, T, H: tl.constexpr, BT: tl.constexpr, + USE_TMA: tl.constexpr, IS_VARLEN: tl.constexpr, + DOT_PRECISION: tl.constexpr, ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H if IS_VARLEN: - i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) - bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) T = eos - bos else: bos, eos = i_b * T, i_b * T + T + o_i = tl.arange(0, 16) + m_A = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] A = A + (bos * H + i_h) * BT - Ad = Ad + (bos * H + i_h) * 16 + Ai = Ai + (bos * H + i_h) * 16 offset = (i_t * 16) % BT - p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * 16, offset), (16, 16), (1, 0)) - p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), (1, 0)) - b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32) - b_A = -tl.where(tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :], b_A, 0) + if not USE_TMA: + p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * 16, offset), (16, 16), (1, 0)) + # [16, 16] + b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32) + else: + desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16]) + desc_o = make_tensor_descriptor(Ai, [T, 16], [H * 16, 1], [16, 16]) + b_A = desc.load([i_t * 16, offset]).to(tl.float32) + b_A = -tl.where(m_A, b_A, 0) - o_i = tl.arange(0, 16) - for i in range(1, min(16, T - i_t * 16)): + for i in range(2, min(16, T - i_t * 16)): + # [16] b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset) b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) - mask = o_i == i - b_A = tl.where(mask[:, None], b_a, b_A) - b_A += o_i[:, None] == o_i[None, :] - tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + b_A = tl.where((o_i == i)[:, None], b_a, b_A) + b_A += m_I + if not USE_TMA: + p_Ai = tl.make_block_ptr(Ai, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), (1, 0)) + tl.store( + p_Ai, + b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + else: + desc_o.store([i_t * 16, 0], b_A.to(desc_o.dtype, fp_downcast_rounding="rtne")) @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) @@ -77,35 +108,100 @@ def solve_tril_16x16_kernel( ) @triton.jit(do_not_specialize=["T"]) def merge_16x16_to_32x32_inverse_kernel( - A, Ad, Ai, cu_seqlens, chunk_indices, T, H: tl.constexpr, BT: tl.constexpr, IS_VARLEN: tl.constexpr + A, + Ai, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + USE_TMA: tl.constexpr, + IS_VARLEN: tl.constexpr, + DOT_PRECISION: tl.constexpr, ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H if IS_VARLEN: - i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) - bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) T = eos - bos else: bos, eos = i_b * T, i_b * T + T - A += (bos * H + i_h) * 32 - Ad += (bos * H + i_h) * 16 - Ai += (bos * H + i_h) * 32 + o_i = tl.arange(0, 16) + m_A = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] + A += (bos * H + i_h) * BT + Ai += (bos * H + i_h) * BT - p_A_21 = tl.make_block_ptr(A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)) - p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32, 0), (16, 16), (1, 0)) - p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)) - p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32, 0), (16, 16), (1, 0)) - p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16), (16, 16), (1, 0)) - p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)) + if not USE_TMA: + p_A_11 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)) + p_A_22 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)) + b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32) + b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32) + else: + desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16]) + desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16]) + b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32) + b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32) + + # [16, 16] + b_Ai_11 = -tl.where(m_A, b_Ai_11, 0) + b_Ai_22 = -tl.where(m_A, b_Ai_22, 0) - A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) - Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32) - Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32) - Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision="ieee"), Ai_11, input_precision="ieee") - tl.store(p_Ai_11, Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - tl.store(p_Ai_22, Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - tl.store(p_Ai_21, Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + for i in range(2, min(16, T - i_t * BT)): + b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i) + b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0) + b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11) + for i in range(16 + 2, min(32, T - i_t * BT)): + b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16) + b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0) + b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22) + + b_Ai_11 += m_I + b_Ai_22 += m_I + + if not USE_TMA: + p_A_21 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)) + b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) + else: + b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32) + + b_Ai_21 = -tl.dot( + tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION), + b_Ai_11, + input_precision=DOT_PRECISION, + ) + + if not USE_TMA: + p_Ai_11 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)) + p_Ai_21 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)) + p_Ai_22 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)) + tl.store( + p_Ai_11, + b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_22, + b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_21, + b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + else: + desc_o.store([i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne")) @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) @@ -119,153 +215,257 @@ def merge_16x16_to_32x32_inverse_kernel( ) @triton.jit(do_not_specialize=["T"]) def merge_16x16_to_64x64_inverse_kernel( - A, Ad, Ai, cu_seqlens, chunk_indices, T, H: tl.constexpr, BT: tl.constexpr, IS_VARLEN: tl.constexpr + A, + Ai, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + USE_TMA: tl.constexpr, + IS_VARLEN: tl.constexpr, + DOT_PRECISION: tl.constexpr, ): i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H if IS_VARLEN: - i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) - bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) T = eos - bos else: bos, eos = i_b * T, i_b * T + T - A += (bos * H + i_h) * 64 - Ad += (bos * H + i_h) * 16 - Ai += (bos * H + i_h) * 64 - - p_A_21 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)) - p_A_32 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0)) - p_A_31 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)) - p_A_43 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0)) - p_A_42 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0)) - p_A_41 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)) - p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64, 0), (16, 16), (1, 0)) - p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)) - p_Ad_33 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)) - p_Ad_44 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)) - - A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) - A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32) - A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32) - A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32) - A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32) - A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32) - - Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32) - Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32) - Ai_33 = tl.load(p_Ad_33, boundary_check=(0, 1)).to(tl.float32) - Ai_44 = tl.load(p_Ad_44, boundary_check=(0, 1)).to(tl.float32) - - Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision="ieee"), Ai_11, input_precision="ieee") - Ai_32 = -tl.dot(tl.dot(Ai_33, A_32, input_precision="ieee"), Ai_22, input_precision="ieee") - Ai_43 = -tl.dot(tl.dot(Ai_44, A_43, input_precision="ieee"), Ai_33, input_precision="ieee") - - Ai_31 = -tl.dot( - Ai_33, - tl.dot(A_31, Ai_11, input_precision="ieee") + tl.dot(A_32, Ai_21, input_precision="ieee"), - input_precision="ieee", + o_i = tl.arange(0, 16) + m_A = o_i[:, None] > o_i[None, :] + m_I = o_i[:, None] == o_i[None, :] + A += (bos * H + i_h) * BT + Ai += (bos * H + i_h) * BT + + if not USE_TMA: + p_A_11 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)) + p_A_22 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)) + p_A_33 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 32, 32), (16, 16), (1, 0)) + p_A_44 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 48, 48), (16, 16), (1, 0)) + b_Ai_11 = tl.load(p_A_11, boundary_check=(0, 1)).to(tl.float32) + b_Ai_22 = tl.load(p_A_22, boundary_check=(0, 1)).to(tl.float32) + b_Ai_33 = tl.load(p_A_33, boundary_check=(0, 1)).to(tl.float32) + b_Ai_44 = tl.load(p_A_44, boundary_check=(0, 1)).to(tl.float32) + else: + desc = make_tensor_descriptor(A, [T, BT], [H * BT, 1], [16, 16]) + desc_o = make_tensor_descriptor(Ai, [T, BT], [H * BT, 1], [16, 16]) + b_Ai_11 = desc.load([i_t * BT + 0, 0]).to(tl.float32) + b_Ai_22 = desc.load([i_t * BT + 16, 16]).to(tl.float32) + b_Ai_33 = desc.load([i_t * BT + 32, 32]).to(tl.float32) + b_Ai_44 = desc.load([i_t * BT + 48, 48]).to(tl.float32) + + # [16, 16] + b_Ai_11 = -tl.where(m_A, b_Ai_11, 0) + b_Ai_22 = -tl.where(m_A, b_Ai_22, 0) + b_Ai_33 = -tl.where(m_A, b_Ai_33, 0) + b_Ai_44 = -tl.where(m_A, b_Ai_44, 0) + + for i in range(2, min(16, T - i_t * BT)): + b_a_11 = -tl.load(A + (i_t * BT + i) * H * BT + o_i) + b_a_11 += tl.sum(b_a_11[:, None] * b_Ai_11, 0) + b_Ai_11 = tl.where((o_i == i)[:, None], b_a_11, b_Ai_11) + for i in range(16 + 2, min(32, T - i_t * BT)): + b_a_22 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 16) + b_a_22 += tl.sum(b_a_22[:, None] * b_Ai_22, 0) + b_Ai_22 = tl.where((o_i == i - 16)[:, None], b_a_22, b_Ai_22) + for i in range(32 + 2, min(48, T - i_t * BT)): + b_a_33 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 32) + b_a_33 += tl.sum(b_a_33[:, None] * b_Ai_33, 0) + b_Ai_33 = tl.where((o_i == i - 32)[:, None], b_a_33, b_Ai_33) + for i in range(48 + 2, min(64, T - i_t * BT)): + b_a_44 = -tl.load(A + (i_t * BT + i) * H * BT + o_i + 48) + b_a_44 += tl.sum(b_a_44[:, None] * b_Ai_44, 0) + b_Ai_44 = tl.where((o_i == i - 48)[:, None], b_a_44, b_Ai_44) + b_Ai_11 += m_I + b_Ai_22 += m_I + b_Ai_33 += m_I + b_Ai_44 += m_I + + if not USE_TMA: + p_A_21 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)) + p_A_31 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 32, 0), (16, 16), (1, 0)) + p_A_32 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 32, 16), (16, 16), (1, 0)) + p_A_41 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 48, 0), (16, 16), (1, 0)) + p_A_42 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 48, 16), (16, 16), (1, 0)) + p_A_43 = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * BT + 48, 32), (16, 16), (1, 0)) + b_A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) + b_A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32) + b_A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32) + b_A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32) + b_A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32) + b_A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32) + else: + b_A_21 = desc.load([i_t * BT + 16, 0]).to(tl.float32) + b_A_31 = desc.load([i_t * BT + 32, 0]).to(tl.float32) + b_A_32 = desc.load([i_t * BT + 32, 16]).to(tl.float32) + b_A_41 = desc.load([i_t * BT + 48, 0]).to(tl.float32) + b_A_42 = desc.load([i_t * BT + 48, 16]).to(tl.float32) + b_A_43 = desc.load([i_t * BT + 48, 32]).to(tl.float32) + + b_Ai_21 = -tl.dot( + tl.dot(b_Ai_22, b_A_21, input_precision=DOT_PRECISION), + b_Ai_11, + input_precision=DOT_PRECISION, ) - Ai_42 = -tl.dot( - Ai_44, - tl.dot(A_42, Ai_22, input_precision="ieee") + tl.dot(A_43, Ai_32, input_precision="ieee"), - input_precision="ieee", + b_Ai_32 = -tl.dot( + tl.dot(b_Ai_33, b_A_32, input_precision=DOT_PRECISION), + b_Ai_22, + input_precision=DOT_PRECISION, ) - Ai_41 = -tl.dot( - Ai_44, - tl.dot(A_41, Ai_11, input_precision="ieee") - + tl.dot(A_42, Ai_21, input_precision="ieee") - + tl.dot(A_43, Ai_31, input_precision="ieee"), - input_precision="ieee", + b_Ai_43 = -tl.dot( + tl.dot(b_Ai_44, b_A_43, input_precision=DOT_PRECISION), + b_Ai_33, + input_precision=DOT_PRECISION, ) - p_Ai_11 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 0), (16, 16), (1, 0)) - p_Ai_22 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 16), (16, 16), (1, 0)) - p_Ai_33 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 32), (16, 16), (1, 0)) - p_Ai_44 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 48), (16, 16), (1, 0)) - p_Ai_21 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)) - p_Ai_31 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)) - p_Ai_32 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0)) - p_Ai_41 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)) - p_Ai_42 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0)) - p_Ai_43 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0)) - tl.store(p_Ai_11, Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - tl.store(p_Ai_22, Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - tl.store(p_Ai_33, Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - tl.store(p_Ai_44, Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - tl.store(p_Ai_21, Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - tl.store(p_Ai_31, Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - tl.store(p_Ai_32, Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - tl.store(p_Ai_41, Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - tl.store(p_Ai_42, Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - tl.store(p_Ai_43, Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - - fill_zeros = tl.zeros((16, 16), dtype=tl.float32) - p_Ai_12 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 16), (16, 16), (1, 0)) - p_Ai_13 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 32), (16, 16), (1, 0)) - p_Ai_14 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 48), (16, 16), (1, 0)) - p_Ai_23 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 32), (16, 16), (1, 0)) - p_Ai_24 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 48), (16, 16), (1, 0)) - p_Ai_34 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 48), (16, 16), (1, 0)) - tl.store(p_Ai_12, fill_zeros.to(p_Ai_12.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - tl.store(p_Ai_13, fill_zeros.to(p_Ai_13.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - tl.store(p_Ai_14, fill_zeros.to(p_Ai_14.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - tl.store(p_Ai_23, fill_zeros.to(p_Ai_23.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - tl.store(p_Ai_24, fill_zeros.to(p_Ai_24.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - tl.store(p_Ai_34, fill_zeros.to(p_Ai_34.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + b_Ai_31 = -tl.dot( + b_Ai_33, + tl.dot(b_A_31, b_Ai_11, input_precision=DOT_PRECISION) + tl.dot(b_A_32, b_Ai_21, input_precision=DOT_PRECISION), + input_precision=DOT_PRECISION, + ) + b_Ai_42 = -tl.dot( + b_Ai_44, + tl.dot(b_A_42, b_Ai_22, input_precision=DOT_PRECISION) + tl.dot(b_A_43, b_Ai_32, input_precision=DOT_PRECISION), + input_precision=DOT_PRECISION, + ) + b_Ai_41 = -tl.dot( + b_Ai_44, + tl.dot(b_A_41, b_Ai_11, input_precision=DOT_PRECISION) + + tl.dot(b_A_42, b_Ai_21, input_precision=DOT_PRECISION) + + tl.dot(b_A_43, b_Ai_31, input_precision=DOT_PRECISION), + input_precision=DOT_PRECISION, + ) + + if not USE_TMA: + p_Ai_11 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT, 0), (16, 16), (1, 0)) + p_Ai_22 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 16), (16, 16), (1, 0)) + p_Ai_33 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 32), (16, 16), (1, 0)) + p_Ai_44 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 48), (16, 16), (1, 0)) + p_Ai_21 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 16, 0), (16, 16), (1, 0)) + p_Ai_31 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 0), (16, 16), (1, 0)) + p_Ai_32 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 32, 16), (16, 16), (1, 0)) + p_Ai_41 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 0), (16, 16), (1, 0)) + p_Ai_42 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 16), (16, 16), (1, 0)) + p_Ai_43 = tl.make_block_ptr(Ai, (T, BT), (H * BT, 1), (i_t * BT + 48, 32), (16, 16), (1, 0)) + tl.store( + p_Ai_11, + b_Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_22, + b_Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_33, + b_Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_44, + b_Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_21, + b_Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_31, + b_Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_32, + b_Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_41, + b_Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_42, + b_Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + tl.store( + p_Ai_43, + b_Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"), + boundary_check=(0, 1), + ) + else: + desc_o.store([i_t * BT + 0, 0], b_Ai_11.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 16, 16], b_Ai_22.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 32, 32], b_Ai_33.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 48, 48], b_Ai_44.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 16, 0], b_Ai_21.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 32, 0], b_Ai_31.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 32, 16], b_Ai_32.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 48, 0], b_Ai_41.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 48, 16], b_Ai_42.to(desc_o.dtype, fp_downcast_rounding="rtne")) + desc_o.store([i_t * BT + 48, 32], b_Ai_43.to(desc_o.dtype, fp_downcast_rounding="rtne")) @input_guard def solve_tril( - A: torch.Tensor, cu_seqlens: Optional[torch.Tensor] = None, output_dtype: torch.dtype = torch.float + A: torch.Tensor, + cu_seqlens: torch.Tensor | None = None, + output_dtype: torch.dtype = torch.float, ) -> torch.Tensor: """ - Compute the inverse of the lower triangular matrix + Compute the inverse of the matrix I + A A should be strictly lower triangular, i.e., A.triu() == 0. Args: A (torch.Tensor): - [B, T, H, K] + [B, T, H, BT], where BT should only be 16, 32, or 64. cu_seqlens (torch.Tensor): - The cumulative sequence lengths of the input tensor. - Default: None. + The cumulative sequence lengths of the input tensor. Default: `None`. output_dtype (torch.dtype): - The dtype of the output tensor. Default: `torch.float` + The dtype of the output tensor. Default: `torch.float`. + If `None`, the output dtype will be the same as the input dtype. Returns: (I + A)^-1 with the same shape as A """ assert A.shape[-1] in [16, 32, 64] + output_dtype = A.dtype if output_dtype is None else output_dtype B, T, H, BT = A.shape - Ad = torch.empty(B, T, H, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype) + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT) - chunk_indices = prepare_chunk_indices(cu_seqlens, 16) if cu_seqlens is not None else None - NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, 16) - solve_tril_16x16_kernel[NT, B * H]( - A=A, - Ad=Ad, - cu_seqlens=cu_seqlens, - chunk_indices=chunk_indices, - T=T, - H=H, - BT=BT, - ) + Ai = torch.zeros_like(A, dtype=output_dtype) if BT == 16: - return Ad + merge_fn = solve_tril_16x16_kernel + elif BT == 32: + merge_fn = merge_16x16_to_32x32_inverse_kernel + elif BT == 64: + merge_fn = merge_16x16_to_64x64_inverse_kernel - Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype) - merge_fn = merge_16x16_to_32x32_inverse_kernel if BT == 32 else merge_16x16_to_64x64_inverse_kernel - chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None - NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT) merge_fn[NT, B * H]( A=A, - Ad=Ad, Ai=Ai, cu_seqlens=cu_seqlens, chunk_indices=chunk_indices, T=T, H=H, BT=BT, + USE_TMA=is_tma_supported, + DOT_PRECISION=FLA_TRIL_PRECISION, ) return Ai diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/utils.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/utils.py index d8f29f287..a890d7010 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/utils.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/utils.py @@ -11,8 +11,9 @@ import functools import logging import os +from collections.abc import Callable from enum import Enum -from typing import Any, Callable, Literal, Optional +from typing import Any, Literal import torch @@ -43,8 +44,8 @@ def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor] A wrapped version of the input function with single-entry caching. """ - cache_entries: tuple[Optional[tuple], Optional[dict], Any] = [] - cache_size = 4 + cache_entries: tuple[tuple | None, dict | None, Any] = [] + cache_size = 8 @functools.wraps(fn) def wrapper(*args: Any, **kwargs: Any) -> Any: @@ -125,8 +126,8 @@ def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]: # For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'. # However, the torch backend is 'cuda' for both Nvidia and AMD GPUs. # Therefore, we need to check the triton backend to determine the actual GPU vendor. -device = get_available_device() if get_available_device() != "hip" else "cuda" -device_torch_lib = getattr(torch, device) +device = "cuda" +device_torch_lib = getattr(torch, device, None) device_platform = _check_platform() is_amd = device_platform == "amd" @@ -136,7 +137,12 @@ def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]: is_nvidia_hopper = is_nvidia and ( "NVIDIA H" in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9 ) -use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1" +use_cuda_graph = True +is_gather_supported = hasattr(triton.language, "gather") +is_tma_supported = (is_nvidia and torch.cuda.get_device_capability(0)[0] >= 9) and ( + hasattr(triton.language, "_experimental_make_tensor_descriptor") + or hasattr(triton.language, "make_tensor_descriptor") +) def get_all_max_shared_mem(): diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/wy_fast.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/wy_fast.py index dec8d2ffc..c5eaa9534 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/wy_fast.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/wy_fast.py @@ -8,7 +8,6 @@ # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang # ruff: noqa: E501 -from typing import Optional import torch @@ -51,8 +50,14 @@ def recompute_w_u_fwd_kernel( i_t, i_bh = tl.program_id(0), tl.program_id(1) i_b, i_h = i_bh // H, i_bh % H if IS_VARLEN: - i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) - bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + i_n, i_t = ( + tl.load(chunk_indices + i_t * 2).to(tl.int32), + tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32), + ) + bos, eos = ( + tl.load(cu_seqlens + i_n).to(tl.int32), + tl.load(cu_seqlens + i_n + 1).to(tl.int32), + ) T = eos - bos else: bos, eos = i_b * T, i_b * T + T @@ -64,8 +69,22 @@ def recompute_w_u_fwd_kernel( b_g = tl.exp(tl.load(p_g, boundary_check=(0,))) for i_v in range(tl.cdiv(V, BV)): - p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_u = tl.make_block_ptr(u + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v = tl.make_block_ptr( + v + (bos * H + i_h) * V, + (T, V), + (H * V, 1), + (i_t * BT, i_v * BV), + (BT, BV), + (1, 0), + ) + p_u = tl.make_block_ptr( + u + (bos * H + i_h) * V, + (T, V), + (H * V, 1), + (i_t * BT, i_v * BV), + (BT, BV), + (1, 0), + ) b_v = tl.load(p_v, boundary_check=(0, 1)) b_vb = (b_v * b_beta[:, None]).to(b_v.dtype) b_u = tl.dot(b_A, b_vb, allow_tf32=False) @@ -73,9 +92,21 @@ def recompute_w_u_fwd_kernel( for i_k in range(tl.cdiv(K, BK)): p_k = tl.make_block_ptr( - k + (bos * Hg + i_h // (H // Hg)) * K, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0) + k + (bos * Hg + i_h // (H // Hg)) * K, + (T, K), + (Hg * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), + ) + p_w = tl.make_block_ptr( + w + (bos * H + i_h) * K, + (T, K), + (H * K, 1), + (i_t * BT, i_k * BK), + (BT, BK), + (1, 0), ) - p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) b_k = tl.load(p_k, boundary_check=(0, 1)) b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype) b_w = tl.dot(b_A, b_kb) @@ -88,7 +119,7 @@ def recompute_w_u_fwd( beta: torch.Tensor, g_cumsum: torch.Tensor, A: torch.Tensor, - cu_seqlens: Optional[torch.LongTensor], + cu_seqlens: torch.LongTensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: B, T, Hg, K, V = *k.shape, v.shape[-1] H = v.shape[-2] diff --git a/lightllm/models/qwen3next/triton_kernel/fla_bak/__init__.py b/lightllm/models/qwen3next/triton_kernel/fla_bak/__init__.py new file mode 100644 index 000000000..cd3b0962a --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla_bak/__init__.py @@ -0,0 +1,15 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +from .chunk import chunk_gated_delta_rule +from .fused_recurrent import fused_recurrent_gated_delta_rule + +__all__ = [ + "chunk_gated_delta_rule", + "fused_recurrent_gated_delta_rule", +] diff --git a/lightllm/models/qwen3next/triton_kernel/fla_bak/chunk.py b/lightllm/models/qwen3next/triton_kernel/fla_bak/chunk.py new file mode 100644 index 000000000..22c81ae63 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla_bak/chunk.py @@ -0,0 +1,225 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import warnings +from typing import Optional + +import torch +from einops import rearrange + +from .chunk_delta_h import chunk_gated_delta_rule_fwd_h +from .chunk_o import chunk_fwd_o +from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd +from .cumsum import chunk_local_cumsum +from .l2norm import l2norm_fwd +from .solve_tril import solve_tril +from .utils import SUPPRESS_LEVEL, input_guard +from .wy_fast import recompute_w_u_fwd + + +def chunk_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, +): + g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens) + # obtain WY representation. u is actually the new v. + A = chunk_scaled_dot_kkt_fwd(k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32) + A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) + w, u = recompute_w_u_fwd( + k=k, + v=v, + beta=beta, + A=A, + g_cumsum=g, + cu_seqlens=cu_seqlens, + ) + h, v_new, final_state = chunk_gated_delta_rule_fwd_h( + k=k, + w=w, + u=u, + g=g, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + ) + o = chunk_fwd_o( + q=q, + k=k, + v=v_new, + h=h, + g=g, + scale=scale, + cu_seqlens=cu_seqlens, + ) + if SUPPRESS_LEVEL < 3: + return g, o, A, final_state, None, None, None + elif SUPPRESS_LEVEL >= 3: + return g, o, A, final_state, w, h, v_new + + +class ChunkGatedDeltaRuleFunction(torch.autograd.Function): + @staticmethod + @input_guard + @torch.amp.custom_fwd(device_type="cuda") + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + output_final_state: bool, + cu_seqlens: Optional[torch.LongTensor] = None, + use_qk_l2norm_in_kernel: bool = False, + ): + if use_qk_l2norm_in_kernel: + q = l2norm_fwd(q) + k = l2norm_fwd(k) + + g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd( + q=q, + k=k, + v=v, + g=g, + beta=beta, + scale=scale, + initial_state=initial_state, + output_final_state=output_final_state, + cu_seqlens=cu_seqlens, + ) + ctx.scale = scale + ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel + return o.to(q.dtype), final_state + + +@torch.compiler.disable +def chunk_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float = None, + initial_state: torch.Tensor = None, + output_final_state: bool = False, + cu_seqlens: Optional[torch.LongTensor] = None, + head_first: bool = False, + use_qk_l2norm_in_kernel: bool = False, +): + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. + v (torch.Tensor): + values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + g (torch.Tensor): + (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. + beta (torch.Tensor): + betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, H, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + output_final_state (Optional[bool]): + Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + head_first (Optional[bool]): + Whether the inputs are in the head-first format, which is not supported for variable-length inputs. + Default: `False`. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. + final_state (torch.Tensor): + Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. + + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule + # inputs with equal lengths + >>> B, T, H, K, V = 4, 2048, 4, 512, 512 + >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda') + >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid() + >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda')) + >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda') + >>> o, ht = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = chunk_gated_delta_rule( + q, k, v, g, beta, + initial_state=h0, + output_final_state=True, + cu_seqlens=cu_seqlens + ) + """ + assert q.dtype == k.dtype == v.dtype + assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16." + assert len(beta.shape) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise." + + if head_first: + raise DeprecationWarning( + "head_first is deprecated and will be removed in a future version. " + "Please use head_first=False for now instead.", + stacklevel=2, + ) + q, k, v, beta, g = map(lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g)) + if not head_first and q.shape[1] < q.shape[2]: + warnings.warn( + f"Input tensor shape suggests potential format mismatch" + f" seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...].", + stacklevel=2, + ) + if cu_seqlens is not None: + if q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) + if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: + raise ValueError( + f"The number of initial states is expected to be equal to the number of input sequences, " + f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." + ) + if scale is None: + scale = k.shape[-1] ** -0.5 + o, final_state = ChunkGatedDeltaRuleFunction.apply( + q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens, use_qk_l2norm_in_kernel + ) + if head_first: + o = rearrange(o, "b t h ... -> b h t ...") + return o, final_state diff --git a/lightllm/models/qwen3next/triton_kernel/fla_bak/chunk_delta_h.py b/lightllm/models/qwen3next/triton_kernel/fla_bak/chunk_delta_h.py new file mode 100644 index 000000000..f20c95d90 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla_bak/chunk_delta_h.py @@ -0,0 +1,257 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +from typing import Optional + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices, prepare_chunk_offsets +from .op import exp +from .utils import is_nvidia_hopper, use_cuda_graph + +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16] + + +@triton.heuristics( + { + "USE_G": lambda args: args["g"] is not None, + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "STORE_FINAL_STATE": lambda args: args["ht"] is not None, + "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + } +) +@triton.autotune( + configs=[ + triton.Config({"BV": BV}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4] + for num_stages in [2, 3, 4] + for BV in [32, 64] + ], + key=["H", "K", "V", "BT", "USE_G"], + use_cuda_graph=use_cuda_graph, +) +@triton.jit(do_not_specialize=["T"]) +def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( + k, + v, + w, + v_new, + g, + h, + h0, + ht, + cu_seqlens, + chunk_offsets, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, + STORE_FINAL_STATE: tl.constexpr, + SAVE_NEW_VALUE: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_nh = tl.program_id(0), tl.program_id(1) + i_n, i_h = i_nh // H, i_nh % H + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + boh = tl.load(chunk_offsets + i_n).to(tl.int32) + else: + bos, eos = i_n * T, i_n * T + T + NT = tl.cdiv(T, BT) + boh = i_n * NT + + # [BK, BV] + b_h1 = tl.zeros([64, BV], dtype=tl.float32) + if K > 64: + b_h2 = tl.zeros([64, BV], dtype=tl.float32) + if K > 128: + b_h3 = tl.zeros([64, BV], dtype=tl.float32) + if K > 192: + b_h4 = tl.zeros([64, BV], dtype=tl.float32) + + # calculate offset + h += (boh * H + i_h) * K * V + v += (bos * H + i_h) * V + k += (bos * Hg + i_h // (H // Hg)) * K + w += (bos * H + i_h) * K + if SAVE_NEW_VALUE: + v_new += (bos * H + i_h) * V + stride_v = H * V + stride_h = H * K * V + stride_k = Hg * K + stride_w = H * K + if USE_INITIAL_STATE: + h0 = h0 + i_nh * K * V + if STORE_FINAL_STATE: + ht = ht + i_nh * K * V + + # load initial state + if USE_INITIAL_STATE: + p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32) + if K > 64: + p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32) + if K > 128: + p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32) + if K > 192: + p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32) + + # main recurrence + for i_t in range(NT): + p_h1 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_h2 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_h3 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_h4 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1)) + + p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_v_new = ( + tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + if SAVE_NEW_VALUE + else None + ) + b_v_new = tl.zeros([BT, BV], dtype=tl.float32) + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v_new += tl.dot(b_w, b_h1.to(b_w.dtype)) + if K > 64: + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v_new += tl.dot(b_w, b_h2.to(b_w.dtype)) + if K > 128: + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v_new += tl.dot(b_w, b_h3.to(b_w.dtype)) + if K > 192: + p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)) + b_w = tl.load(p_w, boundary_check=(0, 1)) + b_v_new += tl.dot(b_w, b_h4.to(b_w.dtype)) + b_v_new = -b_v_new + tl.load(p_v, boundary_check=(0, 1)) + + if SAVE_NEW_VALUE: + p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + tl.store(p_v_new, b_v_new.to(p_v_new.dtype.element_ty), boundary_check=(0, 1)) + + if USE_G: + m_t = (i_t * BT + tl.arange(0, BT)) < T + last_idx = min((i_t + 1) * BT, T) - 1 + b_g_last = tl.load(g + bos * H + last_idx * H + i_h) + p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_v_new = b_v_new * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None] + b_g_last = exp(b_g_last) + b_h1 = b_h1 * b_g_last + if K > 64: + b_h2 = b_h2 * b_g_last + if K > 128: + b_h3 = b_h3 * b_g_last + if K > 192: + b_h4 = b_h4 * b_g_last + b_v_new = b_v_new.to(k.dtype.element_ty) + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h1 += tl.dot(b_k, b_v_new) + if K > 64: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h2 += tl.dot(b_k, b_v_new) + if K > 128: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h3 += tl.dot(b_k, b_v_new) + if K > 192: + p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_h4 += tl.dot(b_k, b_v_new) + + # epilogue + if STORE_FINAL_STATE: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 64: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 128: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + if K > 192: + p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) + tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_gated_delta_rule_fwd_h( + k: torch.Tensor, + w: torch.Tensor, + u: torch.Tensor, + g: Optional[torch.Tensor] = None, + initial_state: Optional[torch.Tensor] = None, + output_final_state: bool = False, + chunk_size: int = 64, # SY: remove this argument and force chunk size 64? + save_new_value: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, Hg, K, V = *k.shape, u.shape[-1] + H = u.shape[-2] + BT = chunk_size + + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + # N: the actual number of sequences in the batch with either equal or variable lengths + if cu_seqlens is None: + N, NT, chunk_offsets = B, triton.cdiv(T, BT), None + else: + N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT) + assert K <= 256, "current kernel does not support head dimension larger than 256." + + h = k.new_empty(B, NT, H, K, V) + final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None + + v_new = torch.empty_like(u) if save_new_value else None + + def grid(meta): + return (triton.cdiv(V, meta["BV"]), N * H) + + chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid]( + k=k, + v=u, + w=w, + v_new=v_new, + g=g, + h=h, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + chunk_offsets=chunk_offsets, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + ) + return h, v_new, final_state diff --git a/lightllm/models/qwen3next/triton_kernel/fla_bak/chunk_o.py b/lightllm/models/qwen3next/triton_kernel/fla_bak/chunk_o.py new file mode 100644 index 000000000..73c2e1f19 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla_bak/chunk_o.py @@ -0,0 +1,167 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# ruff: noqa: E501 + +from typing import Optional + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices +from .op import exp +from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper + +BKV_LIST = [64, 128] if check_shared_mem() else [32, 64] +NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] + + +@triton.heuristics( + {"USE_G": lambda args: args["g"] is not None, "IS_VARLEN": lambda args: args["cu_seqlens"] is not None} +) +@triton.autotune( + configs=[ + triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages) + for BK in BKV_LIST + for BV in BKV_LIST + for num_warps in NUM_WARPS + for num_stages in [2, 3, 4] + ], + key=["H", "K", "V", "BT"], +) +@triton.jit(do_not_specialize=["T"]) +def chunk_fwd_kernel_o( + q, + k, + v, + h, + g, + o, + cu_seqlens, + chunk_indices, + scale, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + USE_G: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + + if IS_VARLEN: + i_tg = i_t + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + NT = tl.cdiv(T, BT) + else: + NT = tl.cdiv(T, BT) + i_tg = i_b * NT + i_t + bos, eos = i_b * T, i_b * T + T + + # offset calculation + q += (bos * Hg + i_h // (H // Hg)) * K + k += (bos * Hg + i_h // (H // Hg)) * K + v += (bos * H + i_h) * V + o += (bos * H + i_h) * V + h += (i_tg * H + i_h).to(tl.int64) * K * V + + b_o = tl.zeros([BT, BV], dtype=tl.float32) + b_A = tl.zeros([BT, BT], dtype=tl.float32) + + for i_k in range(tl.cdiv(K, BK)): + p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) + p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) + # [BT, BK] + b_q = tl.load(p_q, boundary_check=(0, 1)) + # [BK, BT] + b_k = tl.load(p_k, boundary_check=(0, 1)) + # [BK, BV] + b_h = tl.load(p_h, boundary_check=(0, 1)) + + # [BT, BK] @ [BK, BV] -> [BT, BV] + b_o += tl.dot(b_q, b_h) + # [BT, BK] @ [BK, BT] -> [BT, BT] + b_A += tl.dot(b_q, b_k) + + if USE_G: + g += bos * H + i_h + p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_o = b_o * exp(b_g)[:, None] + b_A = b_A * exp(b_g[:, None] - b_g[None, :]) + + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t) + b_A = tl.where(m_A, b_A, 0) + + p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + + # to fix mma -> mma layout conversion + # already solved by triton v3.2 or higher + b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_fwd_o( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + h: torch.Tensor, + g: Optional[torch.Tensor] = None, # cumsum of log decay + scale: Optional[float] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, +) -> torch.Tensor: + B, T, Hg, K, V = *q.shape, v.shape[-1] + H = v.shape[-2] + if FLA_GDN_FIX_BT: + BT = 64 + else: + BT = min(chunk_size, max(16, triton.next_power_of_2(T))) + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + if scale is None: + scale = k.shape[-1] ** -0.5 + + o = torch.empty_like(v) + + def grid(meta): + return (triton.cdiv(V, meta["BV"]), NT, B * H) + + chunk_fwd_kernel_o[grid]( + q, + k, + v, + h, + g, + o, + cu_seqlens, + chunk_indices, + scale, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + ) + return o diff --git a/lightllm/models/qwen3next/triton_kernel/fla_bak/chunk_scaled_dot_kkt.py b/lightllm/models/qwen3next/triton_kernel/fla_bak/chunk_scaled_dot_kkt.py new file mode 100644 index 000000000..aa545e8ec --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla_bak/chunk_scaled_dot_kkt.py @@ -0,0 +1,136 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +from typing import Optional + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices +from .op import exp + + +@triton.heuristics( + {"IS_VARLEN": lambda args: args["cu_seqlens"] is not None, "USE_G": lambda args: args["g_cumsum"] is not None} +) +@triton.autotune( + configs=[ + triton.Config({"BK": BK}, num_warps=num_warps, num_stages=num_stages) + for BK in [32, 64, 128] + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=["H", "K", "BT", "IS_VARLEN"], +) +@triton.jit(do_not_specialize=["T"]) +def chunk_scaled_dot_kkt_fwd_kernel( + k, + beta, + g_cumsum, + A, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + IS_VARLEN: tl.constexpr, + USE_G: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + o_t = i_t * BT + tl.arange(0, BT) + m_t = o_t < T + + p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + + b_A = tl.zeros([BT, BT], dtype=tl.float32) + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr( + k + (bos * Hg + i_h // (H // Hg)) * K, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0) + ) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = b_k * b_beta[:, None] + b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k)) + + if USE_G: + p_g = tl.make_block_ptr(g_cumsum + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + b_g = tl.load(p_g, boundary_check=(0,)) + b_g_diff = b_g[:, None] - b_g[None, :] + b_A = b_A * exp(b_g_diff) + + m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t) + b_A = tl.where(m_A, b_A, 0) + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_scaled_dot_kkt_fwd( + k: torch.Tensor, + beta: torch.Tensor, + g_cumsum: Optional[torch.Tensor] = None, + cu_seqlens: Optional[torch.LongTensor] = None, + chunk_size: int = 64, + output_dtype: torch.dtype = torch.float32, +) -> torch.Tensor: + r""" + Compute beta * K * K^T. + + Args: + k (torch.Tensor): + The key tensor of shape `[B, T, H, K]`. + beta (torch.Tensor): + The beta tensor of shape `[B, T, H]`. + g_cumsum (torch.Tensor): + The cumulative sum of the gate tensor of shape `[B, T, H]`. + Default: None + cu_seqlens (torch.LongTensor): + The cumulative sequence lengths of the input tensor. + Default: None + chunk_size (int): + The chunk size. Default: 64. + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float32` + + Returns: + beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size. + """ + + B, T, Hg, K = k.shape + + H = beta.shape[-1] + BT = chunk_size + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype) + chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)]( + k=k, + beta=beta, + g_cumsum=g_cumsum, + A=A, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + Hg=Hg, + K=K, + BT=BT, + ) + return A diff --git a/lightllm/models/qwen3next/triton_kernel/fla_bak/cumsum.py b/lightllm/models/qwen3next/triton_kernel/fla_bak/cumsum.py new file mode 100644 index 000000000..9cd6a6545 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla_bak/cumsum.py @@ -0,0 +1,200 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import warnings +from typing import Optional + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices +from .utils import check_shared_mem, input_guard + +BS_LIST = [32, 64] if check_shared_mem() else [16, 32] + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.autotune( + configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], + key=["B", "H", "BT", "IS_VARLEN", "REVERSE"], +) +@triton.jit(do_not_specialize=["T"]) +def chunk_local_cumsum_scalar_kernel( + s, + o, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + BT: tl.constexpr, + REVERSE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) + else: + p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + # [BT] + b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32) + b_o = tl.cumsum(b_s, axis=0) + if REVERSE: + b_z = tl.sum(b_s, axis=0) + b_o = -b_o + b_z[None] + b_s + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,)) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.autotune( + configs=[triton.Config({"BS": BS}, num_warps=num_warps) for BS in BS_LIST for num_warps in [2, 4, 8]], + key=["B", "H", "S", "BT", "IS_VARLEN", "REVERSE"], +) +@triton.jit(do_not_specialize=["T"]) +def chunk_local_cumsum_vector_kernel( + s, + o, + cu_seqlens, + chunk_indices, + T, + B: tl.constexpr, + H: tl.constexpr, + S: tl.constexpr, + BT: tl.constexpr, + BS: tl.constexpr, + REVERSE: tl.constexpr, + IS_VARLEN: tl.constexpr, + HEAD_FIRST: tl.constexpr, +): + i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + o_i = tl.arange(0, BT) + if REVERSE: + m_s = tl.where(o_i[:, None] <= o_i[None, :], 1.0, 0.0) + else: + m_s = tl.where(o_i[:, None] >= o_i[None, :], 1.0, 0.0) + + if HEAD_FIRST: + p_s = tl.make_block_ptr(s + (bos * H + i_h * T) * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h * T) * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + else: + p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) + # [BT, BS] + b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) + b_o = tl.dot(m_s, b_s, allow_tf32=False) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) + + +def chunk_local_cumsum_scalar( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False, + output_dtype: Optional[torch.dtype] = torch.float, +) -> torch.Tensor: + if head_first: + B, H, T = g.shape + else: + B, T, H = g.shape + assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2" + BT = chunk_size + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + grid = (NT, B * H) + chunk_local_cumsum_scalar_kernel[grid]( + g_org, g, cu_seqlens, chunk_indices, T=T, B=B, H=H, BT=BT, HEAD_FIRST=head_first, REVERSE=reverse + ) + return g + + +def chunk_local_cumsum_vector( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False, + output_dtype: Optional[torch.dtype] = torch.float, +) -> torch.Tensor: + if head_first: + B, H, T, S = g.shape + else: + B, T, H, S = g.shape + BT = chunk_size + chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2" + + g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + + def grid(meta): + return (triton.cdiv(meta["S"], meta["BS"]), NT, B * H) + + # keep cumulative normalizer in fp32 + # this kernel is equivalent to + # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1) + chunk_local_cumsum_vector_kernel[grid]( + g_org, g, cu_seqlens, chunk_indices, T=T, B=B, H=H, S=S, BT=BT, HEAD_FIRST=head_first, REVERSE=reverse + ) + return g + + +@input_guard +def chunk_local_cumsum( + g: torch.Tensor, + chunk_size: int, + reverse: bool = False, + cu_seqlens: Optional[torch.Tensor] = None, + head_first: bool = False, + output_dtype: Optional[torch.dtype] = torch.float, + **kwargs, +) -> torch.Tensor: + if not head_first and g.shape[1] < g.shape[2]: + warnings.warn( + f"Input tensor shape suggests potential format mismatch" + f" seq_len ({g.shape[1]}) < num_heads ({g.shape[2]}). " + "This may indicate the inputs were passed in head-first format [B, H, T, ...] " + "when head_first=False was specified. " + "Please verify your input tensor format matches the expected shape [B, T, H, ...].", + stacklevel=2, + ) + if cu_seqlens is not None: + assert g.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided" + if len(g.shape) == 3: + return chunk_local_cumsum_scalar(g, chunk_size, reverse, cu_seqlens, head_first, output_dtype) + elif len(g.shape) == 4: + return chunk_local_cumsum_vector(g, chunk_size, reverse, cu_seqlens, head_first, output_dtype) + else: + raise ValueError( + f"Unsupported input shape {g.shape}. " + f"which should be (B, T, H, D) if `head_first=False` " + f"or (B, H, T, D) otherwise" + ) diff --git a/lightllm/models/qwen3next/triton_kernel/fla_bak/fused_recurrent.py b/lightllm/models/qwen3next/triton_kernel/fla_bak/fused_recurrent.py new file mode 100644 index 000000000..4ff18d4f6 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla_bak/fused_recurrent.py @@ -0,0 +1,367 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +from typing import Optional + +import torch + +import triton +import triton.language as tl + +from .op import exp + + +@triton.heuristics( + { + "USE_INITIAL_STATE": lambda args: args["h0"] is not None, + "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, + "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None, + "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None, + } +) +@triton.jit(do_not_specialize=["N", "T"]) +def fused_recurrent_gated_delta_rule_fwd_kernel( + q, + k, + v, + g, + beta, + o, + h0, + ht, + cu_seqlens, + ssm_state_indices, + num_accepted_tokens, + scale, + N: tl.int64, # num of sequences + T: tl.int64, # num of tokens + B: tl.constexpr, + H: tl.constexpr, + HV: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + stride_init_state_token: tl.constexpr, + stride_final_state_token: tl.constexpr, + stride_indices_seq: tl.constexpr, + stride_indices_tok: tl.constexpr, + USE_INITIAL_STATE: tl.constexpr, # whether to use initial state + INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace + IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar, + USE_QK_L2NORM_IN_KERNEL: tl.constexpr, + IS_VARLEN: tl.constexpr, + IS_CONTINUOUS_BATCHING: tl.constexpr, + IS_SPEC_DECODING: tl.constexpr, +): + i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) + i_n, i_hv = i_nh // HV, i_nh % HV + i_h = i_hv // (HV // H) + if IS_VARLEN: + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) + all = T + T = eos - bos + else: + bos, eos = i_n * T, i_n * T + T + all = B * T + + if T == 0: + # no tokens to process for this sequence + return + + o_k = i_k * BK + tl.arange(0, BK) + o_v = i_v * BV + tl.arange(0, BV) + + p_q = q + (bos * H + i_h) * K + o_k + p_k = k + (bos * H + i_h) * K + o_k + p_v = v + (bos * HV + i_hv) * V + o_v + if IS_BETA_HEADWISE: + p_beta = beta + (bos * HV + i_hv) * V + o_v + else: + p_beta = beta + bos * HV + i_hv + p_g = g + bos * HV + i_hv + p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v + + mask_k = o_k < K + mask_v = o_v < V + mask_h = mask_k[:, None] & mask_v[None, :] + + b_h = tl.zeros([BK, BV], dtype=tl.float32) + if USE_INITIAL_STATE: + if IS_CONTINUOUS_BATCHING: + if IS_SPEC_DECODING: + i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1 + else: + i_t = 0 + p_h0 = ( + h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(tl.int64) * stride_init_state_token + ) + else: + p_h0 = h0 + bos * HV * K * V + p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :] + b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) + + for i_t in range(0, T): + b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) + b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) + b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) + b_g = tl.load(p_g).to(tl.float32) + + if USE_QK_L2NORM_IN_KERNEL: + b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6) + b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6) + b_q = b_q * scale + # [BK, BV] + b_h *= exp(b_g) + # [BV] + b_v -= tl.sum(b_h * b_k[:, None], 0) + if IS_BETA_HEADWISE: + b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32) + else: + b_beta = tl.load(p_beta).to(tl.float32) + b_v *= b_beta + # [BK, BV] + b_h += b_k[:, None] * b_v[None, :] + # [BV] + b_o = tl.sum(b_h * b_q[:, None], 0) + tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) + + # keep the states for multi-query tokens + if INPLACE_FINAL_STATE: + p_ht = ( + ht + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(tl.int64) * stride_final_state_token + ) + else: + p_ht = ht + (bos + i_t) * stride_final_state_token + p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] + tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) + + p_q += H * K + p_k += H * K + p_o += HV * V + p_v += HV * V + p_g += HV + p_beta += HV * (V if IS_BETA_HEADWISE else 1) + + +def fused_recurrent_gated_delta_rule_fwd( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, + ssm_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, H, K, V = *k.shape, v.shape[-1] + HV = v.shape[2] + N = B if cu_seqlens is None else len(cu_seqlens) - 1 + BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8) + NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) + assert NK == 1, "NK > 1 is not supported yet" + num_stages = 3 + num_warps = 1 + + o = q.new_empty(NK, *v.shape) + if inplace_final_state: + final_state = initial_state + else: + final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype) + + stride_init_state_token = initial_state.stride(0) + stride_final_state_token = final_state.stride(0) + + if ssm_state_indices is None: + stride_indices_seq, stride_indices_tok = 1, 1 + elif ssm_state_indices.ndim == 1: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1 + else: + stride_indices_seq, stride_indices_tok = ssm_state_indices.stride() + + grid = (NK, NV, N * HV) + fused_recurrent_gated_delta_rule_fwd_kernel[grid]( + q=q, + k=k, + v=v, + g=g, + beta=beta, + o=o, + h0=initial_state, + ht=final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + num_accepted_tokens=num_accepted_tokens, + scale=scale, + N=N, + T=T, + B=B, + H=H, + HV=HV, + K=K, + V=V, + BK=BK, + BV=BV, + stride_init_state_token=stride_init_state_token, + stride_final_state_token=stride_final_state_token, + stride_indices_seq=stride_indices_seq, + stride_indices_tok=stride_indices_tok, + IS_BETA_HEADWISE=beta.ndim == v.ndim, + USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, + INPLACE_FINAL_STATE=inplace_final_state, + num_warps=num_warps, + num_stages=num_stages, + ) + o = o.squeeze(0) + return o, final_state + + +class FusedRecurrentFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor, + scale: float, + initial_state: torch.Tensor, + inplace_final_state: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, + ssm_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + use_qk_l2norm_in_kernel: bool = False, + ): + o, final_state = fused_recurrent_gated_delta_rule_fwd( + q=q.contiguous(), + k=k.contiguous(), + v=v.contiguous(), + g=g.contiguous(), + beta=beta.contiguous(), + scale=scale, + initial_state=initial_state, + inplace_final_state=inplace_final_state, + cu_seqlens=cu_seqlens, + ssm_state_indices=ssm_state_indices, + num_accepted_tokens=num_accepted_tokens, + use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, + ) + + return o, final_state + + +def fused_recurrent_gated_delta_rule( + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + g: torch.Tensor, + beta: torch.Tensor = None, + scale: float = None, + initial_state: torch.Tensor = None, + inplace_final_state: bool = True, + cu_seqlens: Optional[torch.LongTensor] = None, + ssm_state_indices: Optional[torch.Tensor] = None, + num_accepted_tokens: Optional[torch.Tensor] = None, + use_qk_l2norm_in_kernel: bool = False, +) -> tuple[torch.Tensor, torch.Tensor]: + r""" + Args: + q (torch.Tensor): + queries of shape `[B, T, H, K]`. + k (torch.Tensor): + keys of shape `[B, T, H, K]`. + v (torch.Tensor): + values of shape `[B, T, HV, V]`. + GVA is applied if `HV > H`. + g (torch.Tensor): + g (decays) of shape `[B, T, HV]`. + beta (torch.Tensor): + betas of shape `[B, T, HV]`. + scale (Optional[int]): + Scale factor for the RetNet attention scores. + If not provided, it will default to `1 / sqrt(K)`. Default: `None`. + initial_state (Optional[torch.Tensor]): + Initial state of shape `[N, HV, K, V]` for `N` input sequences. + For equal-length input sequences, `N` equals the batch size `B`. + Default: `None`. + inplace_final_state: bool: + Whether to store the final state in-place to save memory. + Default: `True`. + cu_seqlens (torch.LongTensor): + Cumulative sequence lengths of shape `[N+1]` used for variable-length training, + consistent with the FlashAttention API. + ssm_state_indices (Optional[torch.Tensor]): + Indices to map the input sequences to the initial/final states. + num_accepted_tokens (Optional[torch.Tensor]): + Number of accepted tokens for each sequence during decoding. + + Returns: + o (torch.Tensor): + Outputs of shape `[B, T, HV, V]`. + final_state (torch.Tensor): + Final state of shape `[N, HV, K, V]`. + + Examples:: + >>> import torch + >>> import torch.nn.functional as F + >>> from einops import rearrange + >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule + # inputs with equal lengths + >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512 + >>> q = torch.randn(B, T, H, K, device='cuda') + >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1) + >>> v = torch.randn(B, T, HV, V, device='cuda') + >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda')) + >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid() + >>> h0 = torch.randn(B, HV, K, V, device='cuda') + >>> o, ht = fused_gated_recurrent_delta_rule( + q, k, v, g, beta, + initial_state=h0, + ) + # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required + >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta)) + # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected + >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) + >>> o_var, ht_var = fused_gated_recurrent_delta_rule( + q, k, v, g, beta, + initial_state=h0, + cu_seqlens=cu_seqlens + ) + """ + if cu_seqlens is not None and q.shape[0] != 1: + raise ValueError( + f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." + f"Please flatten variable-length inputs before processing." + ) + if scale is None: + scale = k.shape[-1] ** -0.5 + else: + assert scale > 0, "scale must be positive" + if beta is None: + beta = torch.ones_like(q[..., 0]) + o, final_state = FusedRecurrentFunction.apply( + q, + k, + v, + g, + beta, + scale, + initial_state, + inplace_final_state, + cu_seqlens, + ssm_state_indices, + num_accepted_tokens, + use_qk_l2norm_in_kernel, + ) + return o, final_state diff --git a/lightllm/models/qwen3next/triton_kernel/fla_bak/index.py b/lightllm/models/qwen3next/triton_kernel/fla_bak/index.py new file mode 100644 index 000000000..8b1d59fc6 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla_bak/index.py @@ -0,0 +1,30 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import torch + +import triton + +from .utils import tensor_cache + + +@tensor_cache +def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: + return cu_seqlens[1:] - cu_seqlens[:-1] + + +@tensor_cache +def prepare_chunk_indices(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor: + indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()]) + return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens) + + +@tensor_cache +def prepare_chunk_offsets(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor: + return torch.cat([cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]).cumsum(-1) diff --git a/lightllm/models/qwen3next/triton_kernel/fla_bak/l2norm.py b/lightllm/models/qwen3next/triton_kernel/fla_bak/l2norm.py new file mode 100644 index 000000000..7225cd4ae --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla_bak/l2norm.py @@ -0,0 +1,137 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import os +from typing import Optional + +import torch + +import triton +import triton.language as tl + +BT_LIST = [8, 16, 32, 64, 128] + +USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0")) + + +@triton.autotune(configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32]], key=["D"]) +@triton.jit +def l2norm_fwd_kernel1( + x, + y, + D, + BD: tl.constexpr, + eps, +): + i_t = tl.program_id(0) + x += i_t * D + y += i_t * D + # Compute mean and variance + cols = tl.arange(0, BD) + mask = cols < D + b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32) + b_var = tl.sum(b_x * b_x, axis=0) + b_rstd = 1 / tl.sqrt(b_var + eps) + # tl.store(Rstd + i_t, rstd) + # Normalize and apply linear transformation + b_y = b_x * b_rstd + tl.store(y + cols, b_y, mask=mask) + + +@triton.autotune( + configs=[triton.Config({"BT": BT}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST], + key=["D"], +) +@triton.jit(do_not_specialize=["NB"]) +def l2norm_fwd_kernel( + x, + y, + eps, + NB, + T, + D: tl.constexpr, + BT: tl.constexpr, + BD: tl.constexpr, +): + i_t = tl.program_id(0) + p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32) + b_var = tl.sum(b_x * b_x, axis=1) + b_y = b_x / tl.sqrt(b_var + eps)[:, None] + p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) + tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1)) + + +@triton.jit +def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr): + xoffset = tl.program_id(0) * MBLOCK + row_idx = xoffset + tl.arange(0, MBLOCK)[:, None] + xmask = row_idx < M + rindex = tl.arange(0, N)[None, :] + xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32) + square = tl.broadcast_to(xs * xs, [MBLOCK, N]) + square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None] + rsqrt = tl.rsqrt(square_sum + eps) + tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask) + + +def l2norm_fwd(x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None): + x_shape_og = x.shape + x = x.view(-1, x.shape[-1]) + # allocate output + if output_dtype is None: + y = torch.empty_like(x) + else: + y = torch.empty_like(x, dtype=output_dtype) + assert y.stride(-1) == 1 + T, D = x.shape[0], x.shape[-1] + # rstd = torch.empty((T,), dtype=torch.float32, device=x.device) + # Less than 64KB per feature: enqueue fused kernel + MAX_FUSED_SIZE = 65536 // x.element_size() + BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D)) + if D > BD: + raise RuntimeError("This layer doesn't support feature dim >= 64KB.") + + if not USE_DEFAULT_FLA_NORM: + MBLOCK = 32 + # M, N = x.shape + l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK),)]( + x, + y, + eps, + T, + D, + MBLOCK, + ) + else: + if D <= 512: + NB = triton.cdiv(T, 2048) + + def grid(meta): + return (triton.cdiv(T, meta["BT"]),) + + l2norm_fwd_kernel[grid]( + x, + y, + eps, + NB=NB, + T=T, + D=D, + BD=BD, + ) + else: + l2norm_fwd_kernel1[(T,)]( + x, + y, + eps=eps, + D=D, + BD=BD, + ) + + return y.view(x_shape_og) diff --git a/lightllm/models/qwen3next/triton_kernel/fla_bak/op.py b/lightllm/models/qwen3next/triton_kernel/fla_bak/op.py new file mode 100644 index 000000000..ec0999455 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla_bak/op.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +import os + +import triton +import triton.language as tl + + +@triton.jit +def div_normal(x, y): + return x / y + + +div = div_normal +exp = tl.exp +log = tl.log +log2 = tl.log2 + + +if not hasattr(tl, "gather"): + + @triton.jit + def gather(src, index, axis, _builder=None): + # This is a fallback implementation when tl.gather is not supported + # In order to pass triton compiler, there is no actual gather operation + return src + +else: + gather = tl.gather diff --git a/lightllm/models/qwen3next/triton_kernel/fla_bak/solve_tril.py b/lightllm/models/qwen3next/triton_kernel/fla_bak/solve_tril.py new file mode 100644 index 000000000..46e4d5082 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla_bak/solve_tril.py @@ -0,0 +1,271 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +from typing import Optional + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices +from .utils import input_guard + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3, 4, 5] + ], + key=["BT"], +) +@triton.jit(do_not_specialize=["T"]) +def solve_tril_16x16_kernel( + A, + Ad, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + BT: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + A = A + (bos * H + i_h) * BT + Ad = Ad + (bos * H + i_h) * 16 + + offset = (i_t * 16) % BT + p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * 16, offset), (16, 16), (1, 0)) + p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), (1, 0)) + b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32) + b_A = -tl.where(tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :], b_A, 0) + + o_i = tl.arange(0, 16) + for i in range(1, min(16, T - i_t * 16)): + b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset) + b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) + mask = o_i == i + b_A = tl.where(mask[:, None], b_a, b_A) + b_A += o_i[:, None] == o_i[None, :] + tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [1, 2, 4, 8] + for num_stages in [2, 3, 4, 5] + ], + key=["H", "BT", "IS_VARLEN"], +) +@triton.jit(do_not_specialize=["T"]) +def merge_16x16_to_32x32_inverse_kernel( + A, Ad, Ai, cu_seqlens, chunk_indices, T, H: tl.constexpr, BT: tl.constexpr, IS_VARLEN: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + A += (bos * H + i_h) * 32 + Ad += (bos * H + i_h) * 16 + Ai += (bos * H + i_h) * 32 + + p_A_21 = tl.make_block_ptr(A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)) + p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32, 0), (16, 16), (1, 0)) + p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)) + p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32, 0), (16, 16), (1, 0)) + p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16), (16, 16), (1, 0)) + p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)) + + A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) + Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32) + Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32) + Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision="ieee"), Ai_11, input_precision="ieee") + tl.store(p_Ai_11, Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_22, Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_21, Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4, 5] + ], + key=["H", "BT", "IS_VARLEN"], +) +@triton.jit(do_not_specialize=["T"]) +def merge_16x16_to_64x64_inverse_kernel( + A, Ad, Ai, cu_seqlens, chunk_indices, T, H: tl.constexpr, BT: tl.constexpr, IS_VARLEN: tl.constexpr +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + + A += (bos * H + i_h) * 64 + Ad += (bos * H + i_h) * 16 + Ai += (bos * H + i_h) * 64 + + p_A_21 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)) + p_A_32 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0)) + p_A_31 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)) + p_A_43 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0)) + p_A_42 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0)) + p_A_41 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)) + p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64, 0), (16, 16), (1, 0)) + p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)) + p_Ad_33 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)) + p_Ad_44 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)) + + A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) + A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32) + A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32) + A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32) + A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32) + A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32) + + Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32) + Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32) + Ai_33 = tl.load(p_Ad_33, boundary_check=(0, 1)).to(tl.float32) + Ai_44 = tl.load(p_Ad_44, boundary_check=(0, 1)).to(tl.float32) + + Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision="ieee"), Ai_11, input_precision="ieee") + Ai_32 = -tl.dot(tl.dot(Ai_33, A_32, input_precision="ieee"), Ai_22, input_precision="ieee") + Ai_43 = -tl.dot(tl.dot(Ai_44, A_43, input_precision="ieee"), Ai_33, input_precision="ieee") + + Ai_31 = -tl.dot( + Ai_33, + tl.dot(A_31, Ai_11, input_precision="ieee") + tl.dot(A_32, Ai_21, input_precision="ieee"), + input_precision="ieee", + ) + Ai_42 = -tl.dot( + Ai_44, + tl.dot(A_42, Ai_22, input_precision="ieee") + tl.dot(A_43, Ai_32, input_precision="ieee"), + input_precision="ieee", + ) + Ai_41 = -tl.dot( + Ai_44, + tl.dot(A_41, Ai_11, input_precision="ieee") + + tl.dot(A_42, Ai_21, input_precision="ieee") + + tl.dot(A_43, Ai_31, input_precision="ieee"), + input_precision="ieee", + ) + + p_Ai_11 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 0), (16, 16), (1, 0)) + p_Ai_22 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 16), (16, 16), (1, 0)) + p_Ai_33 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 32), (16, 16), (1, 0)) + p_Ai_44 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 48), (16, 16), (1, 0)) + p_Ai_21 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)) + p_Ai_31 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)) + p_Ai_32 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0)) + p_Ai_41 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)) + p_Ai_42 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0)) + p_Ai_43 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0)) + tl.store(p_Ai_11, Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_22, Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_33, Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_44, Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_21, Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_31, Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_32, Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_41, Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_42, Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_43, Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + fill_zeros = tl.zeros((16, 16), dtype=tl.float32) + p_Ai_12 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 16), (16, 16), (1, 0)) + p_Ai_13 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 32), (16, 16), (1, 0)) + p_Ai_14 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 48), (16, 16), (1, 0)) + p_Ai_23 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 32), (16, 16), (1, 0)) + p_Ai_24 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 48), (16, 16), (1, 0)) + p_Ai_34 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 48), (16, 16), (1, 0)) + tl.store(p_Ai_12, fill_zeros.to(p_Ai_12.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_13, fill_zeros.to(p_Ai_13.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_14, fill_zeros.to(p_Ai_14.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_23, fill_zeros.to(p_Ai_23.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_24, fill_zeros.to(p_Ai_24.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + tl.store(p_Ai_34, fill_zeros.to(p_Ai_34.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) + + +@input_guard +def solve_tril( + A: torch.Tensor, cu_seqlens: Optional[torch.Tensor] = None, output_dtype: torch.dtype = torch.float +) -> torch.Tensor: + """ + Compute the inverse of the lower triangular matrix + A should be strictly lower triangular, i.e., A.triu() == 0. + + Args: + A (torch.Tensor): + [B, T, H, K] + cu_seqlens (torch.Tensor): + The cumulative sequence lengths of the input tensor. + Default: None. + output_dtype (torch.dtype): + The dtype of the output tensor. Default: `torch.float` + + Returns: + (I + A)^-1 with the same shape as A + """ + assert A.shape[-1] in [16, 32, 64] + + B, T, H, BT = A.shape + Ad = torch.empty(B, T, H, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype) + + chunk_indices = prepare_chunk_indices(cu_seqlens, 16) if cu_seqlens is not None else None + NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, 16) + solve_tril_16x16_kernel[NT, B * H]( + A=A, + Ad=Ad, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + BT=BT, + ) + if BT == 16: + return Ad + + Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype) + merge_fn = merge_16x16_to_32x32_inverse_kernel if BT == 32 else merge_16x16_to_64x64_inverse_kernel + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT) + merge_fn[NT, B * H]( + A=A, + Ad=Ad, + Ai=Ai, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + BT=BT, + ) + return Ai diff --git a/lightllm/models/qwen3next/triton_kernel/fla_bak/utils.py b/lightllm/models/qwen3next/triton_kernel/fla_bak/utils.py new file mode 100644 index 000000000..d8f29f287 --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla_bak/utils.py @@ -0,0 +1,173 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang +# ruff: noqa: E501 +import contextlib +import functools +import logging +import os +from enum import Enum +from typing import Any, Callable, Literal, Optional + +import torch + +import triton + +logger = logging.getLogger(__name__) + +COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1" +FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1" +FLA_GDN_FIX_BT = os.getenv("FLA_GDN_FIX_BT", "0") == "1" + +SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0")) + + +def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: + """ + A decorator that caches the most recent results of a function with tensor inputs. + + This decorator will store the output of the decorated function for the most recent set of input tensors. + The cache is limited to a fixed size (default is 4). When the cache is full, the oldest entry will be removed. + + Args: + fn (Callable[..., torch.Tensor]): + The function to be decorated. It should take tensor inputs and return tensor outputs. + + Returns: + Callable[..., torch.Tensor]: + A wrapped version of the input function with single-entry caching. + """ + + cache_entries: tuple[Optional[tuple], Optional[dict], Any] = [] + cache_size = 4 + + @functools.wraps(fn) + def wrapper(*args: Any, **kwargs: Any) -> Any: + nonlocal cache_entries, cache_size + for i, entry in enumerate(cache_entries): + last_args, last_kwargs, last_result = entry + if ( + len(args) == len(last_args) + and len(kwargs) == len(last_kwargs) + and all(a is b for a, b in zip(args, last_args)) + and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()) + ): + cache_entries = cache_entries[:i] + cache_entries[i + 1 :] + [(args, kwargs, last_result)] + return last_result + + result = fn(*args, **kwargs) + + if len(cache_entries) >= cache_size: + cache_entries = cache_entries[1:] + cache_entries.append((args, kwargs, result)) + return result + + return wrapper + + +def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: + """ + A decorator to make sure all input tensors are contiguous and set the device based on input tensors. + """ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args) + contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()} + + tensor = None + for arg in args: + if isinstance(arg, torch.Tensor): + tensor = arg + break + if tensor is None: + for value in kwargs.values(): + if isinstance(value, torch.Tensor): + tensor = value + break + + if tensor is not None: + ctx = torch.cuda.device(tensor.device.index) + else: + ctx = contextlib.nullcontext() + + with ctx: + return fn(*contiguous_args, **contiguous_kwargs) + + return wrapper + + +@functools.cache +def get_available_device() -> str: + try: + return triton.runtime.driver.active.get_current_target().backend + except BaseException: + return "cpu" + + +@functools.cache +def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]: + device = get_available_device() + mapping = { + "cuda": "nvidia", + "hip": "amd", + "xpu": "intel", + } + # return the mapped value, or the original if not found + return mapping.get(device, device) + + +# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'. +# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs. +# Therefore, we need to check the triton backend to determine the actual GPU vendor. +device = get_available_device() if get_available_device() != "hip" else "cuda" +device_torch_lib = getattr(torch, device) +device_platform = _check_platform() + +is_amd = device_platform == "amd" +is_intel = device_platform == "intel" +is_nvidia = device_platform == "nvidia" +is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0) +is_nvidia_hopper = is_nvidia and ( + "NVIDIA H" in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9 +) +use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1" + + +def get_all_max_shared_mem(): + try: + return [ + triton.runtime.driver.active.utils.get_device_properties(i)["max_shared_mem"] + for i in range(device_torch_lib.device_count()) + ] + except BaseException: + return [-1] + + +class Backend(Enum): + ADA = 101376 # RTX 4090 + AMPERE = 166912 # A100 + HOPPER = 232448 # H100 + DEFAULT = 102400 # Default + + @classmethod + def get_shared_memory(cls, arch: str) -> int: + try: + return cls[arch.upper()].value + except KeyError: + return cls.DEFAULT.value + + +@functools.cache +def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool: + try: + device_shared_mem_list = get_all_max_shared_mem() + max_shared_memory = device_shared_mem_list[tensor_idx] + return max_shared_memory >= Backend.get_shared_memory(arch) + except Exception: + return False diff --git a/lightllm/models/qwen3next/triton_kernel/fla_bak/wy_fast.py b/lightllm/models/qwen3next/triton_kernel/fla_bak/wy_fast.py new file mode 100644 index 000000000..dec8d2ffc --- /dev/null +++ b/lightllm/models/qwen3next/triton_kernel/fla_bak/wy_fast.py @@ -0,0 +1,122 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang +# +# This file contains code copied from the flash-linear-attention project. +# The original source code was licensed under the MIT license and included +# the following copyright notice: +# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# ruff: noqa: E501 +from typing import Optional + +import torch + +import triton +import triton.language as tl + +from .index import prepare_chunk_indices + + +@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) +@triton.autotune( + configs=[ + triton.Config({}, num_warps=num_warps, num_stages=num_stages) + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ], + key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"], +) +@triton.jit(do_not_specialize=["T"]) +def recompute_w_u_fwd_kernel( + k, + v, + beta, + w, + u, + A, + g, + cu_seqlens, + chunk_indices, + T, + H: tl.constexpr, + Hg: tl.constexpr, + K: tl.constexpr, + V: tl.constexpr, + BT: tl.constexpr, + BK: tl.constexpr, + BV: tl.constexpr, + IS_VARLEN: tl.constexpr, +): + i_t, i_bh = tl.program_id(0), tl.program_id(1) + i_b, i_h = i_bh // H, i_bh % H + if IS_VARLEN: + i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) + bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) + T = eos - bos + else: + bos, eos = i_b * T, i_b * T + T + p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,)) + p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) + b_beta = tl.load(p_beta, boundary_check=(0,)) + b_A = tl.load(p_A, boundary_check=(0, 1)) + b_g = tl.exp(tl.load(p_g, boundary_check=(0,))) + + for i_v in range(tl.cdiv(V, BV)): + p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + p_u = tl.make_block_ptr(u + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) + b_v = tl.load(p_v, boundary_check=(0, 1)) + b_vb = (b_v * b_beta[:, None]).to(b_v.dtype) + b_u = tl.dot(b_A, b_vb, allow_tf32=False) + tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1)) + + for i_k in range(tl.cdiv(K, BK)): + p_k = tl.make_block_ptr( + k + (bos * Hg + i_h // (H // Hg)) * K, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0) + ) + p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) + b_k = tl.load(p_k, boundary_check=(0, 1)) + b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype) + b_w = tl.dot(b_A, b_kb) + tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) + + +def recompute_w_u_fwd( + k: torch.Tensor, + v: torch.Tensor, + beta: torch.Tensor, + g_cumsum: torch.Tensor, + A: torch.Tensor, + cu_seqlens: Optional[torch.LongTensor], +) -> tuple[torch.Tensor, torch.Tensor]: + B, T, Hg, K, V = *k.shape, v.shape[-1] + H = v.shape[-2] + BT = A.shape[-1] + + chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None + NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + BK = 64 + BV = 64 + u = torch.empty_like(v) + w = k.new_empty(B, T, H, K) + recompute_w_u_fwd_kernel[(NT, B * H)]( + k=k, + v=v, + beta=beta, + w=w, + u=u, + A=A, + g=g_cumsum, + cu_seqlens=cu_seqlens, + chunk_indices=chunk_indices, + T=T, + H=H, + Hg=Hg, + K=K, + V=V, + BT=BT, + BK=BK, + BV=BV, + ) + return w, u diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index fbdbcd03d..49496fbd4 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -572,5 +572,5 @@ def make_argument_parser() -> argparse.ArgumentParser: default=False, help="""Enable prefix prompt cache fetch for data parallel inference, disabled by default.""", ) - parser.add_argument("--linear_attn_cache_size", type=int, default=2000, help="""The size of linear attn cache. """) + parser.add_argument("--linear_attn_cache_size", type=int, default=3000, help="""The size of linear attn cache. """) return parser From 6ab8e6a7cf349ebdd9443aa688f8f70a51dfc1ba Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 11 Dec 2025 16:19:19 +0800 Subject: [PATCH 11/19] use autotuner --- .../triton_kernel/fla/ops/chunk_delta_h.py | 49 ++++++++--- .../triton_kernel/fla/ops/chunk_o.py | 52 ++++++++--- .../fla/ops/chunk_scaled_dot_kkt.py | 47 ++++++++-- .../qwen3next/triton_kernel/fla/ops/cumsum.py | 71 ++++++++++++--- .../qwen3next/triton_kernel/fla/ops/l2norm.py | 88 +++++++++++++------ .../triton_kernel/fla/ops/solve_tril.py | 25 +----- .../triton_kernel/fla/ops/wy_fast.py | 9 +- 7 files changed, 238 insertions(+), 103 deletions(-) diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py index e4c020f49..5e9684da5 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py @@ -16,6 +16,7 @@ from .index import prepare_chunk_indices, prepare_chunk_offsets from .op import exp from .utils import use_cuda_graph +from lightllm.common.triton_utils.autotuner import autotune NUM_WARPS = [2, 4, 8, 16] @@ -30,16 +31,6 @@ "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, } ) -@triton.autotune( - configs=[ - triton.Config({"BV": BV}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [2, 4] - for num_stages in [2, 3, 4] - for BV in [32, 64] - ], - key=["H", "K", "V", "BT"], - use_cuda_graph=use_cuda_graph, -) @triton.jit(do_not_specialize=["T"]) def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( k, @@ -237,6 +228,29 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) +def _get_chunk_delta_h_configs(): + return [ + {"BV": BV, "num_warps": num_warps, "num_stages": num_stages} + for num_warps in [2, 4] + for num_stages in [2, 3, 4] + for BV in [32, 64] + ] + + +def _get_chunk_delta_h_static_key(H, K, V, BT, **kwargs): + return {"H": H, "K": K, "V": V, "BT": BT} + + +def _get_chunk_delta_h_run_key(H, K, V, BT, **kwargs): + return f"{H}_{K}_{V}_{BT}" + + +@autotune( + kernel_name="chunk_gated_delta_rule_fwd_h", + configs_gen_func=_get_chunk_delta_h_configs, + static_key_func=_get_chunk_delta_h_static_key, + run_key_func=_get_chunk_delta_h_run_key, +) def chunk_gated_delta_rule_fwd_h( k: torch.Tensor, w: torch.Tensor, @@ -248,6 +262,7 @@ def chunk_gated_delta_rule_fwd_h( chunk_size: int = 64, # SY: remove this argument and force chunk size 64? save_new_value: bool = True, cu_seqlens: torch.LongTensor | None = None, + run_config=None, ) -> tuple[torch.Tensor, torch.Tensor]: # This kernel is slightly different from fla to support Q/K with different head numbers. # In fla, Q/K always have the same head number, so Hg is always equal to H. @@ -272,8 +287,15 @@ def chunk_gated_delta_rule_fwd_h( v_new = torch.empty_like(u) if save_new_value else None - def grid(meta): - return (triton.cdiv(V, meta["BV"]), N * H) + # Extract config parameters + if run_config is None: + run_config = {"BV": 64, "num_warps": 2, "num_stages": 2} + + BV = run_config.get("BV", 64) + num_warps = run_config.get("num_warps", 2) + num_stages = run_config.get("num_stages", 2) + + grid = (triton.cdiv(V, BV), N * H) chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid]( k=k, @@ -293,5 +315,8 @@ def grid(meta): K=K, V=V, BT=BT, + BV=BV, + num_warps=num_warps, + num_stages=num_stages, ) return h, v_new, final_state diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py index b70185a05..440d2f57c 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py @@ -18,6 +18,7 @@ from .index import prepare_chunk_indices from .op import exp from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper +from lightllm.common.triton_utils.autotuner import autotune BKV_LIST = [64, 128] if check_shared_mem() else [32, 64] NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] @@ -29,16 +30,6 @@ "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, } ) -@triton.autotune( - configs=[ - triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages) - for BK in BKV_LIST - for BV in BKV_LIST - for num_warps in NUM_WARPS - for num_stages in [2, 3, 4] - ], - key=["H", "K", "V", "BT"], -) @triton.jit(do_not_specialize=["T"]) def chunk_fwd_kernel_o( q, @@ -129,6 +120,30 @@ def chunk_fwd_kernel_o( tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) +def _get_chunk_o_configs(): + return [ + {"BK": BK, "BV": BV, "num_warps": num_warps, "num_stages": num_stages} + for BK in BKV_LIST + for BV in BKV_LIST + for num_warps in NUM_WARPS + for num_stages in [2, 3, 4] + ] + + +def _get_chunk_o_static_key(H, K, V, BT, **kwargs): + return {"H": H, "K": K, "V": V, "BT": BT} + + +def _get_chunk_o_run_key(H, K, V, BT, **kwargs): + return f"{H}_{K}_{V}_{BT}" + + +@autotune( + kernel_name="chunk_fwd_o", + configs_gen_func=_get_chunk_o_configs, + static_key_func=_get_chunk_o_static_key, + run_key_func=_get_chunk_o_run_key, +) def chunk_fwd_o( q: torch.Tensor, k: torch.Tensor, @@ -138,6 +153,7 @@ def chunk_fwd_o( scale: float | None = None, cu_seqlens: torch.LongTensor | None = None, chunk_size: int = 64, + run_config=None, ) -> torch.Tensor: B, T, Hg, K, V = *q.shape, v.shape[-1] H = v.shape[-2] @@ -149,8 +165,16 @@ def chunk_fwd_o( o = torch.empty_like(v) - def grid(meta): - return (triton.cdiv(V, meta["BV"]), NT, B * H) + # Extract config parameters + if run_config is None: + run_config = {"BK": 64, "BV": 64, "num_warps": 2, "num_stages": 2} + + BK = run_config.get("BK", 64) + BV = run_config.get("BV", 64) + num_warps = run_config.get("num_warps", 2) + num_stages = run_config.get("num_stages", 2) + + grid = (triton.cdiv(V, BV), NT, B * H) chunk_fwd_kernel_o[grid]( q, @@ -168,5 +192,9 @@ def grid(meta): K=K, V=V, BT=BT, + BK=BK, + BV=BV, + num_warps=num_warps, + num_stages=num_stages, ) return o diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py index 31ba1c7df..3f746a0dd 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py @@ -15,6 +15,7 @@ from .index import prepare_chunk_indices from .op import exp +from lightllm.common.triton_utils.autotuner import autotune @triton.heuristics( @@ -23,15 +24,6 @@ "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, } ) -@triton.autotune( - configs=[ - triton.Config({"BK": BK}, num_warps=num_warps, num_stages=num_stages) - for BK in [32, 64, 128] - for num_warps in [2, 4, 8] - for num_stages in [2, 3, 4] - ], - key=["H", "K", "BT", "IS_VARLEN"], -) @triton.jit(do_not_specialize=["T"]) def chunk_scaled_dot_kkt_fwd_kernel( k, @@ -95,6 +87,31 @@ def chunk_scaled_dot_kkt_fwd_kernel( tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) +def _get_chunk_scaled_dot_kkt_configs(): + return [ + {"BK": BK, "num_warps": num_warps, "num_stages": num_stages} + for BK in [32, 64, 128] + for num_warps in [2, 4, 8] + for num_stages in [2, 3, 4] + ] + + +def _get_chunk_scaled_dot_kkt_static_key(H, K, BT, cu_seqlens, **kwargs): + IS_VARLEN = cu_seqlens is not None + return {"H": H, "K": K, "BT": BT, "IS_VARLEN": IS_VARLEN} + + +def _get_chunk_scaled_dot_kkt_run_key(H, K, BT, cu_seqlens, **kwargs): + IS_VARLEN = cu_seqlens is not None + return f"{H}_{K}_{BT}_{IS_VARLEN}" + + +@autotune( + kernel_name="chunk_scaled_dot_kkt_fwd", + configs_gen_func=_get_chunk_scaled_dot_kkt_configs, + static_key_func=_get_chunk_scaled_dot_kkt_static_key, + run_key_func=_get_chunk_scaled_dot_kkt_run_key, +) def chunk_scaled_dot_kkt_fwd( k: torch.Tensor, g: torch.Tensor | None = None, @@ -102,6 +119,7 @@ def chunk_scaled_dot_kkt_fwd( cu_seqlens: torch.LongTensor | None = None, chunk_size: int = 64, output_dtype: torch.dtype = torch.float32, + run_config=None, ) -> torch.Tensor: r""" Compute beta * K * K^T. @@ -132,6 +150,14 @@ def chunk_scaled_dot_kkt_fwd( chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) + # Extract config parameters + if run_config is None: + run_config = {"BK": 64, "num_warps": 2, "num_stages": 2} + + BK = run_config.get("BK", 64) + num_warps = run_config.get("num_warps", 2) + num_stages = run_config.get("num_stages", 2) + A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype) chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)]( k=k, @@ -145,5 +171,8 @@ def chunk_scaled_dot_kkt_fwd( Hg=Hg, K=K, BT=BT, + BK=BK, + num_warps=num_warps, + num_stages=num_stages, ) return A diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py index 98dbccae8..4c936020c 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py @@ -16,15 +16,12 @@ from .index import prepare_chunk_indices from .utils import check_shared_mem, input_guard +from lightllm.common.triton_utils.autotuner import autotune BS_LIST = [32, 64] if check_shared_mem() else [16, 32] @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) -@triton.autotune( - configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], - key=["B", "H", "BT", "IS_VARLEN", "REVERSE"], -) @triton.jit(do_not_specialize=["T"]) def chunk_local_cumsum_scalar_kernel( s, @@ -70,10 +67,6 @@ def chunk_local_cumsum_scalar_kernel( @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) -@triton.autotune( - configs=[triton.Config({"BS": BS}, num_warps=num_warps) for BS in BS_LIST for num_warps in [2, 4, 8]], - key=["B", "H", "S", "BT", "IS_VARLEN", "REVERSE"], -) @triton.jit(do_not_specialize=["T"]) def chunk_local_cumsum_vector_kernel( s, @@ -151,6 +144,26 @@ def chunk_local_cumsum_vector_kernel( tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) +def _get_cumsum_scalar_configs(): + return [{"num_warps": num_warps} for num_warps in [1, 2, 4, 8]] + + +def _get_cumsum_scalar_static_key(B, H, BT, REVERSE, cu_seqlens, **kwargs): + IS_VARLEN = cu_seqlens is not None + return {"B": B, "H": H, "BT": BT, "IS_VARLEN": IS_VARLEN, "REVERSE": REVERSE} + + +def _get_cumsum_scalar_run_key(B, H, BT, REVERSE, cu_seqlens, **kwargs): + IS_VARLEN = cu_seqlens is not None + return f"{B}_{H}_{BT}_{IS_VARLEN}_{REVERSE}" + + +@autotune( + kernel_name="chunk_local_cumsum_scalar", + configs_gen_func=_get_cumsum_scalar_configs, + static_key_func=_get_cumsum_scalar_static_key, + run_key_func=_get_cumsum_scalar_run_key, +) def chunk_local_cumsum_scalar( g: torch.Tensor, chunk_size: int, @@ -158,6 +171,7 @@ def chunk_local_cumsum_scalar( cu_seqlens: torch.Tensor | None = None, head_first: bool = False, output_dtype: torch.dtype | None = torch.float, + run_config=None, ) -> torch.Tensor: if head_first: B, H, T = g.shape @@ -168,6 +182,13 @@ def chunk_local_cumsum_scalar( chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) + + # Extract config parameters + if run_config is None: + run_config = {"num_warps": 2} + + num_warps = run_config.get("num_warps", 2) + grid = (NT, B * H) chunk_local_cumsum_scalar_kernel[grid]( g_org, @@ -180,10 +201,31 @@ def chunk_local_cumsum_scalar( BT=BT, HEAD_FIRST=head_first, REVERSE=reverse, + num_warps=num_warps, ) return g +def _get_cumsum_vector_configs(): + return [{"BS": BS, "num_warps": num_warps} for BS in BS_LIST for num_warps in [2, 4, 8]] + + +def _get_cumsum_vector_static_key(B, H, S, BT, REVERSE, cu_seqlens, **kwargs): + IS_VARLEN = cu_seqlens is not None + return {"B": B, "H": H, "S": S, "BT": BT, "IS_VARLEN": IS_VARLEN, "REVERSE": REVERSE} + + +def _get_cumsum_vector_run_key(B, H, S, BT, REVERSE, cu_seqlens, **kwargs): + IS_VARLEN = cu_seqlens is not None + return f"{B}_{H}_{S}_{BT}_{IS_VARLEN}_{REVERSE}" + + +@autotune( + kernel_name="chunk_local_cumsum_vector", + configs_gen_func=_get_cumsum_vector_configs, + static_key_func=_get_cumsum_vector_static_key, + run_key_func=_get_cumsum_vector_run_key, +) def chunk_local_cumsum_vector( g: torch.Tensor, chunk_size: int, @@ -191,6 +233,7 @@ def chunk_local_cumsum_vector( cu_seqlens: torch.Tensor | None = None, head_first: bool = False, output_dtype: torch.dtype | None = torch.float, + run_config=None, ) -> torch.Tensor: if head_first: B, H, T, S = g.shape @@ -203,8 +246,14 @@ def chunk_local_cumsum_vector( g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) - def grid(meta): - return (triton.cdiv(meta["S"], meta["BS"]), NT, B * H) + # Extract config parameters + if run_config is None: + run_config = {"BS": 32, "num_warps": 2} + + BS = run_config.get("BS", 32) + num_warps = run_config.get("num_warps", 2) + + grid = (triton.cdiv(S, BS), NT, B * H) # keep cumulative normalizer in fp32 # this kernel is equivalent to @@ -219,8 +268,10 @@ def grid(meta): H=H, S=S, BT=BT, + BS=BS, HEAD_FIRST=head_first, REVERSE=reverse, + num_warps=num_warps, ) return g diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py index 0c67bcf9b..e0af7495f 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py @@ -13,16 +13,13 @@ import triton import triton.language as tl +from lightllm.common.triton_utils.autotuner import autotune BT_LIST = [8, 16, 32, 64, 128] USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0")) -@triton.autotune( - configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32]], - key=["D"], -) @triton.jit def l2norm_fwd_kernel1( x, @@ -46,10 +43,6 @@ def l2norm_fwd_kernel1( tl.store(y + cols, b_y, mask=mask) -@triton.autotune( - configs=[triton.Config({"BT": BT}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST], - key=["D"], -) @triton.jit(do_not_specialize=["NB"]) def l2norm_fwd_kernel( x, @@ -83,6 +76,63 @@ def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr): tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask) +def _get_l2norm_kernel1_configs(): + return [{"num_warps": num_warps} for num_warps in [1, 2, 4, 8, 16, 32]] + + +def _get_l2norm_kernel1_static_key(D, **kwargs): + return {"D": D} + + +def _get_l2norm_kernel1_run_key(D, **kwargs): + return f"{D}" + + +@autotune( + kernel_name="l2norm_fwd_kernel1", + configs_gen_func=_get_l2norm_kernel1_configs, + static_key_func=_get_l2norm_kernel1_static_key, + run_key_func=_get_l2norm_kernel1_run_key, +) +def _l2norm_fwd_kernel1_wrapper(x, y, eps, D, BD, run_config=None): + if run_config is None: + run_config = {"num_warps": 4} + + num_warps = run_config.get("num_warps", 4) + T = x.shape[0] + + l2norm_fwd_kernel1[(T,)](x, y, eps=eps, D=D, BD=BD, num_warps=num_warps) + + +def _get_l2norm_kernel_configs(): + return [{"BT": BT, "num_warps": num_warps} for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST] + + +def _get_l2norm_kernel_static_key(D, **kwargs): + return {"D": D} + + +def _get_l2norm_kernel_run_key(D, **kwargs): + return f"{D}" + + +@autotune( + kernel_name="l2norm_fwd_kernel", + configs_gen_func=_get_l2norm_kernel_configs, + static_key_func=_get_l2norm_kernel_static_key, + run_key_func=_get_l2norm_kernel_run_key, +) +def _l2norm_fwd_kernel_wrapper(x, y, eps, T, D, BD, NB, run_config=None): + if run_config is None: + run_config = {"BT": 32, "num_warps": 4} + + BT = run_config.get("BT", 32) + num_warps = run_config.get("num_warps", 4) + + grid = (triton.cdiv(T, BT),) + l2norm_fwd_kernel[grid](x, y, eps, NB=NB, T=T, D=D, BT=BT, BD=BD, num_warps=num_warps) + + def l2norm_fwd(x: torch.Tensor, eps: float = 1e-6, output_dtype: torch.dtype | None = None): x_shape_og = x.shape x = x.view(-1, x.shape[-1]) @@ -114,26 +164,8 @@ def l2norm_fwd(x: torch.Tensor, eps: float = 1e-6, output_dtype: torch.dtype | N else: if D <= 512: NB = triton.cdiv(T, 2048) - - def grid(meta): - return (triton.cdiv(T, meta["BT"]),) - - l2norm_fwd_kernel[grid]( - x, - y, - eps, - NB=NB, - T=T, - D=D, - BD=BD, - ) + _l2norm_fwd_kernel_wrapper(x, y, eps, T, D, BD, NB) else: - l2norm_fwd_kernel1[(T,)]( - x, - y, - eps=eps, - D=D, - BD=BD, - ) + _l2norm_fwd_kernel1_wrapper(x, y, eps, D, BD) return y.view(x_shape_og) diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py index adfd8b11e..9b1cde861 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/solve_tril.py @@ -18,6 +18,7 @@ from .index import prepare_chunk_indices from .op import make_tensor_descriptor from .utils import input_guard, is_amd, is_tma_supported +from lightllm.common.triton_utils.autotuner import autotune FLA_TRIL_PRECISION = os.environ.get("FLA_TRIL_PRECISION", "ieee") ALLOWED_TRIL_PRECISIONS = ["ieee", "tf32"] if is_amd else ["ieee", "tf32", "tf32x3"] @@ -27,14 +28,6 @@ @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) -@triton.autotune( - configs=[ - triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [1, 2, 4, 8] - for num_stages in [2, 3, 4, 5] - ], - key=["BT"], -) @triton.jit(do_not_specialize=["T"]) def solve_tril_16x16_kernel( A, @@ -98,14 +91,6 @@ def solve_tril_16x16_kernel( @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) -@triton.autotune( - configs=[ - triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [1, 2, 4, 8] - for num_stages in [2, 3, 4, 5] - ], - key=["H", "BT", "IS_VARLEN"], -) @triton.jit(do_not_specialize=["T"]) def merge_16x16_to_32x32_inverse_kernel( A, @@ -205,14 +190,6 @@ def merge_16x16_to_32x32_inverse_kernel( @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) -@triton.autotune( - configs=[ - triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [2, 4, 8] - for num_stages in [2, 3, 4, 5] - ], - key=["H", "BT", "IS_VARLEN"], -) @triton.jit(do_not_specialize=["T"]) def merge_16x16_to_64x64_inverse_kernel( A, diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/wy_fast.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/wy_fast.py index c5eaa9534..fb67297e4 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/wy_fast.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/wy_fast.py @@ -15,17 +15,10 @@ import triton.language as tl from .index import prepare_chunk_indices +from lightllm.common.triton_utils.autotuner import autotune @triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) -@triton.autotune( - configs=[ - triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [2, 4, 8] - for num_stages in [2, 3, 4] - ], - key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"], -) @triton.jit(do_not_specialize=["T"]) def recompute_w_u_fwd_kernel( k, From 0fec91cd7b07574787c32e09606046c952580d09 Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 11 Dec 2025 16:41:40 +0800 Subject: [PATCH 12/19] fix autotuner --- .../triton_kernel/fla/ops/chunk_delta_h.py | 12 +++++--- .../triton_kernel/fla/ops/chunk_o.py | 11 ++++++-- .../fla/ops/chunk_scaled_dot_kkt.py | 12 ++++---- .../qwen3next/triton_kernel/fla/ops/cumsum.py | 28 ++++++++++++------- .../qwen3next/triton_kernel/fla/ops/l2norm.py | 14 ++++++---- 5 files changed, 49 insertions(+), 28 deletions(-) diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py index 5e9684da5..2ea88d41f 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py @@ -237,12 +237,16 @@ def _get_chunk_delta_h_configs(): ] -def _get_chunk_delta_h_static_key(H, K, V, BT, **kwargs): - return {"H": H, "K": K, "V": V, "BT": BT} +def _get_chunk_delta_h_static_key(k, u, chunk_size, **kwargs): + B, T, Hg, K = k.shape + V = u.shape[-1] + H = u.shape[-2] + return {"H": H, "K": K, "V": V, "BT": chunk_size} -def _get_chunk_delta_h_run_key(H, K, V, BT, **kwargs): - return f"{H}_{K}_{V}_{BT}" +def _get_chunk_delta_h_run_key(k, u, **kwargs): + # Return batch * heads as run key + return k.shape[0] * k.shape[2] @autotune( diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py index 440d2f57c..9006fb0ac 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py @@ -130,12 +130,17 @@ def _get_chunk_o_configs(): ] -def _get_chunk_o_static_key(H, K, V, BT, **kwargs): +def _get_chunk_o_static_key(q, v, chunk_size, **kwargs): + B, T, Hg, K = q.shape + V = v.shape[-1] + H = v.shape[-2] + BT = 64 if FLA_GDN_FIX_BT else min(chunk_size, max(16, triton.next_power_of_2(T))) return {"H": H, "K": K, "V": V, "BT": BT} -def _get_chunk_o_run_key(H, K, V, BT, **kwargs): - return f"{H}_{K}_{V}_{BT}" +def _get_chunk_o_run_key(q, v, **kwargs): + # Return batch * heads as run key + return q.shape[0] * q.shape[2] @autotune( diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py index 3f746a0dd..c9fcb7d03 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py @@ -96,14 +96,16 @@ def _get_chunk_scaled_dot_kkt_configs(): ] -def _get_chunk_scaled_dot_kkt_static_key(H, K, BT, cu_seqlens, **kwargs): +def _get_chunk_scaled_dot_kkt_static_key(k, beta, chunk_size, cu_seqlens, **kwargs): + B, T, Hg, K = k.shape + H = beta.shape[-1] IS_VARLEN = cu_seqlens is not None - return {"H": H, "K": K, "BT": BT, "IS_VARLEN": IS_VARLEN} + return {"H": H, "K": K, "BT": chunk_size, "IS_VARLEN": IS_VARLEN} -def _get_chunk_scaled_dot_kkt_run_key(H, K, BT, cu_seqlens, **kwargs): - IS_VARLEN = cu_seqlens is not None - return f"{H}_{K}_{BT}_{IS_VARLEN}" +def _get_chunk_scaled_dot_kkt_run_key(k, beta, **kwargs): + # Return batch * heads as run key + return k.shape[0] * k.shape[2] @autotune( diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py index 4c936020c..c548812ca 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py @@ -148,14 +148,18 @@ def _get_cumsum_scalar_configs(): return [{"num_warps": num_warps} for num_warps in [1, 2, 4, 8]] -def _get_cumsum_scalar_static_key(B, H, BT, REVERSE, cu_seqlens, **kwargs): +def _get_cumsum_scalar_static_key(g, chunk_size, reverse, cu_seqlens, head_first, **kwargs): + if head_first: + B, H, T = g.shape + else: + B, T, H = g.shape IS_VARLEN = cu_seqlens is not None - return {"B": B, "H": H, "BT": BT, "IS_VARLEN": IS_VARLEN, "REVERSE": REVERSE} + return {"B": B, "H": H, "BT": chunk_size, "IS_VARLEN": IS_VARLEN, "REVERSE": reverse} -def _get_cumsum_scalar_run_key(B, H, BT, REVERSE, cu_seqlens, **kwargs): - IS_VARLEN = cu_seqlens is not None - return f"{B}_{H}_{BT}_{IS_VARLEN}_{REVERSE}" +def _get_cumsum_scalar_run_key(g, **kwargs): + # Return total number of elements as run key + return g.shape[0] * g.shape[1] @autotune( @@ -210,14 +214,18 @@ def _get_cumsum_vector_configs(): return [{"BS": BS, "num_warps": num_warps} for BS in BS_LIST for num_warps in [2, 4, 8]] -def _get_cumsum_vector_static_key(B, H, S, BT, REVERSE, cu_seqlens, **kwargs): +def _get_cumsum_vector_static_key(g, chunk_size, reverse, cu_seqlens, head_first, **kwargs): + if head_first: + B, H, T, S = g.shape + else: + B, T, H, S = g.shape IS_VARLEN = cu_seqlens is not None - return {"B": B, "H": H, "S": S, "BT": BT, "IS_VARLEN": IS_VARLEN, "REVERSE": REVERSE} + return {"B": B, "H": H, "S": S, "BT": chunk_size, "IS_VARLEN": IS_VARLEN, "REVERSE": reverse} -def _get_cumsum_vector_run_key(B, H, S, BT, REVERSE, cu_seqlens, **kwargs): - IS_VARLEN = cu_seqlens is not None - return f"{B}_{H}_{S}_{BT}_{IS_VARLEN}_{REVERSE}" +def _get_cumsum_vector_run_key(g, **kwargs): + # Return batch * heads as run key + return g.shape[0] * g.shape[2] if len(g.shape) == 4 else g.shape[0] @autotune( diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py index e0af7495f..d05a955b0 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py @@ -80,12 +80,13 @@ def _get_l2norm_kernel1_configs(): return [{"num_warps": num_warps} for num_warps in [1, 2, 4, 8, 16, 32]] -def _get_l2norm_kernel1_static_key(D, **kwargs): +def _get_l2norm_kernel1_static_key(x, **kwargs): + D = x.shape[-1] return {"D": D} -def _get_l2norm_kernel1_run_key(D, **kwargs): - return f"{D}" +def _get_l2norm_kernel1_run_key(x, **kwargs): + return x.shape[0] # T @autotune( @@ -108,12 +109,13 @@ def _get_l2norm_kernel_configs(): return [{"BT": BT, "num_warps": num_warps} for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST] -def _get_l2norm_kernel_static_key(D, **kwargs): +def _get_l2norm_kernel_static_key(x, **kwargs): + D = x.shape[-1] return {"D": D} -def _get_l2norm_kernel_run_key(D, **kwargs): - return f"{D}" +def _get_l2norm_kernel_run_key(x, **kwargs): + return x.shape[0] # T @autotune( From 77b4a8e9e40a0c5a42a5e42eac7754f774274487 Mon Sep 17 00:00:00 2001 From: sufubao Date: Thu, 11 Dec 2025 12:49:45 +0000 Subject: [PATCH 13/19] update_kernel --- .../{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json | 8 ++++ .../{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json | 8 ++++ .../{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json | 8 ++++ .../{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json | 7 +++ ...ARLEN=true,REVERSE=false}_NVIDIA_H200.json | 38 ++++++++++++++++ ...H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json | 7 +++ ...8,a_dtype=torch.bfloat16}_NVIDIA_H200.json | 37 +--------------- ...6,x_dtype=torch.bfloat16}_NVIDIA_H200.json | 12 +++++ ...6,x_dtype=torch.bfloat16}_NVIDIA_H200.json | 2 +- ...num=1,use_fp8_w8a8=false}_NVIDIA_H200.json | 27 ++++++++++++ ...um=10,use_fp8_w8a8=false}_NVIDIA_H200.json | 27 ++++++++++++ .../{topk_num=10}_NVIDIA_H200.json | 12 +++++ ...rch.bfloat16,topk_num=10}_NVIDIA_H200.json | 18 ++++++++ ...out_dtype=torch.bfloat16}_NVIDIA_H200.json | 36 +++++++++++++++ lightllm/common/triton_utils/autotuner.py | 4 +- .../layer_infer/transformer_layer_infer.py | 34 +++++++------- .../layer_weights/transformer_layer_weight.py | 1 - lightllm/models/qwen3next/model.py | 12 +++++ .../qwen3next/triton_kernel/fla/ops/chunk.py | 4 +- .../triton_kernel/fla/ops/chunk_delta_h.py | 4 +- .../triton_kernel/fla/ops/chunk_o.py | 4 +- .../fla/ops/chunk_scaled_dot_kkt.py | 6 ++- .../qwen3next/triton_kernel/fla/ops/cumsum.py | 8 ++-- .../qwen3next/triton_kernel/fla/ops/l2norm.py | 8 ++-- .../triton_kernel/fused_gdn_gating.py | 44 +++++++++++-------- lightllm/server/httpserver/manager.py | 2 +- lightllm/utils/log_utils.py | 2 +- 27 files changed, 287 insertions(+), 93 deletions(-) create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json create mode 100644 lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 000000000..4b002622a --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_fwd_o/{BT=16,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "4": { + "BK": 128, + "BV": 64, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 000000000..cc5c68eb7 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_fwd_o/{BT=32,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "4": { + "BK": 128, + "BV": 64, + "num_stages": 2, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 000000000..7421097fa --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_fwd_o/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,8 @@ +{ + "4": { + "BK": 64, + "BV": 128, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json new file mode 100644 index 000000000..d831f32c4 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_gated_delta_rule_fwd_h/{BT=64,H=8,K=128,V=128}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "4": { + "BV": 32, + "num_stages": 4, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json new file mode 100644 index 000000000..412046d09 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json @@ -0,0 +1,38 @@ +{ + "1": { + "num_warps": 4 + }, + "100": { + "num_warps": 8 + }, + "1024": { + "num_warps": 4 + }, + "128": { + "num_warps": 1 + }, + "16": { + "num_warps": 4 + }, + "2048": { + "num_warps": 2 + }, + "256": { + "num_warps": 1 + }, + "32": { + "num_warps": 8 + }, + "4096": { + "num_warps": 2 + }, + "64": { + "num_warps": 8 + }, + "8": { + "num_warps": 8 + }, + "8448": { + "num_warps": 1 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json new file mode 100644 index 000000000..9fbae2414 --- /dev/null +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_scaled_dot_kkt_fwd/{BT=64,H=8,IS_VARLEN=true,K=128}_NVIDIA_H200.json @@ -0,0 +1,7 @@ +{ + "4": { + "BK": 64, + "num_stages": 3, + "num_warps": 4 + } +} \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json index b139a72ba..9a339538c 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -1,38 +1,3 @@ { - "1": { - "BLK_HEADS": 64, - "num_warps": 2 - }, - "100": { - "BLK_HEADS": 16, - "num_warps": 2 - }, - "1024": { - "BLK_HEADS": 8, - "num_warps": 2 - }, - "128": { - "BLK_HEADS": 64, - "num_warps": 2 - }, - "16": { - "BLK_HEADS": 16, - "num_warps": 1 - }, - "256": { - "BLK_HEADS": 16, - "num_warps": 2 - }, - "32": { - "BLK_HEADS": 16, - "num_warps": 1 - }, - "64": { - "BLK_HEADS": 8, - "num_warps": 2 - }, - "8": { - "BLK_HEADS": 64, - "num_warps": 4 - } + "8448": null } \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json index 0b388b1a8..316fb7678 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -7,6 +7,10 @@ "BLOCK_N": 256, "num_warps": 1 }, + "16384": { + "BLOCK_N": 128, + "num_warps": 1 + }, "2048": { "BLOCK_N": 64, "num_warps": 1 @@ -15,6 +19,10 @@ "BLOCK_N": 256, "num_warps": 1 }, + "32768": { + "BLOCK_N": 256, + "num_warps": 2 + }, "512": { "BLOCK_N": 512, "num_warps": 4 @@ -23,6 +31,10 @@ "BLOCK_N": 256, "num_warps": 1 }, + "67584": { + "BLOCK_N": 64, + "num_warps": 1 + }, "8": { "BLOCK_N": 512, "num_warps": 8 diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json index 10685c2e2..3fd0050d7 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gemma_rmsnorm_forward:v1/{N=2048,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -1,6 +1,6 @@ { "2048": { - "BLOCK_SIZE": 2048, + "BLOCK_SIZE": 4096, "num_stages": 4, "num_warps": 4 } diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=128,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=128,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json index dac851f69..fde50e757 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=128,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=128,N=2048,expert_num=512,mul_routed_weight=true,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -44,6 +44,15 @@ "num_stages": 2, "num_warps": 4 }, + "20480": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, "2560": { "BLOCK_SIZE_K": 64, "BLOCK_SIZE_M": 16, @@ -62,6 +71,15 @@ "num_stages": 3, "num_warps": 4 }, + "40960": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, "640": { "BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 16, @@ -79,5 +97,14 @@ "NEED_TRANS": false, "num_stages": 2, "num_warps": 4 + }, + "84480": { + "BLOCK_SIZE_K": 32, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 } } \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json index b38406dc3..612f2b51e 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/grouped_matmul:v1/{K=2048,N=256,expert_num=512,mul_routed_weight=false,out_dtype=torch.bfloat16,topk_num=10,use_fp8_w8a8=false}_NVIDIA_H200.json @@ -44,6 +44,15 @@ "num_stages": 3, "num_warps": 4 }, + "2048": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 64, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 1, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, "256": { "BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 16, @@ -62,6 +71,15 @@ "num_stages": 3, "num_warps": 4 }, + "4096": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 32, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 + }, "64": { "BLOCK_SIZE_K": 128, "BLOCK_SIZE_M": 16, @@ -79,5 +97,14 @@ "NEED_TRANS": false, "num_stages": 5, "num_warps": 4 + }, + "8448": { + "BLOCK_SIZE_K": 64, + "BLOCK_SIZE_M": 128, + "BLOCK_SIZE_N": 128, + "GROUP_SIZE_M": 16, + "NEED_TRANS": false, + "num_stages": 3, + "num_warps": 4 } } \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_align_fused:v1/{topk_num=10}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_align_fused:v1/{topk_num=10}_NVIDIA_H200.json index 8e0ff1cf8..65c618475 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_align_fused:v1/{topk_num=10}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_align_fused:v1/{topk_num=10}_NVIDIA_H200.json @@ -19,6 +19,10 @@ "BLOCK_SIZE": 256, "num_warps": 8 }, + "2048": { + "BLOCK_SIZE": 256, + "num_warps": 8 + }, "256": { "BLOCK_SIZE": 128, "num_warps": 4 @@ -27,6 +31,10 @@ "BLOCK_SIZE": 128, "num_warps": 4 }, + "4096": { + "BLOCK_SIZE": 128, + "num_warps": 8 + }, "64": { "BLOCK_SIZE": 128, "num_warps": 8 @@ -34,5 +42,9 @@ "8": { "BLOCK_SIZE": 256, "num_warps": 8 + }, + "8448": { + "BLOCK_SIZE": 128, + "num_warps": 8 } } \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json index e459d2f32..4d6191579 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_sum_reduce:v1/{hidden_dim=2048,out_dtype=torch.bfloat16,topk_num=10}_NVIDIA_H200.json @@ -29,6 +29,12 @@ "NUM_STAGE": 2, "num_warps": 16 }, + "2048": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 + }, "256": { "BLOCK_DIM": 1024, "BLOCK_M": 1, @@ -41,6 +47,12 @@ "NUM_STAGE": 1, "num_warps": 8 }, + "4096": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 4 + }, "64": { "BLOCK_DIM": 512, "BLOCK_M": 1, @@ -52,5 +64,11 @@ "BLOCK_M": 1, "NUM_STAGE": 1, "num_warps": 8 + }, + "8448": { + "BLOCK_DIM": 1024, + "BLOCK_M": 1, + "NUM_STAGE": 1, + "num_warps": 8 } } \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json index 1a1eb3f74..4c0fdb9d2 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -59,6 +59,18 @@ "NUM_STAGES": 2, "num_warps": 4 }, + "2048": { + "BLOCK_M": 1, + "BLOCK_N": 256, + "NUM_STAGES": 1, + "num_warps": 4 + }, + "20480": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 1, + "num_warps": 1 + }, "256": { "BLOCK_M": 1, "BLOCK_N": 256, @@ -83,6 +95,18 @@ "NUM_STAGES": 1, "num_warps": 4 }, + "4096": { + "BLOCK_M": 8, + "BLOCK_N": 64, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "40960": { + "BLOCK_M": 1, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, "64": { "BLOCK_M": 1, "BLOCK_N": 32, @@ -106,5 +130,17 @@ "BLOCK_N": 64, "NUM_STAGES": 2, "num_warps": 1 + }, + "8448": { + "BLOCK_M": 8, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, + "84480": { + "BLOCK_M": 32, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 } } \ No newline at end of file diff --git a/lightllm/common/triton_utils/autotuner.py b/lightllm/common/triton_utils/autotuner.py index a919f7b28..ec95c4b27 100644 --- a/lightllm/common/triton_utils/autotuner.py +++ b/lightllm/common/triton_utils/autotuner.py @@ -168,7 +168,7 @@ def __call__(self, *args, **kwargs): if (dist.is_initialized() and get_current_rank_in_node() == 0) or not dist.is_initialized(): logger.warning( f"No kernel config for {self.kernel_name} in {KernelConfigs.get_config_file_name(static_key)}," - f"the performance may be suboptimal!" + f"the performance may be suboptimal! " f"You can use LIGHTLLM_TRITON_AUTOTUNE_LEVEL=1 to enable autotune.", ) self.cached_configs[static_key] = {} @@ -215,7 +215,7 @@ def _try_load_cache(self, static_key): cache_file = os.path.join(self.cache_dir, KernelConfigs.get_config_file_name(static_key)) if os.path.exists(cache_file): - logger.info(f"Loading cached configs for {self.kernel_name} - {static_key}") + logger.info(f"Loading cached configs for {self.kernel_name} - {static_key.items()}") with open(cache_file, "rb") as f: self.cached_configs[static_key] = orjson.loads(f.read()) return True diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index 858e71685..17ecb13c2 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -103,22 +103,23 @@ def _get_qkv( @override def _get_o( - self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight + self, input, gate_value, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight ) -> torch.Tensor: - input = input * layer_weight._gate - layer_weight._gate = None - o_tensor = layer_weight.o_proj.mm(input) + # Handle different input shapes from different attention kernels + input = input.view(-1, gate_value.shape[-1]) + gated_input = input * gate_value + o_tensor = layer_weight.o_proj.mm(gated_input) return o_tensor def _context_full_attn( - self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight + self, input, gate_value, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight ): q, cache_kv = self._get_qkv(input, infer_state, layer_weight) input = None self._post_cache_kv(cache_kv, infer_state, layer_weight) o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight) q = None - o = self._get_o(o, infer_state, layer_weight) + o = self._get_o(o, gate_value, infer_state, layer_weight) if self.tp_world_size_ > 1: all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) return o @@ -130,8 +131,8 @@ def context_forward( if self.is_linear: o = self.linear_attn_infer._linear_attn(input1, infer_state, layer_weight, is_prefill=True, infer_cls=self) else: - layer_weight._gate = torch.sigmoid(layer_weight.o_gate_proj.mm(input1)) - o = self._context_full_attn(input1, infer_state, layer_weight) + gate_value = torch.sigmoid(layer_weight.o_gate_proj.mm(input1)) + o = self._context_full_attn(input1, gate_value, infer_state, layer_weight) input_embdings.add_(o.view(-1, self.embed_dim_)) o = None @@ -143,13 +144,15 @@ def context_forward( input_embdings.add_(ffn_out.view(-1, self.embed_dim_)) return input_embdings - def _token_full_attn(self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight): + def _token_full_attn( + self, input, gate_value, infer_state: LlamaInferStateInfo, layer_weight: Qwen3NextTransformerLayerWeight + ): q, cache_kv = self._get_qkv(input, infer_state, layer_weight) input = None self._post_cache_kv(cache_kv, infer_state, layer_weight) o = self._token_attention_kernel(q, infer_state, layer_weight) q = None - o = self._get_o(o, infer_state, layer_weight) + o = self._get_o(o, gate_value, infer_state, layer_weight) if self.tp_world_size_ > 1: all_reduce(o, op=dist.ReduceOp.SUM, group=infer_state.dist_group, async_op=False) return o @@ -161,8 +164,8 @@ def token_forward( if self.is_linear: o = self.linear_attn_infer._linear_attn(input1, infer_state, layer_weight, is_prefill=False, infer_cls=self) else: - layer_weight._gate = torch.sigmoid(layer_weight.o_gate_proj.mm(input1)) - o = self._token_full_attn(input1, infer_state, layer_weight) + gate_value = torch.sigmoid(layer_weight.o_gate_proj.mm(input1)) + o = self._token_full_attn(input1, gate_value, infer_state, layer_weight) input_embdings.add_(o.view(-1, self.embed_dim_)) o = None @@ -284,10 +287,7 @@ def _linear_attn( # Rearrange mixed_qkv to query, key, value query, key, value = self._rearrange_mixed_qkv(mixed_qkv) - # Compute beta and g - beta = b.sigmoid() - g = fused_gdn_gating(layer_weight.linear_A_log.weight, a, layer_weight.linear_dt_bias.weight) - g, beta = map(lambda x: rearrange(x, "l d -> 1 l d"), (g, beta)) + g, beta = fused_gdn_gating(layer_weight.linear_A_log.weight, a, b, layer_weight.linear_dt_bias.weight) if is_prefill: initial_state = ssm_states[buffer_idx].contiguous() @@ -304,7 +304,7 @@ def _linear_attn( use_qk_l2norm_in_kernel=True, ) # Update SSM state with final state - ssm_states[buffer_idx, ...] = last_recurrent_state.to(ssm_states.dtype) + ssm_states[buffer_idx, ...] = last_recurrent_state.to(ssm_states.dtype, copy=False) else: batch_size = input.shape[0] cu_seqlens = torch.arange(0, batch_size + 1, dtype=torch.int32, device=input.device) diff --git a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py index da6168593..611e2bc5a 100644 --- a/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen3next/layer_weights/transformer_layer_weight.py @@ -57,7 +57,6 @@ def _init_weight(self): layer_num=self.layer_num_, name="o_gate_proj", ) - self._gate = None return @override diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index 5e6cfc2ef..7fa6f6d88 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -1,5 +1,7 @@ import torch +from typing import Optional from typing_extensions import override +import triton from lightllm.models.registry import ModelRegistry from lightllm.models.qwen3_moe.model import Qwen3MOEModel from lightllm.models.qwen3next.layer_weights.transformer_layer_weight import Qwen3NextTransformerLayerWeight @@ -16,6 +18,10 @@ logger = init_logger(__name__) +def _triton_allocator(size: int, alignment: int, stream: Optional[int]) -> torch.Tensor: + return torch.empty(size, device="cuda", dtype=torch.int8) + + @ModelRegistry("qwen3_next") class Qwen3NextTpPartModel(Qwen3MOEModel): # weight class @@ -27,6 +33,12 @@ class Qwen3NextTpPartModel(Qwen3MOEModel): def __init__(self, kvargs) -> None: self.mem_manager: Qwen3NextMemoryManager = None + + # Set Triton allocator for TMA descriptors + # This is required for kernels in qwen3next/triton_kernel/fla/ops/solve_tril.py + triton.set_allocator(_triton_allocator) + logger.info("Triton allocator set for Qwen3Next model") + super().__init__(kvargs) @override diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py index 944c71e2b..db4969cb0 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk.py @@ -35,7 +35,7 @@ def chunk_gated_delta_rule_fwd( ): g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens) # obtain WY representation. u is actually the new v. - A = chunk_scaled_dot_kkt_fwd(k=k, beta=beta, g=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32) + A = chunk_scaled_dot_kkt_fwd(k=k, beta=beta, g=g, cu_seqlens=cu_seqlens, chunk_size=64, output_dtype=torch.float32) A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) w, u = recompute_w_u_fwd( k=k, @@ -53,6 +53,7 @@ def chunk_gated_delta_rule_fwd( initial_state=initial_state, output_final_state=output_final_state, cu_seqlens=cu_seqlens, + chunk_size=64, ) o = chunk_fwd_o( q=q, @@ -62,6 +63,7 @@ def chunk_gated_delta_rule_fwd( g=g, scale=scale, cu_seqlens=cu_seqlens, + chunk_size=64, ) if SUPPRESS_LEVEL < 3: return g, o, A, final_state, None, None, None diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py index 2ea88d41f..f34029927 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py @@ -237,14 +237,14 @@ def _get_chunk_delta_h_configs(): ] -def _get_chunk_delta_h_static_key(k, u, chunk_size, **kwargs): +def _get_chunk_delta_h_static_key(k, u, chunk_size): B, T, Hg, K = k.shape V = u.shape[-1] H = u.shape[-2] return {"H": H, "K": K, "V": V, "BT": chunk_size} -def _get_chunk_delta_h_run_key(k, u, **kwargs): +def _get_chunk_delta_h_run_key(k, u): # Return batch * heads as run key return k.shape[0] * k.shape[2] diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py index 9006fb0ac..12ee5a37e 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py @@ -130,7 +130,7 @@ def _get_chunk_o_configs(): ] -def _get_chunk_o_static_key(q, v, chunk_size, **kwargs): +def _get_chunk_o_static_key(q, v, chunk_size): B, T, Hg, K = q.shape V = v.shape[-1] H = v.shape[-2] @@ -138,7 +138,7 @@ def _get_chunk_o_static_key(q, v, chunk_size, **kwargs): return {"H": H, "K": K, "V": V, "BT": BT} -def _get_chunk_o_run_key(q, v, **kwargs): +def _get_chunk_o_run_key(q, v): # Return batch * heads as run key return q.shape[0] * q.shape[2] diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py index c9fcb7d03..32a87f7a0 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py @@ -17,6 +17,8 @@ from .op import exp from lightllm.common.triton_utils.autotuner import autotune +triton.set_allocator + @triton.heuristics( { @@ -96,14 +98,14 @@ def _get_chunk_scaled_dot_kkt_configs(): ] -def _get_chunk_scaled_dot_kkt_static_key(k, beta, chunk_size, cu_seqlens, **kwargs): +def _get_chunk_scaled_dot_kkt_static_key(k, beta, chunk_size=64, cu_seqlens=None): B, T, Hg, K = k.shape H = beta.shape[-1] IS_VARLEN = cu_seqlens is not None return {"H": H, "K": K, "BT": chunk_size, "IS_VARLEN": IS_VARLEN} -def _get_chunk_scaled_dot_kkt_run_key(k, beta, **kwargs): +def _get_chunk_scaled_dot_kkt_run_key(k, beta): # Return batch * heads as run key return k.shape[0] * k.shape[2] diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py index c548812ca..64ec2d6cd 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/cumsum.py @@ -148,7 +148,7 @@ def _get_cumsum_scalar_configs(): return [{"num_warps": num_warps} for num_warps in [1, 2, 4, 8]] -def _get_cumsum_scalar_static_key(g, chunk_size, reverse, cu_seqlens, head_first, **kwargs): +def _get_cumsum_scalar_static_key(g, chunk_size, reverse, cu_seqlens, head_first): if head_first: B, H, T = g.shape else: @@ -157,7 +157,7 @@ def _get_cumsum_scalar_static_key(g, chunk_size, reverse, cu_seqlens, head_first return {"B": B, "H": H, "BT": chunk_size, "IS_VARLEN": IS_VARLEN, "REVERSE": reverse} -def _get_cumsum_scalar_run_key(g, **kwargs): +def _get_cumsum_scalar_run_key(g): # Return total number of elements as run key return g.shape[0] * g.shape[1] @@ -214,7 +214,7 @@ def _get_cumsum_vector_configs(): return [{"BS": BS, "num_warps": num_warps} for BS in BS_LIST for num_warps in [2, 4, 8]] -def _get_cumsum_vector_static_key(g, chunk_size, reverse, cu_seqlens, head_first, **kwargs): +def _get_cumsum_vector_static_key(g, chunk_size, reverse, cu_seqlens, head_first): if head_first: B, H, T, S = g.shape else: @@ -223,7 +223,7 @@ def _get_cumsum_vector_static_key(g, chunk_size, reverse, cu_seqlens, head_first return {"B": B, "H": H, "S": S, "BT": chunk_size, "IS_VARLEN": IS_VARLEN, "REVERSE": reverse} -def _get_cumsum_vector_run_key(g, **kwargs): +def _get_cumsum_vector_run_key(g): # Return batch * heads as run key return g.shape[0] * g.shape[2] if len(g.shape) == 4 else g.shape[0] diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py index d05a955b0..29f892ef2 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/l2norm.py @@ -80,12 +80,12 @@ def _get_l2norm_kernel1_configs(): return [{"num_warps": num_warps} for num_warps in [1, 2, 4, 8, 16, 32]] -def _get_l2norm_kernel1_static_key(x, **kwargs): +def _get_l2norm_kernel1_static_key(x): D = x.shape[-1] return {"D": D} -def _get_l2norm_kernel1_run_key(x, **kwargs): +def _get_l2norm_kernel1_run_key(x): return x.shape[0] # T @@ -109,12 +109,12 @@ def _get_l2norm_kernel_configs(): return [{"BT": BT, "num_warps": num_warps} for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST] -def _get_l2norm_kernel_static_key(x, **kwargs): +def _get_l2norm_kernel_static_key(x): D = x.shape[-1] return {"D": D} -def _get_l2norm_kernel_run_key(x, **kwargs): +def _get_l2norm_kernel_run_key(x): return x.shape[0] # T diff --git a/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py b/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py index 99a5e2f70..ffe68a637 100644 --- a/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py +++ b/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py @@ -1,14 +1,21 @@ +# Adapted from https://github.com/sgl-project/sglang/python/sglang/srt/layers/attention/fla/fused_gdn_gating.py +from typing import Tuple + import torch import triton import triton.language as tl + from lightllm.common.triton_utils.autotuner import autotune # g = -self.A_log.float().exp() * F.softplus(a.float() + self.dt_bias) +# beta_output = b.sigmoid() @triton.jit def fused_gdn_gating_kernel( g, + beta_output, A_log, a, + b, dt_bias, seq_len, NUM_HEADS: tl.constexpr, @@ -22,16 +29,18 @@ def fused_gdn_gating_kernel( mask = head_off < NUM_HEADS blk_A_log = tl.load(A_log + head_off, mask=mask) blk_a = tl.load(a + off, mask=mask) + blk_b = tl.load(b + off, mask=mask) blk_bias = tl.load(dt_bias + head_off, mask=mask) - # If the model is loaded in fp16, without the .float() here, A might be -inf x = blk_a.to(tl.float32) + blk_bias.to(tl.float32) softplus_x = tl.where(beta * x <= threshold, (1 / beta) * tl.log(1 + tl.exp(beta * x)), x) blk_g = -tl.exp(blk_A_log.to(tl.float32)) * softplus_x tl.store(g + off, blk_g.to(g.dtype.element_ty), mask=mask) + blk_beta_output = tl.sigmoid(blk_b.to(tl.float32)) + tl.store(beta_output + off, blk_beta_output.to(b.dtype.element_ty), mask=mask) def _get_fused_gdn_gating_configs(): - return [{"BLK_HEADS": bh, "num_warps": nw} for bh in [8, 16, 32, 64] for nw in [1, 2, 4]] + return [{"BLK_HEADS": bh, "num_warps": nw} for bh in [4, 8, 16, 32, 64] for nw in [1, 2, 4]] def _get_fused_gdn_gating_static_key(a: torch.Tensor): @@ -48,36 +57,33 @@ def _get_fused_gdn_gating_static_key(a: torch.Tensor): def fused_gdn_gating( A_log: torch.Tensor, a: torch.Tensor, + b: torch.Tensor, dt_bias: torch.Tensor, beta: float = 1.0, threshold: float = 20.0, - run_config: dict | None = None, -) -> torch.Tensor: - batch, num_heads = a.shape - seq_len = 1 + run_config: dict = None, +) -> Tuple[torch.Tensor, torch.Tensor]: - # default heuristic when autotune is disabled if not run_config: - # choose the largest block size that does not exceed num_heads - candidate_blk = [8, 16, 32, 64] - blk_heads = max([c for c in candidate_blk if c <= max(8, num_heads)] or [8]) - run_config = {"BLK_HEADS": blk_heads, "num_warps": 1} + run_config = {"BLK_HEADS": 8, "num_warps": 1} - BLK_HEADS = run_config["BLK_HEADS"] - num_warps = run_config.get("num_warps", 1) - - grid = (batch, seq_len, triton.cdiv(num_heads, BLK_HEADS)) - g = torch.empty_like(a, dtype=torch.float32) + batch, num_heads = a.shape + seq_len = 1 + grid = (batch, seq_len, triton.cdiv(num_heads, 8)) + g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device) + beta_output = torch.empty(1, batch, num_heads, dtype=torch.float32, device=b.device) fused_gdn_gating_kernel[grid]( g, + beta_output, A_log, a, + b, dt_bias, seq_len, num_heads, beta, threshold, - BLK_HEADS, - num_warps=num_warps, + run_config["BLK_HEADS"], + num_warps=run_config["num_warps"], ) - return g + return g, beta_output diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index e1cb32b88..434848f09 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -599,7 +599,7 @@ async def _wait_to_token_package( (out_token_counter - metadata["mtp_accepted_token_num"]), 1 ) format_start_time = datetime.datetime.fromtimestamp(start_time).strftime("%Y-%m-%d %H:%M:%S") - logger.info( + logger.debug( f"X-Request-Id:{x_request_id} " f"X-Session-Id:{x_session_id} start_time:{format_start_time} " f"lightllm_req_id:{group_request_id} first_token_cost:{first_token_cost_ms}ms " diff --git a/lightllm/utils/log_utils.py b/lightllm/utils/log_utils.py index 6bbe87373..03409b4df 100644 --- a/lightllm/utils/log_utils.py +++ b/lightllm/utils/log_utils.py @@ -10,7 +10,7 @@ _FORMAT = "%(levelname)s %(asctime)s [%(filename)s:%(lineno)d] %(message)s" _DATE_FORMAT = "%m-%d %H:%M:%S" -_LOG_LEVEL = os.environ.get("LIGHTLLM_LOG_LEVEL", "debug") +_LOG_LEVEL = os.environ.get("LIGHTLLM_LOG_LEVEL", "info") _LOG_LEVEL = getattr(logging, _LOG_LEVEL.upper(), 0) _LOG_DIR = os.environ.get("LIGHTLLM_LOG_DIR", None) From 117a23c81d48371801009baa71544b7a4517c74b Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 12 Dec 2025 06:46:55 +0000 Subject: [PATCH 14/19] update kernel --- ...8,a_dtype=torch.bfloat16}_NVIDIA_H200.json | 49 ++- .../layer_infer/transformer_layer_infer.py | 1 + lightllm/models/qwen3next/model.py | 16 +- .../qwen3next/triton_kernel/fla/__init__.py | 3 + .../triton_kernel/fla_bak/__init__.py | 15 - .../qwen3next/triton_kernel/fla_bak/chunk.py | 225 ----------- .../triton_kernel/fla_bak/chunk_delta_h.py | 257 ------------ .../triton_kernel/fla_bak/chunk_o.py | 167 -------- .../fla_bak/chunk_scaled_dot_kkt.py | 136 ------- .../qwen3next/triton_kernel/fla_bak/cumsum.py | 200 ---------- .../triton_kernel/fla_bak/fused_recurrent.py | 367 ------------------ .../qwen3next/triton_kernel/fla_bak/index.py | 30 -- .../qwen3next/triton_kernel/fla_bak/l2norm.py | 137 ------- .../qwen3next/triton_kernel/fla_bak/op.py | 36 -- .../triton_kernel/fla_bak/solve_tril.py | 271 ------------- .../qwen3next/triton_kernel/fla_bak/utils.py | 173 --------- .../triton_kernel/fla_bak/wy_fast.py | 122 ------ .../triton_kernel/fused_gdn_gating.py | 8 +- lightllm/server/api_cli.py | 9 +- lightllm/server/core/objs/start_args_type.py | 5 +- .../dynamic_prompt/hybrid_radix_cache.py | 4 +- .../server/router/model_infer/infer_batch.py | 7 +- test/test_api/test_gsmk.py | 230 +++++++++++ 23 files changed, 310 insertions(+), 2158 deletions(-) delete mode 100644 lightllm/models/qwen3next/triton_kernel/fla_bak/__init__.py delete mode 100644 lightllm/models/qwen3next/triton_kernel/fla_bak/chunk.py delete mode 100644 lightllm/models/qwen3next/triton_kernel/fla_bak/chunk_delta_h.py delete mode 100644 lightllm/models/qwen3next/triton_kernel/fla_bak/chunk_o.py delete mode 100644 lightllm/models/qwen3next/triton_kernel/fla_bak/chunk_scaled_dot_kkt.py delete mode 100644 lightllm/models/qwen3next/triton_kernel/fla_bak/cumsum.py delete mode 100644 lightllm/models/qwen3next/triton_kernel/fla_bak/fused_recurrent.py delete mode 100644 lightllm/models/qwen3next/triton_kernel/fla_bak/index.py delete mode 100644 lightllm/models/qwen3next/triton_kernel/fla_bak/l2norm.py delete mode 100644 lightllm/models/qwen3next/triton_kernel/fla_bak/op.py delete mode 100644 lightllm/models/qwen3next/triton_kernel/fla_bak/solve_tril.py delete mode 100644 lightllm/models/qwen3next/triton_kernel/fla_bak/utils.py delete mode 100644 lightllm/models/qwen3next/triton_kernel/fla_bak/wy_fast.py create mode 100644 test/test_api/test_gsmk.py diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json index 9a339538c..e6e34fdbe 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -1,3 +1,50 @@ { - "8448": null + "1": { + "BLK_HEADS": 64, + "num_warps": 1 + }, + "100": { + "BLK_HEADS": 4, + "num_warps": 4 + }, + "1024": { + "BLK_HEADS": 8, + "num_warps": 1 + }, + "128": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "16": { + "BLK_HEADS": 8, + "num_warps": 2 + }, + "2048": { + "BLK_HEADS": 16, + "num_warps": 1 + }, + "256": { + "BLK_HEADS": 32, + "num_warps": 2 + }, + "32": { + "BLK_HEADS": 8, + "num_warps": 1 + }, + "4096": { + "BLK_HEADS": 16, + "num_warps": 4 + }, + "64": { + "BLK_HEADS": 64, + "num_warps": 2 + }, + "8": { + "BLK_HEADS": 8, + "num_warps": 2 + }, + "8448": { + "BLK_HEADS": 32, + "num_warps": 4 + } } \ No newline at end of file diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index 17ecb13c2..53835bfb1 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -270,6 +270,7 @@ def _linear_attn( bias=layer_weight.linear_conv1d.mm_param.bias, query_start_loc=infer_state.b1_cu_q_seq_len, cache_indices=buffer_idx, + has_initial_state=infer_state.b_ready_cache_len > 0, conv_states=conv_states.transpose(1, 2), activation=self.activation, ) diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index 7fa6f6d88..a8558d37a 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -1,3 +1,4 @@ +import os import torch from typing import Optional from typing_extensions import override @@ -62,11 +63,11 @@ def _init_mem_manager(self): start_args: StartArgs = get_env_start_args() mtp_step = start_args.mtp_step - linear_attn_cache_size = start_args.linear_attn_cache_size - if linear_attn_cache_size is not None: + mamba_cache_size = start_args.mamba_cache_size + if mamba_cache_size is not None: assert ( - linear_attn_cache_size >= start_args.running_max_req_size - ), "linear_attn_cache_size must be greater than running_max_req_size" + mamba_cache_size >= start_args.running_max_req_size + ), "mamba_cache_size must be greater than running_max_req_size" self.num_linear_k_heads = self.config["linear_num_key_heads"] self.num_linear_v_heads = self.config["linear_num_value_heads"] @@ -78,9 +79,12 @@ def _init_mem_manager(self): self.head_linear_k_dim * self.num_linear_k_heads * 2 + self.head_linear_v_dim * self.num_linear_v_heads ) + ssm_dtype_dict = {"bfloat16": torch.bfloat16, "float32": torch.float32} + assert start_args.mamba_ssm_data_type in ssm_dtype_dict + self.mem_manager = Qwen3NextMemoryManager( full_attn_cache_size=self.max_total_token_num, - linear_attn_cache_size=linear_attn_cache_size, + linear_attn_cache_size=mamba_cache_size, dtype=self.data_type, num_kv_heads=self.num_kv_heads, head_dim=self.config["head_dim"], @@ -89,7 +93,7 @@ def _init_mem_manager(self): full_attention_interval=self.config["full_attention_interval"], conv_state_dtype=self.data_type, conv_state_shape=(conv_kernel_size - 1 + mtp_step, conv_dim // self.tp_world_size_), - ssm_state_dtype=self.data_type, + ssm_state_dtype=ssm_dtype_dict[start_args.mamba_ssm_data_type], ssm_state_shape=( # mtp_step + 1, self.num_linear_v_heads // self.tp_world_size_, diff --git a/lightllm/models/qwen3next/triton_kernel/fla/__init__.py b/lightllm/models/qwen3next/triton_kernel/fla/__init__.py index 0e89cf9f7..2bde70bb9 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/__init__.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/__init__.py @@ -6,3 +6,6 @@ # The original source code was licensed under the MIT license and included # the following copyright notice: # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang + +# Adapted from +# https://github.com/vllm-project/vllm diff --git a/lightllm/models/qwen3next/triton_kernel/fla_bak/__init__.py b/lightllm/models/qwen3next/triton_kernel/fla_bak/__init__.py deleted file mode 100644 index cd3b0962a..000000000 --- a/lightllm/models/qwen3next/triton_kernel/fla_bak/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang -# -# This file contains code copied from the flash-linear-attention project. -# The original source code was licensed under the MIT license and included -# the following copyright notice: -# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang -from .chunk import chunk_gated_delta_rule -from .fused_recurrent import fused_recurrent_gated_delta_rule - -__all__ = [ - "chunk_gated_delta_rule", - "fused_recurrent_gated_delta_rule", -] diff --git a/lightllm/models/qwen3next/triton_kernel/fla_bak/chunk.py b/lightllm/models/qwen3next/triton_kernel/fla_bak/chunk.py deleted file mode 100644 index 22c81ae63..000000000 --- a/lightllm/models/qwen3next/triton_kernel/fla_bak/chunk.py +++ /dev/null @@ -1,225 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang -# -# This file contains code copied from the flash-linear-attention project. -# The original source code was licensed under the MIT license and included -# the following copyright notice: -# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang -# ruff: noqa: E501 -import warnings -from typing import Optional - -import torch -from einops import rearrange - -from .chunk_delta_h import chunk_gated_delta_rule_fwd_h -from .chunk_o import chunk_fwd_o -from .chunk_scaled_dot_kkt import chunk_scaled_dot_kkt_fwd -from .cumsum import chunk_local_cumsum -from .l2norm import l2norm_fwd -from .solve_tril import solve_tril -from .utils import SUPPRESS_LEVEL, input_guard -from .wy_fast import recompute_w_u_fwd - - -def chunk_gated_delta_rule_fwd( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - scale: float, - initial_state: torch.Tensor, - output_final_state: bool, - cu_seqlens: Optional[torch.LongTensor] = None, -): - g = chunk_local_cumsum(g, chunk_size=64, cu_seqlens=cu_seqlens) - # obtain WY representation. u is actually the new v. - A = chunk_scaled_dot_kkt_fwd(k=k, beta=beta, g_cumsum=g, cu_seqlens=cu_seqlens, output_dtype=torch.float32) - A = solve_tril(A=A, cu_seqlens=cu_seqlens, output_dtype=k.dtype) - w, u = recompute_w_u_fwd( - k=k, - v=v, - beta=beta, - A=A, - g_cumsum=g, - cu_seqlens=cu_seqlens, - ) - h, v_new, final_state = chunk_gated_delta_rule_fwd_h( - k=k, - w=w, - u=u, - g=g, - initial_state=initial_state, - output_final_state=output_final_state, - cu_seqlens=cu_seqlens, - ) - o = chunk_fwd_o( - q=q, - k=k, - v=v_new, - h=h, - g=g, - scale=scale, - cu_seqlens=cu_seqlens, - ) - if SUPPRESS_LEVEL < 3: - return g, o, A, final_state, None, None, None - elif SUPPRESS_LEVEL >= 3: - return g, o, A, final_state, w, h, v_new - - -class ChunkGatedDeltaRuleFunction(torch.autograd.Function): - @staticmethod - @input_guard - @torch.amp.custom_fwd(device_type="cuda") - def forward( - ctx, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - scale: float, - initial_state: torch.Tensor, - output_final_state: bool, - cu_seqlens: Optional[torch.LongTensor] = None, - use_qk_l2norm_in_kernel: bool = False, - ): - if use_qk_l2norm_in_kernel: - q = l2norm_fwd(q) - k = l2norm_fwd(k) - - g, o, A, final_state, w, h, v_new = chunk_gated_delta_rule_fwd( - q=q, - k=k, - v=v, - g=g, - beta=beta, - scale=scale, - initial_state=initial_state, - output_final_state=output_final_state, - cu_seqlens=cu_seqlens, - ) - ctx.scale = scale - ctx.use_qk_l2norm_in_kernel = use_qk_l2norm_in_kernel - return o.to(q.dtype), final_state - - -@torch.compiler.disable -def chunk_gated_delta_rule( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - scale: float = None, - initial_state: torch.Tensor = None, - output_final_state: bool = False, - cu_seqlens: Optional[torch.LongTensor] = None, - head_first: bool = False, - use_qk_l2norm_in_kernel: bool = False, -): - r""" - Args: - q (torch.Tensor): - queries of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. - k (torch.Tensor): - keys of shape `[B, T, H, K]` if `head_first=False` else `[B, H, T, K]`. - v (torch.Tensor): - values of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. - g (torch.Tensor): - (forget) gating tensor (in log space!) of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. - beta (torch.Tensor): - betas of shape `[B, T, H]` if `head_first=False` else `[B, H, T]`. - scale (Optional[int]): - Scale factor for the RetNet attention scores. - If not provided, it will default to `1 / sqrt(K)`. Default: `None`. - initial_state (Optional[torch.Tensor]): - Initial state of shape `[N, H, K, V]` for `N` input sequences. - For equal-length input sequences, `N` equals the batch size `B`. - Default: `None`. - output_final_state (Optional[bool]): - Whether to output the final state of shape `[N, H, K, V]`. Default: `False`. - cu_seqlens (torch.LongTensor): - Cumulative sequence lengths of shape `[N+1]` used for variable-length training, - consistent with the FlashAttention API. - head_first (Optional[bool]): - Whether the inputs are in the head-first format, which is not supported for variable-length inputs. - Default: `False`. - - Returns: - o (torch.Tensor): - Outputs of shape `[B, T, H, V]` if `head_first=False` else `[B, H, T, V]`. - final_state (torch.Tensor): - Final state of shape `[N, H, K, V]` if `output_final_state=True` else `None`. - - Examples:: - >>> import torch - >>> import torch.nn.functional as F - >>> from einops import rearrange - >>> from fla.ops.gated_delta_rule import chunk_gated_delta_rule - # inputs with equal lengths - >>> B, T, H, K, V = 4, 2048, 4, 512, 512 - >>> q = torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda') - >>> k = F.normalize(torch.randn(B, T, H, K, dtype=torch.bfloat16, device='cuda'), p=2, dim=-1) - >>> v = torch.randn(B, T, H, V, dtype=torch.bfloat16, device='cuda') - >>> beta = torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda').sigmoid() - >>> g = F.logsigmoid(torch.rand(B, T, H, dtype=torch.bfloat16, device='cuda')) - >>> h0 = torch.randn(B, H, K, V, dtype=torch.bfloat16, device='cuda') - >>> o, ht = chunk_gated_delta_rule( - q, k, v, g, beta, - initial_state=h0, - output_final_state=True - ) - # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required - >>> q, k, v, beta, g = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, beta, g)) - # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected - >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) - >>> o_var, ht_var = chunk_gated_delta_rule( - q, k, v, g, beta, - initial_state=h0, - output_final_state=True, - cu_seqlens=cu_seqlens - ) - """ - assert q.dtype == k.dtype == v.dtype - assert q.dtype != torch.float32, "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16." - assert len(beta.shape) == 3, "beta must be of shape [B, T, H] if head_first=False, or [B, H, T] otherwise." - - if head_first: - raise DeprecationWarning( - "head_first is deprecated and will be removed in a future version. " - "Please use head_first=False for now instead.", - stacklevel=2, - ) - q, k, v, beta, g = map(lambda x: rearrange(x, "b h t ... -> b t h ..."), (q, k, v, beta, g)) - if not head_first and q.shape[1] < q.shape[2]: - warnings.warn( - f"Input tensor shape suggests potential format mismatch" - f" seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " - "This may indicate the inputs were passed in head-first format [B, H, T, ...] " - "when head_first=False was specified. " - "Please verify your input tensor format matches the expected shape [B, T, H, ...].", - stacklevel=2, - ) - if cu_seqlens is not None: - if q.shape[0] != 1: - raise ValueError( - f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." - f"Please flatten variable-length inputs before processing." - ) - if initial_state is not None and initial_state.shape[0] != len(cu_seqlens) - 1: - raise ValueError( - f"The number of initial states is expected to be equal to the number of input sequences, " - f"i.e., {len(cu_seqlens) - 1} rather than {initial_state.shape[0]}." - ) - if scale is None: - scale = k.shape[-1] ** -0.5 - o, final_state = ChunkGatedDeltaRuleFunction.apply( - q, k, v, g, beta, scale, initial_state, output_final_state, cu_seqlens, use_qk_l2norm_in_kernel - ) - if head_first: - o = rearrange(o, "b t h ... -> b h t ...") - return o, final_state diff --git a/lightllm/models/qwen3next/triton_kernel/fla_bak/chunk_delta_h.py b/lightllm/models/qwen3next/triton_kernel/fla_bak/chunk_delta_h.py deleted file mode 100644 index f20c95d90..000000000 --- a/lightllm/models/qwen3next/triton_kernel/fla_bak/chunk_delta_h.py +++ /dev/null @@ -1,257 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang -# -# This file contains code copied from the flash-linear-attention project. -# The original source code was licensed under the MIT license and included -# the following copyright notice: -# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang -# ruff: noqa: E501 -from typing import Optional - -import torch - -import triton -import triton.language as tl - -from .index import prepare_chunk_indices, prepare_chunk_offsets -from .op import exp -from .utils import is_nvidia_hopper, use_cuda_graph - -NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8, 16] - - -@triton.heuristics( - { - "USE_G": lambda args: args["g"] is not None, - "USE_INITIAL_STATE": lambda args: args["h0"] is not None, - "STORE_FINAL_STATE": lambda args: args["ht"] is not None, - "SAVE_NEW_VALUE": lambda args: args["v_new"] is not None, - "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, - } -) -@triton.autotune( - configs=[ - triton.Config({"BV": BV}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [2, 4] - for num_stages in [2, 3, 4] - for BV in [32, 64] - ], - key=["H", "K", "V", "BT", "USE_G"], - use_cuda_graph=use_cuda_graph, -) -@triton.jit(do_not_specialize=["T"]) -def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( - k, - v, - w, - v_new, - g, - h, - h0, - ht, - cu_seqlens, - chunk_offsets, - T, - H: tl.constexpr, - Hg: tl.constexpr, - K: tl.constexpr, - V: tl.constexpr, - BT: tl.constexpr, - BV: tl.constexpr, - USE_G: tl.constexpr, - USE_INITIAL_STATE: tl.constexpr, - STORE_FINAL_STATE: tl.constexpr, - SAVE_NEW_VALUE: tl.constexpr, - IS_VARLEN: tl.constexpr, -): - i_v, i_nh = tl.program_id(0), tl.program_id(1) - i_n, i_h = i_nh // H, i_nh % H - if IS_VARLEN: - bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) - T = eos - bos - NT = tl.cdiv(T, BT) - boh = tl.load(chunk_offsets + i_n).to(tl.int32) - else: - bos, eos = i_n * T, i_n * T + T - NT = tl.cdiv(T, BT) - boh = i_n * NT - - # [BK, BV] - b_h1 = tl.zeros([64, BV], dtype=tl.float32) - if K > 64: - b_h2 = tl.zeros([64, BV], dtype=tl.float32) - if K > 128: - b_h3 = tl.zeros([64, BV], dtype=tl.float32) - if K > 192: - b_h4 = tl.zeros([64, BV], dtype=tl.float32) - - # calculate offset - h += (boh * H + i_h) * K * V - v += (bos * H + i_h) * V - k += (bos * Hg + i_h // (H // Hg)) * K - w += (bos * H + i_h) * K - if SAVE_NEW_VALUE: - v_new += (bos * H + i_h) * V - stride_v = H * V - stride_h = H * K * V - stride_k = Hg * K - stride_w = H * K - if USE_INITIAL_STATE: - h0 = h0 + i_nh * K * V - if STORE_FINAL_STATE: - ht = ht + i_nh * K * V - - # load initial state - if USE_INITIAL_STATE: - p_h0_1 = tl.make_block_ptr(h0, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) - b_h1 += tl.load(p_h0_1, boundary_check=(0, 1)).to(tl.float32) - if K > 64: - p_h0_2 = tl.make_block_ptr(h0, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) - b_h2 += tl.load(p_h0_2, boundary_check=(0, 1)).to(tl.float32) - if K > 128: - p_h0_3 = tl.make_block_ptr(h0, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) - b_h3 += tl.load(p_h0_3, boundary_check=(0, 1)).to(tl.float32) - if K > 192: - p_h0_4 = tl.make_block_ptr(h0, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) - b_h4 += tl.load(p_h0_4, boundary_check=(0, 1)).to(tl.float32) - - # main recurrence - for i_t in range(NT): - p_h1 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) - tl.store(p_h1, b_h1.to(p_h1.dtype.element_ty), boundary_check=(0, 1)) - if K > 64: - p_h2 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) - tl.store(p_h2, b_h2.to(p_h2.dtype.element_ty), boundary_check=(0, 1)) - if K > 128: - p_h3 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) - tl.store(p_h3, b_h3.to(p_h3.dtype.element_ty), boundary_check=(0, 1)) - if K > 192: - p_h4 = tl.make_block_ptr(h + i_t * stride_h, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) - tl.store(p_h4, b_h4.to(p_h4.dtype.element_ty), boundary_check=(0, 1)) - - p_v = tl.make_block_ptr(v, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_v_new = ( - tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - if SAVE_NEW_VALUE - else None - ) - b_v_new = tl.zeros([BT, BV], dtype=tl.float32) - p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 0), (BT, 64), (1, 0)) - b_w = tl.load(p_w, boundary_check=(0, 1)) - b_v_new += tl.dot(b_w, b_h1.to(b_w.dtype)) - if K > 64: - p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 64), (BT, 64), (1, 0)) - b_w = tl.load(p_w, boundary_check=(0, 1)) - b_v_new += tl.dot(b_w, b_h2.to(b_w.dtype)) - if K > 128: - p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 128), (BT, 64), (1, 0)) - b_w = tl.load(p_w, boundary_check=(0, 1)) - b_v_new += tl.dot(b_w, b_h3.to(b_w.dtype)) - if K > 192: - p_w = tl.make_block_ptr(w, (T, K), (stride_w, 1), (i_t * BT, 192), (BT, 64), (1, 0)) - b_w = tl.load(p_w, boundary_check=(0, 1)) - b_v_new += tl.dot(b_w, b_h4.to(b_w.dtype)) - b_v_new = -b_v_new + tl.load(p_v, boundary_check=(0, 1)) - - if SAVE_NEW_VALUE: - p_v_new = tl.make_block_ptr(v_new, (T, V), (stride_v, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - tl.store(p_v_new, b_v_new.to(p_v_new.dtype.element_ty), boundary_check=(0, 1)) - - if USE_G: - m_t = (i_t * BT + tl.arange(0, BT)) < T - last_idx = min((i_t + 1) * BT, T) - 1 - b_g_last = tl.load(g + bos * H + last_idx * H + i_h) - p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) - b_g = tl.load(p_g, boundary_check=(0,)) - b_v_new = b_v_new * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None] - b_g_last = exp(b_g_last) - b_h1 = b_h1 * b_g_last - if K > 64: - b_h2 = b_h2 * b_g_last - if K > 128: - b_h3 = b_h3 * b_g_last - if K > 192: - b_h4 = b_h4 * b_g_last - b_v_new = b_v_new.to(k.dtype.element_ty) - p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (0, i_t * BT), (64, BT), (0, 1)) - b_k = tl.load(p_k, boundary_check=(0, 1)) - b_h1 += tl.dot(b_k, b_v_new) - if K > 64: - p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (64, i_t * BT), (64, BT), (0, 1)) - b_k = tl.load(p_k, boundary_check=(0, 1)) - b_h2 += tl.dot(b_k, b_v_new) - if K > 128: - p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (128, i_t * BT), (64, BT), (0, 1)) - b_k = tl.load(p_k, boundary_check=(0, 1)) - b_h3 += tl.dot(b_k, b_v_new) - if K > 192: - p_k = tl.make_block_ptr(k, (K, T), (1, stride_k), (192, i_t * BT), (64, BT), (0, 1)) - b_k = tl.load(p_k, boundary_check=(0, 1)) - b_h4 += tl.dot(b_k, b_v_new) - - # epilogue - if STORE_FINAL_STATE: - p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (0, i_v * BV), (64, BV), (1, 0)) - tl.store(p_ht, b_h1.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) - if K > 64: - p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (64, i_v * BV), (64, BV), (1, 0)) - tl.store(p_ht, b_h2.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) - if K > 128: - p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (128, i_v * BV), (64, BV), (1, 0)) - tl.store(p_ht, b_h3.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) - if K > 192: - p_ht = tl.make_block_ptr(ht, (K, V), (V, 1), (192, i_v * BV), (64, BV), (1, 0)) - tl.store(p_ht, b_h4.to(p_ht.dtype.element_ty), boundary_check=(0, 1)) - - -def chunk_gated_delta_rule_fwd_h( - k: torch.Tensor, - w: torch.Tensor, - u: torch.Tensor, - g: Optional[torch.Tensor] = None, - initial_state: Optional[torch.Tensor] = None, - output_final_state: bool = False, - chunk_size: int = 64, # SY: remove this argument and force chunk size 64? - save_new_value: bool = True, - cu_seqlens: Optional[torch.LongTensor] = None, -) -> tuple[torch.Tensor, torch.Tensor]: - B, T, Hg, K, V = *k.shape, u.shape[-1] - H = u.shape[-2] - BT = chunk_size - - chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None - # N: the actual number of sequences in the batch with either equal or variable lengths - if cu_seqlens is None: - N, NT, chunk_offsets = B, triton.cdiv(T, BT), None - else: - N, NT, chunk_offsets = len(cu_seqlens) - 1, len(chunk_indices), prepare_chunk_offsets(cu_seqlens, BT) - assert K <= 256, "current kernel does not support head dimension larger than 256." - - h = k.new_empty(B, NT, H, K, V) - final_state = k.new_empty(N, H, K, V, dtype=torch.float32) if output_final_state else None - - v_new = torch.empty_like(u) if save_new_value else None - - def grid(meta): - return (triton.cdiv(V, meta["BV"]), N * H) - - chunk_gated_delta_rule_fwd_kernel_h_blockdim64[grid]( - k=k, - v=u, - w=w, - v_new=v_new, - g=g, - h=h, - h0=initial_state, - ht=final_state, - cu_seqlens=cu_seqlens, - chunk_offsets=chunk_offsets, - T=T, - H=H, - Hg=Hg, - K=K, - V=V, - BT=BT, - ) - return h, v_new, final_state diff --git a/lightllm/models/qwen3next/triton_kernel/fla_bak/chunk_o.py b/lightllm/models/qwen3next/triton_kernel/fla_bak/chunk_o.py deleted file mode 100644 index 73c2e1f19..000000000 --- a/lightllm/models/qwen3next/triton_kernel/fla_bak/chunk_o.py +++ /dev/null @@ -1,167 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang -# -# This file contains code copied from the flash-linear-attention project. -# The original source code was licensed under the MIT license and included -# the following copyright notice: -# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang - -# ruff: noqa: E501 - -from typing import Optional - -import torch - -import triton -import triton.language as tl - -from .index import prepare_chunk_indices -from .op import exp -from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper - -BKV_LIST = [64, 128] if check_shared_mem() else [32, 64] -NUM_WARPS = [2, 4] if is_nvidia_hopper else [2, 4, 8] - - -@triton.heuristics( - {"USE_G": lambda args: args["g"] is not None, "IS_VARLEN": lambda args: args["cu_seqlens"] is not None} -) -@triton.autotune( - configs=[ - triton.Config({"BK": BK, "BV": BV}, num_warps=num_warps, num_stages=num_stages) - for BK in BKV_LIST - for BV in BKV_LIST - for num_warps in NUM_WARPS - for num_stages in [2, 3, 4] - ], - key=["H", "K", "V", "BT"], -) -@triton.jit(do_not_specialize=["T"]) -def chunk_fwd_kernel_o( - q, - k, - v, - h, - g, - o, - cu_seqlens, - chunk_indices, - scale, - T, - H: tl.constexpr, - Hg: tl.constexpr, - K: tl.constexpr, - V: tl.constexpr, - BT: tl.constexpr, - BK: tl.constexpr, - BV: tl.constexpr, - USE_G: tl.constexpr, - IS_VARLEN: tl.constexpr, -): - i_v, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - i_b, i_h = i_bh // H, i_bh % H - - if IS_VARLEN: - i_tg = i_t - i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) - bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) - T = eos - bos - NT = tl.cdiv(T, BT) - else: - NT = tl.cdiv(T, BT) - i_tg = i_b * NT + i_t - bos, eos = i_b * T, i_b * T + T - - # offset calculation - q += (bos * Hg + i_h // (H // Hg)) * K - k += (bos * Hg + i_h // (H // Hg)) * K - v += (bos * H + i_h) * V - o += (bos * H + i_h) * V - h += (i_tg * H + i_h).to(tl.int64) * K * V - - b_o = tl.zeros([BT, BV], dtype=tl.float32) - b_A = tl.zeros([BT, BT], dtype=tl.float32) - - for i_k in range(tl.cdiv(K, BK)): - p_q = tl.make_block_ptr(q, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - p_k = tl.make_block_ptr(k, (K, T), (1, Hg * K), (i_k * BK, i_t * BT), (BK, BT), (0, 1)) - p_h = tl.make_block_ptr(h, (K, V), (V, 1), (i_k * BK, i_v * BV), (BK, BV), (1, 0)) - # [BT, BK] - b_q = tl.load(p_q, boundary_check=(0, 1)) - # [BK, BT] - b_k = tl.load(p_k, boundary_check=(0, 1)) - # [BK, BV] - b_h = tl.load(p_h, boundary_check=(0, 1)) - - # [BT, BK] @ [BK, BV] -> [BT, BV] - b_o += tl.dot(b_q, b_h) - # [BT, BK] @ [BK, BT] -> [BT, BT] - b_A += tl.dot(b_q, b_k) - - if USE_G: - g += bos * H + i_h - p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,)) - b_g = tl.load(p_g, boundary_check=(0,)) - b_o = b_o * exp(b_g)[:, None] - b_A = b_A * exp(b_g[:, None] - b_g[None, :]) - - o_t = i_t * BT + tl.arange(0, BT) - m_t = o_t < T - m_A = (o_t[:, None] >= o_t[None, :]) & (m_t[:, None] & m_t) - b_A = tl.where(m_A, b_A, 0) - - p_v = tl.make_block_ptr(v, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_o = tl.make_block_ptr(o, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - b_v = tl.load(p_v, boundary_check=(0, 1)) - - # to fix mma -> mma layout conversion - # already solved by triton v3.2 or higher - b_o = b_o * scale + tl.dot(b_A.to(b_v.dtype), b_v) * scale - tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) - - -def chunk_fwd_o( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - h: torch.Tensor, - g: Optional[torch.Tensor] = None, # cumsum of log decay - scale: Optional[float] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - chunk_size: int = 64, -) -> torch.Tensor: - B, T, Hg, K, V = *q.shape, v.shape[-1] - H = v.shape[-2] - if FLA_GDN_FIX_BT: - BT = 64 - else: - BT = min(chunk_size, max(16, triton.next_power_of_2(T))) - chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None - NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) - if scale is None: - scale = k.shape[-1] ** -0.5 - - o = torch.empty_like(v) - - def grid(meta): - return (triton.cdiv(V, meta["BV"]), NT, B * H) - - chunk_fwd_kernel_o[grid]( - q, - k, - v, - h, - g, - o, - cu_seqlens, - chunk_indices, - scale, - T=T, - H=H, - Hg=Hg, - K=K, - V=V, - BT=BT, - ) - return o diff --git a/lightllm/models/qwen3next/triton_kernel/fla_bak/chunk_scaled_dot_kkt.py b/lightllm/models/qwen3next/triton_kernel/fla_bak/chunk_scaled_dot_kkt.py deleted file mode 100644 index aa545e8ec..000000000 --- a/lightllm/models/qwen3next/triton_kernel/fla_bak/chunk_scaled_dot_kkt.py +++ /dev/null @@ -1,136 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang -# -# This file contains code copied from the flash-linear-attention project. -# The original source code was licensed under the MIT license and included -# the following copyright notice: -# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang -# ruff: noqa: E501 -from typing import Optional - -import torch - -import triton -import triton.language as tl - -from .index import prepare_chunk_indices -from .op import exp - - -@triton.heuristics( - {"IS_VARLEN": lambda args: args["cu_seqlens"] is not None, "USE_G": lambda args: args["g_cumsum"] is not None} -) -@triton.autotune( - configs=[ - triton.Config({"BK": BK}, num_warps=num_warps, num_stages=num_stages) - for BK in [32, 64, 128] - for num_warps in [2, 4, 8] - for num_stages in [2, 3, 4] - ], - key=["H", "K", "BT", "IS_VARLEN"], -) -@triton.jit(do_not_specialize=["T"]) -def chunk_scaled_dot_kkt_fwd_kernel( - k, - beta, - g_cumsum, - A, - cu_seqlens, - chunk_indices, - T, - H: tl.constexpr, - Hg: tl.constexpr, - K: tl.constexpr, - BT: tl.constexpr, - BK: tl.constexpr, - IS_VARLEN: tl.constexpr, - USE_G: tl.constexpr, -): - i_t, i_bh = tl.program_id(0), tl.program_id(1) - i_b, i_h = i_bh // H, i_bh % H - if IS_VARLEN: - i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) - bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) - T = eos - bos - else: - bos, eos = i_b * T, i_b * T + T - o_t = i_t * BT + tl.arange(0, BT) - m_t = o_t < T - - p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) - b_beta = tl.load(p_beta, boundary_check=(0,)) - - b_A = tl.zeros([BT, BT], dtype=tl.float32) - for i_k in range(tl.cdiv(K, BK)): - p_k = tl.make_block_ptr( - k + (bos * Hg + i_h // (H // Hg)) * K, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0) - ) - b_k = tl.load(p_k, boundary_check=(0, 1)) - b_kb = b_k * b_beta[:, None] - b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k)) - - if USE_G: - p_g = tl.make_block_ptr(g_cumsum + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) - b_g = tl.load(p_g, boundary_check=(0,)) - b_g_diff = b_g[:, None] - b_g[None, :] - b_A = b_A * exp(b_g_diff) - - m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t) - b_A = tl.where(m_A, b_A, 0) - p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)) - tl.store(p_A, b_A.to(p_A.dtype.element_ty), boundary_check=(0, 1)) - - -def chunk_scaled_dot_kkt_fwd( - k: torch.Tensor, - beta: torch.Tensor, - g_cumsum: Optional[torch.Tensor] = None, - cu_seqlens: Optional[torch.LongTensor] = None, - chunk_size: int = 64, - output_dtype: torch.dtype = torch.float32, -) -> torch.Tensor: - r""" - Compute beta * K * K^T. - - Args: - k (torch.Tensor): - The key tensor of shape `[B, T, H, K]`. - beta (torch.Tensor): - The beta tensor of shape `[B, T, H]`. - g_cumsum (torch.Tensor): - The cumulative sum of the gate tensor of shape `[B, T, H]`. - Default: None - cu_seqlens (torch.LongTensor): - The cumulative sequence lengths of the input tensor. - Default: None - chunk_size (int): - The chunk size. Default: 64. - output_dtype (torch.dtype): - The dtype of the output tensor. Default: `torch.float32` - - Returns: - beta * K * K^T of shape `[B, T, H, BT]` where `BT` is the chunk size. - """ - - B, T, Hg, K = k.shape - - H = beta.shape[-1] - BT = chunk_size - chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None - NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) - A = torch.empty(B, T, H, BT, device=k.device, dtype=output_dtype) - chunk_scaled_dot_kkt_fwd_kernel[(NT, B * H)]( - k=k, - beta=beta, - g_cumsum=g_cumsum, - A=A, - cu_seqlens=cu_seqlens, - chunk_indices=chunk_indices, - T=T, - H=H, - Hg=Hg, - K=K, - BT=BT, - ) - return A diff --git a/lightllm/models/qwen3next/triton_kernel/fla_bak/cumsum.py b/lightllm/models/qwen3next/triton_kernel/fla_bak/cumsum.py deleted file mode 100644 index 9cd6a6545..000000000 --- a/lightllm/models/qwen3next/triton_kernel/fla_bak/cumsum.py +++ /dev/null @@ -1,200 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang -# -# This file contains code copied from the flash-linear-attention project. -# The original source code was licensed under the MIT license and included -# the following copyright notice: -# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang -# ruff: noqa: E501 -import warnings -from typing import Optional - -import torch - -import triton -import triton.language as tl - -from .index import prepare_chunk_indices -from .utils import check_shared_mem, input_guard - -BS_LIST = [32, 64] if check_shared_mem() else [16, 32] - - -@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) -@triton.autotune( - configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8]], - key=["B", "H", "BT", "IS_VARLEN", "REVERSE"], -) -@triton.jit(do_not_specialize=["T"]) -def chunk_local_cumsum_scalar_kernel( - s, - o, - cu_seqlens, - chunk_indices, - T, - B: tl.constexpr, - H: tl.constexpr, - BT: tl.constexpr, - REVERSE: tl.constexpr, - IS_VARLEN: tl.constexpr, - HEAD_FIRST: tl.constexpr, -): - i_t, i_bh = tl.program_id(0), tl.program_id(1) - i_b, i_h = i_bh // H, i_bh % H - if IS_VARLEN: - i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) - bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) - T = eos - bos - else: - bos, eos = i_b * T, i_b * T + T - - if HEAD_FIRST: - p_s = tl.make_block_ptr(s + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) - p_o = tl.make_block_ptr(o + bos * H + i_h * T, (T,), (1,), (i_t * BT,), (BT,), (0,)) - else: - p_s = tl.make_block_ptr(s + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) - p_o = tl.make_block_ptr(o + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) - # [BT] - b_s = tl.load(p_s, boundary_check=(0,)).to(tl.float32) - b_o = tl.cumsum(b_s, axis=0) - if REVERSE: - b_z = tl.sum(b_s, axis=0) - b_o = -b_o + b_z[None] + b_s - tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0,)) - - -@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) -@triton.autotune( - configs=[triton.Config({"BS": BS}, num_warps=num_warps) for BS in BS_LIST for num_warps in [2, 4, 8]], - key=["B", "H", "S", "BT", "IS_VARLEN", "REVERSE"], -) -@triton.jit(do_not_specialize=["T"]) -def chunk_local_cumsum_vector_kernel( - s, - o, - cu_seqlens, - chunk_indices, - T, - B: tl.constexpr, - H: tl.constexpr, - S: tl.constexpr, - BT: tl.constexpr, - BS: tl.constexpr, - REVERSE: tl.constexpr, - IS_VARLEN: tl.constexpr, - HEAD_FIRST: tl.constexpr, -): - i_s, i_t, i_bh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - i_b, i_h = i_bh // H, i_bh % H - if IS_VARLEN: - i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) - bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) - T = eos - bos - else: - bos, eos = i_b * T, i_b * T + T - - o_i = tl.arange(0, BT) - if REVERSE: - m_s = tl.where(o_i[:, None] <= o_i[None, :], 1.0, 0.0) - else: - m_s = tl.where(o_i[:, None] >= o_i[None, :], 1.0, 0.0) - - if HEAD_FIRST: - p_s = tl.make_block_ptr(s + (bos * H + i_h * T) * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) - p_o = tl.make_block_ptr(o + (bos * H + i_h * T) * S, (T, S), (S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) - else: - p_s = tl.make_block_ptr(s + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) - p_o = tl.make_block_ptr(o + (bos * H + i_h) * S, (T, S), (H * S, 1), (i_t * BT, i_s * BS), (BT, BS), (1, 0)) - # [BT, BS] - b_s = tl.load(p_s, boundary_check=(0, 1)).to(tl.float32) - b_o = tl.dot(m_s, b_s, allow_tf32=False) - tl.store(p_o, b_o.to(p_o.dtype.element_ty), boundary_check=(0, 1)) - - -def chunk_local_cumsum_scalar( - g: torch.Tensor, - chunk_size: int, - reverse: bool = False, - cu_seqlens: Optional[torch.Tensor] = None, - head_first: bool = False, - output_dtype: Optional[torch.dtype] = torch.float, -) -> torch.Tensor: - if head_first: - B, H, T = g.shape - else: - B, T, H = g.shape - assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2" - BT = chunk_size - chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None - NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) - g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) - grid = (NT, B * H) - chunk_local_cumsum_scalar_kernel[grid]( - g_org, g, cu_seqlens, chunk_indices, T=T, B=B, H=H, BT=BT, HEAD_FIRST=head_first, REVERSE=reverse - ) - return g - - -def chunk_local_cumsum_vector( - g: torch.Tensor, - chunk_size: int, - reverse: bool = False, - cu_seqlens: Optional[torch.Tensor] = None, - head_first: bool = False, - output_dtype: Optional[torch.dtype] = torch.float, -) -> torch.Tensor: - if head_first: - B, H, T, S = g.shape - else: - B, T, H, S = g.shape - BT = chunk_size - chunk_indices = prepare_chunk_indices(cu_seqlens, chunk_size) if cu_seqlens is not None else None - NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) - assert chunk_size == 2 ** (chunk_size.bit_length() - 1), "chunk_size must be a power of 2" - - g_org, g = g, torch.empty_like(g, dtype=output_dtype or g.dtype) - - def grid(meta): - return (triton.cdiv(meta["S"], meta["BS"]), NT, B * H) - - # keep cumulative normalizer in fp32 - # this kernel is equivalent to - # g = g.view(B, H, NT, BT, -1).cumsum(-2).view(B, H, T, -1) - chunk_local_cumsum_vector_kernel[grid]( - g_org, g, cu_seqlens, chunk_indices, T=T, B=B, H=H, S=S, BT=BT, HEAD_FIRST=head_first, REVERSE=reverse - ) - return g - - -@input_guard -def chunk_local_cumsum( - g: torch.Tensor, - chunk_size: int, - reverse: bool = False, - cu_seqlens: Optional[torch.Tensor] = None, - head_first: bool = False, - output_dtype: Optional[torch.dtype] = torch.float, - **kwargs, -) -> torch.Tensor: - if not head_first and g.shape[1] < g.shape[2]: - warnings.warn( - f"Input tensor shape suggests potential format mismatch" - f" seq_len ({g.shape[1]}) < num_heads ({g.shape[2]}). " - "This may indicate the inputs were passed in head-first format [B, H, T, ...] " - "when head_first=False was specified. " - "Please verify your input tensor format matches the expected shape [B, T, H, ...].", - stacklevel=2, - ) - if cu_seqlens is not None: - assert g.shape[0] == 1, "Only batch size 1 is supported when cu_seqlens are provided" - if len(g.shape) == 3: - return chunk_local_cumsum_scalar(g, chunk_size, reverse, cu_seqlens, head_first, output_dtype) - elif len(g.shape) == 4: - return chunk_local_cumsum_vector(g, chunk_size, reverse, cu_seqlens, head_first, output_dtype) - else: - raise ValueError( - f"Unsupported input shape {g.shape}. " - f"which should be (B, T, H, D) if `head_first=False` " - f"or (B, H, T, D) otherwise" - ) diff --git a/lightllm/models/qwen3next/triton_kernel/fla_bak/fused_recurrent.py b/lightllm/models/qwen3next/triton_kernel/fla_bak/fused_recurrent.py deleted file mode 100644 index 4ff18d4f6..000000000 --- a/lightllm/models/qwen3next/triton_kernel/fla_bak/fused_recurrent.py +++ /dev/null @@ -1,367 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang -# -# This file contains code copied from the flash-linear-attention project. -# The original source code was licensed under the MIT license and included -# the following copyright notice: -# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang -# ruff: noqa: E501 -from typing import Optional - -import torch - -import triton -import triton.language as tl - -from .op import exp - - -@triton.heuristics( - { - "USE_INITIAL_STATE": lambda args: args["h0"] is not None, - "IS_VARLEN": lambda args: args["cu_seqlens"] is not None, - "IS_CONTINUOUS_BATCHING": lambda args: args["ssm_state_indices"] is not None, - "IS_SPEC_DECODING": lambda args: args["num_accepted_tokens"] is not None, - } -) -@triton.jit(do_not_specialize=["N", "T"]) -def fused_recurrent_gated_delta_rule_fwd_kernel( - q, - k, - v, - g, - beta, - o, - h0, - ht, - cu_seqlens, - ssm_state_indices, - num_accepted_tokens, - scale, - N: tl.int64, # num of sequences - T: tl.int64, # num of tokens - B: tl.constexpr, - H: tl.constexpr, - HV: tl.constexpr, - K: tl.constexpr, - V: tl.constexpr, - BK: tl.constexpr, - BV: tl.constexpr, - stride_init_state_token: tl.constexpr, - stride_final_state_token: tl.constexpr, - stride_indices_seq: tl.constexpr, - stride_indices_tok: tl.constexpr, - USE_INITIAL_STATE: tl.constexpr, # whether to use initial state - INPLACE_FINAL_STATE: tl.constexpr, # whether to store final state inplace - IS_BETA_HEADWISE: tl.constexpr, # whether beta is headwise vector or scalar, - USE_QK_L2NORM_IN_KERNEL: tl.constexpr, - IS_VARLEN: tl.constexpr, - IS_CONTINUOUS_BATCHING: tl.constexpr, - IS_SPEC_DECODING: tl.constexpr, -): - i_k, i_v, i_nh = tl.program_id(0), tl.program_id(1), tl.program_id(2) - i_n, i_hv = i_nh // HV, i_nh % HV - i_h = i_hv // (HV // H) - if IS_VARLEN: - bos, eos = tl.load(cu_seqlens + i_n).to(tl.int64), tl.load(cu_seqlens + i_n + 1).to(tl.int64) - all = T - T = eos - bos - else: - bos, eos = i_n * T, i_n * T + T - all = B * T - - if T == 0: - # no tokens to process for this sequence - return - - o_k = i_k * BK + tl.arange(0, BK) - o_v = i_v * BV + tl.arange(0, BV) - - p_q = q + (bos * H + i_h) * K + o_k - p_k = k + (bos * H + i_h) * K + o_k - p_v = v + (bos * HV + i_hv) * V + o_v - if IS_BETA_HEADWISE: - p_beta = beta + (bos * HV + i_hv) * V + o_v - else: - p_beta = beta + bos * HV + i_hv - p_g = g + bos * HV + i_hv - p_o = o + ((i_k * all + bos) * HV + i_hv) * V + o_v - - mask_k = o_k < K - mask_v = o_v < V - mask_h = mask_k[:, None] & mask_v[None, :] - - b_h = tl.zeros([BK, BV], dtype=tl.float32) - if USE_INITIAL_STATE: - if IS_CONTINUOUS_BATCHING: - if IS_SPEC_DECODING: - i_t = tl.load(num_accepted_tokens + i_n).to(tl.int64) - 1 - else: - i_t = 0 - p_h0 = ( - h0 + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(tl.int64) * stride_init_state_token - ) - else: - p_h0 = h0 + bos * HV * K * V - p_h0 = p_h0 + i_hv * K * V + o_k[:, None] * V + o_v[None, :] - b_h += tl.load(p_h0, mask=mask_h, other=0).to(tl.float32) - - for i_t in range(0, T): - b_q = tl.load(p_q, mask=mask_k, other=0).to(tl.float32) - b_k = tl.load(p_k, mask=mask_k, other=0).to(tl.float32) - b_v = tl.load(p_v, mask=mask_v, other=0).to(tl.float32) - b_g = tl.load(p_g).to(tl.float32) - - if USE_QK_L2NORM_IN_KERNEL: - b_q = b_q / tl.sqrt(tl.sum(b_q * b_q) + 1e-6) - b_k = b_k / tl.sqrt(tl.sum(b_k * b_k) + 1e-6) - b_q = b_q * scale - # [BK, BV] - b_h *= exp(b_g) - # [BV] - b_v -= tl.sum(b_h * b_k[:, None], 0) - if IS_BETA_HEADWISE: - b_beta = tl.load(p_beta, mask=mask_v, other=0).to(tl.float32) - else: - b_beta = tl.load(p_beta).to(tl.float32) - b_v *= b_beta - # [BK, BV] - b_h += b_k[:, None] * b_v[None, :] - # [BV] - b_o = tl.sum(b_h * b_q[:, None], 0) - tl.store(p_o, b_o.to(p_o.dtype.element_ty), mask=mask_v) - - # keep the states for multi-query tokens - if INPLACE_FINAL_STATE: - p_ht = ( - ht + tl.load(ssm_state_indices + i_n * stride_indices_seq + i_t).to(tl.int64) * stride_final_state_token - ) - else: - p_ht = ht + (bos + i_t) * stride_final_state_token - p_ht = p_ht + i_hv * K * V + o_k[:, None] * V + o_v[None, :] - tl.store(p_ht, b_h.to(p_ht.dtype.element_ty), mask=mask_h) - - p_q += H * K - p_k += H * K - p_o += HV * V - p_v += HV * V - p_g += HV - p_beta += HV * (V if IS_BETA_HEADWISE else 1) - - -def fused_recurrent_gated_delta_rule_fwd( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - scale: float, - initial_state: torch.Tensor, - inplace_final_state: bool = True, - cu_seqlens: Optional[torch.LongTensor] = None, - ssm_state_indices: Optional[torch.Tensor] = None, - num_accepted_tokens: Optional[torch.Tensor] = None, - use_qk_l2norm_in_kernel: bool = False, -) -> tuple[torch.Tensor, torch.Tensor]: - B, T, H, K, V = *k.shape, v.shape[-1] - HV = v.shape[2] - N = B if cu_seqlens is None else len(cu_seqlens) - 1 - BK, BV = triton.next_power_of_2(K), min(triton.next_power_of_2(V), 8) - NK, NV = triton.cdiv(K, BK), triton.cdiv(V, BV) - assert NK == 1, "NK > 1 is not supported yet" - num_stages = 3 - num_warps = 1 - - o = q.new_empty(NK, *v.shape) - if inplace_final_state: - final_state = initial_state - else: - final_state = q.new_empty(T, HV, K, V, dtype=initial_state.dtype) - - stride_init_state_token = initial_state.stride(0) - stride_final_state_token = final_state.stride(0) - - if ssm_state_indices is None: - stride_indices_seq, stride_indices_tok = 1, 1 - elif ssm_state_indices.ndim == 1: - stride_indices_seq, stride_indices_tok = ssm_state_indices.stride(0), 1 - else: - stride_indices_seq, stride_indices_tok = ssm_state_indices.stride() - - grid = (NK, NV, N * HV) - fused_recurrent_gated_delta_rule_fwd_kernel[grid]( - q=q, - k=k, - v=v, - g=g, - beta=beta, - o=o, - h0=initial_state, - ht=final_state, - cu_seqlens=cu_seqlens, - ssm_state_indices=ssm_state_indices, - num_accepted_tokens=num_accepted_tokens, - scale=scale, - N=N, - T=T, - B=B, - H=H, - HV=HV, - K=K, - V=V, - BK=BK, - BV=BV, - stride_init_state_token=stride_init_state_token, - stride_final_state_token=stride_final_state_token, - stride_indices_seq=stride_indices_seq, - stride_indices_tok=stride_indices_tok, - IS_BETA_HEADWISE=beta.ndim == v.ndim, - USE_QK_L2NORM_IN_KERNEL=use_qk_l2norm_in_kernel, - INPLACE_FINAL_STATE=inplace_final_state, - num_warps=num_warps, - num_stages=num_stages, - ) - o = o.squeeze(0) - return o, final_state - - -class FusedRecurrentFunction(torch.autograd.Function): - @staticmethod - def forward( - ctx, - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor, - scale: float, - initial_state: torch.Tensor, - inplace_final_state: bool = True, - cu_seqlens: Optional[torch.LongTensor] = None, - ssm_state_indices: Optional[torch.Tensor] = None, - num_accepted_tokens: Optional[torch.Tensor] = None, - use_qk_l2norm_in_kernel: bool = False, - ): - o, final_state = fused_recurrent_gated_delta_rule_fwd( - q=q.contiguous(), - k=k.contiguous(), - v=v.contiguous(), - g=g.contiguous(), - beta=beta.contiguous(), - scale=scale, - initial_state=initial_state, - inplace_final_state=inplace_final_state, - cu_seqlens=cu_seqlens, - ssm_state_indices=ssm_state_indices, - num_accepted_tokens=num_accepted_tokens, - use_qk_l2norm_in_kernel=use_qk_l2norm_in_kernel, - ) - - return o, final_state - - -def fused_recurrent_gated_delta_rule( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - g: torch.Tensor, - beta: torch.Tensor = None, - scale: float = None, - initial_state: torch.Tensor = None, - inplace_final_state: bool = True, - cu_seqlens: Optional[torch.LongTensor] = None, - ssm_state_indices: Optional[torch.Tensor] = None, - num_accepted_tokens: Optional[torch.Tensor] = None, - use_qk_l2norm_in_kernel: bool = False, -) -> tuple[torch.Tensor, torch.Tensor]: - r""" - Args: - q (torch.Tensor): - queries of shape `[B, T, H, K]`. - k (torch.Tensor): - keys of shape `[B, T, H, K]`. - v (torch.Tensor): - values of shape `[B, T, HV, V]`. - GVA is applied if `HV > H`. - g (torch.Tensor): - g (decays) of shape `[B, T, HV]`. - beta (torch.Tensor): - betas of shape `[B, T, HV]`. - scale (Optional[int]): - Scale factor for the RetNet attention scores. - If not provided, it will default to `1 / sqrt(K)`. Default: `None`. - initial_state (Optional[torch.Tensor]): - Initial state of shape `[N, HV, K, V]` for `N` input sequences. - For equal-length input sequences, `N` equals the batch size `B`. - Default: `None`. - inplace_final_state: bool: - Whether to store the final state in-place to save memory. - Default: `True`. - cu_seqlens (torch.LongTensor): - Cumulative sequence lengths of shape `[N+1]` used for variable-length training, - consistent with the FlashAttention API. - ssm_state_indices (Optional[torch.Tensor]): - Indices to map the input sequences to the initial/final states. - num_accepted_tokens (Optional[torch.Tensor]): - Number of accepted tokens for each sequence during decoding. - - Returns: - o (torch.Tensor): - Outputs of shape `[B, T, HV, V]`. - final_state (torch.Tensor): - Final state of shape `[N, HV, K, V]`. - - Examples:: - >>> import torch - >>> import torch.nn.functional as F - >>> from einops import rearrange - >>> from fla.ops.gated_delta_rule import fused_recurrent_gated_delta_rule - # inputs with equal lengths - >>> B, T, H, HV, K, V = 4, 2048, 4, 8, 512, 512 - >>> q = torch.randn(B, T, H, K, device='cuda') - >>> k = F.normalize(torch.randn(B, T, H, K, device='cuda'), p=2, dim=-1) - >>> v = torch.randn(B, T, HV, V, device='cuda') - >>> g = F.logsigmoid(torch.rand(B, T, HV, device='cuda')) - >>> beta = torch.rand(B, T, HV, device='cuda').sigmoid() - >>> h0 = torch.randn(B, HV, K, V, device='cuda') - >>> o, ht = fused_gated_recurrent_delta_rule( - q, k, v, g, beta, - initial_state=h0, - ) - # for variable-length inputs, the batch size `B` is expected to be 1 and `cu_seqlens` is required - >>> q, k, v, g, beta = map(lambda x: rearrange(x, 'b t ... -> 1 (b t) ...'), (q, k, v, g, beta)) - # for a batch with 4 sequences, `cu_seqlens` with 5 start/end positions are expected - >>> cu_seqlens = q.new_tensor([0, 2048, 4096, 6144, 8192], dtype=torch.long) - >>> o_var, ht_var = fused_gated_recurrent_delta_rule( - q, k, v, g, beta, - initial_state=h0, - cu_seqlens=cu_seqlens - ) - """ - if cu_seqlens is not None and q.shape[0] != 1: - raise ValueError( - f"The batch size is expected to be 1 rather than {q.shape[0]} when using `cu_seqlens`." - f"Please flatten variable-length inputs before processing." - ) - if scale is None: - scale = k.shape[-1] ** -0.5 - else: - assert scale > 0, "scale must be positive" - if beta is None: - beta = torch.ones_like(q[..., 0]) - o, final_state = FusedRecurrentFunction.apply( - q, - k, - v, - g, - beta, - scale, - initial_state, - inplace_final_state, - cu_seqlens, - ssm_state_indices, - num_accepted_tokens, - use_qk_l2norm_in_kernel, - ) - return o, final_state diff --git a/lightllm/models/qwen3next/triton_kernel/fla_bak/index.py b/lightllm/models/qwen3next/triton_kernel/fla_bak/index.py deleted file mode 100644 index 8b1d59fc6..000000000 --- a/lightllm/models/qwen3next/triton_kernel/fla_bak/index.py +++ /dev/null @@ -1,30 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang -# -# This file contains code copied from the flash-linear-attention project. -# The original source code was licensed under the MIT license and included -# the following copyright notice: -# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang -# ruff: noqa: E501 -import torch - -import triton - -from .utils import tensor_cache - - -@tensor_cache -def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor: - return cu_seqlens[1:] - cu_seqlens[:-1] - - -@tensor_cache -def prepare_chunk_indices(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor: - indices = torch.cat([torch.arange(n) for n in triton.cdiv(prepare_lens(cu_seqlens), chunk_size).tolist()]) - return torch.stack([indices.eq(0).cumsum(0) - 1, indices], 1).to(cu_seqlens) - - -@tensor_cache -def prepare_chunk_offsets(cu_seqlens: torch.LongTensor, chunk_size: int) -> torch.LongTensor: - return torch.cat([cu_seqlens.new_tensor([0]), triton.cdiv(prepare_lens(cu_seqlens), chunk_size)]).cumsum(-1) diff --git a/lightllm/models/qwen3next/triton_kernel/fla_bak/l2norm.py b/lightllm/models/qwen3next/triton_kernel/fla_bak/l2norm.py deleted file mode 100644 index 7225cd4ae..000000000 --- a/lightllm/models/qwen3next/triton_kernel/fla_bak/l2norm.py +++ /dev/null @@ -1,137 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang -# -# This file contains code copied from the flash-linear-attention project. -# The original source code was licensed under the MIT license and included -# the following copyright notice: -# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang - -import os -from typing import Optional - -import torch - -import triton -import triton.language as tl - -BT_LIST = [8, 16, 32, 64, 128] - -USE_DEFAULT_FLA_NORM = int(os.getenv("USE_DEFAULT_FLA_NORM", "0")) - - -@triton.autotune(configs=[triton.Config({}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16, 32]], key=["D"]) -@triton.jit -def l2norm_fwd_kernel1( - x, - y, - D, - BD: tl.constexpr, - eps, -): - i_t = tl.program_id(0) - x += i_t * D - y += i_t * D - # Compute mean and variance - cols = tl.arange(0, BD) - mask = cols < D - b_x = tl.load(x + cols, mask=mask, other=0.0).to(tl.float32) - b_var = tl.sum(b_x * b_x, axis=0) - b_rstd = 1 / tl.sqrt(b_var + eps) - # tl.store(Rstd + i_t, rstd) - # Normalize and apply linear transformation - b_y = b_x * b_rstd - tl.store(y + cols, b_y, mask=mask) - - -@triton.autotune( - configs=[triton.Config({"BT": BT}, num_warps=num_warps) for num_warps in [1, 2, 4, 8, 16] for BT in BT_LIST], - key=["D"], -) -@triton.jit(do_not_specialize=["NB"]) -def l2norm_fwd_kernel( - x, - y, - eps, - NB, - T, - D: tl.constexpr, - BT: tl.constexpr, - BD: tl.constexpr, -): - i_t = tl.program_id(0) - p_x = tl.make_block_ptr(x, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) - b_x = tl.load(p_x, boundary_check=(0, 1)).to(tl.float32) - b_var = tl.sum(b_x * b_x, axis=1) - b_y = b_x / tl.sqrt(b_var + eps)[:, None] - p_y = tl.make_block_ptr(y, (T, D), (D, 1), (i_t * BT, 0), (BT, BD), (1, 0)) - tl.store(p_y, b_y.to(p_y.dtype.element_ty), boundary_check=(0, 1)) - - -@triton.jit -def l2norm_fwd_kernel2(X, Y, eps, M, N: tl.constexpr, MBLOCK: tl.constexpr): - xoffset = tl.program_id(0) * MBLOCK - row_idx = xoffset + tl.arange(0, MBLOCK)[:, None] - xmask = row_idx < M - rindex = tl.arange(0, N)[None, :] - xs = tl.load(X + (rindex + N * row_idx), xmask).to(tl.float32) - square = tl.broadcast_to(xs * xs, [MBLOCK, N]) - square_sum = tl.sum(tl.where(xmask, square, 0), 1)[:, None] - rsqrt = tl.rsqrt(square_sum + eps) - tl.store(Y + (rindex + N * row_idx), xs * rsqrt, xmask) - - -def l2norm_fwd(x: torch.Tensor, eps: float = 1e-6, output_dtype: Optional[torch.dtype] = None): - x_shape_og = x.shape - x = x.view(-1, x.shape[-1]) - # allocate output - if output_dtype is None: - y = torch.empty_like(x) - else: - y = torch.empty_like(x, dtype=output_dtype) - assert y.stride(-1) == 1 - T, D = x.shape[0], x.shape[-1] - # rstd = torch.empty((T,), dtype=torch.float32, device=x.device) - # Less than 64KB per feature: enqueue fused kernel - MAX_FUSED_SIZE = 65536 // x.element_size() - BD = min(MAX_FUSED_SIZE, triton.next_power_of_2(D)) - if D > BD: - raise RuntimeError("This layer doesn't support feature dim >= 64KB.") - - if not USE_DEFAULT_FLA_NORM: - MBLOCK = 32 - # M, N = x.shape - l2norm_fwd_kernel2[(triton.cdiv(T, MBLOCK),)]( - x, - y, - eps, - T, - D, - MBLOCK, - ) - else: - if D <= 512: - NB = triton.cdiv(T, 2048) - - def grid(meta): - return (triton.cdiv(T, meta["BT"]),) - - l2norm_fwd_kernel[grid]( - x, - y, - eps, - NB=NB, - T=T, - D=D, - BD=BD, - ) - else: - l2norm_fwd_kernel1[(T,)]( - x, - y, - eps=eps, - D=D, - BD=BD, - ) - - return y.view(x_shape_og) diff --git a/lightllm/models/qwen3next/triton_kernel/fla_bak/op.py b/lightllm/models/qwen3next/triton_kernel/fla_bak/op.py deleted file mode 100644 index ec0999455..000000000 --- a/lightllm/models/qwen3next/triton_kernel/fla_bak/op.py +++ /dev/null @@ -1,36 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang -# -# This file contains code copied from the flash-linear-attention project. -# The original source code was licensed under the MIT license and included -# the following copyright notice: -# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang - -import os - -import triton -import triton.language as tl - - -@triton.jit -def div_normal(x, y): - return x / y - - -div = div_normal -exp = tl.exp -log = tl.log -log2 = tl.log2 - - -if not hasattr(tl, "gather"): - - @triton.jit - def gather(src, index, axis, _builder=None): - # This is a fallback implementation when tl.gather is not supported - # In order to pass triton compiler, there is no actual gather operation - return src - -else: - gather = tl.gather diff --git a/lightllm/models/qwen3next/triton_kernel/fla_bak/solve_tril.py b/lightllm/models/qwen3next/triton_kernel/fla_bak/solve_tril.py deleted file mode 100644 index 46e4d5082..000000000 --- a/lightllm/models/qwen3next/triton_kernel/fla_bak/solve_tril.py +++ /dev/null @@ -1,271 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang -# -# This file contains code copied from the flash-linear-attention project. -# The original source code was licensed under the MIT license and included -# the following copyright notice: -# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang -# ruff: noqa: E501 -from typing import Optional - -import torch - -import triton -import triton.language as tl - -from .index import prepare_chunk_indices -from .utils import input_guard - - -@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) -@triton.autotune( - configs=[ - triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [1, 2, 4, 8] - for num_stages in [2, 3, 4, 5] - ], - key=["BT"], -) -@triton.jit(do_not_specialize=["T"]) -def solve_tril_16x16_kernel( - A, - Ad, - cu_seqlens, - chunk_indices, - T, - H: tl.constexpr, - BT: tl.constexpr, - IS_VARLEN: tl.constexpr, -): - i_t, i_bh = tl.program_id(0), tl.program_id(1) - i_b, i_h = i_bh // H, i_bh % H - if IS_VARLEN: - i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) - bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) - T = eos - bos - else: - bos, eos = i_b * T, i_b * T + T - - A = A + (bos * H + i_h) * BT - Ad = Ad + (bos * H + i_h) * 16 - - offset = (i_t * 16) % BT - p_A = tl.make_block_ptr(A, (T, BT), (H * BT, 1), (i_t * 16, offset), (16, 16), (1, 0)) - p_Ai = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 16, 0), (16, 16), (1, 0)) - b_A = tl.load(p_A, boundary_check=(0, 1)).to(tl.float32) - b_A = -tl.where(tl.arange(0, 16)[:, None] > tl.arange(0, 16)[None, :], b_A, 0) - - o_i = tl.arange(0, 16) - for i in range(1, min(16, T - i_t * 16)): - b_a = -tl.load(A + (i_t * 16 + i) * H * BT + o_i + offset) - b_a = b_a + tl.sum(b_a[:, None] * b_A, 0) - mask = o_i == i - b_A = tl.where(mask[:, None], b_a, b_A) - b_A += o_i[:, None] == o_i[None, :] - tl.store(p_Ai, b_A.to(p_Ai.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - - -@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) -@triton.autotune( - configs=[ - triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [1, 2, 4, 8] - for num_stages in [2, 3, 4, 5] - ], - key=["H", "BT", "IS_VARLEN"], -) -@triton.jit(do_not_specialize=["T"]) -def merge_16x16_to_32x32_inverse_kernel( - A, Ad, Ai, cu_seqlens, chunk_indices, T, H: tl.constexpr, BT: tl.constexpr, IS_VARLEN: tl.constexpr -): - i_t, i_bh = tl.program_id(0), tl.program_id(1) - i_b, i_h = i_bh // H, i_bh % H - if IS_VARLEN: - i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) - bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) - T = eos - bos - else: - bos, eos = i_b * T, i_b * T + T - - A += (bos * H + i_h) * 32 - Ad += (bos * H + i_h) * 16 - Ai += (bos * H + i_h) * 32 - - p_A_21 = tl.make_block_ptr(A, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)) - p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32, 0), (16, 16), (1, 0)) - p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)) - p_Ai_11 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32, 0), (16, 16), (1, 0)) - p_Ai_22 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 16), (16, 16), (1, 0)) - p_Ai_21 = tl.make_block_ptr(Ai, (T, 32), (H * 32, 1), (i_t * 32 + 16, 0), (16, 16), (1, 0)) - - A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) - Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32) - Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32) - Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision="ieee"), Ai_11, input_precision="ieee") - tl.store(p_Ai_11, Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - tl.store(p_Ai_22, Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - tl.store(p_Ai_21, Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - - -@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) -@triton.autotune( - configs=[ - triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [2, 4, 8] - for num_stages in [2, 3, 4, 5] - ], - key=["H", "BT", "IS_VARLEN"], -) -@triton.jit(do_not_specialize=["T"]) -def merge_16x16_to_64x64_inverse_kernel( - A, Ad, Ai, cu_seqlens, chunk_indices, T, H: tl.constexpr, BT: tl.constexpr, IS_VARLEN: tl.constexpr -): - i_t, i_bh = tl.program_id(0), tl.program_id(1) - i_b, i_h = i_bh // H, i_bh % H - if IS_VARLEN: - i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) - bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) - T = eos - bos - else: - bos, eos = i_b * T, i_b * T + T - - A += (bos * H + i_h) * 64 - Ad += (bos * H + i_h) * 16 - Ai += (bos * H + i_h) * 64 - - p_A_21 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)) - p_A_32 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0)) - p_A_31 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)) - p_A_43 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0)) - p_A_42 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0)) - p_A_41 = tl.make_block_ptr(A, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)) - p_Ad_11 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64, 0), (16, 16), (1, 0)) - p_Ad_22 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)) - p_Ad_33 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)) - p_Ad_44 = tl.make_block_ptr(Ad, (T, 16), (H * 16, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)) - - A_21 = tl.load(p_A_21, boundary_check=(0, 1)).to(tl.float32) - A_32 = tl.load(p_A_32, boundary_check=(0, 1)).to(tl.float32) - A_31 = tl.load(p_A_31, boundary_check=(0, 1)).to(tl.float32) - A_43 = tl.load(p_A_43, boundary_check=(0, 1)).to(tl.float32) - A_42 = tl.load(p_A_42, boundary_check=(0, 1)).to(tl.float32) - A_41 = tl.load(p_A_41, boundary_check=(0, 1)).to(tl.float32) - - Ai_11 = tl.load(p_Ad_11, boundary_check=(0, 1)).to(tl.float32) - Ai_22 = tl.load(p_Ad_22, boundary_check=(0, 1)).to(tl.float32) - Ai_33 = tl.load(p_Ad_33, boundary_check=(0, 1)).to(tl.float32) - Ai_44 = tl.load(p_Ad_44, boundary_check=(0, 1)).to(tl.float32) - - Ai_21 = -tl.dot(tl.dot(Ai_22, A_21, input_precision="ieee"), Ai_11, input_precision="ieee") - Ai_32 = -tl.dot(tl.dot(Ai_33, A_32, input_precision="ieee"), Ai_22, input_precision="ieee") - Ai_43 = -tl.dot(tl.dot(Ai_44, A_43, input_precision="ieee"), Ai_33, input_precision="ieee") - - Ai_31 = -tl.dot( - Ai_33, - tl.dot(A_31, Ai_11, input_precision="ieee") + tl.dot(A_32, Ai_21, input_precision="ieee"), - input_precision="ieee", - ) - Ai_42 = -tl.dot( - Ai_44, - tl.dot(A_42, Ai_22, input_precision="ieee") + tl.dot(A_43, Ai_32, input_precision="ieee"), - input_precision="ieee", - ) - Ai_41 = -tl.dot( - Ai_44, - tl.dot(A_41, Ai_11, input_precision="ieee") - + tl.dot(A_42, Ai_21, input_precision="ieee") - + tl.dot(A_43, Ai_31, input_precision="ieee"), - input_precision="ieee", - ) - - p_Ai_11 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 0), (16, 16), (1, 0)) - p_Ai_22 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 16), (16, 16), (1, 0)) - p_Ai_33 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 32), (16, 16), (1, 0)) - p_Ai_44 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 48), (16, 16), (1, 0)) - p_Ai_21 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 0), (16, 16), (1, 0)) - p_Ai_31 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 0), (16, 16), (1, 0)) - p_Ai_32 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 16), (16, 16), (1, 0)) - p_Ai_41 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 0), (16, 16), (1, 0)) - p_Ai_42 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 16), (16, 16), (1, 0)) - p_Ai_43 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 48, 32), (16, 16), (1, 0)) - tl.store(p_Ai_11, Ai_11.to(p_Ai_11.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - tl.store(p_Ai_22, Ai_22.to(p_Ai_22.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - tl.store(p_Ai_33, Ai_33.to(p_Ai_33.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - tl.store(p_Ai_44, Ai_44.to(p_Ai_44.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - tl.store(p_Ai_21, Ai_21.to(p_Ai_21.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - tl.store(p_Ai_31, Ai_31.to(p_Ai_31.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - tl.store(p_Ai_32, Ai_32.to(p_Ai_32.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - tl.store(p_Ai_41, Ai_41.to(p_Ai_41.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - tl.store(p_Ai_42, Ai_42.to(p_Ai_42.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - tl.store(p_Ai_43, Ai_43.to(p_Ai_43.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - - fill_zeros = tl.zeros((16, 16), dtype=tl.float32) - p_Ai_12 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 16), (16, 16), (1, 0)) - p_Ai_13 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 32), (16, 16), (1, 0)) - p_Ai_14 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64, 48), (16, 16), (1, 0)) - p_Ai_23 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 32), (16, 16), (1, 0)) - p_Ai_24 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 16, 48), (16, 16), (1, 0)) - p_Ai_34 = tl.make_block_ptr(Ai, (T, 64), (H * 64, 1), (i_t * 64 + 32, 48), (16, 16), (1, 0)) - tl.store(p_Ai_12, fill_zeros.to(p_Ai_12.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - tl.store(p_Ai_13, fill_zeros.to(p_Ai_13.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - tl.store(p_Ai_14, fill_zeros.to(p_Ai_14.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - tl.store(p_Ai_23, fill_zeros.to(p_Ai_23.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - tl.store(p_Ai_24, fill_zeros.to(p_Ai_24.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - tl.store(p_Ai_34, fill_zeros.to(p_Ai_34.dtype.element_ty, fp_downcast_rounding="rtne"), boundary_check=(0, 1)) - - -@input_guard -def solve_tril( - A: torch.Tensor, cu_seqlens: Optional[torch.Tensor] = None, output_dtype: torch.dtype = torch.float -) -> torch.Tensor: - """ - Compute the inverse of the lower triangular matrix - A should be strictly lower triangular, i.e., A.triu() == 0. - - Args: - A (torch.Tensor): - [B, T, H, K] - cu_seqlens (torch.Tensor): - The cumulative sequence lengths of the input tensor. - Default: None. - output_dtype (torch.dtype): - The dtype of the output tensor. Default: `torch.float` - - Returns: - (I + A)^-1 with the same shape as A - """ - assert A.shape[-1] in [16, 32, 64] - - B, T, H, BT = A.shape - Ad = torch.empty(B, T, H, 16, device=A.device, dtype=torch.float if BT != 16 else output_dtype) - - chunk_indices = prepare_chunk_indices(cu_seqlens, 16) if cu_seqlens is not None else None - NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, 16) - solve_tril_16x16_kernel[NT, B * H]( - A=A, - Ad=Ad, - cu_seqlens=cu_seqlens, - chunk_indices=chunk_indices, - T=T, - H=H, - BT=BT, - ) - if BT == 16: - return Ad - - Ai = torch.empty(B, T, H, BT, device=A.device, dtype=output_dtype) - merge_fn = merge_16x16_to_32x32_inverse_kernel if BT == 32 else merge_16x16_to_64x64_inverse_kernel - chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None - NT = len(chunk_indices) if cu_seqlens is not None else triton.cdiv(T, BT) - merge_fn[NT, B * H]( - A=A, - Ad=Ad, - Ai=Ai, - cu_seqlens=cu_seqlens, - chunk_indices=chunk_indices, - T=T, - H=H, - BT=BT, - ) - return Ai diff --git a/lightllm/models/qwen3next/triton_kernel/fla_bak/utils.py b/lightllm/models/qwen3next/triton_kernel/fla_bak/utils.py deleted file mode 100644 index d8f29f287..000000000 --- a/lightllm/models/qwen3next/triton_kernel/fla_bak/utils.py +++ /dev/null @@ -1,173 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang -# -# This file contains code copied from the flash-linear-attention project. -# The original source code was licensed under the MIT license and included -# the following copyright notice: -# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang -# ruff: noqa: E501 -import contextlib -import functools -import logging -import os -from enum import Enum -from typing import Any, Callable, Literal, Optional - -import torch - -import triton - -logger = logging.getLogger(__name__) - -COMPILER_MODE = os.getenv("FLA_COMPILER_MODE") == "1" -FLA_CI_ENV = os.getenv("FLA_CI_ENV") == "1" -FLA_GDN_FIX_BT = os.getenv("FLA_GDN_FIX_BT", "0") == "1" - -SUPPRESS_LEVEL = int(os.getenv("GDN_RECOMPUTE_SUPPRESS_LEVEL", "0")) - - -def tensor_cache(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: - """ - A decorator that caches the most recent results of a function with tensor inputs. - - This decorator will store the output of the decorated function for the most recent set of input tensors. - The cache is limited to a fixed size (default is 4). When the cache is full, the oldest entry will be removed. - - Args: - fn (Callable[..., torch.Tensor]): - The function to be decorated. It should take tensor inputs and return tensor outputs. - - Returns: - Callable[..., torch.Tensor]: - A wrapped version of the input function with single-entry caching. - """ - - cache_entries: tuple[Optional[tuple], Optional[dict], Any] = [] - cache_size = 4 - - @functools.wraps(fn) - def wrapper(*args: Any, **kwargs: Any) -> Any: - nonlocal cache_entries, cache_size - for i, entry in enumerate(cache_entries): - last_args, last_kwargs, last_result = entry - if ( - len(args) == len(last_args) - and len(kwargs) == len(last_kwargs) - and all(a is b for a, b in zip(args, last_args)) - and all(k in last_kwargs and v is last_kwargs[k] for k, v in kwargs.items()) - ): - cache_entries = cache_entries[:i] + cache_entries[i + 1 :] + [(args, kwargs, last_result)] - return last_result - - result = fn(*args, **kwargs) - - if len(cache_entries) >= cache_size: - cache_entries = cache_entries[1:] - cache_entries.append((args, kwargs, result)) - return result - - return wrapper - - -def input_guard(fn: Callable[..., torch.Tensor]) -> Callable[..., torch.Tensor]: - """ - A decorator to make sure all input tensors are contiguous and set the device based on input tensors. - """ - - @functools.wraps(fn) - def wrapper(*args, **kwargs): - contiguous_args = (i if not isinstance(i, torch.Tensor) else i.contiguous() for i in args) - contiguous_kwargs = {k: (v if not isinstance(v, torch.Tensor) else v.contiguous()) for k, v in kwargs.items()} - - tensor = None - for arg in args: - if isinstance(arg, torch.Tensor): - tensor = arg - break - if tensor is None: - for value in kwargs.values(): - if isinstance(value, torch.Tensor): - tensor = value - break - - if tensor is not None: - ctx = torch.cuda.device(tensor.device.index) - else: - ctx = contextlib.nullcontext() - - with ctx: - return fn(*contiguous_args, **contiguous_kwargs) - - return wrapper - - -@functools.cache -def get_available_device() -> str: - try: - return triton.runtime.driver.active.get_current_target().backend - except BaseException: - return "cpu" - - -@functools.cache -def _check_platform() -> Literal["nvidia", "amd", "intel", "musa"]: - device = get_available_device() - mapping = { - "cuda": "nvidia", - "hip": "amd", - "xpu": "intel", - } - # return the mapped value, or the original if not found - return mapping.get(device, device) - - -# For AMD GPUs, the triton backend is 'hip', while for Nvidia GPUs, the triton backend is 'cuda'. -# However, the torch backend is 'cuda' for both Nvidia and AMD GPUs. -# Therefore, we need to check the triton backend to determine the actual GPU vendor. -device = get_available_device() if get_available_device() != "hip" else "cuda" -device_torch_lib = getattr(torch, device) -device_platform = _check_platform() - -is_amd = device_platform == "amd" -is_intel = device_platform == "intel" -is_nvidia = device_platform == "nvidia" -is_intel_alchemist = is_intel and "Intel(R) Arc(TM) A" in torch.xpu.get_device_name(0) -is_nvidia_hopper = is_nvidia and ( - "NVIDIA H" in torch.cuda.get_device_name(0) or torch.cuda.get_device_capability()[0] >= 9 -) -use_cuda_graph = is_nvidia and os.environ.get("FLA_USE_CUDA_GRAPH", "0") == "1" - - -def get_all_max_shared_mem(): - try: - return [ - triton.runtime.driver.active.utils.get_device_properties(i)["max_shared_mem"] - for i in range(device_torch_lib.device_count()) - ] - except BaseException: - return [-1] - - -class Backend(Enum): - ADA = 101376 # RTX 4090 - AMPERE = 166912 # A100 - HOPPER = 232448 # H100 - DEFAULT = 102400 # Default - - @classmethod - def get_shared_memory(cls, arch: str) -> int: - try: - return cls[arch.upper()].value - except KeyError: - return cls.DEFAULT.value - - -@functools.cache -def check_shared_mem(arch: str = "none", tensor_idx: int = 0) -> bool: - try: - device_shared_mem_list = get_all_max_shared_mem() - max_shared_memory = device_shared_mem_list[tensor_idx] - return max_shared_memory >= Backend.get_shared_memory(arch) - except Exception: - return False diff --git a/lightllm/models/qwen3next/triton_kernel/fla_bak/wy_fast.py b/lightllm/models/qwen3next/triton_kernel/fla_bak/wy_fast.py deleted file mode 100644 index dec8d2ffc..000000000 --- a/lightllm/models/qwen3next/triton_kernel/fla_bak/wy_fast.py +++ /dev/null @@ -1,122 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project -# SPDX-FileCopyrightText: Songlin Yang, Yu Zhang -# -# This file contains code copied from the flash-linear-attention project. -# The original source code was licensed under the MIT license and included -# the following copyright notice: -# Copyright (c) 2023-2025, Songlin Yang, Yu Zhang - -# ruff: noqa: E501 -from typing import Optional - -import torch - -import triton -import triton.language as tl - -from .index import prepare_chunk_indices - - -@triton.heuristics({"IS_VARLEN": lambda args: args["cu_seqlens"] is not None}) -@triton.autotune( - configs=[ - triton.Config({}, num_warps=num_warps, num_stages=num_stages) - for num_warps in [2, 4, 8] - for num_stages in [2, 3, 4] - ], - key=["H", "K", "V", "BT", "BK", "BV", "IS_VARLEN"], -) -@triton.jit(do_not_specialize=["T"]) -def recompute_w_u_fwd_kernel( - k, - v, - beta, - w, - u, - A, - g, - cu_seqlens, - chunk_indices, - T, - H: tl.constexpr, - Hg: tl.constexpr, - K: tl.constexpr, - V: tl.constexpr, - BT: tl.constexpr, - BK: tl.constexpr, - BV: tl.constexpr, - IS_VARLEN: tl.constexpr, -): - i_t, i_bh = tl.program_id(0), tl.program_id(1) - i_b, i_h = i_bh // H, i_bh % H - if IS_VARLEN: - i_n, i_t = tl.load(chunk_indices + i_t * 2).to(tl.int32), tl.load(chunk_indices + i_t * 2 + 1).to(tl.int32) - bos, eos = tl.load(cu_seqlens + i_n).to(tl.int32), tl.load(cu_seqlens + i_n + 1).to(tl.int32) - T = eos - bos - else: - bos, eos = i_b * T, i_b * T + T - p_beta = tl.make_block_ptr(beta + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) - p_g = tl.make_block_ptr(g + (bos * H + i_h), (T,), (H,), (i_t * BT,), (BT,), (0,)) - p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (H * BT, 1), (i_t * BT, 0), (BT, BT), (1, 0)) - b_beta = tl.load(p_beta, boundary_check=(0,)) - b_A = tl.load(p_A, boundary_check=(0, 1)) - b_g = tl.exp(tl.load(p_g, boundary_check=(0,))) - - for i_v in range(tl.cdiv(V, BV)): - p_v = tl.make_block_ptr(v + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - p_u = tl.make_block_ptr(u + (bos * H + i_h) * V, (T, V), (H * V, 1), (i_t * BT, i_v * BV), (BT, BV), (1, 0)) - b_v = tl.load(p_v, boundary_check=(0, 1)) - b_vb = (b_v * b_beta[:, None]).to(b_v.dtype) - b_u = tl.dot(b_A, b_vb, allow_tf32=False) - tl.store(p_u, b_u.to(p_u.dtype.element_ty), boundary_check=(0, 1)) - - for i_k in range(tl.cdiv(K, BK)): - p_k = tl.make_block_ptr( - k + (bos * Hg + i_h // (H // Hg)) * K, (T, K), (Hg * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0) - ) - p_w = tl.make_block_ptr(w + (bos * H + i_h) * K, (T, K), (H * K, 1), (i_t * BT, i_k * BK), (BT, BK), (1, 0)) - b_k = tl.load(p_k, boundary_check=(0, 1)) - b_kb = (b_k * b_beta[:, None] * b_g[:, None]).to(b_k.dtype) - b_w = tl.dot(b_A, b_kb) - tl.store(p_w, b_w.to(p_w.dtype.element_ty), boundary_check=(0, 1)) - - -def recompute_w_u_fwd( - k: torch.Tensor, - v: torch.Tensor, - beta: torch.Tensor, - g_cumsum: torch.Tensor, - A: torch.Tensor, - cu_seqlens: Optional[torch.LongTensor], -) -> tuple[torch.Tensor, torch.Tensor]: - B, T, Hg, K, V = *k.shape, v.shape[-1] - H = v.shape[-2] - BT = A.shape[-1] - - chunk_indices = prepare_chunk_indices(cu_seqlens, BT) if cu_seqlens is not None else None - NT = triton.cdiv(T, BT) if cu_seqlens is None else len(chunk_indices) - BK = 64 - BV = 64 - u = torch.empty_like(v) - w = k.new_empty(B, T, H, K) - recompute_w_u_fwd_kernel[(NT, B * H)]( - k=k, - v=v, - beta=beta, - w=w, - u=u, - A=A, - g=g_cumsum, - cu_seqlens=cu_seqlens, - chunk_indices=chunk_indices, - T=T, - H=H, - Hg=Hg, - K=K, - V=V, - BT=BT, - BK=BK, - BV=BV, - ) - return w, u diff --git a/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py b/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py index ffe68a637..e1a112c5a 100644 --- a/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py +++ b/lightllm/models/qwen3next/triton_kernel/fused_gdn_gating.py @@ -1,5 +1,5 @@ # Adapted from https://github.com/sgl-project/sglang/python/sglang/srt/layers/attention/fla/fused_gdn_gating.py -from typing import Tuple +from typing import Optional, Tuple import torch import triton @@ -61,15 +61,15 @@ def fused_gdn_gating( dt_bias: torch.Tensor, beta: float = 1.0, threshold: float = 20.0, - run_config: dict = None, + run_config: Optional[dict] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: - if not run_config: + if run_config is None: run_config = {"BLK_HEADS": 8, "num_warps": 1} batch, num_heads = a.shape seq_len = 1 - grid = (batch, seq_len, triton.cdiv(num_heads, 8)) + grid = (batch, seq_len, triton.cdiv(num_heads, run_config["BLK_HEADS"])) g = torch.empty(1, batch, num_heads, dtype=torch.float32, device=a.device) beta_output = torch.empty(1, batch, num_heads, dtype=torch.float32, device=b.device) fused_gdn_gating_kernel[grid]( diff --git a/lightllm/server/api_cli.py b/lightllm/server/api_cli.py index 49496fbd4..4f0c3cace 100644 --- a/lightllm/server/api_cli.py +++ b/lightllm/server/api_cli.py @@ -572,5 +572,12 @@ def make_argument_parser() -> argparse.ArgumentParser: default=False, help="""Enable prefix prompt cache fetch for data parallel inference, disabled by default.""", ) - parser.add_argument("--linear_attn_cache_size", type=int, default=3000, help="""The size of linear attn cache. """) + parser.add_argument("--mamba_cache_size", type=int, default=3000, help="""The size of linear attn cache. """) + parser.add_argument( + "--mamba_ssm_data_type", + type=str, + choices=["bfloat16", "float32"], + default="float32", + help="the data type of the model weight", + ) return parser diff --git a/lightllm/server/core/objs/start_args_type.py b/lightllm/server/core/objs/start_args_type.py index 79fc14dd7..6a4ca401d 100644 --- a/lightllm/server/core/objs/start_args_type.py +++ b/lightllm/server/core/objs/start_args_type.py @@ -149,6 +149,5 @@ class StartArgs: disable_custom_allreduce: bool = field(default=False) # hybrid attention model - linear_attn_cache_size: int = field(default=2000) - - weight_version: str = "default" + mamba_cache_size: int = field(default=2000) + mamba_ssm_data_type: Optional[str] = field(default="float32", metadata={"choices": ["bfloat16", "float32"]}) diff --git a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py index 225821139..95bdcc5e1 100644 --- a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py @@ -82,9 +82,9 @@ def insert_for_hybrid_radix_cache(self, reqs): input_token_ids = req.get_input_token_ids() key = torch.tensor(input_token_ids[0 : req.cur_kv_len], dtype=torch.int64, device="cpu") value = g_infer_context.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].cpu() - + cur_buffer_idx = g_infer_context.req_manager.req_to_buffer_indexes[req.req_idx] # 分配新的 buffer 并复制当前 buffer 的内容 - self.mem_manager.copy_buffer(req.buffer_idx, new_buffer_indexes[i]) + self.mem_manager.copy_buffer(cur_buffer_idx, new_buffer_indexes[i]) _, new_shared_kv_node = self.insert(key, value) new_shared_kv_node.buffer_idx = new_buffer_indexes[i] diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index f05000742..d4e4ea38b 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -114,7 +114,8 @@ def free_a_req_mem(self, free_token_index: List, req: "InferReq"): value = self.req_manager.req_to_token_indexs[req.req_idx][: req.cur_kv_len].detach().cpu() prefix_len, node = self.radix_cache.insert(key, value) - node.buffer_idx = req.buffer_idx + if hasattr(self.req_manager, "req_to_buffer_indexes"): + node.buffer_idx = self.req_manager.req_to_buffer_indexes[req.req_idx] old_prefix_len = 0 if req.shared_kv_node is None else req.shared_kv_node.node_prefix_total_len free_token_index.append(self.req_manager.req_to_token_indexs[req.req_idx][old_prefix_len:prefix_len]) if req.shared_kv_node is not None: @@ -348,10 +349,6 @@ def __init__( self.nixl_pd_task_failed_num: int = 0 self.nixl_trans_device_id: int = -1 - # 可以用于请求在整个生命周期维护单一大小的buffer的场景 - # 例如混合注意力模型 Qwen3Next - self.buffer_idx = -1 - # 在开启 enable_cpu_cache 的情况下,当请求结束后,会将请求的 kv cache # 卸载到 cpu cache 中,该标志变量用于标记请求的卸载任务的状态 self.cpu_cache_task_status: "InferReq._CpuCacheTaskStatus" = InferReq._CpuCacheTaskStatus.NOT_STARTED diff --git a/test/test_api/test_gsmk.py b/test/test_api/test_gsmk.py new file mode 100644 index 000000000..866dd0f01 --- /dev/null +++ b/test/test_api/test_gsmk.py @@ -0,0 +1,230 @@ +# Adapted from https://github.com/sgl-project/sglang/blob/main/benchmark/gsm8k/bench_other.py +import argparse +import ast +import json +import os +import re +import time +from concurrent.futures import ThreadPoolExecutor +from typing import Optional + +import numpy as np +import requests +from tqdm import tqdm + +INVALID = -9999999 + + +def read_jsonl(filename: str): + """Read a JSONL file.""" + with open(filename) as fin: + for line in fin: + if line.startswith("#"): + continue + yield json.loads(line) + + +def dump_state_text(filename: str, states: list, mode: str = "w"): + """Dump program state in a text file.""" + with open(filename, mode) as fout: + for i, s in enumerate(states): + if isinstance(s, str): + fout.write(f"==== {i} ====\n{s}\n") + else: + fout.write(f"==== {i} ====\n{str(s)}\n") + + +def download_and_cache_file(url: str, filename: Optional[str] = None): + """Read and cache a file from a url.""" + if filename is None: + filename = os.path.join("/tmp", url.split("/")[-1]) + + # Check if the cache file already exists + if os.path.exists(filename): + return filename + + print(f"Downloading from {url} to {filename}") + + # Stream the response to show the progress bar + response = requests.get(url, stream=True) + response.raise_for_status() # Check for request errors + + # Total size of the file in bytes + total_size = int(response.headers.get("content-length", 0)) + chunk_size = 1024 # Download in chunks of 1KB + + # Use tqdm to display the progress bar + with open(filename, "wb") as file, tqdm( + desc="Downloading", + total=total_size, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as bar: + for chunk in response.iter_content(chunk_size=chunk_size): + size = file.write(chunk) + bar.update(size) + + return filename + + +def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None): + """Call LightLLM API for text generation.""" + assert url is not None + + data = { + "inputs": prompt, + "parameters": { + "temperature": temperature, + "max_new_tokens": max_tokens, + "stop_sequences": stop, + }, + } + res = requests.post(url, json=data) + assert res.status_code == 200, f"API request failed with status code {res.status_code}: {res.text}" + + response_json = res.json() + if "generated_text" not in response_json: + raise ValueError(f"Invalid API response format. Expected 'generated_text' key, got: {response_json.keys()}") + if not isinstance(response_json["generated_text"], list) or len(response_json["generated_text"]) == 0: + raise ValueError( + "Invalid API response format. 'generated_text' should be a non-empty list, " + f"got: {response_json['generated_text']}" + ) + + pred = response_json["generated_text"][0] + return pred + + +def get_one_example(lines, i, include_answer): + ret = "Question: " + lines[i]["question"] + "\nAnswer:" + if include_answer: + ret += " " + lines[i]["answer"] + return ret + + +def get_few_shot_examples(lines, k): + ret = "" + for i in range(k): + ret += get_one_example(lines, i, True) + "\n\n" + return ret + + +def get_answer_value(answer_str): + answer_str = answer_str.replace(",", "") + numbers = re.findall(r"\d+", answer_str) + if len(numbers) < 1: + return INVALID + try: + return ast.literal_eval(numbers[-1]) + except SyntaxError: + return INVALID + + +def parse_args(): + """Parse command line arguments.""" + parser = argparse.ArgumentParser() + parser.add_argument("--parallel", type=int, default=256) + parser.add_argument("--host", type=str, default="http://127.0.0.1") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--num-shots", type=int, default=5) + parser.add_argument("--num-questions", type=int, default=200) + parser.add_argument("--result-file", type=str, default="result.jsonl") + parser.add_argument("--data-path", type=str, default="test.jsonl") + return parser.parse_args() + + +def main(args): + # LightLLM API URL + url = f"{args.host}:{args.port}/generate" + + # Read data + url_data = "https://raw.githubusercontent.com/openai/grade-school-math/master/grade_school_math/data/test.jsonl" + filename = download_and_cache_file(url_data) + lines = list(read_jsonl(filename)) + + # Construct prompts + num_questions = args.num_questions + num_shots = args.num_shots + few_shot_examples = get_few_shot_examples(lines, num_shots) + + # Ensure we have enough samples and avoid data leakage + # Test questions should start after few-shot examples + max_available = len(lines) - num_shots + if num_questions > max_available: + print( + "Warning: Requested {} questions, but only {} available after reserving {} for few-shot. " + "Using {} questions.".format(num_questions, max_available, num_shots, max_available) + ) + num_questions = max_available + + questions = [] + labels = [] + for i in range(num_shots, num_shots + num_questions): + questions.append(get_one_example(lines, i, False)) + labels.append(get_answer_value(lines[i]["answer"])) + assert all(label != INVALID for label in labels) + + states = [None] * len(labels) + + # Run requests using thread pool + def get_one_answer(i): + answer = call_generate_lightllm( + prompt=few_shot_examples + questions[i], + temperature=0, + max_tokens=1024, + stop=["Question", "Assistant:", "<|separator|>"], + url=url, + ) + states[i] = answer + + tic = time.perf_counter() + if args.parallel == 1: + for i in tqdm(range(len(questions))): + get_one_answer(i) + else: + with ThreadPoolExecutor(args.parallel) as executor: + list( + tqdm( + executor.map(get_one_answer, list(range(len(questions)))), + total=len(questions), + ) + ) + + latency = time.perf_counter() - tic + + preds = [] + for i in range(len(states)): + preds.append(get_answer_value(states[i])) + + # Compute accuracy + acc = np.mean(np.array(preds) == np.array(labels)) + invalid = np.mean(np.array(preds) == INVALID) + + # Print results + print(f"Accuracy: {acc:.3f}") + print(f"Invalid: {invalid:.3f}") + print(f"Latency: {latency:.3f} s") + + # Dump results + dump_state_text("tmp_output_lightllm.txt", states) + + with open(args.result_file, "a") as fout: + value = { + "task": "gsm8k", + "backend": "lightllm", + "num_gpus": 1, + "latency": round(latency, 3), + "accuracy": round(acc, 3), + "num_requests": args.num_questions, + "other": { + "num_questions": args.num_questions, + "parallel": args.parallel, + }, + } + fout.write(json.dumps(value) + "\n") + + +if __name__ == "__main__": + args = parse_args() + main(args) From ea8f30bdf9bf05b4e84f5c9ba57340f1f0b1dd3f Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 12 Dec 2025 15:58:03 +0800 Subject: [PATCH 15/19] try fix --- .../qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py index 32a87f7a0..8f561cfa4 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py @@ -74,8 +74,7 @@ def chunk_scaled_dot_kkt_fwd_kernel( (1, 0), ) b_k = tl.load(p_k, boundary_check=(0, 1)) - b_kb = b_k * b_beta[:, None] - b_A += tl.dot(b_kb.to(b_k.dtype), tl.trans(b_k)) + b_A += tl.dot(b_k, tl.trans(b_k)) if USE_G: p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) @@ -83,6 +82,7 @@ def chunk_scaled_dot_kkt_fwd_kernel( b_g_diff = b_g[:, None] - b_g[None, :] b_A = b_A * exp(b_g_diff) + b_A *= b_beta[:, None] m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t) b_A = tl.where(m_A, b_A, 0) p_A = tl.make_block_ptr(A + (bos * H + i_h) * BT, (T, BT), (BT * H, 1), (i_t * BT, 0), (BT, BT), (1, 0)) From 0f433c0291e6ee65fed5817ca3ed36fe32fd05cc Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 12 Dec 2025 08:27:39 +0000 Subject: [PATCH 16/19] clean code --- .../basemodel/layer_weights/hf_load_utils.py | 2 +- lightllm/common/req_manager.py | 1 - ...VARLEN=true,REVERSE=false}_NVIDIA_H200.json | 3 +++ ...=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json | 4 ++++ ...16,x_dtype=torch.bfloat16}_NVIDIA_H200.json | 4 ++++ .../{topk_num=10}_NVIDIA_H200.json | 4 ++++ ...,out_dtype=torch.bfloat16}_NVIDIA_H200.json | 6 ++++++ lightllm/common/triton_utils/autotuner.py | 2 +- .../layer_infer/transformer_layer_infer.py | 4 ++-- lightllm/models/qwen3next/model.py | 2 +- lightllm/server/api_http.py | 18 +----------------- lightllm/server/router/manager.py | 3 +-- .../server/router/model_infer/infer_batch.py | 3 --- .../model_infer/mode_backend/base_backend.py | 1 - .../mode_backend/chunked_prefill/impl.py | 1 - lightllm/utils/device_utils.py | 2 +- lightllm/utils/log_utils.py | 13 ------------- 17 files changed, 29 insertions(+), 44 deletions(-) diff --git a/lightllm/common/basemodel/layer_weights/hf_load_utils.py b/lightllm/common/basemodel/layer_weights/hf_load_utils.py index bb2d9aec4..8cf66a5ad 100755 --- a/lightllm/common/basemodel/layer_weights/hf_load_utils.py +++ b/lightllm/common/basemodel/layer_weights/hf_load_utils.py @@ -60,7 +60,7 @@ def load_hf_weights(data_type, weight_dir, pre_post_layer=None, transformer_laye transformer_layer_list=transformer_layer_list, weight_dir=weight_dir, ) # noqa - worker = int(os.environ.get("LOADWORKER", 16)) + worker = int(os.environ.get("LOADWORKER", 1)) with Pool(worker) as p: iterator = p.imap_unordered(partial_func, candidate_files, chunksize=1) desc_str = f"pid {os.getpid()} Loading model weights with {worker} workers" diff --git a/lightllm/common/req_manager.py b/lightllm/common/req_manager.py index 572191089..40c8aa993 100644 --- a/lightllm/common/req_manager.py +++ b/lightllm/common/req_manager.py @@ -7,7 +7,6 @@ from lightllm.common.basemodel.triton_kernel.gen_sampling_params import update_req_to_token_id_counter from lightllm.utils.envs_utils import enable_env_vars, get_env_start_args from lightllm.utils.config_utils import get_vocab_size -from lightllm.server.router.dynamic_prompt.hybrid_radix_cache import HybridMemManager logger = init_logger(__name__) diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json index 412046d09..354a6f93a 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/chunk_local_cumsum_scalar/{B=1,BT=64,H=8,IS_VARLEN=true,REVERSE=false}_NVIDIA_H200.json @@ -14,6 +14,9 @@ "16": { "num_warps": 4 }, + "164096": { + "num_warps": 1 + }, "2048": { "num_warps": 2 }, diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json index e6e34fdbe..d00af04ca 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/fused_gdn_gating:v1/{NUM_HEADS=8,a_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -19,6 +19,10 @@ "BLK_HEADS": 8, "num_warps": 2 }, + "164096": { + "BLK_HEADS": 8, + "num_warps": 1 + }, "2048": { "BLK_HEADS": 16, "num_warps": 1 diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json index 316fb7678..84c47d348 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/gated_rmsnorm_forward:v1/{N=128,has_bias=false,weight_dtype=torch.bfloat16,x_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -7,6 +7,10 @@ "BLOCK_N": 256, "num_warps": 1 }, + "1312768": { + "BLOCK_N": 64, + "num_warps": 2 + }, "16384": { "BLOCK_N": 128, "num_warps": 1 diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_align_fused:v1/{topk_num=10}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_align_fused:v1/{topk_num=10}_NVIDIA_H200.json index 65c618475..5923f3164 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_align_fused:v1/{topk_num=10}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/moe_align_fused:v1/{topk_num=10}_NVIDIA_H200.json @@ -31,6 +31,10 @@ "BLOCK_SIZE": 128, "num_warps": 4 }, + "32768": { + "BLOCK_SIZE": 256, + "num_warps": 8 + }, "4096": { "BLOCK_SIZE": 128, "num_warps": 8 diff --git a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json index 4c0fdb9d2..0b3aa1e36 100644 --- a/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json +++ b/lightllm/common/triton_utils/autotune_kernel_configs/triton_3.4.0/NVIDIA_H200/silu_and_mul_fwd:v1/{N=128,out_dtype=torch.bfloat16}_NVIDIA_H200.json @@ -59,6 +59,12 @@ "NUM_STAGES": 2, "num_warps": 4 }, + "164096": { + "BLOCK_M": 64, + "BLOCK_N": 128, + "NUM_STAGES": 4, + "num_warps": 1 + }, "2048": { "BLOCK_M": 1, "BLOCK_N": 256, diff --git a/lightllm/common/triton_utils/autotuner.py b/lightllm/common/triton_utils/autotuner.py index ec95c4b27..c69147087 100644 --- a/lightllm/common/triton_utils/autotuner.py +++ b/lightllm/common/triton_utils/autotuner.py @@ -62,7 +62,7 @@ def autotune( as needed before invocation. """ - def decorator(fn): + def decorator(fn: Callable) -> Callable: return Autotuner( fn=fn, kernel_name=kernel_name, diff --git a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py index 53835bfb1..6b62128cf 100644 --- a/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen3next/layer_infer/transformer_layer_infer.py @@ -271,14 +271,14 @@ def _linear_attn( query_start_loc=infer_state.b1_cu_q_seq_len, cache_indices=buffer_idx, has_initial_state=infer_state.b_ready_cache_len > 0, - conv_states=conv_states.transpose(1, 2), + conv_states=conv_states, activation=self.activation, ) mixed_qkv = out_tensor.transpose(0, 1) else: mixed_qkv = causal_conv1d_update( mixed_qkv, - conv_states.transpose(1, 2), + conv_states, layer_weight.linear_conv1d.mm_param.weight.transpose(0, 1), bias=layer_weight.linear_conv1d.mm_param.bias, activation=self.activation, diff --git a/lightllm/models/qwen3next/model.py b/lightllm/models/qwen3next/model.py index a8558d37a..b83222027 100644 --- a/lightllm/models/qwen3next/model.py +++ b/lightllm/models/qwen3next/model.py @@ -92,7 +92,7 @@ def _init_mem_manager(self): mtp_layer_num=start_args.mtp_step, full_attention_interval=self.config["full_attention_interval"], conv_state_dtype=self.data_type, - conv_state_shape=(conv_kernel_size - 1 + mtp_step, conv_dim // self.tp_world_size_), + conv_state_shape=(conv_dim // self.tp_world_size_, conv_kernel_size - 1 + mtp_step), ssm_state_dtype=ssm_dtype_dict[start_args.mamba_ssm_data_type], ssm_state_shape=( # mtp_step + 1, diff --git a/lightllm/server/api_http.py b/lightllm/server/api_http.py index 32db64174..2dab18dda 100755 --- a/lightllm/server/api_http.py +++ b/lightllm/server/api_http.py @@ -130,22 +130,6 @@ def get_model_name(): return {"model_name": g_objs.args.model_name} -@app.get("/get_server_info") -@app.post("/get_server_info") -def get_server_info(): - # 将 StartArgs 转换为字典格式 - from dataclasses import asdict - - server_info: dict[str, Any] = asdict(g_objs.args) - return {**server_info} - - -@app.get("/get_weight_version") -@app.post("/get_weight_version") -def get_weight_version(): - return {"weight_version": g_objs.args.weight_version} - - @app.get("/healthz", summary="Check server health") @app.get("/health", summary="Check server health") @app.head("/health", summary="Check server health") @@ -267,7 +251,7 @@ async def completions(request: CompletionRequest, raw_request: Request) -> Respo return create_error_response( HTTPStatus.EXPECTATION_FAILED, "service in pd mode dont recv reqs from http interface" ) - logger.info(f"completions request: {request}") + resp = await completions_impl(request, raw_request) return resp diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py index e2a64f6e0..89c46d9ed 100644 --- a/lightllm/server/router/manager.py +++ b/lightllm/server/router/manager.py @@ -250,14 +250,13 @@ async def loop_for_fwd( frozen_token_num = self.shared_token_load.get_frozened_token_count(d_i) estimated_peak_token_count = self.shared_token_load.get_estimated_peak_token_count(d_i) paused_req_num = self._get_paused_req_num_in_dp_index(dp_index=d_i) - logger.debug( f"dp_i {d_i} current batch size: {len(self.running_batch.reqs)} \n" f"dp_i {d_i} paused req num: {paused_req_num} \n" f"dp_i {d_i} frozen token num: {frozen_token_num} \n" f"dp_i {d_i} estimated_peak_token_count: {estimated_peak_token_count} \n" f"dp_i {d_i} token used ratio: {token_ratio1} not contain prompt cache tree unrefed token\n" - f"dp_i {d_i} token used ratio: {token_ratio2} contain prompt cache tree unrefed token\n" + f"dp_i {d_i} token used ratio: {token_ratio2} contain prompt cache tree unrefed token" ) self.metric_client.gauge_set("lightllm_batch_pause_size", paused_req_num) # pd decode mode need to update token_load more frequently diff --git a/lightllm/server/router/model_infer/infer_batch.py b/lightllm/server/router/model_infer/infer_batch.py index d4e4ea38b..c2df005fe 100644 --- a/lightllm/server/router/model_infer/infer_batch.py +++ b/lightllm/server/router/model_infer/infer_batch.py @@ -34,8 +34,6 @@ class InferenceContext: overlap_stream: torch.cuda.Stream = None # 一些情况下推理进程进行异步折叠操作的异步流对象。 cpu_kv_cache_stream: torch.cuda.Stream = None # 用 cpu kv cache 操作的 stream - use_hybrid_radix_cache: bool = False - def register( self, backend, req_manager: ReqManager, radix_cache: RadixCache, shm_req_manager: ShmReqManager, vocab_size: int ): @@ -408,7 +406,6 @@ def _match_radix_cache(self): g_infer_context.req_manager.req_to_token_indexs[self.req_idx, 0:ready_cache_len] = value_tensor self.cur_kv_len = int(ready_cache_len) # 序列化问题, 该对象可能为numpy.int64,用 int(*)转换 self.shm_req.prompt_cache_len = self.cur_kv_len # 记录 prompt cache 的命中长度 - self.buffer_idx = share_node.buffer_idx self.shm_req.shm_cur_kv_len = self.cur_kv_len return diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index eb9f98cc7..caba3e36b 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -100,7 +100,6 @@ def init_model(self, kvargs): self.is_multinode_tp = self.args.nnodes > 1 and self.args.dp == 1 self.is_nixl_pd_mode = self.run_mode in ["nixl_prefill", "nixl_decode"] self.is_nixl_decode_mode = self.run_mode == "nixl_decode" - self.is_hybrid_model = kvargs.get("is_hybrid_model", False) self.logger = init_logger(__name__) diff --git a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py index a95d84116..23586d92f 100644 --- a/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py +++ b/lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl.py @@ -110,7 +110,6 @@ def prefill_normal( model_input, run_reqs = prepare_prefill_inputs( prefill_reqs, is_chuncked_mode=not self.disable_chunked_prefill, is_multimodal=self.is_multimodal ) - with torch.cuda.stream(g_infer_context.get_overlap_stream()): model_output = self.model.forward(model_input) _, next_token_ids_cpu, next_token_logprobs_cpu = self._sample_and_scatter_token( diff --git a/lightllm/utils/device_utils.py b/lightllm/utils/device_utils.py index e2a210621..022c5ab40 100644 --- a/lightllm/utils/device_utils.py +++ b/lightllm/utils/device_utils.py @@ -90,7 +90,7 @@ def get_current_device_name(): gpu_name = gpu_name.replace(" ", "_") return gpu_name else: - raise RuntimeError("No GPU available") + return None @lru_cache(maxsize=None) diff --git a/lightllm/utils/log_utils.py b/lightllm/utils/log_utils.py index 03409b4df..799786fba 100644 --- a/lightllm/utils/log_utils.py +++ b/lightllm/utils/log_utils.py @@ -29,17 +29,6 @@ def format(self, record): return msg -class RankFilter(logging.Filter): - def filter(self, record): - from lightllm.utils.dist_utils import get_current_rank_in_dp - - try: - rank = get_current_rank_in_dp() - return rank == 0 - except: - return False - - _root_logger = logging.getLogger("lightllm") _default_handler = None _default_file_handler = None @@ -56,7 +45,6 @@ def _setup_logger(): _default_handler = logging.StreamHandler(sys.stdout) _default_handler.flush = sys.stdout.flush # type: ignore _default_handler.setLevel(_LOG_LEVEL) - # _default_handler.addFilter(RankFilter()) _root_logger.addHandler(_default_handler) if _default_file_handler is None and _LOG_DIR is not None: @@ -68,7 +56,6 @@ def _setup_logger(): _default_file_handler = logging.FileHandler(_LOG_DIR + "/default.log") _default_file_handler.setLevel(_LOG_LEVEL) _default_file_handler.setFormatter(fmt) - # _default_file_handler.addFilter(RankFilter()) _root_logger.addHandler(_default_file_handler) _default_handler.setFormatter(fmt) From 7216a2acc844df5d33e2d18d022017ec57218498 Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 12 Dec 2025 16:33:28 +0800 Subject: [PATCH 17/19] try fix --- .../triton_kernel/fla/ops/chunk_delta_h.py | 13 ++++++------- .../qwen3next/triton_kernel/fla/ops/chunk_o.py | 4 ++-- .../triton_kernel/fla/ops/chunk_scaled_dot_kkt.py | 4 ++-- .../models/qwen3next/triton_kernel/fla/ops/op.py | 10 ++++++++++ 4 files changed, 20 insertions(+), 11 deletions(-) diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py index f34029927..b27fe7ada 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_delta_h.py @@ -14,7 +14,7 @@ import triton.language as tl from .index import prepare_chunk_indices, prepare_chunk_offsets -from .op import exp +from .op import exp, safe_exp from .utils import use_cuda_graph from lightllm.common.triton_utils.autotuner import autotune @@ -150,19 +150,18 @@ def chunk_gated_delta_rule_fwd_kernel_h_blockdim64( last_idx = min((i_t + 1) * BT, T) - 1 if USE_G: - m_t = (i_t * BT + tl.arange(0, BT)) < T b_g_last = tl.load(g + bos * H + last_idx * H + i_h) p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) b_g = tl.load(p_g, boundary_check=(0,)) - b_v = b_v * tl.where(m_t, exp(b_g_last - b_g), 0)[:, None] + b_v = b_v * safe_exp(b_g_last - b_g)[:, None] b_g_last = exp(b_g_last) - b_h1 *= b_g_last + b_h1 = b_h1 * b_g_last if K > 64: - b_h2 *= b_g_last + b_h2 = b_h2 * b_g_last if K > 128: - b_h3 *= b_g_last + b_h3 = b_h3 * b_g_last if K > 192: - b_h4 *= b_g_last + b_h4 = b_h4 * b_g_last if USE_GK: o_k1 = tl.arange(0, 64) diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py index 12ee5a37e..fc49763ec 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_o.py @@ -16,7 +16,7 @@ import triton.language as tl from .index import prepare_chunk_indices -from .op import exp +from .op import exp, safe_exp from .utils import FLA_GDN_FIX_BT, check_shared_mem, is_nvidia_hopper from lightllm.common.triton_utils.autotuner import autotune @@ -103,7 +103,7 @@ def chunk_fwd_kernel_o( p_g = tl.make_block_ptr(g, (T,), (H,), (i_t * BT,), (BT,), (0,)) b_g = tl.load(p_g, boundary_check=(0,)) b_o = b_o * exp(b_g)[:, None] - b_A = b_A * exp(b_g[:, None] - b_g[None, :]) + b_A = b_A * safe_exp(b_g[:, None] - b_g[None, :]) o_t = i_t * BT + tl.arange(0, BT) m_t = o_t < T diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py index 8f561cfa4..715d52dfa 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/chunk_scaled_dot_kkt.py @@ -14,7 +14,7 @@ import triton.language as tl from .index import prepare_chunk_indices -from .op import exp +from .op import exp, safe_exp from lightllm.common.triton_utils.autotuner import autotune triton.set_allocator @@ -80,7 +80,7 @@ def chunk_scaled_dot_kkt_fwd_kernel( p_g = tl.make_block_ptr(g + bos * H + i_h, (T,), (H,), (i_t * BT,), (BT,), (0,)) b_g = tl.load(p_g, boundary_check=(0,)) b_g_diff = b_g[:, None] - b_g[None, :] - b_A = b_A * exp(b_g_diff) + b_A = b_A * safe_exp(b_g_diff) b_A *= b_beta[:, None] m_A = (o_t[:, None] > o_t[None, :]) & (m_t[:, None] & m_t) diff --git a/lightllm/models/qwen3next/triton_kernel/fla/ops/op.py b/lightllm/models/qwen3next/triton_kernel/fla/ops/op.py index d35c71ab9..f288b1f71 100644 --- a/lightllm/models/qwen3next/triton_kernel/fla/ops/op.py +++ b/lightllm/models/qwen3next/triton_kernel/fla/ops/op.py @@ -19,6 +19,16 @@ log2 = tl.log2 +@triton.jit +def safe_exp(x): + """ + Numerically stable exponential function. + Only applies exp to non-positive values, returns 0 for positive values. + This prevents numerical overflow and improves stability. + """ + return exp(tl.where(x <= 0, x, float("-inf"))) + + if not is_gather_supported: @triton.jit From 3162252d21ede07e1e5e5af89c102c23b4fad9a9 Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 12 Dec 2025 17:33:41 +0800 Subject: [PATCH 18/19] fix prefix cache --- .../dynamic_prompt/hybrid_radix_cache.py | 28 +++++++++++-------- 1 file changed, 16 insertions(+), 12 deletions(-) diff --git a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py index 95bdcc5e1..49e100386 100644 --- a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py @@ -1,4 +1,4 @@ -from typing import Set, Protocol, List +from typing import Set, Protocol, List, Optional, Tuple import torch from sortedcontainers import SortedSet @@ -29,7 +29,6 @@ def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None) self.mem_manager: HybridMemManager = mem_manager super().__init__(unique_name, total_token_num, rank_in_node, mem_manager) self.evict_buffer_set: Set[TreeNode] = SortedSet(key=lambda x: x.time_id) - self.evict_buffer_set.add(self.root_node) def free_radix_cache_to_get_enough_buffer(self, need_buffer_num): if need_buffer_num > self.mem_manager.get_buffer_can_use_size(): @@ -56,16 +55,12 @@ def release_buffer(buffer_idx): def evict_buffer(self, need_evict_buffer_num, evict_buffer_callback, evict_token_callback): while need_evict_buffer_num > 0: - node = self.evict_buffer_set.pop() - if node.buffer_idx is not None: - evict_buffer_callback(node.buffer_idx) - need_evict_buffer_num -= 1 - else: - # 在混合注意力模型的情景里,只能匹配 buffer_idx 不为 None的节点 - # 假如 buffer_idx 为 None,则当做匹配失败。 - # 所以可以直接把这个节点给释放掉 - if node.is_leaf() and node.ref_counter == 0: - self._remove_leaf_node(node) + node = self.evict_buffer_set.pop(0) + assert node.buffer_idx is not None + evict_buffer_callback(node.buffer_idx) + evict_token_callback(node.token_mem_index_value) + need_evict_buffer_num -= 1 + self._remove_leaf_node(node) return def insert_for_hybrid_radix_cache(self, reqs): @@ -90,6 +85,7 @@ def insert_for_hybrid_radix_cache(self, reqs): new_shared_kv_node.buffer_idx = new_buffer_indexes[i] self.dec_node_ref_counter(req.shared_kv_node) self.add_node_ref_counter(new_shared_kv_node) + self.evict_buffer_set.add(req.shared_kv_node) req.shared_kv_node = new_shared_kv_node def match_prefix(self, key, update_refs=False): @@ -113,9 +109,17 @@ def match_prefix(self, key, update_refs=False): def _remove_leaf_node(self, node: TreeNode): self.evict_tree_set.discard(node) + self.evict_buffer_set.discard(node) self.tree_total_tokens_num.arr[0] -= len(node.token_mem_index_value) parent_node: TreeNode = node.parent parent_node.remove_child(node) if parent_node.is_leaf(): self.evict_tree_set.add(parent_node) + if parent_node.buffer_idx is not None: + self.evict_buffer_set.add(parent_node) return parent_node + + def insert(self, key, value=None) -> Tuple[int, Optional[TreeNode]]: + prefix_len, node = super().insert(key, value) + self.evict_buffer_set.add(node) + return prefix_len, node From 9d613bcd43f2e8604f3678d277c139f4ebc7ff9e Mon Sep 17 00:00:00 2001 From: sufubao Date: Fri, 12 Dec 2025 10:17:52 +0000 Subject: [PATCH 19/19] fix --- .../dynamic_prompt/hybrid_radix_cache.py | 24 ++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py index 49e100386..e77363c76 100644 --- a/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py +++ b/lightllm/server/router/dynamic_prompt/hybrid_radix_cache.py @@ -28,7 +28,7 @@ class HybridRadixCache(RadixCache): def __init__(self, unique_name, total_token_num, rank_in_node, mem_manager=None): self.mem_manager: HybridMemManager = mem_manager super().__init__(unique_name, total_token_num, rank_in_node, mem_manager) - self.evict_buffer_set: Set[TreeNode] = SortedSet(key=lambda x: x.time_id) + self.evict_buffer_set: Set[TreeNode] = SortedSet(key=lambda x: (x.time_id,)) def free_radix_cache_to_get_enough_buffer(self, need_buffer_num): if need_buffer_num > self.mem_manager.get_buffer_can_use_size(): @@ -85,7 +85,8 @@ def insert_for_hybrid_radix_cache(self, reqs): new_shared_kv_node.buffer_idx = new_buffer_indexes[i] self.dec_node_ref_counter(req.shared_kv_node) self.add_node_ref_counter(new_shared_kv_node) - self.evict_buffer_set.add(req.shared_kv_node) + if req.shared_kv_node is not None and req.shared_kv_node.buffer_idx is not None: + self.update_buffer_evict_set(req.shared_kv_node) req.shared_kv_node = new_shared_kv_node def match_prefix(self, key, update_refs=False): @@ -105,6 +106,7 @@ def match_prefix(self, key, update_refs=False): return None, 0, None value = torch.concat(ans_value_list) + self.update_buffer_evict_set(tree_node) return tree_node, len(value), value def _remove_leaf_node(self, node: TreeNode): @@ -116,10 +118,26 @@ def _remove_leaf_node(self, node: TreeNode): if parent_node.is_leaf(): self.evict_tree_set.add(parent_node) if parent_node.buffer_idx is not None: - self.evict_buffer_set.add(parent_node) + self.update_buffer_evict_set(parent_node) return parent_node def insert(self, key, value=None) -> Tuple[int, Optional[TreeNode]]: prefix_len, node = super().insert(key, value) + if node is not None: + node.update_buffer_time() self.evict_buffer_set.add(node) return prefix_len, node + + def update_buffer_evict_set(self, node: TreeNode): + if node is None or node.buffer_idx is None: + return + + if node not in self.evict_buffer_set: + self.evict_buffer_set.add(node) + return + + self.evict_buffer_set.discard(node) + node.update_buffer_time() + self.evict_buffer_set.add(node) + + self.update_buffer_evict_set(node.parent)