diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index f917a745d5a..7fb8f65937b 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -4235,9 +4235,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) { - ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size); + ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size); + ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<subgroup_size); } for (auto &c : compiles) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp index bc1c278bf49..5cd0785d20f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp @@ -75,7 +75,7 @@ void softmax_warp_inplace(inout float vals[experts_per_thread], const uint limit } void main() { - const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_LocalInvocationID.y; + const uint row = gl_WorkGroupID.x * gl_WorkGroupSize.y + gl_SubgroupID; if (row >= n_rows) { return; } @@ -83,17 +83,18 @@ void main() { const uint logits_offset = n_experts * row; const uint weights_offset = n_expert_used * row; const uint ids_offset = n_experts * row; + const uint lane = gl_SubgroupInvocationID; float wt[experts_per_thread]; [[unroll]] for (uint i = 0; i < n_experts; i += WARP_SIZE) { - const uint expert = i + gl_LocalInvocationID.x; + const uint expert = i + lane; wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY; } if (!late_softmax) { - softmax_warp_inplace(wt, n_experts, gl_LocalInvocationID.x, false); + softmax_warp_inplace(wt, n_experts, lane, false); } // at this point, each thread holds a portion of softmax, @@ -111,11 +112,11 @@ void main() { for (int k = 0; k < n_expert_used; k++) { float max_val = wt[0]; - uint max_expert = gl_LocalInvocationID.x; + uint max_expert = lane; [[unroll]] for (int i = 1; i < experts_per_thread; i++) { - const uint expert = gl_LocalInvocationID.x + i * WARP_SIZE; + const uint expert = lane + i * WARP_SIZE; if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) { max_val = wt[i]; max_expert = expert; @@ -132,11 +133,11 @@ void main() { } } - if ((k & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) { + if ((k & (WARP_SIZE - 1)) == lane) { output_weights[k / WARP_SIZE] = max_val; } - if ((max_expert & (WARP_SIZE - 1)) == gl_LocalInvocationID.x) { + if ((max_expert & (WARP_SIZE - 1)) == lane) { wt[max_expert / WARP_SIZE] = -INFINITY; ids[ids_offset + k] = max_expert; @@ -158,12 +159,12 @@ void main() { } if (late_softmax) { - softmax_warp_inplace(output_weights, n_expert_used, gl_LocalInvocationID.x, true); + softmax_warp_inplace(output_weights, n_expert_used, lane, true); } [[unroll]] for (uint i = 0; i < experts_per_thread; ++i) { - uint idx = i * WARP_SIZE + gl_LocalInvocationID.x; + uint idx = i * WARP_SIZE + lane; if (idx < n_expert_used) { weights[weights_offset + idx] = output_weights[i]; }