Skip to content

Commit 48fd6ad

Browse files
committed
fix lora tests
1 parent 305a472 commit 48fd6ad

File tree

3 files changed

+15
-76
lines changed

3 files changed

+15
-76
lines changed

tests/lora/test_fused_moe_lora_kernel.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,11 @@ def use_fused_moe_lora_kernel(
135135
expert_ids = torch.empty((max_loras * max_num_m_blocks,), dtype=torch.int32)
136136
num_tokens_post_padded = torch.empty((max_loras,), dtype=torch.int32)
137137

138+
num_tokens_per_lora = torch.ones(max_loras+1, dtype=torch.int32)
139+
adapter_enabled = torch.ones(max_loras+1, dtype=torch.int32)
140+
lora_ids = torch.arange(1,max_loras+1, dtype=torch.int32)
141+
142+
138143
# call kernel
139144
ops.moe_lora_align_block_size(
140145
topk_ids,
@@ -147,6 +152,8 @@ def use_fused_moe_lora_kernel(
147152
sorted_token_ids,
148153
expert_ids,
149154
num_tokens_post_padded,
155+
num_tokens_per_lora,
156+
adapter_enabled,
150157
)
151158

152159
config = {
@@ -171,6 +178,8 @@ def use_fused_moe_lora_kernel(
171178
num_tokens_post_padded,
172179
max_lora_rank,
173180
top_k_num,
181+
lora_ids,
182+
adapter_enabled,
174183
config["BLOCK_SIZE_M"],
175184
config["BLOCK_SIZE_N"],
176185
config["BLOCK_SIZE_K"],

tests/lora/test_moe_lora_align_sum.py

Lines changed: 4 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -60,15 +60,13 @@ def test_moe_lora_align_block_size(
6060
(max_loras * max_num_m_blocks,), num_experts, dtype=torch.int32, device="cuda"
6161
)
6262
num_tokens_post_pad = torch.zeros((max_loras,), dtype=torch.int32, device="cuda")
63-
num_tokens_per_lora = torch.ones((max_loras,), dtype=torch.int32, device="cuda")
64-
adapter_enabled = torch.ones((max_loras,), dtype=torch.int32, device="cuda")
63+
num_tokens_per_lora = torch.ones((max_loras+1,), dtype=torch.int32, device="cuda")
64+
adapter_enabled = torch.ones((max_loras+1,), dtype=torch.int32, device="cuda")
6565

6666
# call kernel
6767
ops.moe_lora_align_block_size(
6868
topk_ids,
6969
token_lora_mapping,
70-
num_tokens_per_lora,
71-
adapter_enabled,
7270
num_experts,
7371
block_size,
7472
max_loras,
@@ -77,6 +75,8 @@ def test_moe_lora_align_block_size(
7775
sorted_token_ids,
7876
expert_ids,
7977
num_tokens_post_pad,
78+
num_tokens_per_lora,
79+
adapter_enabled,
8080
)
8181

8282
# verify values
@@ -91,73 +91,3 @@ def test_moe_lora_align_block_size(
9191
expert_id = expert_ids[lora_idx][token_idx]
9292
assert torch.all(topk_ids.view(-1)[indices] == expert_id)
9393

94-
@pytest.mark.parametrize("num_tokens", [4096])
95-
@pytest.mark.parametrize("topk_num", [6])
96-
@pytest.mark.parametrize("num_experts", [64])
97-
@pytest.mark.parametrize("max_loras", [2])
98-
@pytest.mark.parametrize("block_size", [16])
99-
@pytest.mark.parametrize("adapter_enabled", [[0,1],[0,0]])
100-
def test_moe_lora_align_block_size_early_exit(
101-
num_tokens, topk_num, num_experts, max_loras, block_size, adapter_enabled
102-
):
103-
104-
# sample data
105-
random.seed(1)
106-
topk_ids, token_lora_mapping = sample_data(
107-
num_experts, max_loras, num_tokens, topk_num
108-
)
109-
110-
# compute paddings
111-
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
112-
max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
113-
max_num_m_blocks = CEILDIV(max_num_tokens_padded, block_size)
114-
115-
# init output tensors
116-
sorted_token_ids = torch.full(
117-
(max_loras * max_num_tokens_padded,),
118-
topk_ids.numel(),
119-
dtype=torch.int32,
120-
device="cuda",
121-
)
122-
expert_ids = torch.full(
123-
(max_loras * max_num_m_blocks,), num_experts, dtype=torch.int32, device="cuda"
124-
)
125-
num_tokens_post_pad = torch.zeros((max_loras,), dtype=torch.int32, device="cuda")
126-
127-
num_tokens_per_lora = torch.ones((max_loras,), dtype=torch.int32, device="cuda")
128-
adapter_enabled = torch.tensor(adapter_enabled, dtype=torch.int32, device="cuda")
129-
130-
# call kernel
131-
ops.moe_lora_align_block_size(
132-
topk_ids,
133-
token_lora_mapping,
134-
num_tokens_per_lora,
135-
adapter_enabled,
136-
num_experts,
137-
block_size,
138-
max_loras,
139-
sorted_token_ids,
140-
expert_ids,
141-
num_tokens_post_pad,
142-
)
143-
144-
# verify values
145-
expert_ids = expert_ids.view(max_loras, -1)
146-
sorted_token_ids = sorted_token_ids.view(max_loras, -1, block_size)
147-
148-
for lora_idx in range(max_loras):
149-
150-
# assert not operation was performed
151-
if adapter_enabled[lora_idx].item() == 0:
152-
assert torch.all(sorted_token_ids[lora_idx] == topk_ids.numel())
153-
else:
154-
for token_idx in range(sorted_token_ids.size(1)):
155-
block = sorted_token_ids[lora_idx][token_idx]
156-
indices = block[block != topk_ids.numel()]
157-
if indices.numel() > 0:
158-
expert_id = expert_ids[lora_idx][token_idx]
159-
assert torch.all(topk_ids.view(-1)[indices] == expert_id)
160-
161-
162-
if __name__ == "__main__":
163-
pytest.main([__file__])

vllm/lora/ops/triton_ops/fused_moe_lora_op.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -337,12 +337,12 @@ def _fused_moe_lora_fake(
337337
num_tokens_post_padded: torch.Tensor,
338338
max_lora_rank: int,
339339
top_k_num: int,
340+
lora_ids: torch.Tensor,
341+
adapter_enabled: torch.Tensor,
340342
block_size_m: int,
341343
block_size_n: int,
342344
block_size_k: int,
343345
group_size_m: int,
344-
lora_ids: torch.Tensor,
345-
adapter_enabled: torch.Tensor,
346346
mul_routed_weight: bool = False,
347347
) -> None:
348348
return

0 commit comments

Comments
 (0)