Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions src/devices/cuda/common_cuda.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
#ifndef __COMMON_CUDA_H__
#define __COMMON_CUDA_H__

// Maximum number of threads per block.
// Builds that define ENABLE_SUGON_DCU use 512; all other builds use 1024.
#if defined(ENABLE_SUGON_DCU)
#    define MAX_THREADS_PER_BLOCK 512
#else
#    define MAX_THREADS_PER_BLOCK 1024
#endif

// Warp-related constants; these do not depend on the build configuration.
#define MAX_WARP_PER_BLOCK 32
#define WARP_SIZE 32

Expand Down
36 changes: 33 additions & 3 deletions src/ops/causal_softmax/cuda/causal_softmax.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,12 @@ struct AttentionCausualMask {
}
};

// Binary maximum functor over float operands.
// Serves as the reduction operator in the ENABLE_SUGON_DCU code paths of this
// file, where the other build uses cub::Max().
struct MaxOp {
    // Returns a when a > b; otherwise returns b.
    __device__ float operator()(const float a, const float b) const {
        if (a > b) {
            return a;
        }
        return b;
    }
};

template<unsigned int BLOCK_SIZE, class Tdata, class Tmask>
static __device__ void block_padding(
Tdata *__restrict__ att,
Expand All @@ -33,7 +39,12 @@ static __device__ void block_padding(

__shared__ float max;
{
#ifdef ENABLE_SUGON_DCU
MaxOp max_op;
auto acc = block_op.Reduce(thread_data, max_op, total_seq_len);
#else
auto acc = block_op.Reduce(thread_data, cub::Max(), total_seq_len);
#endif
if (threadIdx.x == 0) { max = acc; }
}
__syncthreads();
Expand Down Expand Up @@ -67,7 +78,12 @@ static __device__ void block_folding(
thread_data[i] = att_idx < total_seq_len && mask(token_idx, seq_len, att_idx, total_seq_len)
? float(att[i])
: -__FLT_MAX__;
#ifdef ENABLE_SUGON_DCU
MaxOp max_op;
thread_max = max_op(thread_max, thread_data[i]);
#else
thread_max = cub::Max()(thread_max, thread_data[i]);
#endif
}

using BlockOp = cub::BlockReduce<float, BLOCK_SIZE>;
Expand All @@ -76,7 +92,12 @@ static __device__ void block_folding(

__shared__ float max;
{
#ifdef ENABLE_SUGON_DCU
MaxOp max_op;
auto acc = block_op.Reduce(thread_max, max_op);
#else
auto acc = block_op.Reduce(thread_max, cub::Max());
#endif
if (threadIdx.x == 0) { max = acc; }
}
__syncthreads();
Expand Down Expand Up @@ -130,7 +151,7 @@ static __forceinline__ __device__ void folding(
}

template<unsigned int BLOCK_SIZE, class Tdata>
__global__ void fused_softmax_padding(
__launch_bounds__(MAX_THREADS_PER_BLOCK) __global__ void fused_softmax_padding(
Tdata *__restrict__ att,
unsigned int const stride_x,
unsigned int const stride_y,
Expand All @@ -140,7 +161,7 @@ __global__ void fused_softmax_padding(
}

template<unsigned int BLOCK_SIZE, unsigned int ITEMS_PER_THREAD, class Tdata>
__global__ void fused_softmax_folding(
__launch_bounds__(MAX_THREADS_PER_BLOCK) __global__ void fused_softmax_folding(
Tdata *__restrict__ att,
unsigned int const stride_x,
unsigned int const stride_y,
Expand All @@ -152,7 +173,7 @@ __global__ void fused_softmax_folding(
}

template<unsigned int BLOCK_SIZE, class Tdata>
__global__ void fused_softmax_standard(
__launch_bounds__(MAX_THREADS_PER_BLOCK) __global__ void fused_softmax_standard(
Tdata *__restrict__ att_,
unsigned int const stride_x,
unsigned int const stride_y,
Expand Down Expand Up @@ -183,7 +204,12 @@ __global__ void fused_softmax_standard(
__syncthreads();
// Block reduce max
{
#ifdef ENABLE_SUGON_DCU
MaxOp max_op;
auto acc = block_op.Reduce(partial, max_op);
#else
auto acc = block_op.Reduce(partial, cub::Max());
#endif
if (threadIdx.x == 0) { max_ = acc; }
}
__syncthreads();
Expand All @@ -200,7 +226,11 @@ __global__ void fused_softmax_standard(

// Block reduce sum
{
#ifdef ENABLE_SUGON_DCU
auto acc = block_op.Sum(partial);
#else
auto acc = block_op.Reduce(partial, cub::Sum());
#endif
if (threadIdx.x == 0) { sum_ = acc; }
}
__syncthreads();
Expand Down
16 changes: 7 additions & 9 deletions src/ops/matmul/cuda/matmul_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -13,20 +13,18 @@ infiniopStatus_t matmul_cuda(MatmulCudaDescriptor_t desc, void *c, float beta, v
std::swap(a, b);
}

Tdata alpha_, beta_;
cudaDataType a_type, b_type, c_type;
cublasComputeType_t compute_type;

if constexpr (std::is_same<Tdata, half>::value) {
alpha_ = __float2half(alpha);
beta_ = __float2half(beta);
a_type = b_type = c_type = CUDA_R_16F;
compute_type = CUBLAS_COMPUTE_16F;
compute_type = CUBLAS_COMPUTE_32F;
} else {
alpha_ = alpha;
beta_ = beta;
a_type = b_type = c_type = CUDA_R_32F;
#ifdef ENABLE_SUGON_DCU
compute_type = CUBLAS_COMPUTE_32F;
#else
compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
#endif
}

auto op_a = info.a_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T;
Expand All @@ -40,7 +38,7 @@ infiniopStatus_t matmul_cuda(MatmulCudaDescriptor_t desc, void *c, float beta, v
info.m,
info.n,
info.k,
&alpha_,
&alpha,
a,
a_type,
info.a_matrix.ld(),
Expand All @@ -49,7 +47,7 @@ infiniopStatus_t matmul_cuda(MatmulCudaDescriptor_t desc, void *c, float beta, v
b_type,
info.b_matrix.ld(),
info.b_matrix.stride,
&beta_,
&beta,
c,
c_type,
info.c_matrix.ld(),
Expand Down
14 changes: 7 additions & 7 deletions src/ops/random_sample/cuda/random_sample.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
#include <cub/cub.cuh>

template<class T, int BLOCK_DIM>
__global__ void softmax(
__launch_bounds__(MAX_THREADS_PER_BLOCK) __global__ void softmax(
T *val_out,
int topk,
float temperature, int voc) {
Expand All @@ -29,14 +29,14 @@ __global__ void softmax(
}
}

__global__ void index(uint64_t *key_in, int voc) {
// Writes each element's own position into key_in: every launched thread whose
// global 1-D index i is below voc stores key_in[i] = i. Threads with an index
// of voc or more return without writing.
__launch_bounds__(MAX_THREADS_PER_BLOCK) __global__ void index(uint64_t *key_in, int voc) {
    const int global_idx = threadIdx.x + blockIdx.x * blockDim.x;
    if (global_idx >= voc) {
        return;
    }
    key_in[global_idx] = static_cast<uint64_t>(global_idx);
}
template<class T>
__global__ void random_sample_kernel(uint64_t *result,
__launch_bounds__(MAX_THREADS_PER_BLOCK) __global__ void random_sample_kernel(uint64_t *result,
T *val_out,
float random_val,
float topp,
Expand Down Expand Up @@ -119,7 +119,9 @@ void random_sample_nv_gpu_f16(RandomSampleCudaDescriptor_t desc, void *workspace
uint64_t *key_in = (uint64_t *) keyTmp;
uint64_t *key_out = key_in + voc;

index<<<(voc + 1023) / 1024, 1024, 0, (cudaStream_t) stream>>>(key_in, voc);
int block_dim = MAX_THREADS_PER_BLOCK;
int num_blocks = ROUND_UP_DIV(voc, block_dim);
index<<<num_blocks, block_dim, 0, (cudaStream_t) stream>>>(key_in, voc);
//下面开始计算workspace空间
size_t size_radix_sort;
size_t size_scan;
Expand All @@ -134,9 +136,7 @@ void random_sample_nv_gpu_f16(RandomSampleCudaDescriptor_t desc, void *workspace
voc, (cudaStream_t) stream);//该函数会把排序结果和对应索引保存在val_out和key_out上
//排序结束,然后开始做softmax变换
if (topp > 0 && topk > 1) {
int BLOCK_DIM = 1024;
int num_blocks = (voc + BLOCK_DIM - 1) / BLOCK_DIM;
softmax<half, 1024><<<num_blocks, BLOCK_DIM, 0, (cudaStream_t) stream>>>(val_out, topk,
softmax<half, MAX_THREADS_PER_BLOCK><<<num_blocks, block_dim, 0, (cudaStream_t) stream>>>(val_out, topk,
temperature, voc);


Expand Down
9 changes: 5 additions & 4 deletions src/ops/rearrange/cuda/rearrange.cu
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
#include "../../../devices/cuda/common_cuda.h"
#include "rearrange.cuh"
#include "../../utils.h"

template<class Tmem>
static __global__ void rearrange(
static __launch_bounds__(MAX_THREADS_PER_BLOCK) __global__ void rearrange(
void *__restrict__ dst,
int const rsa,
int const csa,
Expand Down Expand Up @@ -35,9 +36,9 @@ void rearrange_nv_gpu(RearrangeCudaDescriptor_t desc, void *y, void const *x, vo
return;
}

auto warps = 1024 / WARP_SIZE;
auto grid = dim3((c + warps - 1) / warps, r);
auto block = dim3(WARP_SIZE, (c + grid.x - 1) / grid.x);
auto warps = MAX_THREADS_PER_BLOCK / WARP_SIZE;
auto grid = dim3(ROUND_UP_DIV(c, warps), r);
auto block = dim3(WARP_SIZE, ROUND_UP_DIV(c, grid.x));
dst_rs /= unit;
dst_cs /= unit;
src_rs /= unit;
Expand Down
13 changes: 10 additions & 3 deletions src/ops/rms_norm/cuda/rms_norm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

// assert BLOCK_SIZE >= blockDim.x
template<unsigned int BLOCK_SIZE, class Tdata, class Wdata>
static __global__ void rms_norm_padding(
__launch_bounds__(MAX_THREADS_PER_BLOCK) static __global__ void rms_norm_padding(
Tdata *__restrict__ o_,
unsigned int const stride_y,
Tdata const *__restrict__ x_,
Expand All @@ -19,8 +19,11 @@ static __global__ void rms_norm_padding(

using BlockOp = cub::BlockReduce<float, BLOCK_SIZE>;
__shared__ typename BlockOp::TempStorage temp_storage;
#ifdef ENABLE_SUGON_DCU
auto acc = BlockOp(temp_storage).Sum(x * x);
#else
auto acc = BlockOp(temp_storage).Reduce(x * x, cub::Sum());

#endif
__shared__ Tdata rms;
if (threadIdx.x == 0) {
rms = Tdata(rsqrtf(acc / float(blockDim.x) + epsilon));
Expand All @@ -31,7 +34,7 @@ static __global__ void rms_norm_padding(
}

template<unsigned int BLOCK_SIZE, unsigned int ITEMS_PER_THREAD, class Tdata, class Wdata>
static __global__ void rms_norm_folding(
__launch_bounds__(MAX_THREADS_PER_BLOCK) static __global__ void rms_norm_folding(
Tdata *__restrict__ y,
unsigned int const stride_y,
Tdata const *__restrict__ x,
Expand Down Expand Up @@ -59,7 +62,11 @@ static __global__ void rms_norm_folding(
{
using BlockOp = cub::BlockReduce<float, BLOCK_SIZE>;
__shared__ typename BlockOp::TempStorage temp_storage;
#ifdef ENABLE_SUGON_DCU
acc = BlockOp(temp_storage).Sum(squared);
#else
acc = BlockOp(temp_storage).Reduce(squared, cub::Sum());
#endif
}

__shared__ Tdata rms;
Expand Down
2 changes: 1 addition & 1 deletion src/ops/swiglu/cuda/swiglu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ inline int gcd(int a, int b) {
}

template<class Tdata>
static __global__ void swiglu(
static __launch_bounds__(MAX_THREADS_PER_BLOCK) __global__ void swiglu(
Tdata *__restrict__ c,
int const stride_c,
Tdata const *__restrict__ a,
Expand Down
34 changes: 32 additions & 2 deletions xmake.lua
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@ option("metax-gpu")
option_end()


-- Build option "sugon-dcu": disabled by default and shown in the config menu.
-- It carries both the ENABLE_SUGON_DCU and ENABLE_NV_GPU defines.
option("sugon-dcu")
    set_default(false)
    set_showmenu(true)
    set_description("Enable or disable Sugon DCU kernel")
    add_defines("ENABLE_SUGON_DCU", "ENABLE_NV_GPU")
option_end()

if is_mode("debug") then
add_cxflags("-g -O0")
add_defines("DEBUG_MODE")
Expand All @@ -74,9 +82,11 @@ if has_config("cpu") then

end

if has_config("nv-gpu") then

if has_config("nv-gpu", "sugon-dcu") then
add_defines("ENABLE_NV_GPU")
if has_config("sugon-dcu") then
add_defines("ENABLE_SUGON_DCU")
end
local CUDA_ROOT = os.getenv("CUDA_ROOT") or os.getenv("CUDA_HOME") or os.getenv("CUDA_PATH")
local CUDNN_ROOT = os.getenv("CUDNN_ROOT") or os.getenv("CUDNN_HOME") or os.getenv("CUDNN_PATH")
if CUDA_ROOT ~= nil then
Expand Down Expand Up @@ -267,6 +277,11 @@ if has_config("metax-gpu") then

end


-- Custom toolchain "sugon-dcu-linker": sets the "sh" toolset entry to nvcc.
-- NOTE(review): "sh" is presumably xmake's shared-library linker slot, so
-- shared libraries built with this toolchain would be linked through nvcc —
-- confirm against the xmake toolchain documentation.
toolchain("sugon-dcu-linker")
set_toolset("sh", "nvcc")
toolchain_end()

target("infiniop")
set_kind("shared")

Expand All @@ -276,6 +291,21 @@ target("infiniop")
if has_config("nv-gpu") then
add_deps("nv-gpu")
end
-- Sugon DCU build: link settings applied only when the "sugon-dcu" option is set.
if has_config("sugon-dcu") then
-- Directory expected to contain libnv-gpu.a: build/<plat>/<arch>/<mode>.
-- NOTE(review): assumes the build output directory is named "build" — confirm.
local builddir = string.format(
"build/%s/%s/%s",
get_config("plat"),
get_config("arch"),
get_config("mode")
)
-- Flags passed when linking the shared library.
add_shflags("-s", "-shared", "-fPIC")
-- Libraries to link against.
add_links("cublas", "cudnn", "cudadevrt", "cudart_static", "rt", "pthread", "dl")
-- Linking with -lnv-gpu fails here, so the nv-gpu dependency is declared with
-- inherit = false and the static archive is linked by its explicit path.
-- NOTE(review): inherit = false presumably keeps the dependency for build
-- ordering without inheriting its link settings — confirm in the xmake docs.
add_deps("nv-gpu", {inherit = false})
add_links(builddir.."/libnv-gpu.a")
-- Select the toolchain named "sugon-dcu-linker" for this target.
set_toolchains("sugon-dcu-linker")
end

if has_config("cambricon-mlu") then
add_deps("cambricon-mlu")
end
Expand Down