@@ -16,6 +16,12 @@ struct AttentionCausualMask {
     }
 };
 
+struct MaxOp {
+    __device__ float operator()(const float a, const float b) const {
+        return a > b ? a : b;
+    }
+};
+
 template <unsigned int BLOCK_SIZE, class Tdata, class Tmask>
 static __device__ void block_padding(
     Tdata *__restrict__ att,
@@ -33,7 +39,12 @@ static __device__ void block_padding(
 
     __shared__ float max;
     {
+#ifdef ENABLE_SUGON_DCU
+        MaxOp max_op;
+        auto acc = block_op.Reduce(thread_data, max_op, total_seq_len);
+#else
         auto acc = block_op.Reduce(thread_data, cub::Max(), total_seq_len);
+#endif
         if (threadIdx.x == 0) { max = acc; }
     }
     __syncthreads();
@@ -67,7 +78,12 @@ static __device__ void block_folding(
         thread_data[i] = att_idx < total_seq_len && mask(token_idx, seq_len, att_idx, total_seq_len)
                              ? float(att[i])
                              : -__FLT_MAX__;
+#ifdef ENABLE_SUGON_DCU
+        MaxOp max_op;
+        thread_max = max_op(thread_max, thread_data[i]);
+#else
         thread_max = cub::Max()(thread_max, thread_data[i]);
+#endif
     }
 
     using BlockOp = cub::BlockReduce<float, BLOCK_SIZE>;
@@ -76,7 +92,12 @@ static __device__ void block_folding(
 
     __shared__ float max;
     {
+#ifdef ENABLE_SUGON_DCU
+        MaxOp max_op;
+        auto acc = block_op.Reduce(thread_max, max_op);
+#else
         auto acc = block_op.Reduce(thread_max, cub::Max());
+#endif
         if (threadIdx.x == 0) { max = acc; }
     }
     __syncthreads();
@@ -130,7 +151,7 @@ static __forceinline__ __device__ void folding(
 }
 
 template <unsigned int BLOCK_SIZE, class Tdata>
-__global__ void fused_softmax_padding(
+__launch_bounds__(MAX_THREADS_PER_BLOCK) __global__ void fused_softmax_padding(
     Tdata *__restrict__ att,
     unsigned int const stride_x,
     unsigned int const stride_y,
@@ -140,7 +161,7 @@ __global__ void fused_softmax_padding(
 }
 
 template <unsigned int BLOCK_SIZE, unsigned int ITEMS_PER_THREAD, class Tdata>
-__global__ void fused_softmax_folding(
+__launch_bounds__(MAX_THREADS_PER_BLOCK) __global__ void fused_softmax_folding(
     Tdata *__restrict__ att,
     unsigned int const stride_x,
     unsigned int const stride_y,
@@ -152,7 +173,7 @@ __global__ void fused_softmax_folding(
 }
 
 template <unsigned int BLOCK_SIZE, class Tdata>
-__global__ void fused_softmax_standard(
+__launch_bounds__(MAX_THREADS_PER_BLOCK) __global__ void fused_softmax_standard(
     Tdata *__restrict__ att_,
     unsigned int const stride_x,
     unsigned int const stride_y,
@@ -183,7 +204,12 @@ __global__ void fused_softmax_standard(
     __syncthreads();
     // Block reduce max
     {
+#ifdef ENABLE_SUGON_DCU
+        MaxOp max_op;
+        auto acc = block_op.Reduce(partial, max_op);
+#else
         auto acc = block_op.Reduce(partial, cub::Max());
+#endif
         if (threadIdx.x == 0) { max_ = acc; }
     }
     __syncthreads();
@@ -200,7 +226,11 @@ __global__ void fused_softmax_standard(
 
     // Block reduce sum
     {
+#ifdef ENABLE_SUGON_DCU
+        auto acc = block_op.Sum(partial);
+#else
         auto acc = block_op.Reduce(partial, cub::Sum());
+#endif
         if (threadIdx.x == 0) { sum_ = acc; }
     }
     __syncthreads();
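
For context, the substitution made under ENABLE_SUGON_DCU works because cub::BlockReduce::Reduce accepts any binary functor, so a hand-written MaxOp can stand in for cub::Max() (and block_op.Sum(x) replaces Reduce(x, cub::Sum())). The standalone sketch below illustrates that call shape only; the kernel name block_max_demo, the BLOCK_SIZE wiring, and the host-side details are illustrative assumptions, not part of this patch.

// Minimal sketch (not from the patch): a plain functor driving cub::BlockReduce,
// mirroring the Sugon DCU branch of the kernels above.
#include <cub/cub.cuh>
#include <cfloat>

struct MaxOp {
    __device__ float operator()(const float a, const float b) const {
        return a > b ? a : b;
    }
};

template <unsigned int BLOCK_SIZE>
__global__ void block_max_demo(const float *in, float *out, unsigned int n) {
    using BlockReduce = cub::BlockReduce<float, BLOCK_SIZE>;
    __shared__ typename BlockReduce::TempStorage temp_storage;

    // Each thread loads one element, or -FLT_MAX past the end of the input.
    float thread_val = threadIdx.x < n ? in[threadIdx.x] : -FLT_MAX;

    // Same call shape as the patched kernels: a hand-written functor instead of cub::Max().
    MaxOp max_op;
    float block_max = BlockReduce(temp_storage).Reduce(thread_val, max_op);

    if (threadIdx.x == 0) { *out = block_max; }
}

Because the functor is a drop-in argument, the #ifdef branches only change which reduction operator object is constructed; the surrounding reduction, shared-memory broadcast, and __syncthreads() logic is identical on both paths.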