@@ -60,15 +60,13 @@ def test_moe_lora_align_block_size(
         (max_loras * max_num_m_blocks,), num_experts, dtype=torch.int32, device="cuda"
     )
     num_tokens_post_pad = torch.zeros((max_loras,), dtype=torch.int32, device="cuda")
-    num_tokens_per_lora = torch.ones((max_loras,), dtype=torch.int32, device="cuda")
-    adapter_enabled = torch.ones((max_loras,), dtype=torch.int32, device="cuda")
+    num_tokens_per_lora = torch.ones((max_loras + 1,), dtype=torch.int32, device="cuda")
+    adapter_enabled = torch.ones((max_loras + 1,), dtype=torch.int32, device="cuda")
 
     # call kernel
     ops.moe_lora_align_block_size(
         topk_ids,
         token_lora_mapping,
-        num_tokens_per_lora,
-        adapter_enabled,
         num_experts,
         block_size,
         max_loras,
@@ -77,6 +75,8 @@ def test_moe_lora_align_block_size(
         sorted_token_ids,
         expert_ids,
         num_tokens_post_pad,
+        num_tokens_per_lora,
+        adapter_enabled,
     )
 
     # verify values
@@ -91,73 +91,3 @@ def test_moe_lora_align_block_size(
                 expert_id = expert_ids[lora_idx][token_idx]
                 assert torch.all(topk_ids.view(-1)[indices] == expert_id)
 
-@pytest.mark.parametrize("num_tokens", [4096])
-@pytest.mark.parametrize("topk_num", [6])
-@pytest.mark.parametrize("num_experts", [64])
-@pytest.mark.parametrize("max_loras", [2])
-@pytest.mark.parametrize("block_size", [16])
-@pytest.mark.parametrize("adapter_enabled", [[0, 1], [0, 0]])
-def test_moe_lora_align_block_size_early_exit(
-    num_tokens, topk_num, num_experts, max_loras, block_size, adapter_enabled
-):
-
-    # sample data
-    random.seed(1)
-    topk_ids, token_lora_mapping = sample_data(
-        num_experts, max_loras, num_tokens, topk_num
-    )
-
-    # compute paddings
-    max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    max_num_tokens_padded = round_up(max_num_tokens_padded, block_size)
-    max_num_m_blocks = CEILDIV(max_num_tokens_padded, block_size)
-
-    # init output tensors
-    sorted_token_ids = torch.full(
-        (max_loras * max_num_tokens_padded,),
-        topk_ids.numel(),
-        dtype=torch.int32,
-        device="cuda",
-    )
-    expert_ids = torch.full(
-        (max_loras * max_num_m_blocks,), num_experts, dtype=torch.int32, device="cuda"
-    )
-    num_tokens_post_pad = torch.zeros((max_loras,), dtype=torch.int32, device="cuda")
-
-    num_tokens_per_lora = torch.ones((max_loras,), dtype=torch.int32, device="cuda")
-    adapter_enabled = torch.tensor(adapter_enabled, dtype=torch.int32, device="cuda")
-
-    # call kernel
-    ops.moe_lora_align_block_size(
-        topk_ids,
-        token_lora_mapping,
-        num_tokens_per_lora,
-        adapter_enabled,
-        num_experts,
-        block_size,
-        max_loras,
-        sorted_token_ids,
-        expert_ids,
-        num_tokens_post_pad,
-    )
-
-    # verify values
-    expert_ids = expert_ids.view(max_loras, -1)
-    sorted_token_ids = sorted_token_ids.view(max_loras, -1, block_size)
-
-    for lora_idx in range(max_loras):
-
-        # assert no operation was performed
-        if adapter_enabled[lora_idx].item() == 0:
-            assert torch.all(sorted_token_ids[lora_idx] == topk_ids.numel())
-        else:
-            for token_idx in range(sorted_token_ids.size(1)):
-                block = sorted_token_ids[lora_idx][token_idx]
-                indices = block[block != topk_ids.numel()]
-                if indices.numel() > 0:
-                    expert_id = expert_ids[lora_idx][token_idx]
-                    assert torch.all(topk_ids.view(-1)[indices] == expert_id)
-
-
-if __name__ == "__main__":
-    pytest.main([__file__])
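
For orientation, the net effect of this commit on the surviving test can be read as plain Python rather than diff hunks: `num_tokens_per_lora` and `adapter_enabled` are now allocated with `max_loras + 1` entries, and both tensors move to the tail of the `ops.moe_lora_align_block_size` call, after the output buffers. The sketch below reassembles the updated call from the hunks above; the `ops` import path and the random stand-ins for the test's `sample_data` helper are assumptions for illustration, not part of this commit.

```python
import torch

# Import path assumed for illustration; the test file's own imports apply.
from vllm import _custom_ops as ops

# Sizes taken from the test parametrization.
num_tokens, topk_num, num_experts, max_loras, block_size = 4096, 6, 64, 2, 16

# Random stand-ins for the test's sample_data() helper (hypothetical data).
topk_ids = torch.randint(
    0, num_experts, (num_tokens, topk_num), dtype=torch.int32, device="cuda"
)
token_lora_mapping = torch.randint(
    0, max_loras, (num_tokens,), dtype=torch.int32, device="cuda"
)

# Padding math as in the test, with the round_up/CEILDIV helpers inlined.
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
max_num_tokens_padded = -(-max_num_tokens_padded // block_size) * block_size
max_num_m_blocks = max_num_tokens_padded // block_size  # already block-aligned

# Output buffers, one slab per LoRA adapter.
sorted_token_ids = torch.full(
    (max_loras * max_num_tokens_padded,),
    topk_ids.numel(),
    dtype=torch.int32,
    device="cuda",
)
expert_ids = torch.full(
    (max_loras * max_num_m_blocks,), num_experts, dtype=torch.int32, device="cuda"
)
num_tokens_post_pad = torch.zeros((max_loras,), dtype=torch.int32, device="cuda")

# Per-LoRA inputs are now sized max_loras + 1 (one slot beyond the adapters).
num_tokens_per_lora = torch.ones((max_loras + 1,), dtype=torch.int32, device="cuda")
adapter_enabled = torch.ones((max_loras + 1,), dtype=torch.int32, device="cuda")

# Updated calling convention: the per-LoRA tensors trail the output buffers.
ops.moe_lora_align_block_size(
    topk_ids,
    token_lora_mapping,
    num_experts,
    block_size,
    max_loras,
    sorted_token_ids,
    expert_ids,
    num_tokens_post_pad,
    num_tokens_per_lora,
    adapter_enabled,
)
```

The extra slot in the `max_loras + 1` shapes presumably leaves room for a sentinel "no adapter" entry, but that is an inference from the shapes alone; the commit itself does not say.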