Commit 5f6cbf6

[Feature][Kernel]FusedMoE LoRA (#21229)
Signed-off-by: wuchen <[email protected]>
Signed-off-by: banjuede <[email protected]>
Signed-off-by: Chen Wu <[email protected]>
Signed-off-by: Danielle Robinson <[email protected]>
Signed-off-by: Jee Jee Li <[email protected]>
Signed-off-by: bk-201 <[email protected]>
Co-authored-by: wuchen <[email protected]>
Co-authored-by: Nathan Van Gheem <[email protected]>
Co-authored-by: banjuede <[email protected]>
Co-authored-by: Danielle Robinson <[email protected]>
Co-authored-by: Jee Jee Li <[email protected]>
Co-authored-by: bk-201 <[email protected]>
1 parent 3ada34f commit 5f6cbf6

28 files changed: +2084 −55 lines

.buildkite/test-pipeline.yaml

Lines changed: 7 additions & 1 deletion
@@ -384,7 +384,12 @@ steps:
     --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT \
     --ignore=lora/test_chatglm3_tp.py \
     --ignore=lora/test_llama_tp.py \
-    --ignore=lora/test_llm_with_multi_loras.py
+    --ignore=lora/test_llm_with_multi_loras.py \
+    --ignore=lora/test_olmoe_tp.py \
+    --ignore=lora/test_deepseekv2_tp.py \
+    --ignore=lora/test_gptoss.py \
+    --ignore=lora/test_qwen3moe_tp.py
   parallelism: 4

 - label: PyTorch Compilation Unit Tests # 15min

@@ -1065,6 +1070,7 @@ steps:
   - pytest -v -s -x lora/test_chatglm3_tp.py
   - pytest -v -s -x lora/test_llama_tp.py
   - pytest -v -s -x lora/test_llm_with_multi_loras.py
+  - pytest -v -s -x lora/test_olmoe_tp.py

 - label: Weight Loading Multiple GPU Test # 33min

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -883,6 +883,7 @@ target_compile_definitions(_C PRIVATE CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1)
 set(VLLM_MOE_EXT_SRC
   "csrc/moe/torch_bindings.cpp"
   "csrc/moe/moe_align_sum_kernels.cu"
+  "csrc/moe/moe_lora_align_sum_kernels.cu"
   "csrc/moe/topk_softmax_kernels.cu")

 if(VLLM_GPU_LANG STREQUAL "CUDA")

csrc/moe/moe_lora_align_sum_kernels.cu

Lines changed: 173 additions & 0 deletions
@@ -0,0 +1,173 @@
#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <ATen/ATen.h>
#include <ATen/cuda/Atomic.cuh>

#include "../cuda_compat.h"
#include "../dispatch_utils.h"
#include "core/math.hpp"

namespace {

__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
                                         int32_t col) {
  return row * total_col + col;
}

}  // namespace

// TODO: Refactor common parts with moe_align_sum_kernels
template <typename scalar_t, typename token_cnts_t>
__global__ void moe_lora_align_sum_kernel(
    scalar_t* __restrict__ topk_ids, int32_t* token_lora_mapping,
    int64_t block_size, int num_experts, int max_loras, size_t numel,
    int max_num_tokens_padded, int max_num_m_blocks,
    int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
    int topk_num, int32_t* total_tokens_post_pad) {
  const size_t tokens_per_thread = div_ceil(numel, blockDim.x);
  const size_t start_idx = threadIdx.x * tokens_per_thread;

  int lora_id = blockIdx.x;
  extern __shared__ int32_t shared_mem[];
  int32_t* cumsum = shared_mem;
  token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + num_experts + 1);

  // Initialize sorted_token_ids with numel
  for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {
    sorted_token_ids[lora_id * max_num_tokens_padded + it] = numel;
  }

  // Initialize expert_ids with -1
  for (size_t it = threadIdx.x; it < max_num_m_blocks; it += blockDim.x) {
    expert_ids[lora_id * max_num_m_blocks + it] = -1;
  }

  // Initialize total_tokens_post_pad with 0
  if (threadIdx.x == 0) {
    total_tokens_post_pad[lora_id] = 0;
  }

  for (int i = 0; i < num_experts; ++i) {
    tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
  }

  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int mask = token_lora_mapping[i / topk_num] == lora_id;
    int idx = index(num_experts, threadIdx.x + 1, topk_ids[i]);
    tokens_cnts[idx] += mask;
  }

  __syncthreads();

  // For each expert we accumulate the token counts from the different threads.
  if (threadIdx.x < num_experts) {
    tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
    for (int i = 1; i <= blockDim.x; ++i) {
      tokens_cnts[index(num_experts, i, threadIdx.x)] +=
          tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
    }
  }

  __syncthreads();

  // We accumulate the token counts of all experts in thread 0.
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      cumsum[i] = cumsum[i - 1] +
                  div_ceil(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
                           block_size) *
                      block_size;
    }
    total_tokens_post_pad[lora_id] = static_cast<int32_t>(cumsum[num_experts]);
  }

  __syncthreads();

  /**
   * For each expert, each thread processes the tokens of the corresponding
   * blocks and stores the corresponding expert_id for each block.
   */
  if (threadIdx.x < num_experts) {
    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
         i += block_size) {
      expert_ids[index(max_num_m_blocks, lora_id, i / block_size)] =
          threadIdx.x;
    }
  }

  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int32_t expert_id = topk_ids[i];
    /** The cumsum[expert_id] stores the starting index of the tokens that the
     * expert with expert_id needs to process, and
     * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
     * processed by the expert with expert_id within the current thread's token
     * shard.
     */
    int32_t rank_post_pad =
        tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
        cumsum[expert_id];

    int mask = (int)token_lora_mapping[i / topk_num] == lora_id;
    atomicAdd(
        &sorted_token_ids[index(max_num_tokens_padded, lora_id, rank_post_pad)],
        (i - numel) * mask);
    tokens_cnts[index(num_experts, threadIdx.x, expert_id)] += mask;
  }
}

void moe_lora_align_block_size(torch::Tensor topk_ids,
                               torch::Tensor token_lora_mapping,
                               int64_t num_experts, int64_t block_size,
                               int64_t max_loras,
                               torch::Tensor sorted_token_ids,
                               torch::Tensor expert_ids,
                               torch::Tensor num_tokens_post_pad) {
  const int topk_num = topk_ids.size(1);

  int max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1);

  TORCH_CHECK(block_size > 0, "block_size should be greater than 0. ");
  max_num_tokens_padded = round_to_next_multiple_of(
      max_num_tokens_padded, static_cast<int>(block_size));
  int max_num_m_blocks = div_ceil(max_num_tokens_padded, block_size);

  int device_max_shared_mem;
  auto dev = topk_ids.get_device();
  cudaDeviceGetAttribute(&device_max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const int32_t num_thread = max((int32_t)num_experts, 128);  // WARP_SIZE,
  TORCH_CHECK(num_thread <= 1024,
              "num_thread must be less than 1024, "
              "and fallback is not implemented yet.");
  const int32_t shared_mem = (num_thread + 1) * num_experts * sizeof(int32_t) +
                             (num_experts + 1) * sizeof(int32_t);

  if (shared_mem > device_max_shared_mem) {
    TORCH_CHECK(false,
                "Shared memory usage exceeds device limit, and global memory "
                "fallback is not implemented yet.");
  }

  VLLM_DISPATCH_INTEGRAL_TYPES(
      topk_ids.scalar_type(), "moe_lora_align_sum_kernel", [&] {
        dim3 blockDim(num_thread);
        auto kernel = moe_lora_align_sum_kernel<scalar_t, int32_t>;
        AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
            (void*)kernel, shared_mem));
        kernel<<<max_loras, blockDim, shared_mem, stream>>>(
            topk_ids.data_ptr<scalar_t>(),
            token_lora_mapping.data_ptr<int32_t>(), block_size, num_experts,
            max_loras, topk_ids.numel(), max_num_tokens_padded,
            max_num_m_blocks, sorted_token_ids.data_ptr<int32_t>(),
            expert_ids.data_ptr<int32_t>(), topk_num,
            num_tokens_post_pad.data_ptr<int32_t>());
      });
}
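
To summarize the kernel's contract: the launch grid has one thread block per LoRA adapter, and each block builds a per-LoRA copy of the usual moe_align_block_size layout. Only the top-k entries whose token maps to that adapter are counted, each expert's segment is padded up to a multiple of block_size, unused expert_ids slots stay -1, and unused sorted_token_ids slots keep the sentinel value numel (valid slots are filled by atomically adding (i - numel) onto that sentinel). The CPU reference below is a minimal sketch of that contract for a single lora_id, not vLLM code: the helper names (cdiv, align_for_lora) are illustrative, and it sizes its outputs to the actual padded total instead of the max_num_tokens_padded / max_num_m_blocks upper bounds the real kernel writes into.

import torch

def cdiv(a: int, b: int) -> int:
    return (a + b - 1) // b

def align_for_lora(topk_ids: torch.Tensor,          # [num_tokens, topk], expert per top-k choice
                   token_lora_mapping: torch.Tensor,  # [num_tokens], LoRA id per token
                   lora_id: int, num_experts: int, block_size: int):
    numel = topk_ids.numel()
    topk = topk_ids.size(1)
    flat = topk_ids.flatten()

    # Per-expert count of top-k entries owned by this LoRA, padded to block_size.
    counts = torch.zeros(num_experts, dtype=torch.int64)
    for i in range(numel):
        if token_lora_mapping[i // topk] == lora_id:
            counts[flat[i]] += 1
    padded = torch.tensor([cdiv(int(c), block_size) * block_size for c in counts])

    # cumsum[e] is where expert e's padded segment starts.
    cumsum = torch.zeros(num_experts + 1, dtype=torch.int64)
    cumsum[1:] = torch.cumsum(padded, dim=0)
    total_post_pad = int(cumsum[-1])

    # One expert id per block of block_size slots; -1 marks unused blocks.
    expert_ids = torch.full((cdiv(total_post_pad, block_size),), -1, dtype=torch.int32)
    for e in range(num_experts):
        for blk in range(int(cumsum[e]) // block_size, int(cumsum[e + 1]) // block_size):
            expert_ids[blk] = e

    # Flattened top-k indices grouped by expert; padding slots keep the sentinel numel.
    sorted_token_ids = torch.full((total_post_pad,), numel, dtype=torch.int32)
    fill = cumsum[:-1].clone()
    for i in range(numel):
        if token_lora_mapping[i // topk] == lora_id:
            e = int(flat[i])
            sorted_token_ids[int(fill[e])] = i
            fill[e] += 1
    return sorted_token_ids, expert_ids, total_post_pad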

csrc/moe/moe_ops.h

Lines changed: 7 additions & 0 deletions
@@ -20,6 +20,13 @@ void batched_moe_align_block_size(int64_t max_tokens_per_batch,
                                   torch::Tensor expert_ids,
                                   torch::Tensor num_tokens_post_pad);

+void moe_lora_align_block_size(torch::Tensor topk_ids,
+                               torch::Tensor token_lora_mapping,
+                               int64_t num_experts, int64_t block_size,
+                               int64_t max_loras,
+                               torch::Tensor sorted_token_ids,
+                               torch::Tensor expert_ids,
+                               torch::Tensor num_tokens_post_pad);
 #ifndef USE_ROCM
 torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
                              torch::Tensor b_qweight, torch::Tensor b_scales,
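
In the new declaration above, sorted_token_ids, expert_ids, and num_tokens_post_pad are outputs the caller pre-allocates; the kernel indexes them as [max_loras, max_num_tokens_padded], [max_loras, max_num_m_blocks], and [max_loras], using the same size formulas the launcher computes. A worked example with illustrative numbers (not taken from this PR), mirroring that arithmetic:

# Illustrative sizing only, following moe_lora_align_block_size in the .cu file above.
num_tokens, topk, num_experts, block_size = 16, 2, 64, 16
numel = num_tokens * topk                                        # 32 top-k entries
max_num_tokens_padded = numel + num_experts * (block_size - 1)   # 32 + 64*15 = 992
# round_to_next_multiple_of(992, 16) leaves 992 unchanged here
max_num_tokens_padded = (max_num_tokens_padded + block_size - 1) // block_size * block_size
max_num_m_blocks = max_num_tokens_padded // block_size           # 62 blocks per LoRA
num_thread = max(num_experts, 128)                               # threads per block
shared_mem = (num_thread + 1) * num_experts * 4 + (num_experts + 1) * 4
print(max_num_tokens_padded, max_num_m_blocks, shared_mem)       # 992 62 33284 (~32.5 KiB)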

csrc/moe/torch_bindings.cpp

Lines changed: 12 additions & 0 deletions
@@ -33,6 +33,18 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
   m.impl("batched_moe_align_block_size", torch::kCUDA,
          &batched_moe_align_block_size);

+  // Aligning the number of tokens to be processed by each expert such
+  // that it is divisible by the block size.
+  m.def(
+      "moe_lora_align_block_size(Tensor topk_ids,"
+      " Tensor token_lora_mapping,"
+      " int num_experts,"
+      " int block_size, int max_loras, "
+      " Tensor !sorted_token_ids,"
+      " Tensor !experts_ids,"
+      " Tensor !num_tokens_post_pad) -> () ");
+  m.impl("moe_lora_align_block_size", torch::kCUDA, &moe_lora_align_block_size);
+
 #ifndef USE_ROCM
   m.def(
       "moe_wna16_gemm(Tensor input, Tensor! output, Tensor b_qweight, "

tests/lora/conftest.py

Lines changed: 20 additions & 0 deletions
@@ -230,6 +230,26 @@ def tinyllama_lora_files():
     return snapshot_download(repo_id="jashing/tinyllama-colorist-lora")


+@pytest.fixture(scope="session")
+def deepseekv2_lora_files():
+    return snapshot_download(repo_id="wuchen01/DeepSeek-V2-Lite-Chat-All-LoRA")
+
+
+@pytest.fixture(scope="session")
+def gptoss20b_lora_files():
+    return snapshot_download(repo_id="LevinZheng/gpt-oss-20b-lora-adapter")
+
+
+@pytest.fixture(scope="session")
+def qwen3moe_lora_files():
+    return snapshot_download(repo_id="jeeejeee/qwen3-moe-text2sql-spider")
+
+
+@pytest.fixture(scope="session")
+def olmoe_lora_files():
+    return snapshot_download(repo_id="jeeejeee/olmoe-instruct-text2sql-spider")
+
+
 @pytest.fixture
 def reset_default_device():
     """

tests/lora/test_deepseekv2_tp.py

Lines changed: 97 additions & 0 deletions
@@ -0,0 +1,97 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import vllm
from vllm.lora.request import LoRARequest

from ..utils import multi_gpu_test

MODEL_PATH = "deepseek-ai/DeepSeek-V2-Lite-Chat"

PROMPT_TEMPLATE = "<|begin▁of▁sentence|>You are a helpful assistant.\n\nUser: {context}\n\nAssistant:"  # noqa: E501


def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int):
    prompts = [
        PROMPT_TEMPLATE.format(context="Who are you?"),
    ]
    sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
    outputs = llm.generate(
        prompts,
        sampling_params,
        lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None,
    )
    # Print the outputs.
    generated_texts: list[str] = []
    for output in outputs:
        prompt = output.prompt
        generated_text = output.outputs[0].text.strip()
        generated_texts.append(generated_text)
        print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
    # return generated_texts
    expected_lora_output = [
        "I am \u5f20\u5b50\u8c6a, an AI assistant developed by \u9648\u58eb\u680b.",  # noqa: E501
    ]
    for i in range(len(expected_lora_output)):
        assert generated_texts[i].startswith(expected_lora_output[i])


def test_deepseekv2_lora(deepseekv2_lora_files):
    # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
    # Otherwise, the lora-test will fail due to CUDA OOM.
    llm = vllm.LLM(
        MODEL_PATH,
        max_model_len=1024,
        enable_lora=True,
        max_loras=4,
        enforce_eager=True,
        trust_remote_code=True,
        enable_chunked_prefill=True,
    )
    generate_and_test(llm, deepseekv2_lora_files, 1)


def test_deepseekv2(deepseekv2_lora_files):
    # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
    # Otherwise, the lora-test will fail due to CUDA OOM.
    llm = vllm.LLM(
        MODEL_PATH,
        max_model_len=1024,
        enable_lora=True,
        max_loras=4,
        enforce_eager=True,
        trust_remote_code=True,
    )
    generate_and_test(llm, deepseekv2_lora_files, 1)


@multi_gpu_test(num_gpus=2)
def test_deepseekv2_tp2(deepseekv2_lora_files):
    # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
    # Otherwise, the lora-test will fail due to CUDA OOM.
    llm = vllm.LLM(
        MODEL_PATH,
        max_model_len=1024,
        enable_lora=True,
        max_loras=4,
        enforce_eager=True,
        trust_remote_code=True,
        tensor_parallel_size=2,
    )
    generate_and_test(llm, deepseekv2_lora_files, 2)


@multi_gpu_test(num_gpus=4)
def test_deepseekv2_tp4(deepseekv2_lora_files):
    # We enable enforce_eager=True here to reduce VRAM usage for lora-test CI,
    # Otherwise, the lora-test will fail due to CUDA OOM.
    llm = vllm.LLM(
        MODEL_PATH,
        max_model_len=1024,
        enable_lora=True,
        max_loras=4,
        enforce_eager=True,
        trust_remote_code=True,
        tensor_parallel_size=4,
    )
    generate_and_test(llm, deepseekv2_lora_files, 2)
