diff --git a/src/ggml-metal/ggml-metal-device.cpp b/src/ggml-metal/ggml-metal-device.cpp index 0eefc0b13..78f7a0ea9 100644 --- a/src/ggml-metal/ggml-metal-device.cpp +++ b/src/ggml-metal/ggml-metal-device.cpp @@ -953,6 +953,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax(ggml_metal_library_ return res; } +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_count_equal(ggml_metal_library_t lib, const struct ggml_tensor * op) { + assert(op->op == GGML_OP_COUNT_EQUAL); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_count_equal_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_ARGSORT); @@ -1623,6 +1642,48 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_me return res; } +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cross_entropy_loss(ggml_metal_library_t lib, const struct ggml_tensor * op) { + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_cross_entropy_loss_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + if (res) { + ggml_metal_pipeline_set_smem(res, 32 * sizeof(float)); + } + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cross_entropy_loss_back(ggml_metal_library_t lib, const struct ggml_tensor * op) { + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_cross_entropy_loss_back_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + if (res) { + ggml_metal_pipeline_set_smem(res, 32 * sizeof(float)); + } + + return res; +} + ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_OPT_STEP_ADAMW); diff --git a/src/ggml-metal/ggml-metal-device.h b/src/ggml-metal/ggml-metal-device.h index 39ee6e342..051cc5383 100644 --- a/src/ggml-metal/ggml-metal-device.h +++ b/src/ggml-metal/ggml-metal-device.h @@ -126,6 +126,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_me ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_count_equal (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row); @@ -142,6 +143,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_me ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cross_entropy_loss(ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cross_entropy_loss_back(ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op); diff --git a/src/ggml-metal/ggml-metal-device.m b/src/ggml-metal/ggml-metal-device.m index acf9dfd5f..916a32b06 100644 --- a/src/ggml-metal/ggml-metal-device.m +++ b/src/ggml-metal/ggml-metal-device.m @@ -879,6 +879,10 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); case GGML_OP_ARGMAX: return has_simdgroup_reduction; + case GGML_OP_CROSS_ENTROPY_LOSS: + return has_simdgroup_reduction; + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + return has_simdgroup_reduction; case GGML_OP_NORM: case GGML_OP_RMS_NORM: return has_simdgroup_reduction && (ggml_is_contiguous_rows(op->src[0])); @@ -1017,6 +1021,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te return false; }; } + case GGML_OP_COUNT_EQUAL: + return true; case GGML_OP_OPT_STEP_ADAMW: case GGML_OP_OPT_STEP_SGD: return has_simdgroup_reduction; diff --git a/src/ggml-metal/ggml-metal-impl.h b/src/ggml-metal/ggml-metal-impl.h index 0fae97029..b8fcd9870 100644 --- a/src/ggml-metal/ggml-metal-impl.h +++ b/src/ggml-metal/ggml-metal-impl.h @@ -879,6 +879,19 @@ typedef struct { uint64_t nb01; } ggml_metal_kargs_argmax; +typedef struct { + int32_t n_classes; + int32_t n_rows; +} ggml_metal_kargs_cross_entropy_loss; + +typedef struct { + int32_t n_classes; +} ggml_metal_kargs_cross_entropy_loss_back; + +typedef struct { + int32_t ne0; +} ggml_metal_kargs_count_equal; + typedef struct { int64_t np; } ggml_metal_kargs_opt_step_adamw; diff --git a/src/ggml-metal/ggml-metal-ops.cpp b/src/ggml-metal/ggml-metal-ops.cpp index e46970dad..8624161b6 100644 --- a/src/ggml-metal/ggml-metal-ops.cpp +++ b/src/ggml-metal/ggml-metal-ops.cpp @@ -428,6 +428,18 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_argmax(ctx, idx); } break; + case GGML_OP_CROSS_ENTROPY_LOSS: + { + n_fuse = ggml_metal_op_cross_entropy_loss(ctx, idx); + } break; + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + { + n_fuse = ggml_metal_op_cross_entropy_loss_back(ctx, idx); + } break; + case GGML_OP_COUNT_EQUAL: + { + n_fuse = ggml_metal_op_count_equal(ctx, idx); + } break; case GGML_OP_OPT_STEP_ADAMW: { n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx); @@ -3640,6 +3652,84 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_cross_entropy_loss(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + + GGML_TENSOR_LOCALS( int32_t, ne0, src0, ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, src0, nb); + GGML_TENSOR_LOCALS( int32_t, ne1, src1, ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, src1, nb); + + const int64_t nclasses = ne00; + const int64_t nrows = ggml_nrows(src0); + + ggml_metal_kargs_cross_entropy_loss args = { + /*.n_classes =*/ (int32_t) nclasses, + /*.n_rows =*/ (int32_t) nrows, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cross_entropy_loss(lib, op); + + const int nth = 32; + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(src0), 1); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(src1), 2); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_cross_entropy_loss_back(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + const ggml_tensor * grad = op->src[0]; + const ggml_tensor * src0 = op->src[1]; + const ggml_tensor * src1 = op->src[2]; + + GGML_TENSOR_LOCALS( int32_t, ne0, src0, ne); + + const int64_t nclasses = ne00; + const int64_t nrows = ggml_nrows(src0); + + ggml_metal_kargs_cross_entropy_loss_back args = { + /*.n_classes =*/ (int32_t) nclasses, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cross_entropy_loss_back(lib, op); + + const int nth = 32; + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(grad), 1); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(src0), 2); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(src1), 3); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 4); + ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); + + ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1); + + return 1; +} + int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); @@ -3773,6 +3863,34 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_count_equal(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + + ggml_metal_kargs_count_equal args = { + /*.ne0 =*/ (int32_t) ggml_nelements(src0), + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_count_equal(lib, op); + + const int nth = 256; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(src0), 1); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(src1), 2); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); + + ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, nth, 1, 1); + + return 1; +} + int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); @@ -3815,6 +3933,7 @@ int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) { ggml_metal_library_t lib = ctx->lib; ggml_metal_encoder_t enc = ctx->enc; + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); GGML_TENSOR_LOCALS( int32_t, ne, op, ne); diff --git a/src/ggml-metal/ggml-metal-ops.h b/src/ggml-metal/ggml-metal-ops.h index 332e550ee..1f4e2b43b 100644 --- a/src/ggml-metal/ggml-metal-ops.h +++ b/src/ggml-metal/ggml-metal-ops.h @@ -80,8 +80,11 @@ int ggml_metal_op_pad_reflect_1d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_arange (ggml_metal_op_t ctx, int idx); int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx); int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_cross_entropy_loss(ggml_metal_op_t ctx, int idx); +int ggml_metal_op_cross_entropy_loss_back(ggml_metal_op_t ctx, int idx); int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx); int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_count_equal (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx); diff --git a/src/ggml-metal/ggml-metal.metal b/src/ggml-metal/ggml-metal.metal index 59e576170..c39bd696d 100644 --- a/src/ggml-metal/ggml-metal.metal +++ b/src/ggml-metal/ggml-metal.metal @@ -1832,6 +1832,194 @@ typedef decltype(kernel_sum_rows) kernel_sum_rows_t; template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows; template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows; +template +kernel void kernel_cross_entropy_loss( + constant ggml_metal_kargs_cross_entropy_loss & args, + device const char * logits_ptr, + device const char * labels_ptr, + device float * dst, + threadgroup float * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + + device const T * logits = (device const T *)logits_ptr; + device const T * labels = (device const T *)labels_ptr; + + const int nclasses = args.n_classes; + const int nrows = args.n_rows; + + const ulong offset = (ulong)tgpig.x * (ulong)nclasses; + logits += offset; + labels += offset; + + const uint nsg = (ntg.x + 31) / 32; + float max_logit = -INFINITY; + + for (int i = tpitg.x; i < nclasses; i += ntg.x) { + max_logit = fmax(max_logit, (float)logits[i]); + } + max_logit = simd_max(max_logit); + + if (ntg.x > 32) { + if (tiisg == 0) { + shmem[sgitg] = max_logit; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (sgitg == 0) { + max_logit = (tiisg < nsg) ? shmem[tiisg] : -INFINITY; + max_logit = simd_max(max_logit); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (sgitg == 0 && tiisg == 0) { + shmem[0] = max_logit; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + max_logit = shmem[0]; + } + + float sum = 0.0f; + for (int i = tpitg.x; i < nclasses; i += ntg.x) { + sum += exp((float)logits[i] - max_logit); + } + sum = simd_sum(sum); + + + if (ntg.x > 32) { + if (tiisg == 0) { + shmem[sgitg] = sum; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (sgitg == 0) { + sum = (tiisg < nsg) ? shmem[tiisg] : 0.0f; + sum = simd_sum(sum); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (sgitg == 0 && tiisg == 0) { + shmem[0] = sum; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + sum = shmem[0]; + } + + const float log_sum = log(sum); + + float row_loss = 0.0f; + for (int i = tpitg.x; i < nclasses; i += ntg.x) { + const float log_softmax = (float)logits[i] - max_logit - log_sum; + row_loss += log_softmax * (float)labels[i]; + } + row_loss = simd_sum(row_loss); + + if (sgitg == 0 && tiisg == 0) { + shmem[0] = -row_loss; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (tgpig.x == 0 && sgitg == 0 && tiisg == 0) { + float total_loss = 0.0f; + for (int i = 0; i < nrows; i++) { + total_loss += dst[i]; + } + dst[0] = total_loss / (float)nrows; + } +} + +template +kernel void kernel_cross_entropy_loss_back( + constant ggml_metal_kargs_cross_entropy_loss_back & args, + device const float * grad, + device const char * logits_ptr, + device const char * labels_ptr, + device char * dst_ptr, + threadgroup float * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort3 ntg[[threads_per_threadgroup]], + uint3 tpg[[threadgroups_per_grid]]) { + + device const T * logits = (device const T *)logits_ptr; + device const T * labels = (device const T *)labels_ptr; + device T * dst = (device T *)dst_ptr; + + const int nclasses = args.n_classes; + const float grad_scale = grad[0] / (float)tpg.x; + + const ulong offset = (ulong)tgpig.x * (ulong)nclasses; + logits += offset; + labels += offset; + dst += offset; + + const uint nsg = (ntg.x + 31) / 32; + float maxval = -INFINITY; + + for (int i = tpitg.x; i < nclasses; i += ntg.x) { + maxval = fmax(maxval, (float)logits[i]); + } + maxval = simd_max(maxval); + + if (ntg.x > 32) { + if (tiisg == 0) { + shmem[sgitg] = maxval; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (sgitg == 0) { + maxval = (tiisg < nsg) ? shmem[tiisg] : -INFINITY; + maxval = simd_max(maxval); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (sgitg == 0 && tiisg == 0) { + shmem[0] = maxval; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + maxval = shmem[0]; + } + + float sum = 0.0f; + for (int i = tpitg.x; i < nclasses; i += ntg.x) { + const float val = exp((float)logits[i] - maxval); + + dst[i] = (T)val; + sum += val; + } + + threadgroup_barrier(mem_flags::mem_device); + + sum = simd_sum(sum); + + if (ntg.x > 32) { + if (tiisg == 0) { + shmem[sgitg] = sum; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (sgitg == 0) { + sum = (tiisg < nsg) ? shmem[tiisg] : 0.0f; + sum = simd_sum(sum); + } + threadgroup_barrier(mem_flags::mem_threadgroup); + if (sgitg == 0 && tiisg == 0) { + shmem[0] = sum; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + sum = shmem[0]; + } + + const float sm_scale = 1.0f / sum; + + for (int i = tpitg.x; i < nclasses; i += ntg.x) { + const float val = (float)dst[i]; + dst[i] = (T)((val * sm_scale - (float)labels[i]) * grad_scale); + } +} + template kernel void kernel_cumsum_blk( constant ggml_metal_kargs_cumsum_blk & args, @@ -9672,3 +9860,51 @@ kernel void kernel_opt_step_sgd_f32( x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid]; } + +typedef decltype(kernel_cross_entropy_loss) kernel_cross_entropy_loss_t; + +template [[host_name("kernel_cross_entropy_loss_f32")]] kernel kernel_cross_entropy_loss_t kernel_cross_entropy_loss; +template [[host_name("kernel_cross_entropy_loss_f16")]] kernel kernel_cross_entropy_loss_t kernel_cross_entropy_loss; +template [[host_name("kernel_cross_entropy_loss_i32")]] kernel kernel_cross_entropy_loss_t kernel_cross_entropy_loss; +template [[host_name("kernel_cross_entropy_loss_i16")]] kernel kernel_cross_entropy_loss_t kernel_cross_entropy_loss; + +typedef decltype(kernel_cross_entropy_loss_back) kernel_cross_entropy_loss_back_t; + +template [[host_name("kernel_cross_entropy_loss_back_f32")]] kernel kernel_cross_entropy_loss_back_t kernel_cross_entropy_loss_back; +template [[host_name("kernel_cross_entropy_loss_back_f16")]] kernel kernel_cross_entropy_loss_back_t kernel_cross_entropy_loss_back; +template [[host_name("kernel_cross_entropy_loss_back_i32")]] kernel kernel_cross_entropy_loss_back_t kernel_cross_entropy_loss_back; +template [[host_name("kernel_cross_entropy_loss_back_i16")]] kernel kernel_cross_entropy_loss_back_t kernel_cross_entropy_loss_back; + +template +kernel void kernel_count_equal( + constant ggml_metal_kargs_count_equal & args, + device const char * src0_ptr, + device const char * src1_ptr, + device int * dst, + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + + device const T * src0 = (device const T *)src0_ptr; + device const T * src1 = (device const T *)src1_ptr; + + device atomic_int * dst_atomic = (device atomic_int *)dst; + + if (tpitg.x == 0) { + atomic_store_explicit(dst_atomic, 0, memory_order_relaxed); + } + + threadgroup_barrier(mem_flags::mem_device); + + for (int i = tpitg.x; i < args.ne0; i += ntg.x) { + if (src0[i] == src1[i]) { + atomic_fetch_add_explicit(dst_atomic, 1, memory_order_relaxed); + } + } +} + +typedef decltype(kernel_count_equal) kernel_count_equal_t; + +template [[host_name("kernel_count_equal_f32")]] kernel kernel_count_equal_t kernel_count_equal; +template [[host_name("kernel_count_equal_f16")]] kernel kernel_count_equal_t kernel_count_equal; +template [[host_name("kernel_count_equal_i32")]] kernel kernel_count_equal_t kernel_count_equal; +template [[host_name("kernel_count_equal_i16")]] kernel kernel_count_equal_t kernel_count_equal;