diff --git a/csrc/moe/moe_lora_align_sum_kernels.cu b/csrc/moe/moe_lora_align_sum_kernels.cu
index e76d1c366785..360f1312cf57 100644
--- a/csrc/moe/moe_lora_align_sum_kernels.cu
+++ b/csrc/moe/moe_lora_align_sum_kernels.cu
@@ -28,11 +28,16 @@ __global__ void moe_lora_align_sum_kernel(
     int64_t block_size, int num_experts, int max_loras, size_t numel,
     int max_num_tokens_padded, int max_num_m_blocks,
     int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
-    int topk_num, int32_t* total_tokens_post_pad) {
+    int topk_num, int32_t* total_tokens_post_pad, int32_t* adapter_enabled,
+    int32_t* lora_ids) {
   const size_t tokens_per_thread = div_ceil(numel, blockDim.x);
   const size_t start_idx = threadIdx.x * tokens_per_thread;
 
-  int lora_id = blockIdx.x;
+  int lora_idx = blockIdx.x;
+  int lora_id = lora_ids[lora_idx];
+  if (lora_id == -1 || adapter_enabled[lora_id] == 0) {
+    return;
+  }
   extern __shared__ int32_t shared_mem[];
   int32_t* cumsum = shared_mem;
   token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + num_experts + 1);
@@ -121,14 +126,13 @@ __global__ void moe_lora_align_sum_kernel(
   }
 }
 
-void moe_lora_align_block_size(torch::Tensor topk_ids,
-                               torch::Tensor token_lora_mapping,
-                               int64_t num_experts, int64_t block_size,
-                               int64_t max_loras, int64_t max_num_tokens_padded,
-                               int64_t max_num_m_blocks,
-                               torch::Tensor sorted_token_ids,
-                               torch::Tensor expert_ids,
-                               torch::Tensor num_tokens_post_pad) {
+void moe_lora_align_block_size(
+    torch::Tensor topk_ids, torch::Tensor token_lora_mapping,
+    int64_t num_experts, int64_t block_size, int64_t max_loras,
+    int64_t max_num_tokens_padded, int64_t max_num_m_blocks,
+    torch::Tensor sorted_token_ids, torch::Tensor expert_ids,
+    torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled,
+    torch::Tensor lora_ids) {
   const int topk_num = topk_ids.size(1);
 
   TORCH_CHECK(block_size > 0, "block_size should be greater than 0. ");
@@ -164,6 +168,7 @@ void moe_lora_align_block_size(torch::Tensor topk_ids,
         max_loras, topk_ids.numel(), max_num_tokens_padded,
         max_num_m_blocks, sorted_token_ids.data_ptr(),
         expert_ids.data_ptr(), topk_num,
-        num_tokens_post_pad.data_ptr());
+        num_tokens_post_pad.data_ptr(),
+        adapter_enabled.data_ptr(), lora_ids.data_ptr());
   });
 }
\ No newline at end of file
diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h
index e4bf0aa99421..0adf745689b2 100644
--- a/csrc/moe/moe_ops.h
+++ b/csrc/moe/moe_ops.h
@@ -20,14 +20,13 @@ void batched_moe_align_block_size(int64_t max_tokens_per_batch,
                                   torch::Tensor expert_ids,
                                   torch::Tensor num_tokens_post_pad);
 
-void moe_lora_align_block_size(torch::Tensor topk_ids,
-                               torch::Tensor token_lora_mapping,
-                               int64_t num_experts, int64_t block_size,
-                               int64_t max_loras, int64_t max_num_tokens_padded,
-                               int64_t max_num_m_blocks,
-                               torch::Tensor sorted_token_ids,
-                               torch::Tensor expert_ids,
-                               torch::Tensor num_tokens_post_pad);
+void moe_lora_align_block_size(
+    torch::Tensor topk_ids, torch::Tensor token_lora_mapping,
+    int64_t num_experts, int64_t block_size, int64_t max_loras,
+    int64_t max_num_tokens_padded, int64_t max_num_m_blocks,
+    torch::Tensor sorted_token_ids, torch::Tensor expert_ids,
+    torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled,
+    torch::Tensor lora_ids);
 #ifndef USE_ROCM
 torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                              torch::Tensor b_qweight, torch::Tensor b_scales,
diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp
index c08a543908ef..ace72fad71e8 100644
--- a/csrc/moe/torch_bindings.cpp
+++ b/csrc/moe/torch_bindings.cpp
@@ -44,7 +44,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
       "                           int max_num_m_blocks, "
       "                           Tensor !sorted_token_ids,"
       "                           Tensor !experts_ids,"
-      "                           Tensor !num_tokens_post_pad) -> () ");
+      "                           Tensor !num_tokens_post_pad,"
+      "                           Tensor !adapter_enabled,"
+      "                           Tensor !lora_ids) -> () ");
   m.impl("moe_lora_align_block_size", torch::kCUDA,
          &moe_lora_align_block_size);
 #ifndef USE_ROCM
diff --git a/tests/lora/test_fused_moe_lora_kernel.py b/tests/lora/test_fused_moe_lora_kernel.py
index b724e112b9dd..318a0e58805d 100644
--- a/tests/lora/test_fused_moe_lora_kernel.py
+++ b/tests/lora/test_fused_moe_lora_kernel.py
@@ -134,6 +134,8 @@ def use_fused_moe_lora_kernel(
     )
     expert_ids = torch.empty((max_loras * max_num_m_blocks,), dtype=torch.int32)
     num_tokens_post_padded = torch.empty((max_loras,), dtype=torch.int32)
+    adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32)
+    lora_ids = torch.arange(max_loras + 2, dtype=torch.int32)
 
     # call kernel
     ops.moe_lora_align_block_size(
@@ -147,6 +149,8 @@ def use_fused_moe_lora_kernel(
         sorted_token_ids,
         expert_ids,
         num_tokens_post_padded,
+        adapter_enabled,
+        lora_ids,
     )
 
     config = {
@@ -172,6 +176,8 @@ def use_fused_moe_lora_kernel(
         num_tokens_post_padded,
         max_lora_rank,
         top_k_num,
+        lora_ids,
+        adapter_enabled,
         config["BLOCK_SIZE_M"],
         config["BLOCK_SIZE_N"],
         config["BLOCK_SIZE_K"],
diff --git a/tests/lora/test_moe_lora_align_sum.py b/tests/lora/test_moe_lora_align_sum.py
index 6cd1281c3632..72f1d759f1e7 100644
--- a/tests/lora/test_moe_lora_align_sum.py
+++ b/tests/lora/test_moe_lora_align_sum.py
@@ -60,6 +60,8 @@ def test_moe_lora_align_block_size(
         (max_loras * max_num_m_blocks,), num_experts, dtype=torch.int32, device="cuda"
     )
     num_tokens_post_pad = torch.zeros((max_loras,), dtype=torch.int32, device="cuda")
+    adapter_enabled = torch.ones((max_loras + 1,), dtype=torch.int32, device="cuda")
+    lora_ids = torch.arange(max_loras + 2, dtype=torch.int32, device="cuda")
 
     # call kernel
     ops.moe_lora_align_block_size(
@@ -73,6 +75,8 @@ def test_moe_lora_align_block_size(
         sorted_token_ids,
         expert_ids,
         num_tokens_post_pad,
+        adapter_enabled,
+        lora_ids,
     )
 
     # verify values
diff --git a/tests/lora/test_olmoe_tp.py b/tests/lora/test_olmoe_tp.py
index b954e0776ca4..e659c1e1a9a0 100644
--- a/tests/lora/test_olmoe_tp.py
+++ b/tests/lora/test_olmoe_tp.py
@@ -1,6 +1,7 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 
+
 import vllm
 from vllm.lora.request import LoRARequest
 
@@ -28,8 +29,17 @@
     "SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1",  # noqa: E501
 ]
 
+EXPECTED_BASE_MODEL_OUTPUT = [
+    "SELECT COUNT(Candidate_ID) FROM candidate",
+    "SELECT COUNT(Candidate_ID) FROM candidate",
+    "SELECT Candidate_ID, COUNT(*) as Total_Candidates\nFROM candidate\nINNER JOIN people ON candidate.People_ID = people.People_ID",  # noqa: E501
+    "SELECT Candidate_ID, Poll_Source FROM candidate WHERE People_ID IN (SELECT People_ID FROM people) ORDER BY COUNT(*) DESC LIMIT 1",  # noqa: E501
+]
+
 
-def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
+def generate_and_test(
+    llm: vllm.LLM, lora_path: str, lora_id: list[int | None] | int | None
+) -> None:
     prompts = [
         PROMPT_TEMPLATE.format(context="How many candidates are there?"),
         PROMPT_TEMPLATE.format(context="Count the number of candidates."),
@@ -40,12 +50,18 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
             context="Return the poll resource associated with the most candidates."
         ),
     ]
+
+    lora_request = None
+    if isinstance(lora_id, int):
+        lora_request = LoRARequest(str(lora_id), lora_id, lora_path)
+    elif isinstance(lora_id, list):
+        lora_request = [
+            LoRARequest(str(i), i, lora_path) if i is not None else None
+            for i in lora_id
+        ]
+
     sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
-    outputs = llm.generate(
-        prompts,
-        sampling_params,
-        lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None,
-    )
+    outputs = llm.generate(prompts, sampling_params, lora_request=lora_request)
     # Print the outputs.
     generated_texts: list[str] = []
     for output in outputs:
@@ -55,7 +71,13 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
         print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
 
     for i in range(len(EXPECTED_LORA_OUTPUT)):
-        assert generated_texts[i].startswith(EXPECTED_LORA_OUTPUT[i])
+        req_lora_id = lora_id[i] if isinstance(lora_id, list) else lora_id
+        expected_output = (
+            EXPECTED_LORA_OUTPUT[i]
+            if req_lora_id is not None
+            else EXPECTED_BASE_MODEL_OUTPUT[i]
+        )
+        assert generated_texts[i].startswith(expected_output)
 
 
 def test_olmoe_lora(olmoe_lora_files):
@@ -75,6 +97,20 @@ def test_olmoe_lora(olmoe_lora_files):
     generate_and_test(llm, olmoe_lora_files, lora_id=2)
 
 
+def test_olmoe_lora_mixed(olmoe_lora_files):
+    llm = vllm.LLM(
+        MODEL_PATH,
+        max_model_len=1024,
+        enable_lora=True,
+        max_loras=4,
+        enforce_eager=True,
+        trust_remote_code=True,
+        enable_chunked_prefill=True,
+    )
+
+    generate_and_test(llm, olmoe_lora_files, lora_id=[1, None, 3, None])
+
+
 @multi_gpu_test(num_gpus=2)
 def test_olmoe_lora_tp2(olmoe_lora_files):
     llm = vllm.LLM(
diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py
index 61cf54fcfa39..657b11046809 100644
--- a/vllm/_custom_ops.py
+++ b/vllm/_custom_ops.py
@@ -1823,6 +1823,8 @@ def moe_lora_align_block_size(
     sorted_token_ids: torch.Tensor,
     experts_ids: torch.Tensor,
     num_tokens_post_pad: torch.Tensor,
+    adapter_enabled: torch.Tensor,
+    lora_ids: torch.Tensor,
 ) -> None:
     torch.ops._moe_C.moe_lora_align_block_size(
         topk_ids,
@@ -1835,6 +1837,8 @@ def moe_lora_align_block_size(
         sorted_token_ids,
         experts_ids,
         num_tokens_post_pad,
+        adapter_enabled,
+        lora_ids,
     )
 
 
diff --git a/vllm/lora/layers/fused_moe.py b/vllm/lora/layers/fused_moe.py
index 275a2ed0c681..7711f5c3208b 100644
--- a/vllm/lora/layers/fused_moe.py
+++ b/vllm/lora/layers/fused_moe.py
@@ -111,6 +111,7 @@ def wrapper(*args, **kwargs):
                 config["BLOCK_SIZE_M"],
                 self.base_layer.local_num_experts,
                 max_loras,
+                self.adapter_enabled,
                 expert_map,
             )
 
@@ -138,6 +139,7 @@ def wrapper(*args, **kwargs):
                 max_lora_rank,
                 top_k,
                 config,
+                self.adapter_enabled,
             )
 
             result = func(*args, **kwargs)
@@ -196,6 +198,7 @@ def wrapper(*args, **kwargs):
                 max_lora_rank,
                 top_k,
                 config,
+                self.adapter_enabled,
                 True,
             )
 
@@ -227,6 +230,10 @@ def create_lora_weights(
     ) -> None:
         """Initializes lora matrices."""
 
+        self.adapter_enabled = torch.tensor(
+            [0] * (max_loras + 1), dtype=torch.int, device=self.device
+        )
+
         self.w1_lora_a_stacked = torch.zeros(
             (
                 max_loras,
@@ -313,6 +320,7 @@ def reset_lora(self, index: int):
         self.w3_lora_b_stacked[index] = 0
         self.w2_lora_a_stacked[index] = 0
         self.w2_lora_b_stacked[index] = 0
+        self.adapter_enabled[index] = 0
 
     def set_lora(
         self,
@@ -322,8 +330,9 @@ def set_lora(
         embeddings_tensor: torch.Tensor | None,
         bias: torch.Tensor | None = None,
     ):
-        self.reset_lora(index)
         """Overwrites lora tensors at index."""
+        self.reset_lora(index)
+        self.adapter_enabled[index] = 1
         for eid in range(len(lora_a) // 3):
             w1_lora_a = lora_a[eid * 3]
             w2_lora_a = lora_a[eid * 3 + 1]
diff --git a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
index 15031f5e2f9e..539605c7c534 100644
--- a/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
+++ b/vllm/lora/ops/triton_ops/fused_moe_lora_op.py
@@ -54,6 +54,8 @@ def _fused_moe_lora_kernel(
     EM,
     num_valid_tokens,
     num_experts,
+    lora_ids,
+    adapter_enabled,
     # The stride variables represent how much to increase the ptr by when
     # moving by 1 element in a particular dimension. E.g. `stride_am` is
     # how much to increase `a_ptr` by to get the element one row down
@@ -84,6 +86,11 @@
     pid = tl.program_id(axis=0)
     slice_id = tl.program_id(axis=1)
     lora_idx = tl.program_id(axis=2)
+    lora_id = tl.load(lora_ids + lora_idx)
+    moe_enabled = tl.load(adapter_enabled + lora_id)
+    if lora_id == -1 or moe_enabled == 0:
+        # Early exit for the no-lora case.
+        return
     max_loras = tl.num_programs(axis=2)
     grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)
 
@@ -100,12 +107,12 @@
     pid_m = first_pid_m + ((pid_m_n % num_pid_in_group) % group_size_m)
     pid_n = (pid_m_n % num_pid_in_group) // group_size_m
 
-    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_idx)
+    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_id)
     if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
         return
 
     # get the expert_id to process curr shard
-    ind = lora_idx * stride_el + pid_m
+    ind = lora_id * stride_el + pid_m
     expert_id = tl.load(expert_ids_ptr + ind, ind < max_loras * stride_el, -1)
     if expert_id == -1:
         return
@@ -119,7 +126,7 @@
     offs_k = pid_sk * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
     offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
 
-    token_ind = stride_tl * lora_idx + offs_token_id
+    token_ind = stride_tl * lora_id + offs_token_id
     offs_token = tl.load(
         sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0
     )
@@ -132,7 +139,7 @@
 
     b_ptrs = (
         cur_b_ptr
-        + lora_idx * stride_bl
+        + lora_id * stride_bl
         + expert_id * stride_be
         + offs_k[:, None] * stride_bk
         + offs_bn[None, :] * stride_bn
@@ -184,6 +191,8 @@ def _fused_moe_lora(
     num_tokens_post_padded: torch.Tensor,  # (max_loras, )
     max_lora_rank: int,
    top_k_num: int,
+    lora_ids: torch.Tensor,
+    adapter_enabled: torch.Tensor,
     block_size_m: int,
     block_size_n: int,
     block_size_k: int,
@@ -234,7 +243,7 @@ def _fused_moe_lora(
     num_tokens = M * top_k_num
     w1_output_dim_size = w1_lora_b_stacked.shape[2]
 
-    lora_intermediate_cache1 = torch.empty(
+    lora_intermediate_cache1 = torch.zeros(
         (num_slices * M * top_k_num * (max_lora_rank + w1_output_dim_size)),
         dtype=output.dtype,
         device=device,
@@ -272,6 +281,8 @@
         EM,
         num_tokens,
         num_experts,
+        lora_ids,
+        adapter_enabled,
         qcurr_hidden_states.stride(0),
         qcurr_hidden_states.stride(1),
         w1_lora_a_stacked.stride(0),
@@ -319,6 +330,8 @@
         EM,
         num_tokens,
         num_experts,
+        lora_ids,
+        adapter_enabled,
         a_intermediate_cache1.stride(0),
         a_intermediate_cache1.stride(1),
         w1_lora_b_stacked.stride(0),
@@ -352,6 +365,8 @@ def _fused_moe_lora_fake(
     num_tokens_post_padded: torch.Tensor,
     max_lora_rank: int,
     top_k_num: int,
+    lora_ids: torch.Tensor,
+    adapter_enabled: torch.Tensor,
     block_size_m: int,
     block_size_n: int,
     block_size_k: int,
diff --git a/vllm/lora/punica_wrapper/punica_base.py b/vllm/lora/punica_wrapper/punica_base.py
index 5b4a18cf4789..c552412cfd62 100644
--- a/vllm/lora/punica_wrapper/punica_base.py
+++ b/vllm/lora/punica_wrapper/punica_base.py
@@ -456,6 +456,7 @@ def moe_lora_align_block_size(
         block_size: int,
         num_experts: int,
         max_loras: int,
+        adapter_enabled: torch.Tensor,
         expert_map: torch.Tensor | None = None,
         pad_sorted_ids: bool = False,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -479,6 +480,7 @@ def add_lora_fused_moe(
         max_lora_rank: int,
         top_k_num: int,
         config,
+        adapter_enabled: torch.Tensor,
         mul_routed_weight=False,
     ):
         """
diff --git a/vllm/lora/punica_wrapper/punica_gpu.py b/vllm/lora/punica_wrapper/punica_gpu.py
index d9590769778e..30def90380db 100644
--- a/vllm/lora/punica_wrapper/punica_gpu.py
+++ b/vllm/lora/punica_wrapper/punica_gpu.py
@@ -305,6 +305,7 @@ def moe_lora_align_block_size(
         block_size: int,
         num_experts: int,
         max_loras: int,
+        adapter_enabled: torch.Tensor,
         expert_map: torch.Tensor | None = None,
         pad_sorted_ids: bool = False,
     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -331,7 +332,7 @@ def moe_lora_align_block_size(
             (max_loras), dtype=torch.int32, device=topk_ids.device
         )
 
-        (token_lora_mapping, _, _, _, _, _) = self.token_mapping_meta.meta_args(
+        (token_lora_mapping, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(
             num_tokens
         )
 
@@ -346,6 +347,8 @@ def moe_lora_align_block_size(
             sorted_ids,
             expert_ids,
             num_tokens_post_pad,
+            adapter_enabled,
+            lora_ids,
         )
         if expert_map is not None:
             expert_ids = expert_map[expert_ids]
@@ -365,11 +368,13 @@ def add_lora_fused_moe(
         max_lora_rank: int,
         top_k_num: int,
         config,
+        adapter_enabled: torch.Tensor,
         mul_routed_weight=False,
     ):
         """
         Performs a fused forward computation for LoRA of Mixture-of-Experts (MoE) layer.
         """
+        (_, _, _, _, lora_ids, _) = self.token_mapping_meta.meta_args(x.size(0))
        fused_moe_lora(
             y,
             x,
@@ -381,6 +386,8 @@ def add_lora_fused_moe(
             num_tokens_post_padded,
             max_lora_rank,
             top_k_num,
+            lora_ids,
+            adapter_enabled,
             config["BLOCK_SIZE_M"],
             config["BLOCK_SIZE_N"],
             config["BLOCK_SIZE_K"],
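
For reviewers who want to exercise the changed custom op by hand, here is a minimal usage sketch of the updated `ops.moe_lora_align_block_size` signature, mirroring `tests/lora/test_moe_lora_align_sum.py` from this diff. The tail arguments (`adapter_enabled`, `lora_ids`) are the ones added by this patch; the leading argument order is assumed to follow the C++ declaration above, and the shapes and sizing constants are illustrative assumptions rather than values taken from the test suite. A CUDA build of vLLM's `_moe_C` extension is needed to actually run it.

# Usage sketch only (not part of the patch). Shapes/constants are illustrative.
import torch
from vllm import _custom_ops as ops

num_tokens, top_k, num_experts, max_loras, block_size = 16, 2, 8, 4, 16
max_num_tokens_padded = num_tokens * top_k + num_experts * (block_size - 1)
max_num_m_blocks = (max_num_tokens_padded + block_size - 1) // block_size  # ceil div

topk_ids = torch.randint(
    0, num_experts, (num_tokens, top_k), dtype=torch.int32, device="cuda"
)
token_lora_mapping = torch.randint(
    0, max_loras, (num_tokens,), dtype=torch.int32, device="cuda"
)
sorted_token_ids = torch.empty(
    (max_loras * max_num_tokens_padded,), dtype=torch.int32, device="cuda"
)
expert_ids = torch.full(
    (max_loras * max_num_m_blocks,), num_experts, dtype=torch.int32, device="cuda"
)
num_tokens_post_pad = torch.zeros((max_loras,), dtype=torch.int32, device="cuda")

# New in this change: per-slot enable flags and the list of active LoRA ids.
adapter_enabled = torch.ones((max_loras + 1,), dtype=torch.int32, device="cuda")
lora_ids = torch.arange(max_loras + 2, dtype=torch.int32, device="cuda")

ops.moe_lora_align_block_size(
    topk_ids,
    token_lora_mapping,
    num_experts,
    block_size,
    max_loras,
    max_num_tokens_padded,
    max_num_m_blocks,
    sorted_token_ids,
    expert_ids,
    num_tokens_post_pad,
    adapter_enabled,
    lora_ids,
)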