Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions operatorspy/tests/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,14 @@ def test_ascend(lib, test_cases) :
test(lib, handle, "npu", shape, strides, dtype)
destroy_handle(lib, handle)

def test_teco(lib, test_cases):
    """Run every rotary-embedding test case against the Teco backend.

    A handle is created for ``DeviceEnum.DEVICE_TECO``, ``test`` is called
    with the ``"sdaa"`` torch device string for each
    ``(shape, strides, dtype)`` entry of ``test_cases``, and the handle is
    destroyed afterwards.
    """
    # Imported only for its side effects; presumably this registers the
    # "sdaa" device with torch -- TODO confirm.
    import torch_sdaa

    handle = create_handle(lib, DeviceEnum.DEVICE_TECO)
    for case_shape, case_strides, case_dtype in test_cases:
        test(lib, handle, "sdaa", case_shape, case_strides, case_dtype)
    destroy_handle(lib, handle)

if __name__ == "__main__":
test_cases = [
((1, 32, 128), None, torch.float16),
Expand Down Expand Up @@ -215,5 +223,7 @@ def test_ascend(lib, test_cases) :
test_bang(lib, test_cases)
if args.ascend:
test_ascend(lib, test_cases)
if args.teco:
    test_teco(lib, test_cases)
# Fall back to the CPU test only when no backend flag was given at all.
# args.teco has to be part of this check; otherwise passing only `--teco`
# would run the Teco test and then the CPU test as well.
if not (args.cpu or args.cuda or args.bang or args.ascend or args.teco):
    test_cpu(lib, test_cases)
31 changes: 31 additions & 0 deletions src/ops/rotary_embedding/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
#ifdef ENABLE_ASCEND_NPU
#include "ascend/rotary_embedding.h"
#endif
#ifdef ENABLE_TECO_SDAA
#include "teco/rotary_embedding_sdaa.h"
#endif

struct RoPEDescriptor {
Device device;
Expand Down Expand Up @@ -52,6 +55,15 @@ __C infiniopStatus_t infiniopCreateRoPEDescriptor(infiniopHandle_t handle,
sin_table,
cos_table);
}
#endif
#ifdef ENABLE_TECO_SDAA
case DevTecoSDAA:
return tecoCreateRoPEDescriptor((TecoHandle_t) handle,
(RoPETecoDescriptor_t *) desc_ptr,
t,
pos_ids,
sin_table,
cos_table);
#endif
}
return STATUS_BAD_DEVICE;
Expand Down Expand Up @@ -79,6 +91,11 @@ __C infiniopStatus_t infiniopGetRoPEWorkspaceSize(infiniopRoPEDescriptor_t desc,
return ascendGetRoPEWorkspaceSize((RoPEAscendDescriptor_t) desc,
size);
}
#endif
#ifdef ENABLE_TECO_SDAA
case DevTecoSDAA:
return tecoGetRoPEWorkspaceSize((RoPETecoDescriptor_t) desc,
size);
#endif
}
return STATUS_BAD_DEVICE;
Expand Down Expand Up @@ -119,6 +136,16 @@ __C infiniopStatus_t infiniopRoPE(infiniopRoPEDescriptor_t desc,
cos_table,
stream);
}
#endif
#ifdef ENABLE_TECO_SDAA
case DevTecoSDAA:
return tecoRoPE((RoPETecoDescriptor_t) desc, workspace,
workspace_size,
t,
pos_ids,
sin_table,
cos_table,
stream);
#endif
}
return STATUS_BAD_DEVICE;
Expand All @@ -145,6 +172,10 @@ __C infiniopStatus_t infiniopDestroyRoPEDescriptor(infiniopRoPEDescriptor_t desc
case DevAscendNpu: {
return ascendDestroyRoPEDescriptor((RoPEAscendDescriptor_t) desc);
}
#endif
#ifdef ENABLE_TECO_SDAA
case DevTecoSDAA:
return tecoDestroyRoPEDescriptor((RoPETecoDescriptor_t) desc);
#endif
}
return STATUS_BAD_DEVICE;
Expand Down
43 changes: 43 additions & 0 deletions src/ops/rotary_embedding/teco/rotary_embedding_sdaa.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#ifndef __SDAA_ROPE_H__
#define __SDAA_ROPE_H__
#include "../../../devices/teco/teco_handle.h"
#include "../../utils.h"
#include "operators.h"
#include <sdaa_runtime.h>
// Descriptor for the Teco (SDAA) rotary-embedding operator.
// Allocated and filled in by tecoCreateRoPEDescriptor and released by
// tecoDestroyRoPEDescriptor. Field order matters: the descriptor is built
// with positional aggregate initialization.
struct RoPETecoDescriptor {
// Device kind copied from the handle.
Device device;
// Device index copied from the handle.
int device_id;
// Data type of the tensor `t` (creation only accepts F16).
DT dtype;
// t->shape[0]
uint64_t seqlen;
// t->shape[1]
uint64_t nhead;
// t->shape[2]; creation rejects odd values.
uint64_t dhead;
// sin_table->shape[0] (creation checks it equals cos_table->shape[0]).
uint64_t total_seqlen;
// t->strides[0], narrowed to int at creation time.
int x_stride_seqlen;
// t->strides[1], narrowed to int at creation time.
int x_stride_nhead;
};

