Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 61 additions & 0 deletions src/ggml-metal/ggml-metal-device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -953,6 +953,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax(ggml_metal_library_
return res;
}

// Returns the compute pipeline used for GGML_OP_COUNT_EQUAL.
//
// lib - Metal library that is queried for (and, if needed, compiles) the pipeline
// op  - the COUNT_EQUAL node; the type of op->src[0] selects the kernel
//
// The kernel is named "kernel_count_equal_<type>", where <type> is
// ggml_type_name(op->src[0]->type). If the library already returns a pipeline
// for that name it is reused; otherwise the pipeline is compiled and the result
// of the compile call is returned unchanged.
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_count_equal(ggml_metal_library_t lib, const struct ggml_tensor * op) {
    assert(op->op == GGML_OP_COUNT_EQUAL);

    char base[256];
    char name[256];

    // this kernel has no specialization suffix: the pipeline name equals the base name
    snprintf(base, sizeof(base), "kernel_count_equal_%s", ggml_type_name(op->src[0]->type));
    snprintf(name, sizeof(name), "%s", base);

    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline(lib, name);
    if (!pipeline) {
        pipeline = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
    }

    return pipeline;
}

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_ARGSORT);

Expand Down Expand Up @@ -1623,6 +1642,48 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_me
return res;
}

// Returns the compute pipeline used for GGML_OP_CROSS_ENTROPY_LOSS.
//
// lib - Metal library that is queried for (and, if needed, compiles) the pipeline
// op  - the CROSS_ENTROPY_LOSS node; the type of op->src[0] selects the kernel
//
// The kernel is named "kernel_cross_entropy_loss_<type>", where <type> is
// ggml_type_name(op->src[0]->type). If the library already returns a pipeline
// for that name it is returned directly. Otherwise the pipeline is compiled
// and, when compilation yields a pipeline, it is assigned 32 floats of
// threadgroup memory before being returned.
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cross_entropy_loss(ggml_metal_library_t lib, const struct ggml_tensor * op) {
    // same guard as the other per-op pipeline getters in this file
    assert(op->op == GGML_OP_CROSS_ENTROPY_LOSS);

    char base[256];
    char name[256];

    // this kernel has no specialization suffix: the pipeline name equals the base name
    snprintf(base, sizeof(base), "kernel_cross_entropy_loss_%s", ggml_type_name(op->src[0]->type));
    snprintf(name, sizeof(name), "%s", base);

    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
    if (res) {
        return res;
    }

    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);

    if (res) {
        // NOTE(review): 32 matches the threadgroup size used by the encoder for
        // this op - presumably one float per thread; confirm against the kernel
        ggml_metal_pipeline_set_smem(res, 32 * sizeof(float));
    }

    return res;
}

// Returns the compute pipeline used for GGML_OP_CROSS_ENTROPY_LOSS_BACK.
//
// lib - Metal library that is queried for (and, if needed, compiles) the pipeline
// op  - the CROSS_ENTROPY_LOSS_BACK node; the type of op->src[0] selects the kernel
//
// The kernel is named "kernel_cross_entropy_loss_back_<type>", where <type> is
// ggml_type_name(op->src[0]->type). If the library already returns a pipeline
// for that name it is returned directly. Otherwise the pipeline is compiled
// and, when compilation yields a pipeline, it is assigned 32 floats of
// threadgroup memory before being returned.
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cross_entropy_loss_back(ggml_metal_library_t lib, const struct ggml_tensor * op) {
    // same guard as the other per-op pipeline getters in this file
    assert(op->op == GGML_OP_CROSS_ENTROPY_LOSS_BACK);

    char base[256];
    char name[256];

    // NOTE(review): for this op the encoder binds src[0] as the incoming
    // gradient, so the kernel name follows the gradient's type rather than the
    // type of src[1]/src[2] - confirm this is the intended selector
    snprintf(base, sizeof(base), "kernel_cross_entropy_loss_back_%s", ggml_type_name(op->src[0]->type));
    snprintf(name, sizeof(name), "%s", base);

    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
    if (res) {
        return res;
    }

    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);

    if (res) {
        // NOTE(review): 32 matches the threadgroup size used by the encoder for
        // this op - presumably one float per thread; confirm against the kernel
        ggml_metal_pipeline_set_smem(res, 32 * sizeof(float));
    }

    return res;
}

ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_OPT_STEP_ADAMW);

