diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 0b10e5f6ae0..e4d0f2d5708 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -461,6 +461,64 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
     return x;
 }
 
+static __device__ __forceinline__ unsigned int get_warp_mask() {
+#ifdef __HIP_PLATFORM_AMD__
+    return __ballot(1);    // HIP equivalent
+#else
+    return __activemask(); // CUDA
+#endif
+}
+
+template <typename T, int width = WARP_SIZE>
+static __device__ __forceinline__ T warp_prefix_inclusive_sum(T x) {
+    const int lane_id = threadIdx.x % width;
+    const auto mask = get_warp_mask();
+#pragma unroll
+    for (int offset = 1; offset < width; offset <<= 1) {
+        const T t = __shfl_up_sync(mask, x, offset, width);
+        if (lane_id >= offset) {
+            x += t;
+        }
+    }
+    return x;
+}
+
+template <int width = WARP_SIZE>
+static __device__ __forceinline__ float2 warp_prefix_inclusive_sum(float2 a) {
+    const int lane_id = threadIdx.x % width;
+    const auto mask = get_warp_mask();
+#pragma unroll
+    for (int offset = 1; offset < width; offset <<= 1) {
+        const float t_x = __shfl_up_sync(mask, a.x, offset, width);
+        const float t_y = __shfl_up_sync(mask, a.y, offset, width);
+        if (lane_id >= offset) {
+            a.x += t_x;
+            a.y += t_y;
+        }
+    }
+    return a;
+}
+
+template <int width = WARP_SIZE>
+static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) {
+#ifdef FP16_AVAILABLE
+    const int lane_id = threadIdx.x % width;
+    const auto mask = get_warp_mask();
+#pragma unroll
+    for (int offset = 1; offset < width; offset <<= 1) {
+        const half2 t = __shfl_up_sync(mask, a, offset, width);
+        if (lane_id >= offset) {
+            a = __hadd2(a, t);
+        }
+    }
+    return a;
+
+#else
+    NO_DEVICE_CODE;
+    return a;
+#endif // FP16_AVAILABLE
+}
+
 static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
 #ifdef FP16_AVAILABLE
 
diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu
new file mode 100644
index 00000000000..b6ceb7a1495
--- /dev/null
+++ b/ggml/src/ggml-cuda/cumsum.cu
@@ -0,0 +1,241 @@
+#include <algorithm>
+
+#include "cumsum.cuh"
+#include "ggml-impl.h"
+
+// Check if CUB is available
+#ifdef __has_include
+# if __has_include(<cub/block/block_scan.cuh>)
+# define HAS_CUB_DEVICE_SCAN 1
+# include <cub/block/block_scan.cuh>
+# else
+# define HAS_CUB_DEVICE_SCAN 0
+# endif
+#else
+# define HAS_CUB_DEVICE_SCAN 0
+#endif
+
+#if HAS_CUB_DEVICE_SCAN
+
+template <typename T, int BLOCK_SIZE>
+static __global__ void cumsum_cub_kernel(
+        const T* __restrict__ src,
+        T* __restrict__ dst,
+        int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne03,
+        int64_t nb01, int64_t nb02, int64_t nb03,
+        int64_t nb1, int64_t nb2, int64_t nb3)
+{
+    using BlockScan = cub::BlockScan<T, BLOCK_SIZE>;
+
+    __shared__ typename BlockScan::TempStorage temp_storage;
+    __shared__ T block_carry; // carry from previous tile
+    __shared__ T block_total; // total of current tile
+
+    const int tid = threadIdx.x;
+
+    const int64_t i1 = blockIdx.x;
+    const int64_t i2 = blockIdx.y;
+    const int64_t i3 = blockIdx.z;
+
+    if (i1 >= ne01 || i2 >= ne02 || i3 >= ne03) {
+        return;
+    }
+
+    const T * src_row = src + i1 * nb01 + i2 * nb02 + i3 * nb03;
+    T * dst_row = dst + i1 * nb1 + i2 * nb2 + i3 * nb3;
+
+    if (tid == 0) {
+        block_carry = 0;
+    }
+    __syncthreads();
+
+    for (int64_t start = 0; start < ne00; start += BLOCK_SIZE) {
+        int64_t idx = start + tid;
+
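+        // elements past the end of the row are loaded as zero so they do not affect the tile total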
+        T x = (idx < ne00) ? src_row[idx] : T(0);
+
+        T inclusive;
+        BlockScan(temp_storage).InclusiveSum(x, inclusive);
+
+        // Last thread stores total
+        if (tid == BLOCK_SIZE - 1) {
+            block_total = inclusive;
+        }
+        __syncthreads();
+
+        T final = inclusive + block_carry;
+
+        if (idx < ne00) {
+            dst_row[idx] = final;
+        }
+        __syncthreads();
+
+        if (tid == 0) {
+            block_carry += block_total;
+        }
+        __syncthreads();
+    }
+}
+
+#endif // HAS_CUB_DEVICE_SCAN
+
+// Fallback kernel implementation (original)
+template <typename T>
+static __global__ void cumsum_kernel(
+        const T * src, T * dst,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
+        const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) {
+
+    const int tid = threadIdx.x;
+    const int lane = tid & (WARP_SIZE - 1);
+    const int warp = tid / WARP_SIZE;
+    const int warps_per_block = blockDim.x / WARP_SIZE;
+
+    extern __shared__ float smem[];
+    float* s_vals        = smem;
+    float* s_warp_sums   = smem + blockDim.x;
+    float* s_carry       = smem + blockDim.x + warps_per_block;
+    float* s_chunk_total = s_carry + 1;
+
+    // Initialize carry
+    if (tid == 0) {
+        *s_carry = 0.0f;
+    }
+    __syncthreads();
+
+    const int64_t i3 = blockIdx.z;
+    const int64_t i2 = blockIdx.y;
+    const int64_t i1 = blockIdx.x;
+    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+        return;
+    }
+
+    const T * src_row = src + i1 * nb01 + i2 * nb02 + i3 * nb03;
+    T * dst_row = dst + i1 * nb1 + i2 * nb2 + i3 * nb3;
+
+    for (int64_t start = 0; start < ne00; start += blockDim.x) {
+        int64_t idx = start + tid;
+        float val = (idx < ne00) ? static_cast<float>(src_row[idx]) : 0.0f;
+
+        // 1. Warp inclusive scan
+        val = warp_prefix_inclusive_sum(val);
+        s_vals[tid] = val;
+
+        // Store warp total
+        if (lane == WARP_SIZE - 1) {
+            s_warp_sums[warp] = val;
+        }
+        __syncthreads();
+
+        // 2. Exclusive scan of warp sums (warp 0 only)
+        if (warp == 0) {
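+            // scan the per-warp totals so each warp's offset becomes the sum of all preceding warps in this chunk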
+            float w = (tid < warps_per_block) ? s_warp_sums[tid] : 0.0f;
+            float inc = warp_prefix_inclusive_sum(w);
+            if (tid < warps_per_block) {
+                s_warp_sums[tid] = inc - w; // exclusive sum
+            }
+            if (tid == warps_per_block - 1) {
+                *s_chunk_total = inc; // total sum of this chunk
+            }
+        }
+        __syncthreads();
+
+        float carry = *s_carry;
+        float final_val = s_vals[tid] + s_warp_sums[warp] + carry;
+        if (idx < ne00) {
+            dst_row[idx] = static_cast<T>(final_val);
+        }
+        __syncthreads();
+
+        // Update carry for next chunk
+        if (tid == 0) {
+            *s_carry += *s_chunk_total;
+        }
+        __syncthreads();
+    }
+}
+
+template <typename T>
+static void cumsum_cuda(
+        const T * src, T * dst,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
+        const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
+        cudaStream_t stream) {
+
+    const size_t type_size = sizeof(T);
+    bool use_cub = false;
+#if HAS_CUB_DEVICE_SCAN
+    // Check if we can use CUB (data must be contiguous along innermost dimension)
+    const bool is_contiguous = (nb00 == type_size) && (nb0 == type_size);
+
+    if (is_contiguous) {
+        use_cub = true;
+    }
+#endif // HAS_CUB_DEVICE_SCAN
+    dim3 grid_dims(ne01, ne02, ne03);
+    const int num_warps = (ne00 + WARP_SIZE - 1) / WARP_SIZE;
+    int block_size = num_warps * WARP_SIZE;
+    block_size = std::min(block_size, CUDA_CUMSUM_BLOCK_SIZE);
+    dim3 block_dims(block_size, 1, 1);
+    const int warps_per_block = block_size / WARP_SIZE;
+    const size_t shmem_size = (block_size + warps_per_block + 2) * sizeof(float);
+
+    if (use_cub) {
+#if HAS_CUB_DEVICE_SCAN
+        // cub::BlockScan requires the launch block size to match the compile-time BLOCK_SIZE
+        cumsum_cub_kernel<T, CUDA_CUMSUM_BLOCK_SIZE><<<grid_dims, CUDA_CUMSUM_BLOCK_SIZE, 0, stream>>>(
+            src, dst,
+            ne00, ne01, ne02, ne03,
+            nb01 / type_size, nb02 / type_size, nb03 / type_size,
+            nb1 / type_size, nb2 / type_size, nb3 / type_size
+        );
+#endif // HAS_CUB_DEVICE_SCAN
+    } else {
+        cumsum_kernel<<<grid_dims, block_dims, shmem_size, stream>>>(
+            src, dst,
+            ne00, ne01, ne02, ne03,
+            nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
+            nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
+        );
+    }
+}
+
+void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == dst->type);
+    switch(src0->type) {
+        case GGML_TYPE_F32:
+            {
+                cumsum_cuda(
+                    (const float *)src0->data, (float *)dst->data,
+                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+                    stream
+                );
+            } break;
+        case GGML_TYPE_F16:
+            {
+                cumsum_cuda(
+                    (const half *)src0->data, (half *)dst->data,
+                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+                    stream
+                );
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                cumsum_cuda(
+                    (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,
+                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+                    stream
+                );
+            } break;
+        default:
+            GGML_ABORT("fatal error");
+    }
+}
diff --git a/ggml/src/ggml-cuda/cumsum.cuh b/ggml/src/ggml-cuda/cumsum.cuh
new file mode 100644
index 00000000000..782d1d92e9b
--- /dev/null
+++ b/ggml/src/ggml-cuda/cumsum.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_CUMSUM_BLOCK_SIZE 256
+
+void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index a844a3d99a2..689e5dfc384 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -54,6 +54,8 @@
 #include "ggml-cuda/set-rows.cuh"
 #include "ggml-cuda/pad_reflect_1d.cuh"
 #include "ggml-cuda/solve_tri.cuh"
+#include "ggml-cuda/tri.cuh"
+#include "ggml-cuda/cumsum.cuh"
 #include "ggml.h"
 
 #include <algorithm>
@@ -2700,6 +2702,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_CROSS_ENTROPY_LOSS:
             ggml_cuda_cross_entropy_loss(ctx, dst);
             break;
+        case GGML_OP_CUMSUM:
+            ggml_cuda_op_cumsum(ctx, dst);
+            break;
+        case GGML_OP_TRI:
+            ggml_cuda_op_tri(ctx, dst);
+            break;
         case GGML_OP_RWKV_WKV6:
             ggml_cuda_op_rwkv_wkv6(ctx, dst);
             break;
@@ -4262,6 +4270,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
         case GGML_OP_OPT_STEP_SGD:
+        case GGML_OP_CUMSUM:
+        case GGML_OP_TRI:
             return true;
         case GGML_OP_SOLVE_TRI:
             return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32;
diff --git a/ggml/src/ggml-cuda/tri.cu b/ggml/src/ggml-cuda/tri.cu
new file mode 100644
index 00000000000..a3b1601fe46
--- /dev/null
+++ b/ggml/src/ggml-cuda/tri.cu
@@ -0,0 +1,133 @@
+#include "ggml-cuda/common.cuh"
+#include "tri.cuh"
+#include "ggml.h"
+
+template <typename T, bool prefix_keep, int add_to_split>
+static __global__ void tri_kernel(
+        const T * src, T * dst,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
+        const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) {
+    const int64_t i3 = blockIdx.z;
+    const int64_t i2 = blockIdx.y;
+    const int64_t i1 = blockIdx.x;
+    const int64_t split_point = i1 + add_to_split;
+
+    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+        return;
+    }
+
+    const T * src_row = src + i1*nb01 + i2*nb02 + i3*nb03;
+    T * dst_row = dst + i1*nb1 + i2*nb2 + i3*nb3;
+
+    if constexpr (prefix_keep) {
+        for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) {
+            dst_row[i0] = src_row[i0];
+        }
+        for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) {
+            dst_row[i0] = T(0);
+        }
+    } else {
+        for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) {
+            dst_row[i0] = T(0);
+        }
+        for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) {
+            dst_row[i0] = src_row[i0];
+        }
+    }
+}
+
+template <typename T>
+static void tri_cuda(
+        const T * src, T * dst,
+        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
+        const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
+        const ggml_tri_type ttype,
+        cudaStream_t stream) {
+
+    dim3 block_dims(CUDA_TRI_BLOCK_SIZE, 1, 1);
+    dim3 grid_dims(ne01, ne02, ne03);
+    const size_t type_size = sizeof(T);
+
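+    // each row i1 is split at split_point = i1 + add_to_split: prefix_keep copies [0, split_point) and zeroes
+    // the rest (lower variants), otherwise it zeroes the prefix and copies the rest (upper variants);
+    // add_to_split == 1 moves the diagonal into the kept prefix (LOWER_DIAG) or the zeroed prefix (UPPER)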
+    const int add_to_split = (ttype == GGML_TRI_TYPE_LOWER_DIAG || ttype == GGML_TRI_TYPE_UPPER) ? 1 : 0;
+    const bool prefix_keep = (ttype == GGML_TRI_TYPE_LOWER || ttype == GGML_TRI_TYPE_LOWER_DIAG);
+
+    if (prefix_keep) {
+        if (add_to_split == 0) {
+            tri_kernel<T, true, 0><<<grid_dims, block_dims, 0, stream>>>(
+                src, dst,
+                ne00, ne01, ne02, ne03,
+                nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
+                nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
+            );
+        } else { // only 0 and 1 supported
+            tri_kernel<T, true, 1><<<grid_dims, block_dims, 0, stream>>>(
+                src, dst,
+                ne00, ne01, ne02, ne03,
+                nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
+                nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
+            );
+        }
+    } else {
+        if (add_to_split == 0) {
+            tri_kernel<T, false, 0><<<grid_dims, block_dims, 0, stream>>>(
+                src, dst,
+                ne00, ne01, ne02, ne03,
+                nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
+                nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
+            );
+        } else {
+            tri_kernel<T, false, 1><<<grid_dims, block_dims, 0, stream>>>(
+                src, dst,
+                ne00, ne01, ne02, ne03,
+                nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
+                nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
+            );
+        }
+    }
+}
+
+void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    cudaStream_t stream = ctx.stream();
+
+    const ggml_tri_type ttype = static_cast<ggml_tri_type>(ggml_get_op_params_i32(dst, 0));
+
+    GGML_ASSERT(src0->type == dst->type);
+
+    switch(src0->type) {
+        case GGML_TYPE_F32:
+            {
+                tri_cuda(
+                    (const float *)src0->data, (float *)dst->data,
+                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+                    ttype, stream
+                );
+            } break;
+        case GGML_TYPE_F16:
+            {
+                tri_cuda(
+                    (const half *)src0->data, (half *)dst->data,
+                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+                    ttype, stream
+                );
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                tri_cuda(
+                    (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,
+                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+                    ttype, stream
+                );
+            } break;
+        default:
+            GGML_ABORT("fatal error");
+    }
+}
diff --git a/ggml/src/ggml-cuda/tri.cuh b/ggml/src/ggml-cuda/tri.cuh
new file mode 100644
index 00000000000..a4cc66750d3
--- /dev/null
+++ b/ggml/src/ggml-cuda/tri.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_TRI_BLOCK_SIZE 256
+
+void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 60bab47b9f2..43c26e6be95 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -7709,6 +7709,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
     test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 10, 5, 4, 3 }));
     test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 127, 5, 4, 3 }));
     test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 5, 4, 3 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 4 }));
     test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 255, 5, 4, 3 }));
     test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 256, 5, 4, 3 }));
     test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 511, 5, 4, 3 }));
@@ -7938,6 +7939,13 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 }));
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 }));
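+    // perf cases covering the new CUDA kernels for GGML_OP_TRI and GGML_OP_CUMSUM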
+    test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F32, { 256, 256, 4, 4 }));
+    test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG, GGML_TYPE_F32, { 1024, 1024, 8, 4 }));
+
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 4 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2048, 16, 5, 4 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 20000, 10, 4, 1 }));
+
     for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
         for (ggml_type type_a : all_types) {
             for (ggml_type type_b : {GGML_TYPE_F32}) {