// Pointer alias used by the teco* entry points for the descriptor.
typedef struct RoPETecoDescriptor *RoPETecoDescriptor_t;


// Validates the tensor descriptors for a rotary-embedding call and, on
// success, stores a newly allocated RoPETecoDescriptor in *desc_ptr.
//
//   handle     supplies the device and device_id recorded in the descriptor
//   desc_ptr   output; a null pointer yields STATUS_MEMORY_NOT_ALLOCATED
//   t          3-D F16 tensor (seqlen, nhead, dhead), dhead even, last stride 1
//   pos_ids    1-D U64 tensor of length seqlen, stride 1
//   sin_table  2-D F32 tensor with dhead columns, last stride 1
//   cos_table  2-D F32 tensor with the same shape as sin_table
//
// Returns STATUS_BAD_TENSOR_SHAPE, STATUS_BAD_TENSOR_STRIDES or
// STATUS_BAD_TENSOR_DTYPE when the corresponding requirement is violated,
// STATUS_SUCCESS otherwise.
infiniopStatus_t tecoCreateRoPEDescriptor(TecoHandle_t handle,
RoPETecoDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t t,
infiniopTensorDescriptor_t pos_ids,
infiniopTensorDescriptor_t sin_table,
infiniopTensorDescriptor_t cos_table);

// Reports, through *size, how many workspace bytes tecoRoPE needs for `desc`.
// The Teco implementation needs no workspace and reports 0.
infiniopStatus_t tecoGetRoPEWorkspaceSize(RoPETecoDescriptor_t desc, uint64_t *size);

// Applies the rotary embedding to `t` in place by launching the RoPE kernel
// on `stream`.
//
//   desc            descriptor from tecoCreateRoPEDescriptor (shape and strides)
//   workspace       not read by the current implementation
//   workspace_size  not read by the current implementation
//   t               half-precision data, overwritten with the rotated values
//   pos_ids         uint64 position per sequence entry; selects the table row
//   sin_table       float sine table
//   cos_table       float cosine table
//   stream          cast to sdaaStream_t for the kernel launch
//
// Returns STATUS_SUCCESS after issuing the launch; the implementation does
// not check for launch errors.
infiniopStatus_t tecoRoPE(RoPETecoDescriptor_t desc,
void *workspace,
uint64_t workspace_size,
void *t,
void const *pos_ids,
void const *sin_table,
void const *cos_table,
void *stream);

// Releases a descriptor obtained from tecoCreateRoPEDescriptor.
// Returns STATUS_SUCCESS.
infiniopStatus_t tecoDestroyRoPEDescriptor(RoPETecoDescriptor_t desc);


#endif
159 changes: 159 additions & 0 deletions src/ops/rotary_embedding/teco/rotary_embedding_sdaa.scpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
#include "rotary_embedding_sdaa.h"

// Scratch vectors used by the RoPE kernel below, one 16-element chunk at a
// time: x_local receives the half values loaded from the tensor, y_local the
// half values written back, sin_local / cos_local the matching float table
// entries, and tmp_local the float working copy of x_local.
// NOTE(review): __local__ presumably places these in per-thread local
// memory -- confirm against the SDAA programming guide.
__local__ halfv16 x_local, y_local;
__local__ floatv16 sin_local, cos_local, tmp_local;

