27 changes: 16 additions & 11 deletions csrc/moe/moe_lora_align_sum_kernels.cu
@@ -28,11 +28,16 @@ __global__ void moe_lora_align_sum_kernel(
int64_t block_size, int num_experts, int max_loras, size_t numel,
int max_num_tokens_padded, int max_num_m_blocks,
int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
int topk_num, int32_t* total_tokens_post_pad) {
int topk_num, int32_t* total_tokens_post_pad, int32_t* adapter_enabled,
int32_t* lora_ids) {
const size_t tokens_per_thread = div_ceil(numel, blockDim.x);
const size_t start_idx = threadIdx.x * tokens_per_thread;

int lora_id = blockIdx.x;
int lora_idx = blockIdx.x;
int lora_id = lora_ids[lora_idx];
if (lora_id == -1 || adapter_enabled[lora_id] == 0) {
return;
}
extern __shared__ int32_t shared_mem[];
int32_t* cumsum = shared_mem;
token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + num_experts + 1);
@@ -121,14 +126,13 @@ __global__ void moe_lora_align_sum_kernel(
}
}

void moe_lora_align_block_size(torch::Tensor topk_ids,
torch::Tensor token_lora_mapping,
int64_t num_experts, int64_t block_size,
int64_t max_loras, int64_t max_num_tokens_padded,
int64_t max_num_m_blocks,
torch::Tensor sorted_token_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad) {
void moe_lora_align_block_size(
torch::Tensor topk_ids, torch::Tensor token_lora_mapping,
int64_t num_experts, int64_t block_size, int64_t max_loras,
int64_t max_num_tokens_padded, int64_t max_num_m_blocks,
torch::Tensor sorted_token_ids, torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled,
torch::Tensor lora_ids) {
const int topk_num = topk_ids.size(1);

TORCH_CHECK(block_size > 0, "block_size should be greater than 0. ");
@@ -164,6 +168,7 @@ void moe_lora_align_block_size(torch::Tensor topk_ids,
max_loras, topk_ids.numel(), max_num_tokens_padded,
max_num_m_blocks, sorted_token_ids.data_ptr<int32_t>(),
expert_ids.data_ptr<int32_t>(), topk_num,
num_tokens_post_pad.data_ptr<int32_t>());
num_tokens_post_pad.data_ptr<int32_t>(),
adapter_enabled.data_ptr<int32_t>(), lora_ids.data_ptr<int32_t>());
});
}
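
A minimal host-side sketch of how the updated op is called after this change, mirroring the tensor shapes and argument order used in the tests below. The example sizes, the padding formula, and the `from vllm import _custom_ops as ops` import path are assumptions for illustration, not part of this diff.

```python
import torch

from vllm import _custom_ops as ops

# Illustrative sizes; real callers derive these from the routing output.
num_tokens, top_k, num_experts, block_size, max_loras = 128, 2, 8, 16, 4
topk_ids = torch.randint(
    0, num_experts, (num_tokens, top_k), dtype=torch.int32, device="cuda"
)
token_lora_mapping = torch.randint(
    0, max_loras, (num_tokens,), dtype=torch.int32, device="cuda"
)

# Assumed upper bounds (the exact sizing used by callers may differ).
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
max_num_m_blocks = (max_num_tokens_padded + block_size - 1) // block_size

sorted_token_ids = torch.empty(
    (max_loras * max_num_tokens_padded,), dtype=torch.int32, device="cuda"
)
expert_ids = torch.empty(
    (max_loras * max_num_m_blocks,), dtype=torch.int32, device="cuda"
)
num_tokens_post_pad = torch.zeros((max_loras,), dtype=torch.int32, device="cuda")

# New control tensors added by this PR: per-adapter enable flags and the
# slot -> LoRA id mapping (-1 marks an unused slot).
adapter_enabled = torch.ones((max_loras + 1,), dtype=torch.int32, device="cuda")
lora_ids = torch.arange(max_loras + 2, dtype=torch.int32, device="cuda")

ops.moe_lora_align_block_size(
    topk_ids,
    token_lora_mapping,
    num_experts,
    block_size,
    max_loras,
    max_num_tokens_padded,
    max_num_m_blocks,
    sorted_token_ids,
    expert_ids,
    num_tokens_post_pad,
    adapter_enabled,
    lora_ids,
)
```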
15 changes: 7 additions & 8 deletions csrc/moe/moe_ops.h
@@ -20,14 +20,13 @@ void batched_moe_align_block_size(int64_t max_tokens_per_batch,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad);

void moe_lora_align_block_size(torch::Tensor topk_ids,
torch::Tensor token_lora_mapping,
int64_t num_experts, int64_t block_size,
int64_t max_loras, int64_t max_num_tokens_padded,
int64_t max_num_m_blocks,
torch::Tensor sorted_token_ids,
torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad);
void moe_lora_align_block_size(
torch::Tensor topk_ids, torch::Tensor token_lora_mapping,
int64_t num_experts, int64_t block_size, int64_t max_loras,
int64_t max_num_tokens_padded, int64_t max_num_m_blocks,
torch::Tensor sorted_token_ids, torch::Tensor expert_ids,
torch::Tensor num_tokens_post_pad, torch::Tensor adapter_enabled,
torch::Tensor lora_ids);
#ifndef USE_ROCM
torch::Tensor moe_wna16_gemm(torch::Tensor input, torch::Tensor output,
torch::Tensor b_qweight, torch::Tensor b_scales,
4 changes: 3 additions & 1 deletion csrc/moe/torch_bindings.cpp
@@ -44,7 +44,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) {
" int max_num_m_blocks, "
" Tensor !sorted_token_ids,"
" Tensor !experts_ids,"
" Tensor !num_tokens_post_pad) -> () ");
" Tensor !num_tokens_post_pad,"
" Tensor !adapter_enabled,"
" Tensor !lora_ids) -> () ");
m.impl("moe_lora_align_block_size", torch::kCUDA, &moe_lora_align_block_size);

#ifndef USE_ROCM
6 changes: 6 additions & 0 deletions tests/lora/test_fused_moe_lora_kernel.py
@@ -134,6 +134,8 @@ def use_fused_moe_lora_kernel(
)
expert_ids = torch.empty((max_loras * max_num_m_blocks,), dtype=torch.int32)
num_tokens_post_padded = torch.empty((max_loras,), dtype=torch.int32)
adapter_enabled = torch.ones(max_loras + 1, dtype=torch.int32)
lora_ids = torch.arange(max_loras + 2, dtype=torch.int32)

# call kernel
ops.moe_lora_align_block_size(
@@ -147,6 +149,8 @@ def use_fused_moe_lora_kernel(
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
adapter_enabled,
lora_ids,
)

config = {
@@ -172,6 +176,8 @@ def use_fused_moe_lora_kernel(
num_tokens_post_padded,
max_lora_rank,
top_k_num,
lora_ids,
adapter_enabled,
config["BLOCK_SIZE_M"],
config["BLOCK_SIZE_N"],
config["BLOCK_SIZE_K"],
4 changes: 4 additions & 0 deletions tests/lora/test_moe_lora_align_sum.py
@@ -60,6 +60,8 @@ def test_moe_lora_align_block_size(
(max_loras * max_num_m_blocks,), num_experts, dtype=torch.int32, device="cuda"
)
num_tokens_post_pad = torch.zeros((max_loras,), dtype=torch.int32, device="cuda")
adapter_enabled = torch.ones((max_loras + 1,), dtype=torch.int32, device="cuda")
lora_ids = torch.arange(max_loras + 2, dtype=torch.int32, device="cuda")

# call kernel
ops.moe_lora_align_block_size(
@@ -73,6 +75,8 @@ def test_moe_lora_align_block_size(
sorted_token_ids,
expert_ids,
num_tokens_post_pad,
adapter_enabled,
lora_ids,
)

# verify values
64 changes: 57 additions & 7 deletions tests/lora/test_olmoe_tp.py
@@ -1,6 +1,7 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import vllm
from vllm.lora.request import LoRARequest

@@ -28,8 +29,17 @@
"SELECT poll_source FROM candidate GROUP BY poll_source ORDER BY count(*) DESC LIMIT 1", # noqa: E501
]

EXPECTED_BASE_MODEL_OUTPUT = [
"SELECT COUNT(Candidate_ID) FROM candidate",
"SELECT COUNT(Candidate_ID) FROM candidate",
"SELECT Candidate_ID, COUNT(*) as Total_Candidates\nFROM candidate\nINNER JOIN people ON candidate.People_ID = people.People_ID", # noqa: E501
"SELECT Candidate_ID, Poll_Source FROM candidate WHERE People_ID IN (SELECT People_ID FROM people) ORDER BY COUNT(*) DESC LIMIT 1", # noqa: E501
]


def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
def generate_and_test(
llm: vllm.LLM, lora_path: str, lora_id: list[int | None] | int | None
) -> None:
prompts = [
PROMPT_TEMPLATE.format(context="How many candidates are there?"),
PROMPT_TEMPLATE.format(context="Count the number of candidates."),
@@ -40,12 +50,18 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
context="Return the poll resource associated with the most candidates."
),
]

lora_request = None
if isinstance(lora_id, int):
lora_request = LoRARequest(str(lora_id), lora_id, lora_path)
elif isinstance(lora_id, list):
lora_request = [
LoRARequest(str(i), i, lora_path) if i is not None else None
for i in lora_id
]

sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
outputs = llm.generate(
prompts,
sampling_params,
lora_request=LoRARequest(str(lora_id), lora_id, lora_path) if lora_id else None,
)
outputs = llm.generate(prompts, sampling_params, lora_request=lora_request)
# Print the outputs.
generated_texts: list[str] = []
for output in outputs:
@@ -55,7 +71,13 @@ def generate_and_test(llm: vllm.LLM, lora_path: str, lora_id: int) -> None:
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")

for i in range(len(EXPECTED_LORA_OUTPUT)):
assert generated_texts[i].startswith(EXPECTED_LORA_OUTPUT[i])
req_lora_id = lora_id[i] if isinstance(lora_id, list) else lora_id
expected_output = (
EXPECTED_LORA_OUTPUT[i]
if req_lora_id is not None
else EXPECTED_BASE_MODEL_OUTPUT[i]
)
assert generated_texts[i].startswith(expected_output)


def test_olmoe_lora(olmoe_lora_files):
@@ -75,6 +97,34 @@ def test_olmoe_lora(olmoe_lora_files):
generate_and_test(llm, olmoe_lora_files, lora_id=2)


def test_olmoe_lora_base_model(olmoe_lora_files):
Review comment (Collaborator), on `test_olmoe_lora_base_model`: "There's no need to add this test."

llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
enforce_eager=True,
trust_remote_code=True,
enable_chunked_prefill=True,
)

generate_and_test(llm, olmoe_lora_files, lora_id=None)


def test_olmoe_lora_mixed(olmoe_lora_files):
llm = vllm.LLM(
MODEL_PATH,
max_model_len=1024,
enable_lora=True,
max_loras=4,
enforce_eager=True,
trust_remote_code=True,
enable_chunked_prefill=True,
)

generate_and_test(llm, olmoe_lora_files, lora_id=[1, None, 3, None])


@multi_gpu_test(num_gpus=2)
def test_olmoe_lora_tp2(olmoe_lora_files):
llm = vllm.LLM(
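
With the reworked `generate_and_test`, `lora_id` may be an int, `None` (base model), or a per-prompt list, as exercised by the new mixed test above. A brief usage sketch of the mixed-batch form; the model name and adapter path are placeholders, not values from this PR.

```python
import vllm
from vllm.lora.request import LoRARequest

# Placeholders; the test uses MODEL_PATH and the olmoe_lora_files fixture.
llm = vllm.LLM("<olmoe-model-path>", enable_lora=True, max_loras=4, max_model_len=1024)

prompts = ["How many candidates are there?", "Count the number of candidates."]
# One entry per prompt: a LoRARequest for adapter-backed prompts, None for the base model.
lora_requests = [LoRARequest("1", 1, "/path/to/olmoe-lora"), None]

sampling_params = vllm.SamplingParams(temperature=0, max_tokens=64)
outputs = llm.generate(prompts, sampling_params, lora_request=lora_requests)
for out in outputs:
    print(out.outputs[0].text)
```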
4 changes: 4 additions & 0 deletions vllm/_custom_ops.py
@@ -1815,6 +1815,8 @@ def moe_lora_align_block_size(
sorted_token_ids: torch.Tensor,
experts_ids: torch.Tensor,
num_tokens_post_pad: torch.Tensor,
adapter_enabled: torch.Tensor,
lora_ids: torch.Tensor,
) -> None:
torch.ops._moe_C.moe_lora_align_block_size(
topk_ids,
@@ -1827,6 +1829,8 @@ def moe_lora_align_block_size(
sorted_token_ids,
experts_ids,
num_tokens_post_pad,
adapter_enabled,
lora_ids,
)


10 changes: 9 additions & 1 deletion vllm/lora/layers/fused_moe.py
@@ -120,6 +120,7 @@
config["BLOCK_SIZE_M"],
global_num_experts,
max_loras,
self.adapter_enabled,
expert_map,
)

@@ -147,6 +148,7 @@
max_lora_rank,
top_k,
config,
self.adapter_enabled,
)

result = func(*args, **kwargs)
@@ -193,7 +195,7 @@
intermediate_cache2 = moe_state_dict["intermediate_cache2"]
intermediate_cache3 = args[0]
max_lora_rank = self.w1_lora_a_stacked.shape[-2]
self.punica_wrapper.add_lora_fused_moe(

Check failure on line 198 in vllm/lora/layers/fused_moe.py (GitHub Actions / pre-commit, reported four times): Too many arguments for "add_lora_fused_moe" of "PunicaWrapperBase" [call-arg]
intermediate_cache3,
intermediate_cache2,
[self.w2_lora_a_stacked],
Expand All @@ -205,6 +207,7 @@
max_lora_rank,
top_k,
config,
self.adapter_enabled,
True,
)

Expand Down Expand Up @@ -239,6 +242,9 @@
assert not self.base_layer.use_ep, (
"EP support for Fused MoE LoRA is not implemented yet."
)
self.adapter_enabled = torch.tensor(
[0] * (max_loras + 1), dtype=torch.int, device=self.device
)

self.w1_lora_a_stacked = torch.zeros(
(
@@ -326,6 +332,7 @@
self.w3_lora_b_stacked[index] = 0
self.w2_lora_a_stacked[index] = 0
self.w2_lora_b_stacked[index] = 0
self.adapter_enabled[index] = 0

def set_lora(
self,
@@ -335,8 +342,9 @@
embeddings_tensor: torch.Tensor | None,
bias: torch.Tensor | None = None,
):
self.reset_lora(index)
"""Overwrites lora tensors at index."""
self.reset_lora(index)
self.adapter_enabled[index] = 1
for eid in range(len(lora_a) // 3):
w1_lora_a = lora_a[eid * 3]
w2_lora_a = lora_a[eid * 3 + 1]
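
A condensed sketch of the `adapter_enabled` lifecycle introduced in this file: the layer allocates one flag per adapter slot, `set_lora` switches a slot on, and `reset_lora` switches it off so the kernels can skip that slot's blocks. The standalone class and its constructor are illustrative stand-ins, not the actual layer API.

```python
import torch


class FusedMoEWithLoRASketch:
    """Illustrative stand-in for the adapter_enabled bookkeeping in the layer."""

    def __init__(self, max_loras: int, device: str = "cuda") -> None:
        # One flag per adapter slot (plus one spare, matching the diff);
        # all slots start disabled.
        self.adapter_enabled = torch.zeros(
            max_loras + 1, dtype=torch.int, device=device
        )

    def reset_lora(self, index: int) -> None:
        # Clearing a slot also disables it, so the kernels early-exit for it.
        self.adapter_enabled[index] = 0

    def set_lora(self, index: int) -> None:
        """Overwrites lora tensors at index (weight copies omitted here)."""
        self.reset_lora(index)
        self.adapter_enabled[index] = 1
```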
23 changes: 19 additions & 4 deletions vllm/lora/ops/triton_ops/fused_moe_lora_op.py
@@ -54,6 +54,8 @@ def _fused_moe_lora_kernel(
EM,
num_valid_tokens,
num_experts,
lora_ids,
adapter_enabled,
# The stride variables represent how much to increase the ptr by when
# moving by 1 element in a particular dimension. E.g. `stride_am` is
# how much to increase `a_ptr` by to get the element one row down
@@ -84,6 +86,11 @@
pid = tl.program_id(axis=0)
slice_id = tl.program_id(axis=1)
lora_idx = tl.program_id(axis=2)
lora_id = tl.load(lora_ids + lora_idx)
moe_enabled = tl.load(adapter_enabled + lora_id)
if lora_id == -1 or moe_enabled == 0:
# Early exit for the no-lora case.
return
max_loras = tl.num_programs(axis=2)
grid_k = tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)

@@ -97,12 +104,12 @@ def _fused_moe_lora_kernel(
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m

num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_idx)
num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr + lora_id)
if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
return

# get the expert_id to process curr shard
ind = lora_idx * stride_el + pid_m
ind = lora_id * stride_el + pid_m
expert_id = tl.load(expert_ids_ptr + ind, ind < max_loras * stride_el, -1)
if expert_id == -1:
return
@@ -116,7 +123,7 @@ def _fused_moe_lora_kernel(
offs_k = tl.arange(0, BLOCK_SIZE_K)

offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
token_ind = stride_tl * lora_idx + offs_token_id
token_ind = stride_tl * lora_id + offs_token_id
offs_token = tl.load(
sorted_token_ids_ptr + token_ind, token_ind < max_loras * stride_tl, 0
)
@@ -129,7 +136,7 @@ def _fused_moe_lora_kernel(

b_ptrs = (
cur_b_ptr
+ lora_idx * stride_bl
+ lora_id * stride_bl
+ expert_id * stride_be
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
)
@@ -180,6 +187,8 @@ def _fused_moe_lora(
num_tokens_post_padded: torch.Tensor, # (max_loras, )
max_lora_rank: int,
top_k_num: int,
lora_ids: torch.Tensor,
adapter_enabled: torch.Tensor,
block_size_m: int,
block_size_n: int,
block_size_k: int,
@@ -268,6 +277,8 @@ def _fused_moe_lora(
EM,
num_tokens,
num_experts,
lora_ids,
adapter_enabled,
qcurr_hidden_states.stride(0),
qcurr_hidden_states.stride(1),
w1_lora_a_stacked.stride(0),
@@ -315,6 +326,8 @@ def _fused_moe_lora(
EM,
num_tokens,
num_experts,
lora_ids,
adapter_enabled,
a_intermediate_cache1.stride(0),
a_intermediate_cache1.stride(1),
w1_lora_b_stacked.stride(0),
@@ -348,6 +361,8 @@ def _fused_moe_lora_fake(
num_tokens_post_padded: torch.Tensor,
max_lora_rank: int,
top_k_num: int,
lora_ids: torch.Tensor,
adapter_enabled: torch.Tensor,
block_size_m: int,
block_size_n: int,
block_size_k: int,
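
Both the CUDA alignment kernel and the Triton `_fused_moe_lora_kernel` now gate their per-slot work the same way. A host-side Python rendering of that gating, assuming `lora_ids` maps grid slots to adapter ids with -1 for unused slots; the helper name and example tensors are illustrative only.

```python
import torch


def slot_has_work(
    lora_idx: int, lora_ids: torch.Tensor, adapter_enabled: torch.Tensor
) -> bool:
    """Mirror of the early-exit check each kernel block performs for its slot."""
    lora_id = int(lora_ids[lora_idx])
    if lora_id == -1:                        # slot not mapped to any adapter
        return False
    if int(adapter_enabled[lora_id]) == 0:   # adapter present but disabled
        return False
    # Otherwise the block proceeds, indexing sorted_token_ids / expert_ids /
    # num_tokens_post_padded by lora_id rather than the raw grid index.
    return True


# Example: four slots, adapter 2 disabled, last slot unused.
lora_ids = torch.tensor([0, 1, 2, -1], dtype=torch.int32)
adapter_enabled = torch.tensor([1, 1, 0, 1], dtype=torch.int32)
print([slot_has_work(i, lora_ids, adapter_enabled) for i in range(4)])
# -> [True, True, False, False]
```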