Expand Down
3 changes: 3 additions & 0 deletions src/ggml-metal/ggml-metal-device.h
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0 (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argmax (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_count_equal (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
Expand All @@ -142,6 +143,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_me
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cross_entropy_loss(ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cross_entropy_loss_back(ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op);

Expand Down
6 changes: 6 additions & 0 deletions src/ggml-metal/ggml-metal-device.m
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,10 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
case GGML_OP_ARGMAX:
return has_simdgroup_reduction;
case GGML_OP_CROSS_ENTROPY_LOSS:
return has_simdgroup_reduction;
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
return has_simdgroup_reduction;
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
return has_simdgroup_reduction && (ggml_is_contiguous_rows(op->src[0]));
Expand Down Expand Up @@ -1017,6 +1021,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
return false;
};
}
case GGML_OP_COUNT_EQUAL:
return true;
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
return has_simdgroup_reduction;
Expand Down
13 changes: 13 additions & 0 deletions src/ggml-metal/ggml-metal-impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -879,6 +879,19 @@ typedef struct {
uint64_t nb01;
} ggml_metal_kargs_argmax;

// kernel arguments for the cross-entropy loss kernel
// filled on the host side by ggml_metal_op_cross_entropy_loss and bound at index 0
typedef struct {
// ne[0] of src[0], narrowed to 32 bits
int32_t n_classes;
// ggml_nrows(src[0]), narrowed to 32 bits
int32_t n_rows;
} ggml_metal_kargs_cross_entropy_loss;

// kernel arguments for the cross-entropy loss backward kernel
// filled on the host side by ggml_metal_op_cross_entropy_loss_back and bound at index 0
typedef struct {
// ne[0] of the op's src[1], narrowed to 32 bits
int32_t n_classes;
} ggml_metal_kargs_cross_entropy_loss_back;

// kernel arguments for the count-equal kernel
// filled on the host side by ggml_metal_op_count_equal and bound at index 0
typedef struct {
// total element count of src[0] (ggml_nelements), narrowed to 32 bits
int32_t ne0;
} ggml_metal_kargs_count_equal;

typedef struct {
int64_t np;
} ggml_metal_kargs_opt_step_adamw;
Expand Down
119 changes: 119 additions & 0 deletions src/ggml-metal/ggml-metal-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,18 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_argmax(ctx, idx);
} break;
case GGML_OP_CROSS_ENTROPY_LOSS:
{
n_fuse = ggml_metal_op_cross_entropy_loss(ctx, idx);
} break;
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
{
n_fuse = ggml_metal_op_cross_entropy_loss_back(ctx, idx);
} break;
case GGML_OP_COUNT_EQUAL:
{
n_fuse = ggml_metal_op_count_equal(ctx, idx);
} break;
case GGML_OP_OPT_STEP_ADAMW:
{
n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx);
Expand Down Expand Up @@ -3640,6 +3652,84 @@ int ggml_metal_op_argmax(ggml_metal_op_t ctx, int idx) {
return 1;
}

// Encodes the GGML_OP_CROSS_ENTROPY_LOSS node at position idx.
//
// ctx - encoding context; supplies the graph node, the Metal library and the encoder
// idx - index of the node to encode
//
// Kernel bindings: 0 = kernel arguments, 1 = src[0], 2 = src[1], 3 = the node
// itself (destination). One threadgroup of 32 threads is dispatched per row of
// src[0], using the threadgroup memory size stored on the pipeline.
//
// Returns 1; the dispatcher stores this value as n_fuse.
int ggml_metal_op_cross_entropy_loss(ggml_metal_op_t ctx, int idx) {
    ggml_tensor * op = ctx->node(idx);

    ggml_metal_library_t lib = ctx->lib;
    ggml_metal_encoder_t enc = ctx->enc;

    const ggml_tensor * src0 = op->src[0];
    const ggml_tensor * src1 = op->src[1];

    GGML_TENSOR_LOCALS( int32_t, ne0, src0, ne);
    GGML_TENSOR_LOCALS(uint64_t, nb0, src0, nb);
    GGML_TENSOR_LOCALS( int32_t, ne1, src1, ne);
    GGML_TENSOR_LOCALS(uint64_t, nb1, src1, nb);

    // one row per threadgroup; the class count is the innermost dimension of src0
    const int64_t n_rows = ggml_nrows(src0);

    ggml_metal_kargs_cross_entropy_loss kargs = {
        /*.n_classes =*/ (int32_t) ne00,
        /*.n_rows    =*/ (int32_t) n_rows,
    };

    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cross_entropy_loss(lib, op);

    const int    threads_per_group = 32;
    const size_t smem_bytes        = ggml_metal_pipeline_get_smem(pipeline);

    ggml_metal_encoder_set_pipeline(enc, pipeline);

    ggml_metal_encoder_set_bytes (enc, &kargs, sizeof(kargs),          0);
    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(src0), 1);
    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(src1), 2);
    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op),   3);

    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem_bytes, 0);

    ggml_metal_encoder_dispatch_threadgroups(enc, n_rows, 1, 1, threads_per_group, 1, 1);

    return 1;
}