// Validates the tensor descriptors of a rotary-embedding call and, when they
// are acceptable, allocates a RoPETecoDescriptor into *desc_ptr.
//
// Checks are performed in this order, each returning the status shown:
//   1. desc_ptr is non-null                          STATUS_MEMORY_NOT_ALLOCATED
//   2. ranks: t 3-D, pos_ids 1-D, tables 2-D         STATUS_BAD_TENSOR_SHAPE
//   3. dhead even; pos_ids has seqlen entries;
//      both tables have dhead columns and the
//      same number of rows                           STATUS_BAD_TENSOR_SHAPE
//   4. innermost stride of every tensor is 1         STATUS_BAD_TENSOR_STRIDES
//   5. t is F16, both tables are F32, pos_ids U64    STATUS_BAD_TENSOR_DTYPE
//
// The two outer strides of `t` are narrowed to int when stored.
infiniopStatus_t tecoCreateRoPEDescriptor(TecoHandle_t handle,
                                          RoPETecoDescriptor_t *desc_ptr,
                                          infiniopTensorDescriptor_t t,
                                          infiniopTensorDescriptor_t pos_ids,
                                          infiniopTensorDescriptor_t sin_table,
                                          infiniopTensorDescriptor_t cos_table) {
    if (!desc_ptr) {
        return STATUS_MEMORY_NOT_ALLOCATED;
    }

    const bool ranks_ok = t->ndim == 3 &&
                          pos_ids->ndim == 1 &&
                          sin_table->ndim == 2 &&
                          cos_table->ndim == 2;
    if (!ranks_ok) {
        return STATUS_BAD_TENSOR_SHAPE;
    }

    const auto seq_len = t->shape[0];
    const auto head_count = t->shape[1];
    const auto head_dim = t->shape[2];
    const auto table_rows = sin_table->shape[0];

    // Elements are rotated in adjacent pairs, hence the even head dimension.
    const bool shapes_ok = head_dim % 2 == 0 &&
                           pos_ids->shape[0] == seq_len &&
                           sin_table->shape[1] == head_dim &&
                           cos_table->shape[1] == head_dim &&
                           table_rows == cos_table->shape[0];
    if (!shapes_ok) {
        return STATUS_BAD_TENSOR_SHAPE;
    }

    // The kernel reads each row contiguously, so innermost strides must be 1.
    const bool strides_ok = t->strides[2] == 1 &&
                            pos_ids->strides[0] == 1 &&
                            sin_table->strides[1] == 1 &&
                            cos_table->strides[1] == 1;
    if (!strides_ok) {
        return STATUS_BAD_TENSOR_STRIDES;
    }

    const bool dtypes_ok = dtype_eq(t->dt, F16) &&
                           dtype_eq(sin_table->dt, F32) &&
                           dtype_eq(cos_table->dt, F32) &&
                           dtype_eq(pos_ids->dt, U64);
    if (!dtypes_ok) {
        return STATUS_BAD_TENSOR_DTYPE;
    }

    *desc_ptr = new RoPETecoDescriptor{
        handle->device,
        handle->device_id,
        t->dt,
        seq_len,
        head_count,
        head_dim,
        table_rows,
        static_cast<int>(t->strides[0]),
        static_cast<int>(t->strides[1])};

    return STATUS_SUCCESS;
}

// Reports the workspace size tecoRoPE requires for `desc`.
//
// The Teco kernel works entirely in place, so the size is always 0 and
// `desc` is not inspected. A null `size` is rejected with
// STATUS_MEMORY_NOT_ALLOCATED instead of being dereferenced, mirroring the
// null check on the output pointer in tecoCreateRoPEDescriptor.
infiniopStatus_t tecoGetRoPEWorkspaceSize(RoPETecoDescriptor_t desc, uint64_t *size) {
    if (size == nullptr) {
        return STATUS_MEMORY_NOT_ALLOCATED;
    }
    *size = 0;
    return STATUS_SUCCESS;
}

