# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch
- import triton
- import triton.language as tl

+ from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import direct_register_custom_op

_LORA_PTR_DICT: dict[tuple[int, ...], torch.tensor] = {}
@@ -110,7 +109,7 @@ def _fused_moe_lora_kernel(

    # get a_ptr,b_ptr,c_ptr
    cur_a_ptr = a_ptr + (slice_id % num_slice_a) * slice_a_size
-     cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(tl.bfloat16))
+     cur_b_ptr = tl.load(b_ptr + slice_id).to(tl.pointer_type(c_ptr.dtype.element_ty))
    cur_c_ptr = c_ptr + (slice_id % num_slice_c) * slice_c_size

    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
@@ -154,7 +153,7 @@ def _fused_moe_lora_kernel(
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]

-     accumulator = accumulator.to(tl.bfloat16)
+     accumulator = accumulator.to(c_ptr.dtype.element_ty)
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = cur_c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
@@ -205,6 +204,10 @@ def _fused_moe_lora(
    assert output.shape[0] == topk_weights.shape[0]
    assert top_k_num == topk_weights.shape[1]

+     for lora_a, lora_b in zip(lora_a_stacked, lora_b_stacked):
+         assert lora_a.dtype == lora_b.dtype == output.dtype == qcurr_hidden_states.dtype
+         assert lora_a.dtype in [torch.float16, torch.bfloat16]
+ 
    device = qcurr_hidden_states.device
    num_slices = len(lora_a_stacked)

@@ -227,9 +230,9 @@ def _fused_moe_lora(
    num_tokens = M * top_k_num
    w1_output_dim_size = w1_lora_b_stacked.shape[2]

-     lora_intermediate_cache1 = torch.zeros(
+     lora_intermediate_cache1 = torch.empty(
        (num_slices * M * top_k_num * (max_lora_rank + w1_output_dim_size)),
-         dtype=torch.bfloat16,
+         dtype=output.dtype,
        device=device,
    )

@@ -288,10 +291,6 @@ def _fused_moe_lora(
    K = max_lora_rank
    N = w1_output_dim_size

-     # a_intermediate_cache1 = a_intermediate_cache1.view(
-     #     M, -1, a_intermediate_cache1.shape[3]
-     # )
- 
    a_intermediate_cache1 = a_intermediate_cache1.view(
        -1, a_intermediate_cache1.shape[3]
    )
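For reference, a minimal standalone sketch of the dtype contract that the new assertions in _fused_moe_lora enforce: every stacked LoRA tensor must share the dtype of the output and the current hidden states, and only float16 or bfloat16 is accepted (the kernel itself now takes its element type from c_ptr.dtype.element_ty rather than hard-coded bfloat16). The tensor shapes below are hypothetical placeholders, not the real stacked-LoRA layout.

import torch

def check_fused_moe_lora_dtypes(lora_a_stacked, lora_b_stacked, output, hidden_states):
    # Mirrors the guard in the diff: every LoRA slice, the output buffer, and the
    # hidden states must agree on a single half-precision dtype.
    for lora_a, lora_b in zip(lora_a_stacked, lora_b_stacked):
        assert lora_a.dtype == lora_b.dtype == output.dtype == hidden_states.dtype
        assert lora_a.dtype in [torch.float16, torch.bfloat16]

dtype = torch.float16  # valid now that the kernel is not pinned to bfloat16
lora_a_stacked = [torch.zeros(4, 8, 16, 32, dtype=dtype)]  # hypothetical shapes
lora_b_stacked = [torch.zeros(4, 8, 64, 16, dtype=dtype)]
output = torch.zeros(2, 8, 64, dtype=dtype)
hidden_states = torch.zeros(2, 32, dtype=dtype)
check_fused_moe_lora_dtypes(lora_a_stacked, lora_b_stacked, output, hidden_states)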