Add support for CUMSUM and TRI for CUDA. #17584
Open
pwilkin wants to merge 9 commits into ggml-org:master from pwilkin:tri_cumsum_cuda
+353 −0
Commits (9, all by pwilkin):

- d138a03 Add support for CUMSUM and TRI for CUDA.
- 67207d2 Minor optimizations.
- fab0029 Correct warp_prefix_inclusive_sum in float2 variant to return float2
- 51c40a5 Optimize TRI
- c30f565 Whitespace
- 31b55fa Fix strides.
- d1ca1c2 Implement double loop
- 5289b53 Whitespace
- f422ba8 Fix HIP compilation bugs
@@ -0,0 +1,136 @@

```cpp
#include <algorithm>

#include "cumsum.cuh"

// Kernel to compute cumulative sum along the innermost dimension (ne[0])
// Each block processes one row (ne[0] elements)
// Algorithm matches Metal implementation:
// 1. Each warp computes prefix sum within itself
// 2. Last thread of each warp stores result in shared memory
// 3. All warps sync
// 4. Each element adds the sum of all preceding warps

template<typename T>
static __global__ void cumsum_kernel(
        const T * src, T * dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
        const int64_t nb0,  const int64_t nb1,  const int64_t nb2,  const int64_t nb3) {

    // Shared memory to store warp sums (always use float for accumulation)
    extern __shared__ float shmem[];

    const int64_t i3 = blockIdx.z;
    const int64_t i2 = blockIdx.y;
    const int64_t i1 = blockIdx.x;

    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
        return;
    }

    const T * src_row = src + i1*nb01 + i2*nb02 + i3*nb03;
    T       * dst_row = dst + i1*nb1  + i2*nb2  + i3*nb3;

    const int tid     = threadIdx.x;
    const int lane_id = tid % WARP_SIZE;

    if (tid >= ne00) {
        return;
    }

    // Phase 1: Each thread processes elements at stride blockDim.x
    // Compute warp-level prefix sums
    for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) {
        // Load value and compute prefix sum within warp
        float val = static_cast<float>(src_row[i0]);
        val = warp_prefix_inclusive_sum(val);
        dst_row[i0] = static_cast<T>(val);

        // Last thread of warp stores its sum to shared memory at position based on data index
        if (lane_id == WARP_SIZE - 1 || i0 == ne00 - 1) {
            const int shmem_idx = i0 / WARP_SIZE;
            shmem[shmem_idx] = val;
        }
    }

    // Sync once after all warp prefix sums are computed
    __syncthreads();

    // Phase 2: Add the sum of all preceding warp groups to each element
    for (int64_t i0 = tid; i0 < ne00; i0 += blockDim.x) {
        const int shmem_idx = i0 / WARP_SIZE;
        float sum = 0.0f;
        for (int j = 0; j < shmem_idx; ++j) {
            sum += shmem[j];
        }
        dst_row[i0] = static_cast<T>(static_cast<float>(dst_row[i0]) + sum);
    }
}

template<typename T>
static void cumsum_cuda(
        const T * src, T * dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
        const int64_t nb0,  const int64_t nb1,  const int64_t nb2,  const int64_t nb3,
        cudaStream_t stream) {

    dim3 grid_dims(ne01, ne02, ne03);

    // Shared memory size: one float per warp
    const int num_warps = (ne00 + WARP_SIZE - 1) / WARP_SIZE;
    const size_t shmem_size = num_warps * sizeof(float);
    const size_t type_size  = sizeof(T);

    int block_size = num_warps * WARP_SIZE;
    block_size = std::min(block_size, CUDA_CUMSUM_BLOCK_SIZE);
    dim3 block_dims(block_size, 1, 1);

    cumsum_kernel<<<grid_dims, block_dims, shmem_size, stream>>>(
        src, dst,
        ne00, ne01, ne02, ne03,
        nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
        nb0  / type_size, nb1  / type_size, nb2  / type_size, nb3  / type_size
    );
}

void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    cudaStream_t stream = ctx.stream();

    GGML_ASSERT(src0->type == dst->type);
    switch(src0->type) {
        case GGML_TYPE_F32:
            {
                cumsum_cuda(
                    (const float *)src0->data, (float *)dst->data,
                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
                    stream
                );
            } break;
        case GGML_TYPE_F16:
            {
                cumsum_cuda(
                    (const half *)src0->data, (half *)dst->data,
                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
                    stream
                );
            } break;
        case GGML_TYPE_BF16:
            {
                cumsum_cuda(
                    (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,
                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
                    stream
                );
            } break;
        default:
            GGML_ABORT("fatal error");
    }
}
```
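`cumsum_kernel` relies on a `warp_prefix_inclusive_sum` helper that is not part of this diff (it presumably lives in `common.cuh`, alongside `WARP_SIZE`). For readers unfamiliar with the pattern, a typical warp-level inclusive scan built on shuffle intrinsics looks roughly like the sketch below; the real helper's name and details may differ.

```cpp
// Illustrative sketch only: the helper actually used by cumsum_kernel is assumed
// to be defined elsewhere (e.g. common.cuh) and may differ in name and details.
// WARP_SIZE (32 on NVIDIA hardware) also comes from common.cuh in this codebase.
static __device__ __forceinline__ float warp_prefix_inclusive_sum_sketch(float val) {
    const int lane = threadIdx.x % WARP_SIZE;
    // Hillis-Steele scan across the warp: after the step with shift `offset`,
    // each lane holds the sum of its own value and up to `offset` preceding lanes.
    #pragma unroll
    for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
        const float other = __shfl_up_sync(0xFFFFFFFF, val, offset, WARP_SIZE);
        if (lane >= offset) {
            val += other;
        }
    }
    return val;
}
```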
@@ -0,0 +1,5 @@

```cpp
#include "common.cuh"

#define CUDA_CUMSUM_BLOCK_SIZE 256

void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
```
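For reference, the operation these files implement is an inclusive prefix sum along the innermost dimension (ne[0]) of each row. A plain C++ sketch of the expected per-row result (illustrative only, not code from the PR):

```cpp
// Expected per-row semantics (reference sketch): dst[i] = src[0] + src[1] + ... + src[i]
#include <cstdio>

int main() {
    const float src[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    float dst[4];
    float acc = 0.0f;
    for (int i = 0; i < 4; ++i) {
        acc   += src[i];
        dst[i] = acc;
    }
    // Prints: 1 3 6 10
    for (int i = 0; i < 4; ++i) {
        printf("%g ", dst[i]);
    }
    printf("\n");
    return 0;
}
```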
@@ -0,0 +1,133 @@

```cpp
#include "ggml-cuda/common.cuh"
#include "tri.cuh"
#include "ggml.h"

template<typename T, bool prefix_keep, int add_to_split>
static __global__ void tri_kernel(
        const T * src, T * dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
        const int64_t nb0,  const int64_t nb1,  const int64_t nb2,  const int64_t nb3) {
    const int64_t i3 = blockIdx.z;
    const int64_t i2 = blockIdx.y;
    const int64_t i1 = blockIdx.x;
    const int64_t split_point = i1 + add_to_split;

    if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
        return;
    }

    const T * src_row = src + i1*nb01 + i2*nb02 + i3*nb03;
    T       * dst_row = dst + i1*nb1  + i2*nb2  + i3*nb3;

    if constexpr (prefix_keep) {
        for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) {
            dst_row[i0] = src_row[i0];
        }
        for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) {
            dst_row[i0] = T(0);
        }
    } else {
        for (int64_t i0 = threadIdx.x; i0 < split_point; i0 += blockDim.x) {
            dst_row[i0] = T(0);
        }
        for (int64_t i0 = threadIdx.x + split_point; i0 < ne00; i0 += blockDim.x) {
            dst_row[i0] = src_row[i0];
        }
    }
}

template<typename T>
static void tri_cuda(
        const T * src, T * dst,
        const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
        const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
        const int64_t nb0,  const int64_t nb1,  const int64_t nb2,  const int64_t nb3,
        const ggml_tri_type ttype,
        cudaStream_t stream) {

    dim3 block_dims(CUDA_TRI_BLOCK_SIZE, 1, 1);
    dim3 grid_dims(ne01, ne02, ne03);
    const size_t type_size = sizeof(T);

    const int  add_to_split = (ttype == GGML_TRI_TYPE_LOWER_DIAG || ttype == GGML_TRI_TYPE_UPPER) ? 1 : 0;
    const bool prefix_keep  = (ttype == GGML_TRI_TYPE_LOWER || ttype == GGML_TRI_TYPE_LOWER_DIAG);

    if (prefix_keep) {
        if (add_to_split == 0) {
            tri_kernel<T, true, 0><<<grid_dims, block_dims, 0, stream>>>(
                src, dst,
                ne00, ne01, ne02, ne03,
                nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
                nb0  / type_size, nb1  / type_size, nb2  / type_size, nb3  / type_size
            );
        } else { // only 0 and 1 supported
            tri_kernel<T, true, 1><<<grid_dims, block_dims, 0, stream>>>(
                src, dst,
                ne00, ne01, ne02, ne03,
                nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
                nb0  / type_size, nb1  / type_size, nb2  / type_size, nb3  / type_size
            );
        }
    } else {
        if (add_to_split == 0) {
            tri_kernel<T, false, 0><<<grid_dims, block_dims, 0, stream>>>(
                src, dst,
                ne00, ne01, ne02, ne03,
                nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
                nb0  / type_size, nb1  / type_size, nb2  / type_size, nb3  / type_size
            );
        } else {
            tri_kernel<T, false, 1><<<grid_dims, block_dims, 0, stream>>>(
                src, dst,
                ne00, ne01, ne02, ne03,
                nb00 / type_size, nb01 / type_size, nb02 / type_size, nb03 / type_size,
                nb0  / type_size, nb1  / type_size, nb2  / type_size, nb3  / type_size
            );
        }
    }
}

void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    const ggml_tensor * src0 = dst->src[0];
    cudaStream_t stream = ctx.stream();

    const ggml_tri_type ttype = static_cast<ggml_tri_type>(ggml_get_op_params_i32(dst, 0));

    GGML_ASSERT(src0->type == dst->type);

    switch(src0->type) {
        case GGML_TYPE_F32:
            {
                tri_cuda(
                    (const float *)src0->data, (float *)dst->data,
                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
                    ttype, stream
                );
            } break;
        case GGML_TYPE_F16:
            {
                tri_cuda(
                    (const half *)src0->data, (half *)dst->data,
                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
                    ttype, stream
                );
            } break;
        case GGML_TYPE_BF16:
            {
                tri_cuda(
                    (const nv_bfloat16 *)src0->data, (nv_bfloat16 *)dst->data,
                    src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
                    src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
                    dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],
                    ttype, stream
                );
            } break;
        default:
            GGML_ABORT("fatal error");
    }
}
```
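As read from `tri_cuda` above, the template parameters map onto the triangle types as follows: `GGML_TRI_TYPE_LOWER` keeps columns `i0 < i1`, `GGML_TRI_TYPE_LOWER_DIAG` keeps `i0 <= i1`, `GGML_TRI_TYPE_UPPER` keeps `i0 > i1`, and the remaining case keeps `i0 >= i1`. A small CPU sketch of the per-row masking logic (illustrative only; `tri_row_ref` is a hypothetical helper, not part of the PR):

```cpp
// CPU reference for the row masking done by tri_kernel: for row i1,
// split_point = i1 + add_to_split; prefix_keep decides whether the columns
// before the split are kept or zeroed.
#include <cstdio>

static void tri_row_ref(const float * src, float * dst, int ne00, int i1,
                        bool prefix_keep, int add_to_split) {
    const int split_point = i1 + add_to_split;
    for (int i0 = 0; i0 < ne00; ++i0) {
        const bool before_split = i0 < split_point;
        dst[i0] = (before_split == prefix_keep) ? src[i0] : 0.0f;
    }
}

int main() {
    // Lower triangle including the diagonal: prefix_keep = true, add_to_split = 1
    const float row[4] = {1.0f, 2.0f, 3.0f, 4.0f};
    float out[4];
    for (int i1 = 0; i1 < 4; ++i1) {
        tri_row_ref(row, out, 4, i1, /*prefix_keep=*/true, /*add_to_split=*/1);
        printf("row %d: %g %g %g %g\n", i1, out[0], out[1], out[2], out[3]);
    }
    // row 0: 1 0 0 0
    // row 1: 1 2 0 0
    // row 2: 1 2 3 0
    // row 3: 1 2 3 4
    return 0;
}
```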
@@ -0,0 +1,5 @@

```cpp
#include "common.cuh"

#define CUDA_TRI_BLOCK_SIZE 256

void ggml_cuda_op_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
```
Conversations
It would be much preferable to store the temporary results in registers or shared memory rather than global memory.
Isn't `val` here already stored in a register though? I'm afraid I'll need some more guidance here.