// In-place rotary position embedding kernel.
//
// `destination` is addressed as (seqlen, nhead, dhead) with element strides
// x_stride_seqlen / x_stride_nhead and a contiguous last dimension.
// For the row at (s, h) the table row is pos_ids[s]; adjacent elements
// (a, b) = (x[2k], x[2k + 1]) are replaced by
//     x[2k]     = a * cos[2k]     - b * sin[2k]
//     x[2k + 1] = a * sin[2k + 1] + b * cos[2k + 1]
// where sin / cos are read at offset pos_ids[s] * dhead of the float tables.
//
// The seqlen * nhead rows are split into contiguous ranges, one per thread;
// the first `remain` threads take one extra row. Each row is processed in
// 16-element vector chunks, with a scalar tail for dhead % 16 leftovers.
__global__ void RoPE(half *destination,
                     const uint64_t *pos_ids,
                     const float *sin_table, const float *cos_table,
                     int x_stride_seqlen, int x_stride_nhead,
                     int seqlen, int nhead, int dhead) {
    // Work distribution: rows [ind_start, ind_start + step) belong to this thread.
    int other_size = seqlen * nhead;
    int remain = other_size % threadDim;
    int step_easy = (other_size - remain) / threadDim;
    int step_hard = step_easy + 1;
    int step = (threadIdx < remain ? step_hard : step_easy);
    int ind_start = (threadIdx < remain ? threadIdx * step_hard
                                        : remain * step_hard + (threadIdx - remain) * step_easy);

    // Vector width of the halfv16 / floatv16 scratch registers.
    int buf_size = 16;
    int remain_dhead = dhead % buf_size;
    int repeat = (dhead - remain_dhead) / buf_size;

    for (int i = ind_start; i < ind_start + step; i++) {
        // Flat row index -> (sequence position, head). i < seqlen * nhead,
        // so seq_idx is already in [0, seqlen).
        int head_idx = i % nhead;
        int seq_idx = i / nhead;
        int ind_s = head_idx * x_stride_nhead + seq_idx * x_stride_seqlen;

        // Start of the sin/cos table row selected by this token's position id.
        int index = static_cast<int>(pos_ids[seq_idx]) * dhead;

        // Full 16-element chunks.
        for (int r = 0; r < repeat; r++) {
            int start_s = ind_s + r * buf_size;
            int sin_cos_index = index + r * buf_size;

            simd_load(x_local, destination + start_s);
            simd_load(sin_local, sin_table + sin_cos_index);
            simd_load(cos_local, cos_table + sin_cos_index);

            // Rotate in float precision, then convert back to half.
            tmp_local = simd_cvt_h2f(x_local);

            for (int k = 0; k < buf_size / 2; k++) {
                float a = tmp_local[2 * k];
                float b = tmp_local[2 * k + 1];
                float sin0 = sin_local[2 * k], cos0 = cos_local[2 * k];
                float sin1 = sin_local[2 * k + 1], cos1 = cos_local[2 * k + 1];
                tmp_local[2 * k] = a * cos0 - b * sin0;
                tmp_local[2 * k + 1] = a * sin1 + b * cos1;
            }
            y_local = simd_cvt_f2h(tmp_local);
            simd_store(y_local, destination + start_s);
        }

        // Scalar tail for the dhead % 16 trailing elements.
        if (remain_dhead) {
            int start_s = ind_s + repeat * buf_size;
            int sin_cos_index = index + repeat * buf_size;
            for (int k = 0; k < remain_dhead / 2; k++) {
                float a = static_cast<float>(destination[start_s + 2 * k]);
                float b = static_cast<float>(destination[start_s + 2 * k + 1]);
                // Read straight from the tables: sin_cos_index is a table
                // offset, not a lane index, so the 16-lane sin_local /
                // cos_local scratch vectors must not be indexed with it.
                float sin0 = sin_table[sin_cos_index + 2 * k], cos0 = cos_table[sin_cos_index + 2 * k];
                float sin1 = sin_table[sin_cos_index + 2 * k + 1], cos1 = cos_table[sin_cos_index + 2 * k + 1];
                destination[start_s + 2 * k] = static_cast<half>(a * cos0 - b * sin0);
                destination[start_s + 2 * k + 1] = static_cast<half>(a * sin1 + b * cos1);
            }
        }
    }
}

// Host entry point: launches the RoPE kernel on `stream` to rotate `t` in
// place using the shape and strides recorded in `desc`.
//
// `workspace` and `workspace_size` are accepted for interface uniformity but
// are not read. The 64-bit extents stored in the descriptor are narrowed to
// int for the kernel. Returns STATUS_SUCCESS once the launch has been issued;
// no launch error is checked.
infiniopStatus_t tecoRoPE(RoPETecoDescriptor_t desc,
                          void *workspace,
                          uint64_t workspace_size,
                          void *t,
                          void const *pos_ids,
                          void const *sin_table,
                          void const *cos_table,
                          void *stream) {
    // Typed views of the opaque buffers handed in by the caller.
    half *const tensor = static_cast<half *>(t);
    const uint64_t *const positions = static_cast<const uint64_t *>(pos_ids);
    const float *const sin_values = static_cast<const float *>(sin_table);
    const float *const cos_values = static_cast<const float *>(cos_table);

    const int seq_len = static_cast<int>(desc->seqlen);
    const int head_count = static_cast<int>(desc->nhead);
    const int head_dim = static_cast<int>(desc->dhead);

    RoPE<<<1, (sdaaStream_t) stream>>>(tensor, positions, sin_values, cos_values,
                                       desc->x_stride_seqlen, desc->x_stride_nhead,
                                       seq_len, head_count, head_dim);
    return STATUS_SUCCESS;
}

// Releases a descriptor allocated with `new` by tecoCreateRoPEDescriptor.
// `delete` on a null pointer is a no-op, so a null `desc` is tolerated.
// Always returns STATUS_SUCCESS.
infiniopStatus_t tecoDestroyRoPEDescriptor(RoPETecoDescriptor_t desc){
delete desc;
return STATUS_SUCCESS;
}