58 changes: 58 additions & 0 deletions ggml/src/ggml-cuda/common.cuh
@@ -461,6 +461,64 @@ static __device__ __forceinline__ float warp_reduce_max(float x) {
return x;
}

static __device__ __forceinline__ unsigned int get_warp_mask() {
#ifdef __HIP_PLATFORM_AMD__
return __ballot(1); // HIP equivalent
#else
return __activemask(); // CUDA
#endif
}

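// Warp-level inclusive prefix sum (scan): after the call, lane i holds the sum of the inputs of
// lanes 0..i within the warp, computed with the classic shuffle-up scan (doubling offsets, log2(width) steps).
// Example: inputs 1 1 1 1 ... -> outputs 1 2 3 4 ...
// The float2/half2 overloads below apply the same scan to the two vector components independently.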
template<typename T, int width = WARP_SIZE>
static __device__ __forceinline__ T warp_prefix_inclusive_sum(T x) {
const int lane_id = threadIdx.x % width;
const auto mask = get_warp_mask();
#pragma unroll
for (int offset = 1; offset < width; offset <<= 1) {
const T t = __shfl_up_sync(mask, x, offset, width);
if (lane_id >= offset) {
x += t;
}
}
return x;
}

template<int width = WARP_SIZE>
static __device__ __forceinline__ float2 warp_prefix_inclusive_sum(float2 a) {
const int lane_id = threadIdx.x % width;
const auto mask = get_warp_mask();
#pragma unroll
for (int offset = 1; offset < width; offset <<= 1) {
const float t_x = __shfl_up_sync(mask, a.x, offset, width);
const float t_y = __shfl_up_sync(mask, a.y, offset, width);
if (lane_id >= offset) {
a.x += t_x;
a.y += t_y;
}
}
return a;
}

template<int width = WARP_SIZE>
static __device__ __forceinline__ half2 warp_prefix_inclusive_sum(half2 a) {
#ifdef FP16_AVAILABLE
const int lane_id = threadIdx.x % width;
const auto mask = get_warp_mask();
#pragma unroll
for (int offset = 1; offset < width; offset <<= 1) {
const half2 t = __shfl_up_sync(mask, a, offset, width);
if (lane_id >= offset) {
a = __hadd2(a, t);
}
}
return a;

#else
NO_DEVICE_CODE;
return a;
#endif // FP16_AVAILABLE
}

static __device__ __forceinline__ half ggml_cuda_hmax(const half a, const half b) {
#ifdef FP16_AVAILABLE

136 changes: 136 additions & 0 deletions ggml/src/ggml-cuda/cumsum.cu
@@ -0,0 +1,136 @@
#include <algorithm>

#include "cumsum.cuh"

// Kernel to compute cumulative sum along the innermost dimension (ne[0])
// Each block processes one row (ne[0] elements)
// Algorithm matches Metal implementation:
// 1. Each warp computes prefix sum within itself
// 2. Last thread of each warp stores result in shared memory
// 3. All warps sync
// 4. Each element adds the sum of all preceding warps
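//
// Worked illustration (WARP_SIZE shrunk to 4 for brevity):
//   input row:          [1 2 3 4 | 5 6 7 8]
//   phase 1 per-warp:   [1 3 6 10 | 5 11 18 26],  warp sums in shmem = {10, 26}
//   phase 2:            second group adds shmem[0] = 10  ->  [1 3 6 10 15 21 28 36]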

template<typename T>
static __global__ void cumsum_kernel(
const T * src, T * dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) {

// Shared memory to store warp sums (always use float for accumulation)
extern __shared__ float shmem[];

const int64_t i3 = blockIdx.z;
const int64_t i2 = blockIdx.y;
const int64_t i1 = blockIdx.x;

if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
return;
}

const T * src_row = src + i1 * nb01 + i2*nb02 + i3*nb03;
T * dst_row = dst + i1 * nb1 + i2*nb2 + i3*nb3;

const int tid = threadIdx.x;
const int lane_id = tid % WARP_SIZE;

if (tid >= ne00) {
return;
}

// Phase 1: Each thread processes elements at stride blockDim.x
// Compute warp-level prefix sums
for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) {
// Load value and compute prefix sum within warp
float val = static_cast<float>(src_row[i0]);
val = warp_prefix_inclusive_sum(val);
dst_row[i0] = static_cast<T>(val);
Collaborator:

It would be much preferable to store the temporary results in registers or shared memory rather than global memory.

Collaborator (author):

Isn't val here already stored in a register though? I'm afraid I'll need some more guidance here.

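One shape the suggestion above could take, shown purely as an illustrative sketch (not part of this diff): keep each thread's phase-1 partials in a small per-thread array so that phase 2 reads them back from registers/local memory instead of re-reading dst_row from global memory. MAX_ELEMS_PER_THREAD is a hypothetical compile-time bound on ne00 / blockDim.x.

    float local_vals[MAX_ELEMS_PER_THREAD]; // hypothetical upper bound on strided iterations per thread
    int n_local = 0;
    for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) {
        float val = static_cast<float>(src_row[i0]);
        val = warp_prefix_inclusive_sum(val);
        local_vals[n_local++] = val; // partial stays on-chip instead of going through dst_row
        if (lane_id == WARP_SIZE - 1 || i0 == ne00 - 1) {
            shmem[i0 / WARP_SIZE] = val;
        }
    }
    __syncthreads();
    n_local = 0;
    for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) {
        float sum = 0.0f;
        for (int64_t j = 0; j < i0 / WARP_SIZE; ++j) {
            sum += shmem[j];
        }
        dst_row[i0] = static_cast<T>(local_vals[n_local++] + sum); // single global write per element
    }
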
// Last thread of warp stores its sum to shared memory at position based on data index
if (lane_id == WARP_SIZE - 1 || i0 == ne00 - 1) {
const int shmem_idx = i0 / WARP_SIZE;
shmem[shmem_idx] = val;
}
}

// Sync once after all warp prefix sums are computed
__syncthreads();

// Phase 2: Add the sum of all preceding warp groups to each element
for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) {
const int shmem_idx = i0 / WARP_SIZE;
float sum = 0.0f;
for (int j = 0; j < shmem_idx; ++j) {
sum += shmem[j];
}
dst_row[i0] = static_cast<T>(static_cast<float>(dst_row[i0]) + sum);
}
}

template<typename T>
static void cumsum_cuda(
const T * src, T * dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
cudaStream_t stream) {

dim3 grid_dims(ne01, ne02, ne03);

// Shared memory size: one float per WARP_SIZE-sized group of input elements (ceil(ne00 / WARP_SIZE) entries)
const int num_warps = (ne00 + WARP_SIZE - 1) / WARP_SIZE;
const size_t shmem_size = num_warps * sizeof(float);
const size_t type_size = sizeof(T);

int block_size = num_warps * WARP_SIZE;
block_size = std::min(block_size, CUDA_CUMSUM_BLOCK_SIZE);
dim3 block_dims(block_size, 1, 1);

cumsum_kernel<<<grid_dims, block_dims, shmem_size, stream>>>(
src, dst,
ne00, ne01, ne02, ne03,
nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
);
}

void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
cudaStream_t stream = ctx.stream();

GGML_ASSERT(src0->type == dst->type);
switch(src0->type) {
case GGML_TYPE_F32:
{
cumsum_cuda(
(const float *)src0->data, (float *)dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
stream
);
} break;
case GGML_TYPE_F16:
{
cumsum_cuda(
(const half *)src0->data, (half *)dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
stream
);
} break;
case GGML_TYPE_BF16:
{
cumsum_cuda(
(const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
stream
);
} break;
default:
GGML_ABORT("fatal error");
}
}
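
For reference, the per-row result the kernels above produce is a plain inclusive prefix sum along ne[0]. A minimal CPU sketch (hypothetical helper, float accumulation; the CUDA path rounds the per-warp partials to T once, so f16/bf16 outputs can differ slightly):

    static void cumsum_row_reference(const float * src, float * dst, const int64_t n) {
        float acc = 0.0f;
        for (int64_t i = 0; i < n; ++i) {
            acc += src[i]; // inclusive: dst[i] includes src[i]
            dst[i] = acc;
        }
    }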
5 changes: 5 additions & 0 deletions ggml/src/ggml-cuda/cumsum.cuh
@@ -0,0 +1,5 @@
#include "common.cuh"

#define CUDA_CUMSUM_BLOCK_SIZE 256

void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
10 changes: 10 additions & 0 deletions ggml/src/ggml-cuda/ggml-cuda.cu
@@ -54,6 +54,8 @@
#include "ggml-cuda/set-rows.cuh"
#include "ggml-cuda/pad_reflect_1d.cuh"
#include "ggml-cuda/solve_tri.cuh"
#include "ggml-cuda/tri.cuh"
#include "ggml-cuda/cumsum.cuh"
#include "ggml.h"

#include <algorithm>
@@ -2700,6 +2702,12 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_CROSS_ENTROPY_LOSS:
ggml_cuda_cross_entropy_loss(ctx, dst);
break;
case GGML_OP_CUMSUM:
ggml_cuda_op_cumsum(ctx, dst);
break;
case GGML_OP_TRI:
ggml_cuda_op_tri(ctx, dst);
break;
case GGML_OP_RWKV_WKV6:
ggml_cuda_op_rwkv_wkv6(ctx, dst);
break;
@@ -4262,6 +4270,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
case GGML_OP_CUMSUM:
case GGML_OP_TRI:
return true;
case GGML_OP_SOLVE_TRI:
return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32;
133 changes: 133 additions & 0 deletions ggml/src/ggml-cuda/tri.cu
@@ -0,0 +1,133 @@
#include "ggml-cuda/common.cuh"
#include "tri.cuh"
#include "ggml.h"

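// Copies src into dst row by row, keeping one triangular part and zeroing the rest.
// The split point for row i1 is i1 + add_to_split; prefix_keep selects which side of the split is
// copied from src. As dispatched from tri_cuda() below:
//   GGML_TRI_TYPE_LOWER      -> prefix_keep = true,  add_to_split = 0  (keep i0 <  i1, strictly lower)
//   GGML_TRI_TYPE_LOWER_DIAG -> prefix_keep = true,  add_to_split = 1  (keep i0 <= i1, lower incl. diagonal)
//   GGML_TRI_TYPE_UPPER      -> prefix_keep = false, add_to_split = 1  (keep i0 >  i1, strictly upper)
//   GGML_TRI_TYPE_UPPER_DIAG -> prefix_keep = false, add_to_split = 0  (keep i0 >= i1, upper incl. diagonal)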
template<typename T, bool prefix_keep, int add_to_split>
static __global__ void tri_kernel(
const T * src, T * dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3) {
const int64_t i3 = blockIdx.z;
const int64_t i2 = blockIdx.y;
const int64_t i1 = blockIdx.x;
const int64_t split_point = i1 + add_to_split;

if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
return;
}

const T * src_row = src + i1*nb01 + i2*nb02 + i3*nb03;
T * dst_row = dst + i1*nb1 + i2*nb2 + i3*nb3;

if constexpr (prefix_keep) {
for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) {
dst_row[i0] = src_row[i0];
}
for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) {
dst_row[i0] = T(0);
}
} else {
for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) {
dst_row[i0] = T(0);
}
for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) {
dst_row[i0] = src_row[i0];
}
}
}

template<typename T>
static void tri_cuda(
const T * src, T * dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
const ggml_tri_type ttype,
cudaStream_t stream) {

dim3 block_dims(CUDA_TRI_BLOCK_SIZE, 1, 1);
dim3 grid_dims(ne01, ne02, ne03);
const size_t type_size = sizeof(T);

const int add_to_split = (ttype == GGML_TRI_TYPE_LOWER_DIAG || ttype == GGML_TRI_TYPE_UPPER) ? 1 : 0;
const bool prefix_keep = (ttype == GGML_TRI_TYPE_LOWER || ttype == GGML_TRI_TYPE_LOWER_DIAG);

if (prefix_keep) {
if (add_to_split == 0) {
tri_kernel<T, true, 0><<<grid_dims, block_dims, 0, stream>>>(
src, dst,
ne00, ne01, ne02, ne03,
nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
);
} else { // only 0 and 1 supported
tri_kernel<T, true, 1><<<grid_dims, block_dims, 0, stream>>>(
src, dst,
ne00, ne01, ne02, ne03,
nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
);
}
} else {
if (add_to_split == 0) {
tri_kernel<T, false, 0><<<grid_dims, block_dims, 0, stream>>>(
src, dst,
ne00, ne01, ne02, ne03,
nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
);
} else {
tri_kernel<T, false, 1><<<grid_dims, block_dims, 0, stream>>>(
src, dst,
ne00, ne01, ne02, ne03,
nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
nb0 / type_size, nb1 / type_size, nb2 / type_size, nb3 / type_size
);
}
}
}

void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
cudaStream_t stream = ctx.stream();

const ggml_tri_type ttype = static_cast<ggml_tri_type>(ggml_get_op_params_i32(dst, 0));

GGML_ASSERT(src0->type == dst->type);

switch(src0->type) {
case GGML_TYPE_F32:
{
tri_cuda(
(const float *)src0->data, (float *)dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
ttype, stream
);
} break;
case GGML_TYPE_F16:
{
tri_cuda(
(const half *)src0->data, (half *)dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
ttype, stream
);
} break;
case GGML_TYPE_BF16:
{
tri_cuda(
(const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
ttype, stream
);
} break;
default:
GGML_ABORT("fatal error");
}
}
5 changes: 5 additions & 0 deletions ggml/src/ggml-cuda/tri.cuh
@@ -0,0 +1,5 @@
#include "common.cuh"

#define CUDA_TRI_BLOCK_SIZE 256

void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
6 changes: 6 additions & 0 deletions tests/test-backend-ops.cpp
@@ -7938,6 +7938,12 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 }));

test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_LOWER, GGML_TYPE_F32, { 256, 256, 4, 4 }));
test_cases.emplace_back(new test_tri(GGML_TRI_TYPE_UPPER_DIAG, GGML_TYPE_F32, { 1024, 1024, 8, 4 }));

test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 1 }));
test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 128, 128, 4, 1 }));

for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
for (ggml_type type_a : all_types) {
for (ggml_type type_b : {GGML_TYPE_F32}) {