diff --git a/operatorspy/tests/causal_softmax.py b/operatorspy/tests/causal_softmax.py
index bc63d87a..222a456c 100644
--- a/operatorspy/tests/causal_softmax.py
+++ b/operatorspy/tests/causal_softmax.py
@@ -68,6 +68,7 @@ def test(lib, handle, torch_device, x_shape, x_stride=None, x_dtype=torch.float1
             None,
         )
     )
+    #print(x.flatten()[0], ans.flatten()[0])
     assert torch.allclose(x, ans, atol=0, rtol=1e-2)
     check_error(lib.infiniopDestroyCausalSoftmaxDescriptor(descriptor))
 
@@ -106,6 +107,14 @@ def test_ascend(lib, test_cases):
         test(lib, handle, "npu", x_shape, x_stride)
     destroy_handle(lib, handle)
 
+def test_teco(lib, test_cases):
+    import torch_sdaa
+    device = DeviceEnum.DEVICE_TECO
+    handle = create_handle(lib, device)
+    for x_shape, x_stride in test_cases:
+        test(lib, handle, "sdaa", x_shape, x_stride)
+
+    destroy_handle(lib, handle)
 
 if __name__ == "__main__":
     test_cases = [
@@ -147,6 +156,8 @@ def test_ascend(lib, test_cases):
         test_bang(lib, test_cases)
     if args.ascend:
         test_ascend(lib, test_cases)
-    if not (args.cpu or args.cuda or args.bang or args.ascend):
+    if args.teco:
+        test_teco(lib, test_cases)
+    if not (args.cpu or args.cuda or args.bang or args.ascend or args.teco):
         test_cpu(lib, test_cases)
     print("Test passed!")
diff --git a/src/ops/causal_softmax/operator.cc b/src/ops/causal_softmax/operator.cc
index ef10919f..6bab007d 100644
--- a/src/ops/causal_softmax/operator.cc
+++ b/src/ops/causal_softmax/operator.cc
@@ -7,8 +7,8 @@
 #endif
 #ifdef ENABLE_NV_GPU
 #include "../../devices/cuda/common_cuda.h"
-#include "cuda/causal_softmax.cuh"
 #include "../../devices/cuda/cuda_handle.h"
+#include "cuda/causal_softmax.cuh"
 #endif
 #ifdef ENABLE_CAMBRICON_MLU
 #include "../../devices/bang/bang_handle.h"
@@ -18,6 +18,9 @@
 #ifdef ENABLE_ASCEND_NPU
 #include "ascend/causal_softmax_aclnn.h"
 #endif
+#ifdef ENABLE_TECO_SDAA
+#include "teco/causal_softmax_sdaa.h"
+#endif
 
 __C infiniopStatus_t infiniopCreateCausalSoftmaxDescriptor(
     infiniopHandle_t handle,
@@ -30,7 +33,7 @@ __C infiniopStatus_t infiniopCreateCausalSoftmaxDescriptor(
 #endif
 #ifdef ENABLE_NV_GPU
         case DevNvGpu: {
-            return cudaCreateCausalSoftmaxDescriptor((CudaHandle_t)handle, (CausalSoftmaxCudaDescriptor_t *) desc_ptr, y_desc);
+            return cudaCreateCausalSoftmaxDescriptor((CudaHandle_t) handle, (CausalSoftmaxCudaDescriptor_t *) desc_ptr, y_desc);
         }
 #endif
 
@@ -44,6 +47,10 @@ __C infiniopStatus_t infiniopCreateCausalSoftmaxDescriptor(
         case DevAscendNpu: {
             return aclnnCreateCausalSoftmaxDescriptor((AscendHandle_t) handle, (CausalSoftmaxAclnnDescriptor_t *) desc_ptr, y_desc);
         }
+#endif
+#ifdef ENABLE_TECO_SDAA
+        case DevTecoSDAA:
+            return tecoCreateCausalSoftmaxDescriptor((TecoHandle_t) handle, (CausalSoftmaxTecoDescriptor_t *) desc_ptr, y_desc);
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -72,6 +79,10 @@ __C infiniopStatus_t infiniopGetCausalSoftmaxWorkspaceSize(infiniopCausalSoftmax
         case DevAscendNpu: {
             return aclnnGetCausalSoftmaxWorkspaceSize((CausalSoftmaxAclnnDescriptor_t) desc, size);
         }
+#endif
+#ifdef ENABLE_TECO_SDAA
+        case DevTecoSDAA:
+            return tecoGetCausalSoftmaxWorkspaceSize((CausalSoftmaxTecoDescriptor_t) desc, size);
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -99,6 +110,10 @@ __C infiniopStatus_t infiniopCausalSoftmax(infiniopCausalSoftmaxDescriptor_t des
         case DevAscendNpu: {
             return aclnnCausalSoftmax((CausalSoftmaxAclnnDescriptor_t) desc, workspace, workspace_size, data, stream);
         }
+#endif
+#ifdef ENABLE_TECO_SDAA
+        case DevTecoSDAA:
+            return tecoCausalSoftmax((CausalSoftmaxTecoDescriptor_t) desc, workspace, workspace_size, data, stream);
 #endif
     }
     return STATUS_BAD_DEVICE;
@@ -126,6 +141,10 @@ __C infiniopStatus_t infiniopDestroyCausalSoftmaxDescriptor(infiniopCausalSoftma
         case DevAscendNpu: {
             return aclnnDestroyCausalSoftmaxDescriptor((CausalSoftmaxAclnnDescriptor_t) desc);
         }
+#endif
+#ifdef ENABLE_TECO_SDAA
+        case DevTecoSDAA:
+            return tecoDestroyCausalSoftmaxDescriptor((CausalSoftmaxTecoDescriptor_t) desc);
 #endif
     }
     return STATUS_BAD_DEVICE;
diff --git a/src/ops/causal_softmax/teco/causal_softmax_sdaa.h b/src/ops/causal_softmax/teco/causal_softmax_sdaa.h
new file mode 100644
index 00000000..4fa6ce00
--- /dev/null
+++ b/src/ops/causal_softmax/teco/causal_softmax_sdaa.h
@@ -0,0 +1,34 @@
+#ifndef __SDAA_CAUSAL_SOFTMAX_H__
+#define __SDAA_CAUSAL_SOFTMAX_H__
+#include "../../../devices/teco/teco_handle.h"
+#include "../../utils.h"
+#include "operators.h"
+#include <stdint.h>
+struct CausalSoftmaxTecoDescriptor {
+    Device device;
+    int device_id;
+    DT dtype;
+    int ndim;
+    int *stride;
+    int *shape;
+};
+
+typedef struct CausalSoftmaxTecoDescriptor *CausalSoftmaxTecoDescriptor_t;
+
+
+infiniopStatus_t tecoCreateCausalSoftmaxDescriptor(TecoHandle_t handle,
+                                                   CausalSoftmaxTecoDescriptor_t *desc_ptr,
+                                                   infiniopTensorDescriptor_t y_desc);
+
+infiniopStatus_t tecoGetCausalSoftmaxWorkspaceSize(CausalSoftmaxTecoDescriptor_t desc, uint64_t *size);
+
+infiniopStatus_t tecoCausalSoftmax(CausalSoftmaxTecoDescriptor_t desc,
+                                   void *workspace,
+                                   uint64_t workspace_size,
+                                   void *data,
+                                   void *stream);
+
+infiniopStatus_t tecoDestroyCausalSoftmaxDescriptor(CausalSoftmaxTecoDescriptor_t desc);
+
+
+#endif
diff --git a/src/ops/causal_softmax/teco/causal_softmax_sdaa.scpp b/src/ops/causal_softmax/teco/causal_softmax_sdaa.scpp
new file mode 100644
index 00000000..a8369682
--- /dev/null
+++ b/src/ops/causal_softmax/teco/causal_softmax_sdaa.scpp
@@ -0,0 +1,243 @@
+#include "causal_softmax_sdaa.h"
+
+__local__ halfv16 h_local;
+__local__ floatv16 f_local;
+
+infiniopStatus_t tecoCreateCausalSoftmaxDescriptor(TecoHandle_t handle,
+                                                   CausalSoftmaxTecoDescriptor_t *desc_ptr,
+                                                   infiniopTensorDescriptor_t y_desc) {
+    if (y_desc->ndim < 2 || y_desc->shape[y_desc->ndim - 1] < y_desc->shape[y_desc->ndim - 2]) {
+        return STATUS_BAD_TENSOR_SHAPE;
+    }
+
+    int ndim = y_desc->ndim;
+    int *shape = (int *) malloc(ndim * sizeof(int));
+    int *stride = (int *) malloc(ndim * sizeof(int));
+
+
+    for (int i = 0; i < ndim; i++) {
+        stride[i] = static_cast<int>(y_desc->strides[i]);
+        shape[i] = static_cast<int>(y_desc->shape[i]);
+    }
+
+    *desc_ptr = new CausalSoftmaxTecoDescriptor{
+        handle->device,
+        handle->device_id,
+        y_desc->dt,
+        ndim,
+        stride,
+        shape};
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t tecoGetCausalSoftmaxWorkspaceSize(CausalSoftmaxTecoDescriptor_t desc, uint64_t *size) {
+    // Workspace holds two device-side int arrays: strides and shape.
+    *size = desc->ndim * sizeof(int) * 2;
+    return STATUS_SUCCESS;
+}
+
+template<typename T>
+__global__ void causalSoftmax(T *destination, int *shape, int *stride, int ndim, int mask) {
+    int othersize = 1;
+    for (int i = 0; i < ndim - 1; i++) {
+        othersize *= shape[i];
+    }
+    // Split the rows as evenly as possible across the threads.
+    int remain = othersize % threadDim;
+    int step_easy = (othersize - remain) / threadDim;
+    int step_hard = step_easy + 1;
+    int step = (threadIdx < remain ? step_hard : step_easy);
+    int ind_start = (threadIdx < remain ? threadIdx * step_hard : (remain * step_hard + (threadIdx - remain) * step_easy));
+
+    int dimsize = shape[ndim - 1];
+    int buf_size = 16;
+
+    for (int i = ind_start; i < ind_start + step; i++) {
+        int ind_d = 0;
+        int ind_i = i;
+        int lastI = ind_i % shape[ndim - 2];
+
+        int remain_dhead = (lastI + mask + 1) % buf_size;
+        int repeat = (lastI + mask + 1 - remain_dhead) / buf_size;// softmax is applied over this leading part
+
+        int length = dimsize - (lastI + mask + 1);
+        int remainI = length % buf_size;
+        int rI = (length - remainI) / buf_size;// this trailing part is set to zero
+
+        for (int j = ndim - 2; j >= 0; --j) {
+            ind_d += (ind_i % shape[j]) * stride[j];
+            ind_i /= shape[j];
+        }
+        // first compute the running max and sum
+
+        float new_max = destination[ind_d];
+        float old_max = new_max;
+        float sum_value = 0.0f;
+        for (int r = 0; r < repeat; r++) {
+            int start = ind_d + r * buf_size;
+            if constexpr (std::is_same<T, half>::value) {
+                simd_load(h_local, destination + start);
+                f_local = simd_cvt_h2f(h_local);
+            }
+            else if constexpr (std::is_same<T, float>::value) {
+                simd_load(f_local, destination + start);
+            }
+            for (int k = 0; k < buf_size; k++) {
+                if (new_max < f_local[k]) {
+                    new_max = f_local[k];
+                }
+            }
+            for (int k = 0; k < buf_size; k++) {
+                f_local[k] = expf(f_local[k] - new_max);
+            }
+            if (r > 0) {
+                sum_value = sum_value * expf(old_max - new_max);
+            }
+            sum_value += simd_redsum(f_local);
+            old_max = new_max;
+        }
+        if (remain_dhead) {
+            int start = ind_d + repeat * buf_size;
+            for (int k = 0; k < remain_dhead; k++) {
+                if constexpr (std::is_same<T, half>::value) {
+                    if (new_max < static_cast<float>(destination[start + k])) {
+                        new_max = static_cast<float>(destination[start + k]);
+                    }
+                }
+                else if constexpr (std::is_same<T, float>::value) {
+                    if (new_max < destination[start + k]) {
+                        new_max = destination[start + k];
+                    }
+                }
+            }
+            if (repeat > 0) {
+                sum_value = sum_value * expf(old_max - new_max);
+            }
+            for (int k = 0; k < remain_dhead; k++) {
+                if constexpr (std::is_same<T, half>::value) {
+                    sum_value += expf(static_cast<float>(destination[start + k]) - new_max);
+                }
+                else if constexpr (std::is_same<T, float>::value) {
+                    sum_value += expf(destination[start + k] - new_max);
+                }
+            }
+        }
+
+        float sum_inv = 1.0f / sum_value;
+        // then apply the softmax transform
+        for (int r = 0; r < repeat; r++) {
+            int start = ind_d + r * buf_size;
+            if constexpr (std::is_same<T, half>::value) {
+                simd_load(h_local, destination + start);
+                f_local = simd_cvt_h2f(h_local);
+            }
+            else if constexpr (std::is_same<T, float>::value) {
+                simd_load(f_local, destination + start);
+            }
+
+            for (int k = 0; k < buf_size; k++) {
+                f_local[k] = expf(f_local[k] - new_max) * sum_inv;
+            }
+            if constexpr (std::is_same<T, half>::value) {
+                h_local = simd_cvt_f2h(f_local);
+                simd_store(h_local, destination + start);
+            }
+            else if constexpr (std::is_same<T, float>::value) {
+                simd_store(f_local, destination + start);
+            }
+        }
+        if (remain_dhead) {
+            int start = ind_d + repeat * buf_size;
+            for (int k = 0; k < remain_dhead; k++) {
+                if constexpr (std::is_same<T, half>::value) {
+                    destination[start + k] = static_cast<half>(expf(static_cast<float>(destination[start + k]) - new_max) * sum_inv);
+                }
+                else if constexpr (std::is_same<T, float>::value) {
+                    destination[start + k] = expf(destination[start + k] - new_max) * sum_inv;
+                }
+            }
+
+        }
+
+        // zero out the masked remainder of the row
+        for (int r = 0; r < rI; r++) {
+            int start = ind_d + mask + 1 + lastI + r * buf_size;
+            if constexpr (std::is_same<T, half>::value) {
+                for (int k = 0; k < buf_size; k++) {
+                    destination[start + k] = static_cast<half>(0.0f);
+                }
+            }
+            else if constexpr (std::is_same<T, float>::value) {
+                for (int k = 0; k < buf_size; k++) {
+                    destination[start + k] = 0.0f;
+                }
+            }
+            /***
+            if constexpr (std::is_same<T, half>::value) {
+                simd_load(h_local, destination + start);
+                for (int k = 0; k < buf_size; k++) {
+                    h_local[k] = static_cast<half>(0.0f);
+                }
+                simd_store(h_local, destination + start);
+            }
+            else if constexpr (std::is_same<T, float>::value) {
+                simd_load(f_local, destination + start);
+                for (int k = 0; k < buf_size; k++) {
+                    f_local[k] = 0.0f;
+                }
+                simd_store(f_local, destination + start);
+            }
+            ***/
+        }
+
+        if (remainI) {
+            int start = ind_d + mask + 1 + lastI + rI * buf_size;
+            if constexpr (std::is_same<T, half>::value) {
+                for (int k = 0; k < remainI; k++) {
+                    destination[start + k] = static_cast<half>(0.0f);
+                }
+            }
+            else if constexpr (std::is_same<T, float>::value) {
+                for (int k = 0; k < remainI; k++) {
+                    destination[start + k] = 0.0f;
+                }
+            }
+        }
+
+    }
+}
+
+infiniopStatus_t tecoCausalSoftmax(CausalSoftmaxTecoDescriptor_t desc,
+                                   void *workspace,
+                                   uint64_t workspace_size,
+                                   void *data,
+                                   void *stream) {
+    int ndim = desc->ndim;
+    int mask = desc->shape[ndim - 1] - desc->shape[ndim - 2];
+
+    // Stage the stride/shape metadata in the device-side workspace.
+    int *teco_stride = reinterpret_cast<int *>(workspace);
+    int *teco_shape = teco_stride + ndim;
+
+    sdaaMemcpy(teco_stride, desc->stride, ndim * sizeof(int), sdaaMemcpyHostToDevice);
+    sdaaMemcpy(teco_shape, desc->shape, ndim * sizeof(int), sdaaMemcpyHostToDevice);
+    sdaaDeviceSynchronize();
+    if (dtype_eq(desc->dtype, F16)) {
+        auto destination = reinterpret_cast<half *>(data);
+        causalSoftmax<<<1, (sdaaStream_t) stream>>>(destination, teco_shape, teco_stride, ndim, mask);
+        sdaaDeviceSynchronize();
+        return STATUS_SUCCESS;
+    }
+    else if (dtype_eq(desc->dtype, F32)) {
+        auto destination = reinterpret_cast<float *>(data);
+        causalSoftmax<<<1, (sdaaStream_t) stream>>>(destination, teco_shape, teco_stride, ndim, mask);
+        sdaaDeviceSynchronize();
+        return STATUS_SUCCESS;
+    }
+
+    return STATUS_BAD_TENSOR_DTYPE;
+}
+
+infiniopStatus_t tecoDestroyCausalSoftmaxDescriptor(CausalSoftmaxTecoDescriptor_t desc) {
+    //free(desc->stride);
+    //free(desc->shape);
+    delete desc;
+    return STATUS_SUCCESS;
+}
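
Note on the kernel's semantics: with mask = shape[-1] - shape[-2], row i of each trailing h x w matrix keeps its first i + mask + 1 entries, softmax-normalizes them, and zeroes the rest. Below is a minimal PyTorch sketch of that reference behavior; causal_softmax_ref is a hypothetical helper for illustration only (not part of this patch), and the test's `ans` is assumed to be computed equivalently.

    import torch

    def causal_softmax_ref(x: torch.Tensor) -> torch.Tensor:
        # Row i may keep columns 0..i + (w - h); this mirrors
        # mask = shape[ndim - 1] - shape[ndim - 2] in tecoCausalSoftmax.
        h, w = x.shape[-2], x.shape[-1]
        keep = torch.ones(h, w, dtype=torch.bool, device=x.device).tril(diagonal=w - h)
        y = torch.softmax(x.float().masked_fill(~keep, float("-inf")), dim=-1)
        return y.to(x.dtype)  # masked positions come out as exact zeros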