#include <stdio.h>
#include <stdlib.h>
#include <time.h>
#include <torch/all.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>

#include <ATen/ATen.h>
#include <ATen/cuda/Atomic.cuh>

#include "../cuda_compat.h"
#include "../dispatch_utils.h"
#include "core/math.hpp"

namespace {

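// Flattens a (row, col) pair into an offset within a row-major 2D buffer
// that is total_col entries wide.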
__device__ __forceinline__ int32_t index(int32_t total_col, int32_t row,
                                         int32_t col) {
  return row * total_col + col;
}

}  // namespace

// TODO: Refactor common parts with moe_align_sum_kernels
template <typename scalar_t, typename token_cnts_t>
__global__ void moe_lora_align_sum_kernel(
    scalar_t* __restrict__ topk_ids, int32_t* token_lora_mapping,
    int64_t block_size, int num_experts, int max_loras, size_t numel,
    int max_num_tokens_padded, int max_num_m_blocks,
    int32_t* __restrict__ sorted_token_ids, int32_t* __restrict__ expert_ids,
    int topk_num, int32_t* total_tokens_post_pad) {
  const size_t tokens_per_thread = div_ceil(numel, blockDim.x);
  const size_t start_idx = threadIdx.x * tokens_per_thread;

  int lora_id = blockIdx.x;
  extern __shared__ int32_t shared_mem[];
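  // Dynamic shared memory layout (sized by the host launcher):
  //   cumsum:      num_experts + 1 entries
  //   tokens_cnts: (blockDim.x + 1) rows of num_experts counters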
  int32_t* cumsum = shared_mem;
  token_cnts_t* tokens_cnts = (token_cnts_t*)(shared_mem + num_experts + 1);

  // Initialize sorted_token_ids with numel
  for (size_t it = threadIdx.x; it < max_num_tokens_padded; it += blockDim.x) {
    sorted_token_ids[lora_id * max_num_tokens_padded + it] = numel;
  }

  // Initialize expert_ids with -1
  for (size_t it = threadIdx.x; it < max_num_m_blocks; it += blockDim.x) {
    expert_ids[lora_id * max_num_m_blocks + it] = -1;
  }

  // Initialize total_tokens_post_pad with 0
  if (threadIdx.x == 0) {
    total_tokens_post_pad[lora_id] = 0;
  }

  for (int i = 0; i < num_experts; ++i) {
    tokens_cnts[index(num_experts, threadIdx.x + 1, i)] = 0;
  }

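  // Each thread counts the experts seen in its token shard into its own row
  // (threadIdx.x + 1) of tokens_cnts. Tokens mapped to a different LoRA are
  // masked out instead of branched over, so all threads do uniform work.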
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int mask = token_lora_mapping[i / topk_num] == lora_id;
    int idx = index(num_experts, threadIdx.x + 1, topk_ids[i]);
    tokens_cnts[idx] += mask;
  }

  __syncthreads();

  // For each expert we accumulate the token counts from the different threads.
  if (threadIdx.x < num_experts) {
    tokens_cnts[index(num_experts, 0, threadIdx.x)] = 0;
    for (int i = 1; i <= blockDim.x; ++i) {
      tokens_cnts[index(num_experts, i, threadIdx.x)] +=
          tokens_cnts[index(num_experts, i - 1, threadIdx.x)];
    }
  }

  __syncthreads();

  // We accumulate the token counts of all experts in thread 0.
  if (threadIdx.x == 0) {
    cumsum[0] = 0;
    for (int i = 1; i <= num_experts; ++i) {
      cumsum[i] = cumsum[i - 1] +
                  div_ceil(tokens_cnts[index(num_experts, blockDim.x, i - 1)],
                           block_size) *
                      block_size;
    }
    total_tokens_post_pad[lora_id] = static_cast<int32_t>(cumsum[num_experts]);
  }

  __syncthreads();

  /**
   * For each expert, each thread processes the tokens of the corresponding
   * blocks and stores the corresponding expert_id for each block.
   */
  if (threadIdx.x < num_experts) {
    for (int i = cumsum[threadIdx.x]; i < cumsum[threadIdx.x + 1];
         i += block_size) {
      expert_ids[index(max_num_m_blocks, lora_id, i / block_size)] =
          threadIdx.x;
    }
  }

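  // Scatter each token index into its padded, expert-sorted slot. Entries were
  // initialized to numel, so atomically adding (i - numel) * mask writes i for
  // tokens that belong to this LoRA and leaves the padding value untouched
  // otherwise.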
  for (int i = start_idx; i < numel && i < start_idx + tokens_per_thread; ++i) {
    int32_t expert_id = topk_ids[i];
    /** The cumsum[expert_id] stores the starting index of the tokens that the
     * expert with expert_id needs to process, and
     * tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
     * processed by the expert with expert_id within the current thread's token
     * shard.
     */
    int32_t rank_post_pad =
        tokens_cnts[index(num_experts, threadIdx.x, expert_id)] +
        cumsum[expert_id];

    int mask = (int)token_lora_mapping[i / topk_num] == lora_id;
    atomicAdd(
        &sorted_token_ids[index(max_num_tokens_padded, lora_id, rank_post_pad)],
        (i - numel) * mask);
    tokens_cnts[index(num_experts, threadIdx.x, expert_id)] += mask;
  }
}

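// Per-LoRA variant of moe_align_block_size: for each LoRA adapter, groups the
// token indices in topk_ids by expert, pads each expert's segment to a
// multiple of block_size, and records the expert id of every block plus the
// padded token count per adapter.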
void moe_lora_align_block_size(torch::Tensor topk_ids,
                               torch::Tensor token_lora_mapping,
                               int64_t num_experts, int64_t block_size,
                               int64_t max_loras,
                               torch::Tensor sorted_token_ids,
                               torch::Tensor expert_ids,
                               torch::Tensor num_tokens_post_pad) {
  const int topk_num = topk_ids.size(1);

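  // Worst case: every expert's segment is padded up by (block_size - 1).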
  int max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1);

  TORCH_CHECK(block_size > 0, "block_size should be greater than 0.");
  max_num_tokens_padded = round_to_next_multiple_of(
      max_num_tokens_padded, static_cast<int>(block_size));
  int max_num_m_blocks = div_ceil(max_num_tokens_padded, block_size);

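  // Query the opt-in per-block shared memory limit so the dynamic shared
  // memory request below can be validated against the device.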
  int device_max_shared_mem;
  auto dev = topk_ids.get_device();
  cudaDeviceGetAttribute(&device_max_shared_mem,
                         cudaDevAttrMaxSharedMemoryPerBlockOptin, dev);
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  const int32_t num_thread = max((int32_t)num_experts, 128);  // WARP_SIZE,
  TORCH_CHECK(num_thread <= 1024,
              "num_thread must be less than 1024, "
              "and fallback is not implemented yet.");
  const int32_t shared_mem = (num_thread + 1) * num_experts * sizeof(int32_t) +
                             (num_experts + 1) * sizeof(int32_t);

  if (shared_mem > device_max_shared_mem) {
    TORCH_CHECK(false,
                "Shared memory usage exceeds device limit, and global memory "
                "fallback is not implemented yet.");
  }

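  // Launch one block per LoRA adapter; each block scans the full topk_ids
  // tensor and keeps only the tokens mapped to its lora_id (blockIdx.x).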
  VLLM_DISPATCH_INTEGRAL_TYPES(
      topk_ids.scalar_type(), "moe_lora_align_sum_kernel", [&] {
        dim3 blockDim(num_thread);
        auto kernel = moe_lora_align_sum_kernel<scalar_t, int32_t>;
        AT_CUDA_CHECK(VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(
            (void*)kernel, shared_mem));
        kernel<<<max_loras, blockDim, shared_mem, stream>>>(
            topk_ids.data_ptr<scalar_t>(),
            token_lora_mapping.data_ptr<int32_t>(), block_size, num_experts,
            max_loras, topk_ids.numel(), max_num_tokens_padded,
            max_num_m_blocks, sorted_token_ids.data_ptr<int32_t>(),
            expert_ids.data_ptr<int32_t>(), topk_num,
            num_tokens_post_pad.data_ptr<int32_t>());
      });
}