// Encodes the GGML_OP_CROSS_ENTROPY_LOSS_BACK node at position idx.
//
// ctx - encoding context; supplies the graph node, the Metal library and the encoder
// idx - index of the node to encode
//
// The node's sources are used as: src[0] = incoming gradient, src[1] and
// src[2] = the two tensors named src0 and src1 below.
// Kernel bindings: 0 = kernel arguments, 1 = gradient, 2 = src0, 3 = src1,
// 4 = the node itself (destination). One threadgroup of 32 threads is
// dispatched per row of src0, using the threadgroup memory size stored on the
// pipeline.
//
// Returns 1; the dispatcher stores this value as n_fuse.
int ggml_metal_op_cross_entropy_loss_back(ggml_metal_op_t ctx, int idx) {
    ggml_tensor * op = ctx->node(idx);

    ggml_metal_library_t lib = ctx->lib;
    ggml_metal_encoder_t enc = ctx->enc;

    const ggml_tensor * grad = op->src[0];
    const ggml_tensor * src0 = op->src[1];
    const ggml_tensor * src1 = op->src[2];

    GGML_TENSOR_LOCALS( int32_t, ne0, src0, ne);

    // one row per threadgroup; the class count is the innermost dimension of src0
    const int64_t n_rows = ggml_nrows(src0);

    ggml_metal_kargs_cross_entropy_loss_back kargs = {
        /*.n_classes =*/ (int32_t) ne00,
    };

    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cross_entropy_loss_back(lib, op);

    const int    threads_per_group = 32;
    const size_t smem_bytes        = ggml_metal_pipeline_get_smem(pipeline);

    ggml_metal_encoder_set_pipeline(enc, pipeline);

    ggml_metal_encoder_set_bytes (enc, &kargs, sizeof(kargs),          0);
    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(grad), 1);
    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(src0), 2);
    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(src1), 3);
    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op),   4);

    ggml_metal_encoder_set_threadgroup_memory_size(enc, smem_bytes, 0);

    ggml_metal_encoder_dispatch_threadgroups(enc, n_rows, 1, 1, threads_per_group, 1, 1);

    return 1;
}

int ggml_metal_op_argsort(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);

Expand Down Expand Up @@ -3773,6 +3863,34 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
return 1;
}

// Encodes the GGML_OP_COUNT_EQUAL node at position idx.
//
// ctx - encoding context; supplies the graph node, the Metal library and the encoder
// idx - index of the node to encode
//
// Kernel bindings: 0 = kernel arguments, 1 = src[0], 2 = src[1], 3 = the node
// itself (destination). Exactly one threadgroup of 256 threads is dispatched,
// regardless of the input size.
//
// Returns 1; the dispatcher stores this value as n_fuse.
int ggml_metal_op_count_equal(ggml_metal_op_t ctx, int idx) {
    ggml_tensor * op = ctx->node(idx);

    ggml_metal_library_t lib = ctx->lib;
    ggml_metal_encoder_t enc = ctx->enc;

    const ggml_tensor * src0 = op->src[0];
    const ggml_tensor * src1 = op->src[1];

    // NOTE(review): the element count is narrowed to 32 bits here because the
    // kargs field is int32_t - inputs with more than INT32_MAX elements would
    // not be represented correctly
    ggml_metal_kargs_count_equal kargs = {
        /*.ne0 =*/ (int32_t) ggml_nelements(src0),
    };

    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_count_equal(lib, op);

    const int threads_per_group = 256;

    ggml_metal_encoder_set_pipeline(enc, pipeline);

    ggml_metal_encoder_set_bytes (enc, &kargs, sizeof(kargs),          0);
    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(src0), 1);
    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(src1), 2);
    ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op),   3);

    ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, threads_per_group, 1, 1);

    return 1;
}

int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);

Expand Down Expand Up @@ -3815,6 +3933,7 @@ int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;


GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
Expand Down
3 changes: 3 additions & 0 deletions src/ggml-metal/ggml-metal-ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,11 @@ int ggml_metal_op_pad_reflect_1d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_arange (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_cross_entropy_loss(ggml_metal_op_t ctx, int idx);
int ggml_metal_op_cross_entropy_loss_back(ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_count_equal (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);

Expand Down
Loading