From d138a03ddfb23b6a373f757503cf107658eff1e2 Mon Sep 17 00:00:00 2001
From: Piotr Wilkin
Date: Sat, 29 Nov 2025 00:15:13 +0100
Subject: [PATCH 1/9] Add support for CUMSUM and TRI for CUDA.

---
 ggml/src/ggml-cuda/common.cuh   |  50 +++++++++++++
 ggml/src/ggml-cuda/cumsum.cu    | 126 ++++++++++++++++++++++++++++++++
 ggml/src/ggml-cuda/cumsum.cuh   |   5 ++
 ggml/src/ggml-cuda/ggml-cuda.cu |  10 +++
 ggml/src/ggml-cuda/tri.cu       | 104 ++++++++++++++++++++++++++
 ggml/src/ggml-cuda/tri.cuh      |   5 ++
 tests/test-backend-ops.cpp      |   6 ++
 7 files changed, 306 insertions(+)
 create mode 100644 ggml/src/ggml-cuda/cumsum.cu
 create mode 100644 ggml/src/ggml-cuda/cumsum.cuh
 create mode 100644 ggml/src/ggml-cuda/tri.cu
 create mode 100644 ggml/src/ggml-cuda/tri.cuh

diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index 0b10e5f6ae0..c53208bed8b 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -461,6 +461,56 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
     return x;
 }
 
+template <typename T, int width = WARP_SIZE>
+static __device__ __forceinline__ T warp_prefix_inclusive_sum(T x) {
+    const int lane_id = threadIdx.x % width;
+    const auto mask = __activemask();
+#pragma unroll
+    for (int offset = 1; offset < width; offset <<= 1) {
+        const T t = __shfl_up_sync(mask, x, offset, width);
+        if (lane_id >= offset) {
+            x += t;
+        }
+    }
+    return x;
+}
+
+template <int width = WARP_SIZE>
+static __device__ __forceinline__ float warp_prefix_inclusive_sum(float2 a) {
+    const int lane_id = threadIdx.x % width;
+    const auto mask = __activemask();
+#pragma unroll
+    for (int offset = 1; offset < width; offset <<= 1) {
+        const float t_x = __shfl_up_sync(mask, a.x, offset, width);
+        const float t_y = __shfl_up_sync(mask, a.y, offset, width);
+        if (lane_id >= offset) {
+            a.x += t_x;
+            a.y += t_y;
+        }
+    }
+    return a;
+}
+
+template <int width = WARP_SIZE>
+static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) {
+#ifdef FP16_AVAILABLE
+    const int lane_id = threadIdx.x % width;
+    const auto mask = __activemask();
+#pragma unroll
+    for (int offset = 1; offset < width; offset <<= 1) {
+        const t = __hadd2(__shfl_up_sync(mask, a, offset, width));
+        if (lane_id >= offset) {
+            a += t;
+        }
+    }
+    return a;
+
+#else
+    NO_DEVICE_CODE;
+    return a;
+#endif // FP16_AVAILABLE
+}
+
 static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
 #ifdef FP16_AVAILABLE
 
diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu
new file mode 100644
index 00000000000..e14be0721c6
--- /dev/null
+++ b/ggml/src/ggml-cuda/cumsum.cu
@@ -0,0 +1,126 @@
+#include "cumsum.cuh"
+
+// Kernel to compute cumulative sum along the innermost dimension (ne[0])
+// Each block processes one row (ne[0] elements)
+// Algorithm matches Metal implementation:
+// 1. Each warp computes prefix sum within itself
+// 2. Last thread of each warp stores result in shared memory
+// 3. All warps sync
+// 4. Each element adds the sum of all preceding warps
+
+template <typename T>
+static __global__ void cumsum_kernel(
+    const T * src, T * dst,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+    const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
+    const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) {
+
+    // Shared memory to store warp sums (always use float for accumulation)
+    extern __shared__ float shmem[];
+
+    const int64_t i3 = blockIdx.z;
+    const int64_t i2 = blockIdx.y;
+    const int64_t i1 = blockIdx.x;
+
+    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+        return;
+    }
+
+    const T * src_row = (const T *) ((const char *) src + i1*nb01 + i2*nb02 + i3*nb03);
+    T * dst_row = (T *) (( char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
+
+    const int tid = threadIdx.x;
+    const int lane_id = tid % WARP_SIZE;
+
+    // Phase 1: Each thread processes elements at stride blockDim.x
+    // Compute warp-level prefix sums
+    for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) {
+        // Load value and compute prefix sum within warp
+        float val = static_cast<float>(src_row[i0]);
+        val = warp_prefix_inclusive_sum(val);
+        dst_row[i0] = static_cast<T>(val);
+
+        // Last thread of warp stores its sum to shared memory at position based on data index
+        if (lane_id == WARP_SIZE - 1 || i0 == ne00 - 1) {
+            const int shmem_idx = i0 / WARP_SIZE;
+            shmem[shmem_idx] = val;
+        }
+    }
+
+    // Sync once after all warp prefix sums are computed
+    __syncthreads();
+
+    // Phase 2: Add the sum of all preceding warp groups to each element
+    for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) {
+        const int shmem_idx = i0 / WARP_SIZE;
+        float sum = 0.0f;
+        for (int j = 0; j < shmem_idx; ++j) {
+            sum += shmem[j];
+        }
+        dst_row[i0] = static_cast<T>(static_cast<float>(dst_row[i0]) + sum);
+    }
+}
+
+template <typename T>
+static void cumsum_cuda(
+    const T * src, T * dst,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+    const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
+    const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
+    cudaStream_t stream) {
+
+    dim3 block_dims(CUDA_CUMSUM_BLOCK_SIZE, 1, 1);
+    dim3 grid_dims(ne01, ne02, ne03);
+
+    // Shared memory size: one float per warp
+    const int num_warps = (ne00 + WARP_SIZE - 1) / WARP_SIZE;
+    const size_t shmem_size = num_warps * sizeof(float);
+
+    cumsum_kernel<<<grid_dims, block_dims, shmem_size, stream>>>(
+        src, dst,
+        ne00, ne01, ne02, ne03,
+        nb00, nb01, nb02, nb03,
+        nb0, nb1, nb2, nb3
+    );
+}
+
+void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    cudaStream_t stream = ctx.stream();
+
+    GGML_ASSERT(src0->type == dst->type);
+    switch(src0->type) {
+        case GGML_TYPE_F32:
+            {
+                cumsum_cuda(
+                    (const float *)src0->data, (float *)dst->data,
+                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+                    stream
+                );
+            } break;
+        case GGML_TYPE_F16:
+            {
+                cumsum_cuda(
+                    (const half *)src0->data, (half *)dst->data,
+                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+                    stream
+                );
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                cumsum_cuda(
+                    (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,
+                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+                    stream
+                );
+            } break;
+        default:
+            GGML_ABORT("fatal error");
+    }
+}
diff --git a/ggml/src/ggml-cuda/cumsum.cuh b/ggml/src/ggml-cuda/cumsum.cuh
new file mode 100644
index 00000000000..782d1d92e9b
--- /dev/null
+++ b/ggml/src/ggml-cuda/cumsum.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_CUMSUM_BLOCK_SIZE 256
+
+void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu
index a844a3d99a2..689e5dfc384 100644
--- a/ggml/src/ggml-cuda/ggml-cuda.cu
+++ b/ggml/src/ggml-cuda/ggml-cuda.cu
@@ -54,6 +54,8 @@
 #include "ggml-cuda/set-rows.cuh"
 #include "ggml-cuda/pad_reflect_1d.cuh"
 #include "ggml-cuda/solve_tri.cuh"
+#include "ggml-cuda/tri.cuh"
+#include "ggml-cuda/cumsum.cuh"
 #include "ggml.h"
 #include
@@ -2700,6 +2702,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_OP_CROSS_ENTROPY_LOSS:
             ggml_cuda_cross_entropy_loss(ctx, dst);
             break;
+        case GGML_OP_CUMSUM:
+            ggml_cuda_op_cumsum(ctx, dst);
+            break;
+        case GGML_OP_TRI:
+            ggml_cuda_op_tri(ctx, dst);
+            break;
         case GGML_OP_RWKV_WKV6:
             ggml_cuda_op_rwkv_wkv6(ctx, dst);
             break;
@@ -4262,6 +4270,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
         case GGML_OP_OPT_STEP_ADAMW:
         case GGML_OP_OPT_STEP_SGD:
+        case GGML_OP_CUMSUM:
+        case GGML_OP_TRI:
             return true;
         case GGML_OP_SOLVE_TRI:
             return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32;
diff --git a/ggml/src/ggml-cuda/tri.cu b/ggml/src/ggml-cuda/tri.cu
new file mode 100644
index 00000000000..b531f696302
--- /dev/null
+++ b/ggml/src/ggml-cuda/tri.cu
@@ -0,0 +1,104 @@
+#include "tri.cuh"
+#include "ggml.h"
+
+// Triangle type comparison - determines which elements to keep
+__device__ static inline bool tri_compare(const int i, const int r, const ggml_tri_type type) {
+    switch (type) {
+        case GGML_TRI_TYPE_LOWER:      return i <  r;
+        case GGML_TRI_TYPE_LOWER_DIAG: return i <= r;
+        case GGML_TRI_TYPE_UPPER:      return i >  r;
+        case GGML_TRI_TYPE_UPPER_DIAG: return i >= r;
+        default:                       return false;
+    }
+}
+
+template <typename T>
+static __global__ void tri_kernel(
+    const T * src, T * dst,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+    const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
+    const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
+    const ggml_tri_type ttype) {
+
+    const int64_t i3 = blockIdx.z;
+    const int64_t i2 = blockIdx.y;
+    const int64_t i1 = blockIdx.x;
+
+    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
+        return;
+    }
+
+    const T * src_row = (const T *) ((const char *) src + i1*nb01 + i2*nb02 + i3*nb03);
+    T * dst_row = (T *) (( char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
+
+    // Each thread processes elements at stride blockDim.x
+    for (int64_t i0 = threadIdx.x; i0 < ne00; i0 += blockDim.x) {
+        dst_row[i0] = tri_compare(i0, i1, ttype)
+            ? src_row[i0] : static_cast<T>(0.f);
+    }
+}
+
+template <typename T>
+static void tri_cuda(
+    const T * src, T * dst,
+    const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
+    const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
+    const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
+    const ggml_tri_type ttype,
+    cudaStream_t stream) {
+
+    dim3 block_dims(CUDA_TRI_BLOCK_SIZE, 1, 1);
+    dim3 grid_dims(ne01, ne02, ne03);
+
+    tri_kernel<<<grid_dims, block_dims, 0, stream>>>(
+        src, dst,
+        ne00, ne01, ne02, ne03,
+        nb00, nb01, nb02, nb03,
+        nb0, nb1, nb2, nb3,
+        ttype
+    );
+}
+
+void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    const ggml_tensor * src0 = dst->src[0];
+    cudaStream_t stream = ctx.stream();
+
+    const ggml_tri_type ttype = static_cast<ggml_tri_type>(ggml_get_op_params_i32(dst, 0));
+
+    GGML_ASSERT(src0->type == dst->type);
+
+    switch(src0->type) {
+        case GGML_TYPE_F32:
+            {
+                tri_cuda(
+                    (const float *)src0->data, (float *)dst->data,
+                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+                    ttype, stream
+                );
+            } break;
+        case GGML_TYPE_F16:
+            {
+                tri_cuda(
+                    (const half *)src0->data, (half *)dst->data,
+                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+                    ttype, stream
+                );
+            } break;
+        case GGML_TYPE_BF16:
+            {
+                tri_cuda(
+                    (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,
+                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
+                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
+                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
+                    ttype, stream
+                );
+            } break;
+        default:
+            GGML_ABORT("fatal error");
+    }
+}
diff --git a/ggml/src/ggml-cuda/tri.cuh b/ggml/src/ggml-cuda/tri.cuh
new file mode 100644
index 00000000000..a4cc66750d3
--- /dev/null
+++ b/ggml/src/ggml-cuda/tri.cuh
@@ -0,0 +1,5 @@
+#include "common.cuh"
+
+#define CUDA_TRI_BLOCK_SIZE 256
+
+void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp
index 60bab47b9f2..306fa15b923 100644
--- a/tests/test-backend-ops.cpp
+++ b/tests/test-backend-ops.cpp
@@ -7938,6 +7938,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 }));
     test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 }));
 
+    test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F32, { 256, 256, 4, 4 }));
+    test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG, GGML_TYPE_F32, { 1024, 1024, 8, 4 }));
+
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 1 }));
+    test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 1 }));
+
     for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
         for (ggml_type type_a : all_types) {
             for (ggml_type type_b : {GGML_TYPE_F32}) {
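For reference, here is the scan pattern used by cumsum_kernel above in isolation. The sketch below is not part of the patch: it reproduces the same three steps (warp-level inclusive scan with __shfl_up_sync, warp totals staged in shared memory, then each element adding the totals of the preceding warps) as a self-contained CUDA program. The kernel and helper names (block_inclusive_scan, warp_inclusive_scan) are illustrative rather than ggml symbols, and a fixed warp size of 32 is assumed.

#include <cstdio>
#include <cuda_runtime.h>

__device__ float warp_inclusive_scan(float x) {
    const int lane = threadIdx.x % 32;
    for (int offset = 1; offset < 32; offset <<= 1) {
        const float t = __shfl_up_sync(0xffffffff, x, offset);
        if (lane >= offset) {
            x += t;
        }
    }
    return x;
}

__global__ void block_inclusive_scan(const float * src, float * dst, int n) {
    __shared__ float warp_sums[32];              // one partial sum per warp
    const int tid  = threadIdx.x;
    const int lane = tid % 32;

    for (int i = tid; i < n; i += blockDim.x) {
        float val = warp_inclusive_scan(src[i]); // phase 1: scan within the warp
        dst[i] = val;
        if (lane == 31 || i == n - 1) {
            warp_sums[i / 32] = val;             // stage this warp's total
        }
    }
    __syncthreads();

    for (int i = tid; i < n; i += blockDim.x) {  // phase 2: add preceding warp totals
        float sum = 0.0f;
        for (int j = 0; j < i / 32; ++j) {
            sum += warp_sums[j];
        }
        dst[i] += sum;
    }
}

int main() {
    const int n = 256;
    float h[n], *d_src, *d_dst;
    for (int i = 0; i < n; ++i) h[i] = 1.0f;
    cudaMalloc(&d_src, n * sizeof(float));
    cudaMalloc(&d_dst, n * sizeof(float));
    cudaMemcpy(d_src, h, n * sizeof(float), cudaMemcpyHostToDevice);
    block_inclusive_scan<<<1, 256>>>(d_src, d_dst, n);
    cudaMemcpy(h, d_dst, n * sizeof(float), cudaMemcpyDeviceToHost);
    printf("last prefix sum: %g (expected %d)\n", h[n - 1], n);
    cudaFree(d_src); cudaFree(d_dst);
    return 0;
}

Built with nvcc and run with a single 256-thread block, it prints 256 for the last element, i.e. the inclusive prefix sum of 256 ones.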
From 67207d21f9f84f1e0ac407606f40bd382100c096 Mon Sep 17 00:00:00 2001
From: Piotr Wilkin
Date: Sat, 29 Nov 2025 00:17:29 +0100
Subject: [PATCH 2/9] Minor optimizations.

---
 ggml/src/ggml-cuda/cumsum.cu | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu
index e14be0721c6..030397d403d 100644
--- a/ggml/src/ggml-cuda/cumsum.cu
+++ b/ggml/src/ggml-cuda/cumsum.cu
@@ -32,6 +32,10 @@ static __global__ void cumsum_kernel(
     const int tid = threadIdx.x;
     const int lane_id = tid % WARP_SIZE;
 
+    if (tid >= ne00) {
+        return;
+    }
+
     // Phase 1: Each thread processes elements at stride blockDim.x
     // Compute warp-level prefix sums
     for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) {
@@ -69,13 +73,18 @@ static void cumsum_cuda(
     const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
     cudaStream_t stream) {
 
-    dim3 block_dims(CUDA_CUMSUM_BLOCK_SIZE, 1, 1);
     dim3 grid_dims(ne01, ne02, ne03);
 
     // Shared memory size: one float per warp
     const int num_warps = (ne00 + WARP_SIZE - 1) / WARP_SIZE;
     const size_t shmem_size = num_warps * sizeof(float);
 
+    int block_size = num_warps * WARP_SIZE;
+    if (block_size > CUDA_CUMSUM_BLOCK_SIZE) {
+        block_size = CUDA_CUMSUM_BLOCK_SIZE;
+    }
+    dim3 block_dims(block_size, 1, 1);
+
     cumsum_kernel<<<grid_dims, block_dims, shmem_size, stream>>>(
         src, dst,
         ne00, ne01, ne02, ne03,

From fab002949f9f4458577a3f314dc5772d9e25ec68 Mon Sep 17 00:00:00 2001
From: Piotr Wilkin
Date: Sat, 29 Nov 2025 00:40:51 +0100
Subject: [PATCH 3/9] Correct warp_prefix_inclusive_sum in float2 variant to return float2

---
 ggml/src/ggml-cuda/common.cuh | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index c53208bed8b..c747c1c80df 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -476,7 +476,7 @@ static __device__ __forceinline__ T warp_prefix_inclusive_sum(T x) {
 }
 
 template <int width = WARP_SIZE>
-static __device__ __forceinline__ float warp_prefix_inclusive_sum(float2 a) {
+static __device__ __forceinline__ float2 warp_prefix_inclusive_sum(float2 a) {
     const int lane_id = threadIdx.x % width;
     const auto mask = __activemask();
 #pragma unroll

From 51c40a5a3951b7eeca080dd7a7c9f84025eaec90 Mon Sep 17 00:00:00 2001
From: Piotr Wilkin
Date: Mon, 1 Dec 2025 16:10:05 +0100
Subject: [PATCH 4/9] Optimize TRI

---
 ggml/src/ggml-cuda/cumsum.cu | 15 ++++++++-------
 ggml/src/ggml-cuda/tri.cu    | 30 ++++++++++++------------------
 2 files changed, 20 insertions(+), 25 deletions(-)

diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu
index 030397d403d..e758fd8bdba 100644
--- a/ggml/src/ggml-cuda/cumsum.cu
+++ b/ggml/src/ggml-cuda/cumsum.cu
@@ -1,3 +1,5 @@
+#include <algorithm>
+
 #include "cumsum.cuh"
 
 // Kernel to compute cumulative sum along the innermost dimension (ne[0])
 // Each block processes one row (ne[0] elements)
 // Algorithm matches Metal implementation:
@@ -26,8 +28,8 @@ static __global__ void cumsum_kernel(
         return;
     }
 
-    const T * src_row = (const T *) ((const char *) src + i1*nb01 + i2*nb02 + i3*nb03);
-    T * dst_row = (T *) (( char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
+    const T * src_row = src + i1 * nb01 + i2*nb02 + i3*nb03;
+    T * dst_row = dst + i1 * nb1 + i2*nb2 + i3*nb3;
 
     const int tid = threadIdx.x;
     const int lane_id = tid % WARP_SIZE;
@@ -78,18 +80,17 @@ static void cumsum_cuda(
     // Shared memory size: one float per warp
     const int num_warps = (ne00 + WARP_SIZE - 1) / WARP_SIZE;
     const size_t shmem_size = num_warps * sizeof(float);
+    const size_t type_size = sizeof(T);
 
     int block_size = num_warps * WARP_SIZE;
-    if (block_size > CUDA_CUMSUM_BLOCK_SIZE) {
-        block_size = CUDA_CUMSUM_BLOCK_SIZE;
-    }
+    block_size = std::min(block_size, CUDA_CUMSUM_BLOCK_SIZE);
     dim3 block_dims(block_size, 1, 1);
 
     cumsum_kernel<<<grid_dims, block_dims, shmem_size, stream>>>(
         src, dst,
         ne00, ne01, ne02, ne03,
-        nb00, nb01, nb02, nb03,
-        nb0, nb1, nb2, nb3
+        nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
+        nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
     );
 }
diff --git a/ggml/src/ggml-cuda/tri.cu b/ggml/src/ggml-cuda/tri.cu
index b531f696302..9ac13e33d4a 100644
--- a/ggml/src/ggml-cuda/tri.cu
+++ b/ggml/src/ggml-cuda/tri.cu
@@ -1,28 +1,18 @@
 #include "tri.cuh"
 #include "ggml.h"
 
-// Triangle type comparison - determines which elements to keep
-__device__ static inline bool tri_compare(const int i, const int r, const ggml_tri_type type) {
-    switch (type) {
-        case GGML_TRI_TYPE_LOWER:      return i <  r;
-        case GGML_TRI_TYPE_LOWER_DIAG: return i <= r;
-        case GGML_TRI_TYPE_UPPER:      return i >  r;
-        case GGML_TRI_TYPE_UPPER_DIAG: return i >= r;
-        default:                       return false;
-    }
-}
-
 template <typename T>
 static __global__ void tri_kernel(
     const T * src, T * dst,
     const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
     const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
     const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
-    const ggml_tri_type ttype) {
+    const int add_to_split, const bool prefix_keep) {
 
     const int64_t i3 = blockIdx.z;
     const int64_t i2 = blockIdx.y;
     const int64_t i1 = blockIdx.x;
+    const int64_t split_point = i1 + add_to_split;
 
     if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
         return;
@@ -30,11 +20,11 @@ static __global__ void tri_kernel(
     const T * src_row = (const T *) ((const char *) src + i1*nb01 + i2*nb02 + i3*nb03);
     T * dst_row = (T *) (( char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
-
+    
     // Each thread processes elements at stride blockDim.x
     for (int64_t i0 = threadIdx.x; i0 < ne00; i0 += blockDim.x) {
-        dst_row[i0] = tri_compare(i0, i1, ttype)
-            ? src_row[i0] : static_cast<T>(0.f);
+        const bool keep = ((i0 < split_point) == prefix_keep);
+        dst_row[i0] = keep ? src_row[i0] : T(0);
     }
 }
@@ -49,13 +39,17 @@ static void tri_cuda(
 
     dim3 block_dims(CUDA_TRI_BLOCK_SIZE, 1, 1);
     dim3 grid_dims(ne01, ne02, ne03);
+    const size_t type_size = sizeof(T);
+
+    const int add_to_split = (ttype == GGML_TRI_TYPE_LOWER_DIAG || ttype == GGML_TRI_TYPE_UPPER) ? 1 : 0;
+    const bool prefix_keep = (ttype == GGML_TRI_TYPE_LOWER || ttype == GGML_TRI_TYPE_LOWER_DIAG);
 
     tri_kernel<<<grid_dims, block_dims, 0, stream>>>(
         src, dst,
         ne00, ne01, ne02, ne03,
-        nb00, nb01, nb02, nb03,
-        nb0, nb1, nb2, nb3,
-        ttype
+        nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
+        nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size,
+        add_to_split, prefix_keep
    );
 }

From c30f56543eb4c7c2be522f5c2a4458da787e5169 Mon Sep 17 00:00:00 2001
From: Piotr Wilkin
Date: Mon, 1 Dec 2025 16:12:03 +0100
Subject: [PATCH 5/9] Whitespace

---
 ggml/src/ggml-cuda/tri.cu | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/ggml/src/ggml-cuda/tri.cu b/ggml/src/ggml-cuda/tri.cu
index 9ac13e33d4a..ddc0fb64ce2 100644
--- a/ggml/src/ggml-cuda/tri.cu
+++ b/ggml/src/ggml-cuda/tri.cu
@@ -20,7 +20,7 @@ static __global__ void tri_kernel(
     const T * src_row = (const T *) ((const char *) src + i1*nb01 + i2*nb02 + i3*nb03);
     T * dst_row = (T *) (( char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
-    
+
     // Each thread processes elements at stride blockDim.x
     for (int64_t i0 = threadIdx.x; i0 < ne00; i0 += blockDim.x) {
         const bool keep = ((i0 < split_point) == prefix_keep);
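A note on the "Optimize TRI" rewrite above: the per-element tri_compare() switch is replaced by a split point i1 + add_to_split plus a prefix_keep flag. The small host-side check below is not part of the patch; the local tri_type enum only mirrors ggml's ggml_tri_type for the purpose of the check. It verifies that the split-point formulation agrees with the original predicate for all four triangle types, which is what allows the later "double loop" patch to drop the per-element branch entirely.

#include <cassert>
#include <cstdio>
#include <initializer_list>

enum tri_type { LOWER, LOWER_DIAG, UPPER, UPPER_DIAG };

static bool tri_compare(int i, int r, tri_type t) {    // original predicate
    switch (t) {
        case LOWER:      return i <  r;
        case LOWER_DIAG: return i <= r;
        case UPPER:      return i >  r;
        case UPPER_DIAG: return i >= r;
    }
    return false;
}

static bool split_keep(int i0, int i1, tri_type t) {    // split-point formulation
    const int  add_to_split = (t == LOWER_DIAG || t == UPPER) ? 1 : 0;
    const bool prefix_keep  = (t == LOWER || t == LOWER_DIAG);
    const int  split_point  = i1 + add_to_split;
    return (i0 < split_point) == prefix_keep;
}

int main() {
    for (tri_type t : {LOWER, LOWER_DIAG, UPPER, UPPER_DIAG}) {
        for (int i1 = 0; i1 < 64; ++i1) {
            for (int i0 = 0; i0 < 64; ++i0) {
                assert(tri_compare(i0, i1, t) == split_keep(i0, i1, t));
            }
        }
    }
    printf("split-point formulation matches tri_compare for all four types\n");
    return 0;
}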
From 31b55fabd03e5f038a2222a7e63e167efe58850d Mon Sep 17 00:00:00 2001
From: Piotr Wilkin
Date: Mon, 1 Dec 2025 16:15:41 +0100
Subject: [PATCH 6/9] Fix strides.

---
 ggml/src/ggml-cuda/tri.cu | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/ggml/src/ggml-cuda/tri.cu b/ggml/src/ggml-cuda/tri.cu
index ddc0fb64ce2..8e7ed14b03f 100644
--- a/ggml/src/ggml-cuda/tri.cu
+++ b/ggml/src/ggml-cuda/tri.cu
@@ -18,8 +18,8 @@ static __global__ void tri_kernel(
         return;
     }
 
-    const T * src_row = (const T *) ((const char *) src + i1*nb01 + i2*nb02 + i3*nb03);
-    T * dst_row = (T *) (( char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
+    const T * src_row = src + i1*nb01 + i2*nb02 + i3*nb03;
+    T * dst_row = dst + i1*nb1 + i2*nb2 + i3*nb3;
 
     // Each thread processes elements at stride blockDim.x
     for (int64_t i0 = threadIdx.x; i0 < ne00; i0 += blockDim.x) {

From d1ca1c2592c196360b1c20e955fc340665ed9af4 Mon Sep 17 00:00:00 2001
From: Piotr Wilkin
Date: Mon, 1 Dec 2025 18:02:47 +0100
Subject: [PATCH 7/9] Implement double loop

---
 ggml/src/ggml-cuda/tri.cu | 65 ++++++++++++++++++++++++++++++---------
 1 file changed, 50 insertions(+), 15 deletions(-)

diff --git a/ggml/src/ggml-cuda/tri.cu b/ggml/src/ggml-cuda/tri.cu
index 8e7ed14b03f..0e7dda79318 100644
--- a/ggml/src/ggml-cuda/tri.cu
+++ b/ggml/src/ggml-cuda/tri.cu
@@ -1,14 +1,13 @@
+#include "ggml-cuda/common.cuh"
 #include "tri.cuh"
 #include "ggml.h"
 
-template <typename T>
+template <typename T, bool prefix_keep, int add_to_split>
 static __global__ void tri_kernel(
     const T * src, T * dst,
     const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
     const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
-    const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
-    const int add_to_split, const bool prefix_keep) {
-
+    const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) {
     const int64_t i3 = blockIdx.z;
     const int64_t i2 = blockIdx.y;
     const int64_t i1 = blockIdx.x;
@@ -21,10 +20,20 @@ static __global__ void tri_kernel(
     const T * src_row = src + i1*nb01 + i2*nb02 + i3*nb03;
     T * dst_row = dst + i1*nb1 + i2*nb2 + i3*nb3;
 
-    // Each thread processes elements at stride blockDim.x
-    for (int64_t i0 = threadIdx.x; i0 < ne00; i0 += blockDim.x) {
-        const bool keep = ((i0 < split_point) == prefix_keep);
-        dst_row[i0] = keep ? src_row[i0] : T(0);
+    if constexpr (prefix_keep) {
+        for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) {
+            dst_row[i0] = src_row[i0];
+        }
+        for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) {
+            dst_row[i0] = T(0);
+        }
+    } else {
+        for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) {
+            dst_row[i0] = T(0);
+        }
+        for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) {
+            dst_row[i0] = src_row[i0];
+        }
     }
 }
 
@@ -44,13 +53,39 @@ static void tri_cuda(
     const int add_to_split = (ttype == GGML_TRI_TYPE_LOWER_DIAG || ttype == GGML_TRI_TYPE_UPPER) ? 1 : 0;
     const bool prefix_keep = (ttype == GGML_TRI_TYPE_LOWER || ttype == GGML_TRI_TYPE_LOWER_DIAG);
 
-    tri_kernel<<<grid_dims, block_dims, 0, stream>>>(
-        src, dst,
-        ne00, ne01, ne02, ne03,
-        nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
-        nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size,
-        add_to_split, prefix_keep
-    );
+    if (prefix_keep) {
+        if (add_to_split == 0) {
+            tri_kernel<T, true, 0><<<grid_dims, block_dims, 0, stream>>>(
+                src, dst,
+                ne00, ne01, ne02, ne03,
+                nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
+                nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
+            );
+        } else { // only 0 and 1 supported
+            tri_kernel<T, true, 1><<<grid_dims, block_dims, 0, stream>>>(
+                src, dst,
+                ne00, ne01, ne02, ne03,
+                nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
+                nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
+            );
+        }
+    } else {
+        if (add_to_split == 0) {
+            tri_kernel<T, false, 0><<<grid_dims, block_dims, 0, stream>>>(
+                src, dst,
+                ne00, ne01, ne02, ne03,
+                nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
+                nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
+            );
+        } else {
+            tri_kernel<T, false, 1><<<grid_dims, block_dims, 0, stream>>>(
+                src, dst,
+                ne00, ne01, ne02, ne03,
+                nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
+                nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
+            );
+        }
+    }
 }

From 5289b530285370604130ca2d43e65a03db4987e7 Mon Sep 17 00:00:00 2001
From: Piotr Wilkin
Date: Mon, 1 Dec 2025 18:03:50 +0100
Subject: [PATCH 8/9] Whitespace

---
 ggml/src/ggml-cuda/tri.cu | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/ggml/src/ggml-cuda/tri.cu b/ggml/src/ggml-cuda/tri.cu
index 0e7dda79318..a3b1601fe46 100644
--- a/ggml/src/ggml-cuda/tri.cu
+++ b/ggml/src/ggml-cuda/tri.cu
@@ -21,17 +21,17 @@ static __global__ void tri_kernel(
     T * dst_row = dst + i1*nb1 + i2*nb2 + i3*nb3;
 
     if constexpr (prefix_keep) {
-        for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) {  
+        for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) {
             dst_row[i0] = src_row[i0];
         }
-        for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) {  
+        for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) {
             dst_row[i0] = T(0);
         }
     } else {
-        for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) {  
+        for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) {
            dst_row[i0] = T(0);
         }
-        for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) {  
+        for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) {
             dst_row[i0] = src_row[i0];
         }
     }

From f422ba8ee0d581a36a31d95a55138970d07baf90 Mon Sep 17 00:00:00 2001
From: Piotr Wilkin
Date: Mon, 1 Dec 2025 21:33:43 +0100
Subject: [PATCH 9/9] Fix HIP compilation bugs

---
 ggml/src/ggml-cuda/common.cuh | 18 +++++++++++++-----
 1 file changed, 13 insertions(+), 5 deletions(-)

diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh
index c747c1c80df..e4d0f2d5708 100644
--- a/ggml/src/ggml-cuda/common.cuh
+++ b/ggml/src/ggml-cuda/common.cuh
@@ -461,10 +461,18 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
     return x;
 }
 
+static __device__ __forceinline__ unsigned int get_warp_mask() {
+#ifdef __HIP_PLATFORM_AMD__
+    return __ballot(1);     // HIP equivalent
+#else
+    return __activemask();  // CUDA
+#endif
+}
+
 template <typename T, int width = WARP_SIZE>
 static __device__ __forceinline__ T warp_prefix_inclusive_sum(T x) {
     const int lane_id = threadIdx.x % width;
-    const auto mask = __activemask();
+    const auto mask = get_warp_mask();
 #pragma unroll
     for (int offset = 1; offset < width; offset <<= 1) {
         const T t = __shfl_up_sync(mask, x, offset, width);
@@ -478,7 +486,7 @@ static __device__ __forceinline__ T warp_prefix_inclusive_sum(T x) {
 
 template <int width = WARP_SIZE>
 static __device__ __forceinline__ float2 warp_prefix_inclusive_sum(float2 a) {
     const int lane_id = threadIdx.x % width;
-    const auto mask = __activemask();
+    const auto mask = get_warp_mask();
 #pragma unroll
     for (int offset = 1; offset < width; offset <<= 1) {
         const float t_x = __shfl_up_sync(mask, a.x, offset, width);
         const float t_y = __shfl_up_sync(mask, a.y, offset, width);
@@ -495,12 +503,12 @@
 template <int width = WARP_SIZE>
 static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) {
 #ifdef FP16_AVAILABLE
     const int lane_id = threadIdx.x % width;
-    const auto mask = __activemask();
+    const auto mask = get_warp_mask();
 #pragma unroll
     for (int offset = 1; offset < width; offset <<= 1) {
-        const t = __hadd2(__shfl_up_sync(mask, a, offset, width));
+        const half2 t = __shfl_up_sync(mask, a, offset, width);
         if (lane_id >= offset) {
-            a += t;
+            a = __hadd2(a, t);
         }
     }
     return a;
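For reference, a minimal host-side model (not part of the patch) of what the CUDA/HIP cumsum path is expected to produce: an inclusive prefix sum along ne[0] for every row, accumulated in float regardless of the storage type. The contiguous row layout and the cumsum_reference name are simplifications for illustration; the real operator works on strided ggml tensors.

#include <cstdio>
#include <vector>

template <typename T>
static void cumsum_reference(const T * src, T * dst, int64_t ne0, int64_t nrows) {
    for (int64_t r = 0; r < nrows; ++r) {
        float acc = 0.0f;                        // accumulate in float, like the kernel
        for (int64_t i = 0; i < ne0; ++i) {
            acc += static_cast<float>(src[r * ne0 + i]);
            dst[r * ne0 + i] = static_cast<T>(acc);
        }
    }
}

int main() {
    const int64_t ne0 = 8, nrows = 2;
    std::vector<float> src(ne0 * nrows, 1.0f), dst(ne0 * nrows);
    cumsum_reference(src.data(), dst.data(), ne0, nrows);
    printf("row 0: ");
    for (int64_t i = 0; i < ne0; ++i) printf("%g ", dst[i]);    // 1 2 3 ... 8
    printf("\n");
    return 0;
}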