diff --git a/include/ops/layer_norm/layer_norm.h b/include/ops/layer_norm/layer_norm.h new file mode 100644 index 00000000..f49af8d8 --- /dev/null +++ b/include/ops/layer_norm/layer_norm.h @@ -0,0 +1,30 @@ +#ifndef LAYER_NORM_H +#define LAYER_NORM_H + +#include "../../export.h" +#include "../../operators.h" + +typedef struct LayerNormDescriptor { + Device device; +} LayerNormDescriptor; + +typedef LayerNormDescriptor *infiniopLayerNormDescriptor_t; + +__C __export infiniopStatus_t infiniopCreateLayerNormDescriptor( + infiniopHandle_t handle, + infiniopLayerNormDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t w_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t y_desc, + float epsilon); + + +__C infiniopStatus_t infiniopGetLayerNormWorkspaceSize(infiniopLayerNormDescriptor_t desc, uint64_t *size); +__C __export infiniopStatus_t infiniopLayerNorm(infiniopLayerNormDescriptor_t desc, void *workspace, + uint64_t workspace_size, + void const *x, void const *w, void const *b, void *y, void *stream); + +__C __export infiniopStatus_t infiniopDestroyLayerNormDescriptor(infiniopLayerNormDescriptor_t desc); + +#endif diff --git a/operatorspy/tests/layer_norm.py b/operatorspy/tests/layer_norm.py new file mode 100644 index 00000000..cad3ed2a --- /dev/null +++ b/operatorspy/tests/layer_norm.py @@ -0,0 +1,168 @@ +from ctypes import POINTER, Structure, c_int32, c_uint64, c_void_p, c_float +import ctypes +import sys +import os + + +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) +from operatorspy import ( + open_lib, + to_tensor, + DeviceEnum, + infiniopHandle_t, + infiniopTensorDescriptor_t, + create_handle, + destroy_handle, + create_workspace, + check_error, + rearrange_tensor, +) + +from operatorspy.tests.test_utils import get_args +import torch +import torch.nn as nn + +class LayerNormDescriptor(Structure): + _fields_ = [("device", c_int32)] + + +infiniopLayerNormDescriptor_t = POINTER(LayerNormDescriptor) + + +def LayerNormFunction(input, scale, bias, eps): + normlize_shape = scale.shape + layer_norm = nn.LayerNorm(normlize_shape, elementwise_affine=True, eps = eps) + layer_norm.weight.data = scale + layer_norm.bias.data = bias + return layer_norm.forward(input) + + +def test(lib, handle, torch_device, x_shape, axis, x_dtype=torch.float16): + print( + f"Testing Layernorm on {torch_device} with test_shape:{x_shape}, axis:{axis} ,dtype:{x_dtype}" + ) + eps = 1e-5 + ndim = len(x_shape) + normlize_shape = [] + for i in range(axis, ndim): + normlize_shape.append(x_shape[i]) + + x = torch.rand(x_shape, dtype=x_dtype).to(torch_device) + scale = torch.rand(normlize_shape, dtype=x_dtype).to(torch_device) + bias = torch.rand(normlize_shape, dtype=x_dtype).to(torch_device) + y = torch.rand(x_shape, dtype=x_dtype).to(torch_device) + ans = LayerNormFunction(x, scale, bias, eps) + x_tensor = to_tensor(x, lib) + w_tensor = to_tensor(scale, lib) + b_tensor = to_tensor(bias, lib) + y_tensor = to_tensor(y, lib) + descriptor = infiniopLayerNormDescriptor_t() + check_error( + lib.infiniopCreateLayerNormDescriptor( + handle, ctypes.byref(descriptor), x_tensor.descriptor, w_tensor.descriptor, b_tensor.descriptor, y_tensor.descriptor, eps + ) + ) + workspace_size = c_uint64(0) + check_error( + lib.infiniopGetLayerNormWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = create_workspace(workspace_size.value, torch_device) + check_error( + lib.infiniopLayerNorm( + descriptor, + workspace.data_ptr() if workspace is not None else None, + workspace_size.value, + x_tensor.data, + w_tensor.data, + b_tensor.data, + y_tensor.data, + None, + ) + ) + err = y.reshape(-1,1) - ans.reshape(-1,1) + print(max(abs(err))) + assert torch.allclose(y, ans, atol=1e-3, rtol=1e-3) + check_error(lib.infiniopDestroyLayerNormDescriptor(descriptor)) + + +def test_cpu(lib, test_cases): + device = DeviceEnum.DEVICE_CPU + handle = create_handle(lib, device) + for x_shape, axis, x_dtype in test_cases: + test(lib, handle, "cpu", x_shape, axis, x_dtype) + destroy_handle(lib, handle) + + +def test_cuda(lib, test_cases): + device = DeviceEnum.DEVICE_CUDA + handle = create_handle(lib, device) + for x_shape, axis, x_dtype in test_cases: + test(lib, handle, "cuda", x_shape, axis, x_dtype) + destroy_handle(lib, handle) + + +def test_bang(lib, test_cases): + import torch_mlu + + device = DeviceEnum.DEVICE_BANG + handle = create_handle(lib, device) + for x_shape, axis, x_dtype in test_cases: + test(lib, handle, "mlu", x_shape, axis, x_dtype) + destroy_handle(lib, handle) + + +if __name__ == "__main__": + test_cases = [ + # x_shape, axis + # cnnllayernorm不支持axis=0, cpu torch.layernorm不支持half + #手写layernorm在float16上精度不足,但是在float32上可以通过测试 + #((32, 20, 512), 0, torch.float16), + ((32, 20, 512), 1, torch.float16), + ((32, 20, 512), 2, torch.float16), + + #((32, 20, 512), 0, torch.float32), + ((32, 20, 512), 1, torch.float32), + ((32, 20, 512), 2, torch.float32), + + ] + args = get_args() + lib = open_lib() + lib.infiniopCreateLayerNormDescriptor.restype = c_int32 + lib.infiniopCreateLayerNormDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopLayerNormDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + c_float, + ] + + lib.infiniopLayerNorm.restype = c_int32 + lib.infiniopLayerNorm.argtypes = [ + infiniopLayerNormDescriptor_t, + c_void_p, + c_uint64, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + ] + lib.infiniopDestroyLayerNormDescriptor.restype = c_int32 + lib.infiniopDestroyLayerNormDescriptor.argtypes = [ + infiniopLayerNormDescriptor_t, + ] + + if args.cpu: + test_cpu(lib, test_cases) + if args.cuda: + test_cuda(lib, test_cases) + if args.bang: + test_bang(lib, test_cases) + + if not (args.cpu or args.cuda or args.bang): + test_cpu(lib, test_cases) + print("Test passed!") diff --git a/src/ops/layer_norm/bang/layer_norm_bang.cc b/src/ops/layer_norm/bang/layer_norm_bang.cc new file mode 100644 index 00000000..8de3558a --- /dev/null +++ b/src/ops/layer_norm/bang/layer_norm_bang.cc @@ -0,0 +1,53 @@ +#include "layer_norm_bang.h" +#include "../../utils.h" +infiniopStatus_t bangCreateLayerNormDescriptor(BangHandle_t handle, LayerNormBangDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t w_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t y_desc, + float epsilon) { + if (w_desc->ndim != b_desc->ndim) { + return STATUS_BAD_TENSOR_SHAPE; + } + int wDim = w_desc->ndim; + for(int i = 0; i < wDim; i++){ + if(w_desc->shape[i] != b_desc->shape[i]){ + return STATUS_BAD_TENSOR_SHAPE; + } + } + int ndim = x_desc->ndim; + for(int i = 0; i < wDim; i++){ + if(x_desc->shape[i + ndim - wDim] != w_desc->shape[i]){ + return STATUS_BAD_TENSOR_SHAPE; + } + } + if (!dtype_eq(x_desc->dt, F16) && !dtype_eq(x_desc->dt, F32)) { + return STATUS_BAD_TENSOR_DTYPE; + } + int size = 1; + int behindsize = 1; + for(int i = 0; i < ndim; i++){ + size *= static_cast(x_desc->shape[i]); + if(i >= ndim - wDim){ + behindsize *= static_cast(x_desc->shape[i]); + } + } + *desc_ptr = new LayerNormBangDescriptor{ + handle->device, + handle->device_id, + x_desc->dt, + size, + behindsize, + epsilon}; + + return STATUS_SUCCESS; +} +infiniopStatus_t bangGetLayerNormWorkspaceSize(LayerNormBangDescriptor_t desc, unsigned long int *size) { + *size = 32 * sizeof(desc->dtype);//taskDim * sizeof(T),taskDim不超过32 + return STATUS_SUCCESS; +} + +infiniopStatus_t bangDestroyLayerNormDescriptor(LayerNormBangDescriptor_t desc) { + delete desc; + return STATUS_SUCCESS; +} diff --git a/src/ops/layer_norm/bang/layer_norm_bang.h b/src/ops/layer_norm/bang/layer_norm_bang.h new file mode 100644 index 00000000..08afe6c8 --- /dev/null +++ b/src/ops/layer_norm/bang/layer_norm_bang.h @@ -0,0 +1,36 @@ +#ifndef __BANG_LAYER_NORM_H__ +#define __BANG_LAYER_NORM_H__ + +#include "../../../devices/bang/bang_handle.h" +#include "../../utils.h" +#include "operators.h" + +struct LayerNormBangDescriptor { + Device device; + int device_id; + DT dtype; + int size; + int behindsize; + float epsilon; +}; + +typedef struct LayerNormBangDescriptor *LayerNormBangDescriptor_t; + +infiniopStatus_t bangCreateLayerNormDescriptor(BangHandle_t handle, + LayerNormBangDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t w_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t y_desc, + float epsilon); + +infiniopStatus_t bangGetLayerNormWorkspaceSize(LayerNormBangDescriptor_t desc, unsigned long int *size); + +infiniopStatus_t bangLayerNorm(LayerNormBangDescriptor_t desc, void *workspace, + uint64_t workspace_size, + void const *x, void const *w, void const *b, void *y, + void *stream); + +infiniopStatus_t bangDestroyLayerNormDescriptor(LayerNormBangDescriptor_t desc); + +#endif// __BANG_LAYER_NORM_H__ diff --git a/src/ops/layer_norm/bang/layer_norm_bang.mlu b/src/ops/layer_norm/bang/layer_norm_bang.mlu new file mode 100644 index 00000000..36f55504 --- /dev/null +++ b/src/ops/layer_norm/bang/layer_norm_bang.mlu @@ -0,0 +1,392 @@ +#include "bang.h" +#include "cnrt.h" +#include "layer_norm_bang.h" +#include "../../../devices/bang/common_bang.h" + +const int SRC_MAX_SIZE = 1024 * 16; +__nram__ char nram_buffer[NRAM_MAX_SIZE]; + +template +__mlu_global__ void layer_norm(T const *input, T const *scale, T const *bias, T *output, T *tmpGdram, float eps, int size, int behindsize, int bSize){ + int frontsize = size / behindsize; + const int wSize = 128 / sizeof(T); + + const int maxNum = SRC_MAX_SIZE / sizeof(T); + + + T *src = (T *)nram_buffer;//[maxNum] + T *destSum = src + 3 * maxNum;//[3 * maxNum] + T *destSumFinal = destSum + maxNum;//[wSize] + T *s_src = destSumFinal + wSize;//[3 * maxNum] + T *b_src = s_src + 3 * maxNum;//[3 * maxNum] + //bSize是大于等于behindsize的最小2次幂 + + if (behindsize >= taskDim * maxNum){ + int segNum = maxNum / wSize; + int taskSize = taskDim * maxNum; + int remain = behindsize % taskSize; + int repeat = (behindsize - remain) / taskSize; + + int remainT = remain % taskDim; + int stepEasy = (remain - remainT) / taskDim; + int stepHard = stepEasy + 1; + int step = (taskId < remainT ? stepHard : stepEasy); + int indStart = repeat * taskSize + (taskId < remainT ? taskId * stepHard : (remainT * stepHard + (taskId - remainT) * stepEasy)); + for(int i = 0; i < frontsize; i++){ + int tid = i * behindsize; + __bang_write_zero(destSum, maxNum); + __bang_write_zero(destSumFinal, wSize); + for(int j = 0; j < repeat + 1; j++){ + if(j < repeat){ + __memcpy_async(src + j % 2 * maxNum, input + tid + j * taskSize + taskId * maxNum, maxNum * sizeof(T), GDRAM2NRAM); + } + if(j > 0){ + __bang_add(destSum, destSum, src + (j - 1) % 2 * maxNum, maxNum); + } + __sync_all_ipu(); + } + if(step){ + __bang_write_zero(src, maxNum); + __memcpy(src, input + tid + indStart, step * sizeof(T), GDRAM2NRAM); + __bang_add(destSum, destSum, src, maxNum); + } + __bang_mul_scalar(destSum, destSum, 1.0 / behindsize, maxNum); + for(int strip = segNum/2; strip > 0; strip = strip / 2){ + for(int i = 0; i < strip ; i++){ + __bang_add(destSum + i * wSize, destSum + i * wSize, destSum + (i + strip) * wSize, wSize); + } + } + __bang_reduce_sum(destSumFinal, destSum, wSize);//destSumFinal[0]存储的是当前task对应数据的规约和 + tmpGdram[taskId] = destSumFinal[0]; + __sync_all(); + __bang_write_zero(destSum, maxNum); + __bang_write_zero(destSumFinal, wSize); + __memcpy(destSum, tmpGdram, taskDim * sizeof(T), GDRAM2NRAM); + __bang_reduce_sum(destSumFinal, destSum, wSize); + T mu = destSumFinal[0]; + //下面计算方差 + __bang_write_zero(destSum, maxNum); + __bang_write_zero(destSumFinal, wSize); + for(int j = 0; j < repeat + 1; j++){ + if (j < repeat){ + __memcpy_async(src + j % 2 * maxNum, input + tid + j * taskSize + taskId * maxNum, maxNum * sizeof(T), GDRAM2NRAM); + } + if(j > 0){ + __bang_sub_scalar(src + (j - 1) % 2 * maxNum, src + (j - 1) % 2 * maxNum, mu, maxNum); + __bang_mul(src + (j - 1) % 2 * maxNum, src + (j - 1) % 2 * maxNum, src + (j - 1) % 2 * maxNum, maxNum); + __bang_add(destSum, destSum, src + (j - 1) % 2 * maxNum, maxNum); + } + __sync_all_ipu(); + } + if (step){ + __bang_write_value(src, maxNum, mu);//保证后面减去均值为0 + __memcpy(src, input + tid + indStart, step * sizeof(T), GDRAM2NRAM); + __bang_sub_scalar(src, src, mu, maxNum); + __bang_mul(src, src, src, maxNum); + __bang_add(destSum, destSum, src, maxNum); + } + __bang_mul_scalar(destSum, destSum, 1.0 / behindsize, maxNum); + for(int strip = segNum/2; strip > 0; strip = strip / 2){ + for(int i = 0; i < strip ; i++){ + __bang_add(destSum + i * wSize, destSum + i * wSize, destSum + (i + strip) * wSize, wSize); + } + } + __bang_reduce_sum(destSumFinal, destSum, wSize);//destSumFinal[0]存储的是当前task对应数据的规约和 + + tmpGdram[taskId] = destSumFinal[0]; + __sync_all(); + __bang_write_zero(destSum, maxNum); + __bang_write_zero(destSumFinal, wSize); + __memcpy(destSum, tmpGdram, taskDim * sizeof(T), GDRAM2NRAM); + __bang_reduce_sum(destSumFinal, destSum, wSize); + T sigma2 = destSumFinal[0] + static_cast(eps); + sigma2 = 1.0 / pow(sigma2, 0.5); + //下面开始做变换 + for(int j = 0; j < repeat + 2; j++){ + if(j < repeat){ + __memcpy_async(src + j % 3 * maxNum, input + tid + j * taskSize + taskId * maxNum, maxNum * sizeof(T), GDRAM2NRAM); + __memcpy_async(s_src + j % 3 * maxNum, scale + j * taskSize + taskId * maxNum, maxNum * sizeof(T), GDRAM2NRAM); + __memcpy_async(b_src + j % 3 * maxNum, bias + j * taskSize + taskId * maxNum, maxNum * sizeof(T), GDRAM2NRAM); + } + if(j > 0 && j < repeat + 1){ + __bang_sub_scalar(src + (j - 1) % 3 * maxNum, src + (j - 1) % 3 * maxNum, mu, maxNum); + __bang_mul_scalar(src + (j - 1) % 3 * maxNum, src + (j - 1) % 3 * maxNum, sigma2, maxNum); + __bang_mul(src + (j - 1) % 3 * maxNum, src + (j - 1) % 3 * maxNum, s_src + (j - 1) % 3 * maxNum, maxNum); + __bang_add(src + (j - 1) % 3 * maxNum, src + (j - 1) % 3 * maxNum, b_src + (j - 1) % 3 * maxNum, maxNum); + } + if(j > 1){ + __memcpy_async(output + tid + (j - 2) * taskSize + taskId * maxNum, src + (j - 2) % 3 * maxNum, maxNum * sizeof(T), NRAM2GDRAM); + } + __sync_all_ipu(); + } + if (step){ + __memcpy(src, input + tid + indStart, step * sizeof(T), GDRAM2NRAM); + __memcpy(s_src, scale + indStart, step * sizeof(T), GDRAM2NRAM); + __memcpy(b_src, bias + indStart, step * sizeof(T), GDRAM2NRAM); + __bang_sub_scalar(src, src, mu, maxNum); + __bang_mul_scalar(src, src, sigma2, maxNum); + __bang_mul(src, src, s_src, maxNum); + __bang_add(src, src, b_src, maxNum); + __memcpy(output + tid + indStart, src, step * sizeof(T), NRAM2GDRAM); + } + } + } + else if(behindsize >= maxNum && behindsize < taskDim * maxNum){ + int segNum = maxNum / wSize; + int remainT = behindsize % maxNum; + int repeat = (behindsize - remainT) / maxNum; + + int remain = frontsize % taskDim; + int stepEasy = (frontsize - remain) / taskDim; + int stepHard = stepEasy + 1; + int step = (taskId < remain ? stepHard : stepEasy); + int indStart = (taskId < remain ? taskId * stepHard : (remain * stepHard + (taskId - remain) * stepEasy)); + for(int i = indStart; i < indStart + step; i++){ + int tid = i * behindsize; + //下面开始计算均值 + __bang_write_zero(destSum, maxNum); + __bang_write_zero(destSumFinal, wSize); + for(int j = 0; j < repeat + 1; j++){ + if (j < repeat){ + __memcpy_async(src + j % 2 * maxNum, input + tid + j * maxNum, maxNum * sizeof(T), GDRAM2NRAM); + } + if(j > 0){ + __bang_add(destSum, destSum, src + (j - 1) % 2 * maxNum, maxNum); + } + __sync_all_ipu(); + } + if (remainT){ + __bang_write_zero(src, maxNum); + __memcpy(src, input + tid + repeat * maxNum, remainT * sizeof(T), GDRAM2NRAM); + __bang_add(destSum, destSum, src, maxNum); + } + + for(int strip = segNum/2; strip > 0; strip = strip / 2){ + for(int i = 0; i < strip ; i++){ + __bang_add(destSum + i * wSize, destSum + i * wSize, destSum + (i + strip) * wSize, wSize); + } + } + __bang_reduce_sum(destSumFinal, destSum, wSize); + //下面开始计算方差,destSumFinal[0]存储的就是均值 + T mu = destSumFinal[0] / behindsize; + __bang_write_zero(destSum, maxNum); + __bang_write_zero(destSumFinal, wSize); + for(int j = 0; j < repeat + 1; j++){ + if(j < repeat){ + __memcpy_async(src + j % 2 * maxNum, input + tid + j * maxNum, maxNum * sizeof(T), GDRAM2NRAM); + } + if(j > 0){ + __bang_sub_scalar(src + (j - 1) % 2 * maxNum, src + (j - 1) % 2 * maxNum, mu, maxNum); + __bang_mul(src + (j - 1) % 2 * maxNum, src + (j - 1) % 2 * maxNum, src + (j - 1) % 2 * maxNum, maxNum); + __bang_add(destSum, destSum, src + (j - 1) % 2 * maxNum, maxNum); + } + __sync_all_ipu(); + } + if (remainT){ + __bang_write_value(src, maxNum, mu);//保证后面减去均值为0 + __memcpy(src, input + tid + repeat * maxNum, remainT * sizeof(T), GDRAM2NRAM); + __bang_sub_scalar(src, src, mu, maxNum); + __bang_mul(src, src, src, maxNum); + __bang_add(destSum, destSum, src, maxNum); + } + + for(int strip = segNum/2; strip > 0; strip = strip / 2){ + for(int i = 0; i < strip ; i++){ + __bang_add(destSum + i * wSize, destSum + i * wSize, destSum + (i + strip) * wSize, wSize); + } + } + __bang_reduce_sum(destSumFinal, destSum, wSize); + T sigma2 = destSumFinal[0] / behindsize + static_cast(eps); + sigma2 = 1.0 / pow(sigma2, 0.5); + //下面开始做变换 + for(int j = 0; j < repeat + 2; j++){ + if(j < repeat){ + __memcpy_async(src + j % 3 * maxNum, input + tid + j * maxNum, maxNum * sizeof(T), GDRAM2NRAM); + __memcpy_async(s_src + j % 3 * maxNum, scale + j * maxNum, maxNum * sizeof(T), GDRAM2NRAM); + __memcpy_async(b_src + j % 3 * maxNum, bias + j * maxNum, maxNum * sizeof(T), GDRAM2NRAM); + } + if(j > 0 && j < repeat + 1){ + __bang_sub_scalar(src + (j - 1) % 3 * maxNum, src + (j - 1) % 3 * maxNum, mu, maxNum); + __bang_mul_scalar(src + (j - 1) % 3 * maxNum, src + (j - 1) % 3 * maxNum, sigma2, maxNum); + __bang_mul(src + (j - 1) % 3 * maxNum, src + (j - 1) % 3 * maxNum, s_src + (j - 1) % 3 * maxNum, maxNum); + __bang_add(src + (j - 1) % 3 * maxNum, src + (j - 1) % 3 * maxNum, b_src + (j - 1) % 3 * maxNum, maxNum); + } + if(j > 1){ + __memcpy_async(output + tid + (j - 2) * maxNum, src + (j - 2) % 3 * maxNum, maxNum * sizeof(T), NRAM2GDRAM); + } + __sync_all_ipu(); + } + if(remainT){ + __memcpy(src, input + tid + repeat * maxNum, remainT * sizeof(T), GDRAM2NRAM); + __memcpy(s_src, scale + repeat * maxNum, remainT * sizeof(T), GDRAM2NRAM); + __memcpy(b_src, bias + repeat * maxNum, remainT * sizeof(T), GDRAM2NRAM); + __bang_sub_scalar(src, src, mu, maxNum); + __bang_mul_scalar(src, src, sigma2, maxNum); + __bang_mul(src, src, s_src, maxNum); + __bang_add(src, src, b_src, maxNum); + __memcpy(output + tid + repeat * maxNum, src, remainT * sizeof(T), NRAM2GDRAM); + } + } + } + else{ + int multiple = maxNum / behindsize;//一个core一次可以处理multiple个behindsize + int taskSize = taskDim * multiple; + int remainT = frontsize % taskSize; + int repeat = (frontsize - remainT) / taskSize; + int remain = remainT % taskDim; + int stepEasy = (remainT - remain) / taskDim; + int stepHard = stepEasy + 1; + int step = (taskId < remain ? stepHard : stepEasy); + int indStart = (taskId < remain ? taskId * stepHard : (remain * stepHard + (taskId - remain) * stepEasy)); + int segNum = bSize / wSize; + __memcpy(s_src, scale, behindsize * sizeof(T), GDRAM2NRAM); + __memcpy(b_src, bias, behindsize * sizeof(T), GDRAM2NRAM); + int tid; + for(int i = 0; i < repeat + 2; i++){ + if(i < repeat){ + tid = i * taskSize * behindsize; + __memcpy_async(src + i % 3 * maxNum, input + tid + taskId * multiple * behindsize, multiple * behindsize * sizeof(T), GDRAM2NRAM); + } + if(i > 0 && i < repeat + 1){ + for(int m = 0; m < multiple; m++){ + __bang_write_zero(destSum, maxNum); + __bang_write_zero(destSumFinal, wSize); + __bang_add(destSum, destSum, src + (i - 1) % 3 * maxNum + m *behindsize, behindsize); + for(int strip = segNum/2; strip > 0; strip = strip / 2){ + for(int i = 0; i < strip ; i++){ + __bang_add(destSum + i * wSize, destSum + i * wSize, destSum + (i + strip) * wSize, wSize); + } + } + __bang_reduce_sum(destSumFinal, destSum, wSize);//destSumFinal[0] / behindsize = mu + T mu = destSumFinal[0] / behindsize; + __bang_write_zero(destSum, maxNum); + __bang_sub_scalar(destSum, src + (i - 1) % 3 * maxNum + m * behindsize, mu, behindsize); + + __bang_mul(destSum, destSum, destSum, bSize); + __bang_write_zero(destSumFinal, wSize); + for(int strip = segNum/2; strip > 0; strip = strip / 2){ + for(int i = 0; i < strip ; i++){ + __bang_add(destSum + i * wSize, destSum + i * wSize, destSum + (i + strip) * wSize, wSize); + } + } + + __bang_reduce_sum(destSumFinal, destSum, wSize); + T sigma2 = 1.0 / (pow(destSumFinal[0] / behindsize + static_cast(eps), 0.5)); + //下面开始做变换 + __bang_sub_scalar(src + (i - 1) % 3 * maxNum + m * behindsize, src + (i - 1) % 3 * maxNum + m * behindsize, mu, behindsize); + __bang_mul_scalar(src + (i - 1) % 3 * maxNum + m * behindsize, src + (i - 1) % 3 * maxNum + m * behindsize, sigma2, behindsize); + __bang_mul(src + (i - 1) % 3 * maxNum + m * behindsize, src + (i - 1) % 3 * maxNum + m * behindsize, s_src, behindsize); + __bang_add(src + (i - 1) % 3 * maxNum + m * behindsize, src + (i - 1) % 3 * maxNum + m * behindsize, b_src, behindsize); + } + } + if(i > 1){ + tid = (i - 2) * taskSize * behindsize; + __memcpy_async(output + tid + taskId * multiple * behindsize, src + (i - 2) % 3 * maxNum, multiple * behindsize * sizeof(T), NRAM2GDRAM); + } + __sync_all_ipu(); + } + if(step){ + int tid = (repeat * taskSize + indStart) * behindsize; + __memcpy(src, input + tid, step * behindsize * sizeof(T), GDRAM2NRAM); + for(int m = 0; m < step; m++){ + __bang_write_zero(destSum, maxNum); + __bang_write_zero(destSumFinal, wSize); + __bang_add(destSum, destSum, src + m *behindsize, behindsize); + for(int strip = segNum/2; strip > 0; strip = strip / 2){ + for(int i = 0; i < strip ; i++){ + __bang_add(destSum + i * wSize, destSum + i * wSize, destSum + (i + strip) * wSize, wSize); + } + } + __bang_reduce_sum(destSumFinal, destSum, wSize);//destSumFinal[0] / behindsize = mu + T mu = destSumFinal[0] / behindsize; + __bang_write_zero(destSum, maxNum); + __bang_sub_scalar(destSum, src + m * behindsize, mu, behindsize); + + __bang_mul(destSum, destSum, destSum, bSize); + __bang_write_zero(destSumFinal, wSize); + for(int strip = segNum/2; strip > 0; strip = strip / 2){ + for(int i = 0; i < strip ; i++){ + __bang_add(destSum + i * wSize, destSum + i * wSize, destSum + (i + strip) * wSize, wSize); + } + } + + __bang_reduce_sum(destSumFinal, destSum, wSize); + T sigma2 = 1.0 / (pow(destSumFinal[0] / behindsize + static_cast(eps), 0.5)); + //下面开始做变换 + __bang_sub_scalar(src + m * behindsize, src + m * behindsize, mu, behindsize); + __bang_mul_scalar(src + m * behindsize, src + m * behindsize, sigma2, behindsize); + __bang_mul(src + m * behindsize, src + m * behindsize, s_src, behindsize); + __bang_add(src + m * behindsize, src + m * behindsize, b_src, behindsize); + + } + __memcpy(output + tid, src, step * behindsize * sizeof(T), NRAM2GDRAM); + } + } +} +template +void layer_normUnion(cnrtQueue_t queue, void *workspace, + uint64_t workspace_size, void const *input, void const *scale, void const *bias, void *output, float eps, int size, int behindsize){ + int wSize = 128 / sizeof(T); + int bSize; + float mi = log2(behindsize); + if (floor(mi) == mi) + { + bSize = behindsize; + } + else + { + bSize = static_cast(pow(2, floor(mi) + 1)); + } + if (bSize < wSize) + { + bSize = wSize; + } + auto source = reinterpret_cast(input); + auto weight = reinterpret_cast(scale); + auto _bias = reinterpret_cast(bias); + auto destination = reinterpret_cast(output); + + cnrtDim3_t k_dim; + cnrtFunctionType_t k_type; + + k_dim.x = 4; + k_dim.y = 1; + k_dim.z = 1; + + k_type = CNRT_FUNC_TYPE_UNION1; + T *tmpGdram = reinterpret_cast(workspace); + + layer_norm<<>>(source, weight, _bias, destination, tmpGdram, eps, size, behindsize, bSize); + + cnrtQueueSync(queue); +} +void layer_norm_bang(LayerNormBangDescriptor_t desc, void *workspace, + uint64_t workspace_size, void const *x, void const *w, void const *b, void *y, + void *stream){ + auto queue = reinterpret_cast(stream); + auto eps = desc->epsilon;//float + int size = desc->size; + int behindsize = desc->behindsize; + if (dtype_eq(desc->dtype, F16)){ + layer_normUnion(queue, workspace, workspace_size, x, w, b, y, eps, size, behindsize); + } + else if (dtype_eq(desc->dtype, F32)){ + layer_normUnion(queue, workspace, workspace_size, x, w, b, y, eps, size, behindsize); + } +} +infiniopStatus_t bangLayerNorm(LayerNormBangDescriptor_t desc, void *workspace, + uint64_t workspace_size, + void const *x, + void const *w, + void const *b, + void *y, + void *stream) { + if (cnrtSetDevice(desc->device_id) != cnrtSuccess) { + return STATUS_BAD_DEVICE; + } + if (dtype_eq(desc->dtype, F16) || dtype_eq(desc->dtype, F32)) { + layer_norm_bang(desc, workspace, workspace_size, x, w, b, y, stream); + return STATUS_SUCCESS; + } + return STATUS_BAD_TENSOR_DTYPE; +} diff --git a/src/ops/layer_norm/bang/layer_norm_cnnl.cc b/src/ops/layer_norm/bang/layer_norm_cnnl.cc new file mode 100644 index 00000000..e4c2d652 --- /dev/null +++ b/src/ops/layer_norm/bang/layer_norm_cnnl.cc @@ -0,0 +1,180 @@ +#include "layer_norm_cnnl.h" +#include "../../utils.h" +infiniopStatus_t cnnlCreateLayerNormDescriptor(BangHandle_t handle, LayerNormCnnlDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t w_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t y_desc, + float epsilon) { + if (w_desc->ndim != b_desc->ndim) { + return STATUS_BAD_TENSOR_SHAPE; + } + int wDim = w_desc->ndim; + for(int i = 0; i < wDim; i++){ + if(w_desc->shape[i] != b_desc->shape[i]){ + return STATUS_BAD_TENSOR_SHAPE; + } + } + int ndim = x_desc->ndim; + for(int i = 0; i < wDim; i++){ + if(x_desc->shape[i + ndim - wDim] != w_desc->shape[i]){ + return STATUS_BAD_TENSOR_SHAPE; + } + } + if (!dtype_eq(x_desc->dt, F16) && !dtype_eq(x_desc->dt, F32)) { + return STATUS_BAD_TENSOR_DTYPE; + } + int axis = ndim - wDim; + std::vector inDim(ndim); + std::vector outDim(ndim); + std::vector filter_biasDim(wDim); + std::vector mean_rstdDim(axis); + size_t mean_rstd_size = 1; + for (int i = 0; i < ndim; i++) { + inDim[i] = static_cast(x_desc->shape[i]); + outDim[i] = static_cast(x_desc->shape[i]); + if(i >= axis){ + filter_biasDim[i - axis] = static_cast(x_desc->shape[i]); + } + else{ + mean_rstdDim[i] = static_cast(x_desc->shape[i]); + mean_rstd_size *= static_cast(x_desc->shape[i]); + } + } + size_t dtype_size = 0; + cnnlTensorDescriptor_t yDesc, xDesc, filter_bias_desc, mean_rstd_desc; + cnnlCreateTensorDescriptor(&yDesc); + cnnlCreateTensorDescriptor(&xDesc); + cnnlCreateTensorDescriptor(&filter_bias_desc); + cnnlCreateTensorDescriptor(&mean_rstd_desc); + + if(dtype_eq(x_desc->dt, F16)){ + cnnlGetSizeOfDataType(CNNL_DTYPE_HALF, &dtype_size); + cnnlSetTensorDescriptor( + xDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_HALF, + inDim.size(), inDim.data()); + cnnlSetTensorDescriptor( + yDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_HALF, + outDim.size(), outDim.data()); + cnnlSetTensorDescriptor( + filter_bias_desc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_HALF, + filter_biasDim.size(), filter_biasDim.data()); + cnnlSetTensorDescriptor( + mean_rstd_desc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_HALF, + mean_rstdDim.size(), mean_rstdDim.data()); + } + else if(dtype_eq(x_desc->dt, F32)){ + cnnlGetSizeOfDataType(CNNL_DTYPE_FLOAT, &dtype_size); + cnnlSetTensorDescriptor( + xDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, + inDim.size(), inDim.data()); + cnnlSetTensorDescriptor( + yDesc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, + outDim.size(), outDim.data()); + cnnlSetTensorDescriptor( + filter_bias_desc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, + filter_biasDim.size(), filter_biasDim.data()); + cnnlSetTensorDescriptor( + mean_rstd_desc, CNNL_LAYOUT_ARRAY, CNNL_DTYPE_FLOAT, + mean_rstdDim.size(), mean_rstdDim.data()); + } + + + + size_t size_mean_rstd = mean_rstd_size * dtype_size; + size_t wsSize; + cnrtQueue_t queue; + CNRT_CHECK(cnrtQueueCreate(&queue)); + use_cnnl(handle->cnnl_handles, handle->device_id, queue, + [&](cnnlHandle_t handle) { + cnnlGetLayerNormOpWorkspaceSize(handle, axis, xDesc, &wsSize); + }); + CNRT_CHECK(cnrtQueueDestroy(queue)); + + *desc_ptr = new LayerNormCnnlDescriptor{ + handle->device, + handle->device_id, + x_desc->dt, + handle->cnnl_handles, + xDesc, + yDesc, + filter_bias_desc, + mean_rstd_desc, + axis, + size_mean_rstd, + wsSize, + epsilon}; + + return STATUS_SUCCESS; +} +infiniopStatus_t cnnlGetLayerNormWorkspaceSize(LayerNormCnnlDescriptor_t desc, unsigned long int *size) { + *size = 2 * desc->size_mean_rstd + desc->wsSize; + return STATUS_SUCCESS; +} +template +void layerNorm_cnnl(LayerNormCnnlDescriptor_t desc, void *workspace, + uint64_t workspace_size, + void const *x, void const *w, void const *b, void *y, + void *stream) { + + cnnlTensorDescriptor_t xDesc = desc->xDesc; + cnnlTensorDescriptor_t yDesc = desc->yDesc; + cnnlTensorDescriptor_t filter_bias_desc = desc->filter_bias_desc; + cnnlTensorDescriptor_t mean_rstd_desc = desc->mean_rstd_desc; + int axis = desc->axis; + float eps = desc->epsilon; + + T *mean_dev = reinterpret_cast(workspace); + T *rstd_dev = mean_dev + desc->size_mean_rstd; + + void *workspace_extra = reinterpret_cast(workspace) + 2 * desc->size_mean_rstd; + int wsSize = (int)workspace_size - 2 * desc->size_mean_rstd; + use_cnnl(desc->cnnl_handles, desc->device_id, (cnrtQueue_t) stream, + [&](cnnlHandle_t handle) { + cnnlLayerNormForward(handle, + xDesc, + x, + axis, + filter_bias_desc, + w, + b, + eps, + workspace_extra, + wsSize, + yDesc, + y, + mean_rstd_desc, + mean_dev, + rstd_dev); + }); + +} +infiniopStatus_t cnnlLayerNorm(LayerNormCnnlDescriptor_t desc, void *workspace, + uint64_t workspace_size, + void const *x, void const *w, void const *b, void *y, + void *stream) { + if (cnrtSetDevice(desc->device_id) != cnrtSuccess) { + return STATUS_BAD_DEVICE; + } + + if (dtype_eq(desc->dtype, F16)) { + layerNorm_cnnl(desc, workspace, workspace_size, x, w, b, y, stream); + + return STATUS_SUCCESS; + } + if (dtype_eq(desc->dtype, F32)) { + layerNorm_cnnl(desc, workspace, workspace_size, x, w, b, y, stream); + + return STATUS_SUCCESS; + } + return STATUS_BAD_TENSOR_DTYPE; +} +infiniopStatus_t cnnlDestroyLayerNormDescriptor(LayerNormCnnlDescriptor_t desc) { + desc->cnnl_handles = nullptr; + cnnlDestroyTensorDescriptor(desc->xDesc); + cnnlDestroyTensorDescriptor(desc->yDesc); + cnnlDestroyTensorDescriptor(desc->filter_bias_desc); + cnnlDestroyTensorDescriptor(desc->mean_rstd_desc); + delete desc; + return STATUS_SUCCESS; +} diff --git a/src/ops/layer_norm/bang/layer_norm_cnnl.h b/src/ops/layer_norm/bang/layer_norm_cnnl.h new file mode 100644 index 00000000..eae2f6e9 --- /dev/null +++ b/src/ops/layer_norm/bang/layer_norm_cnnl.h @@ -0,0 +1,42 @@ +#ifndef __CNNL_LAYER_NORM_H__ +#define __CNNL_LAYER_NORM_H__ + +#include "../../../devices/bang/bang_handle.h" +#include "../../utils.h" +#include "operators.h" + +struct LayerNormCnnlDescriptor { + Device device; + int device_id; + DT dtype; + std::shared_ptr> cnnl_handles; + cnnlTensorDescriptor_t xDesc; + cnnlTensorDescriptor_t yDesc; + cnnlTensorDescriptor_t filter_bias_desc; + cnnlTensorDescriptor_t mean_rstd_desc; + int axis; + size_t size_mean_rstd; + size_t wsSize; + float epsilon; +}; + +typedef struct LayerNormCnnlDescriptor *LayerNormCnnlDescriptor_t; + +infiniopStatus_t cnnlCreateLayerNormDescriptor(BangHandle_t handle, + LayerNormCnnlDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t w_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t y_desc, + float epsilon); + +infiniopStatus_t cnnlGetLayerNormWorkspaceSize(LayerNormCnnlDescriptor_t desc, unsigned long int *size); + +infiniopStatus_t cnnlLayerNorm(LayerNormCnnlDescriptor_t desc, void *workspace, + uint64_t workspace_size, + void const *x, void const *w, void const *b, void *y, + void *stream); + +infiniopStatus_t cnnlDestroyLayerNormDescriptor(LayerNormCnnlDescriptor_t desc); + +#endif// __CNNL_LAYER_NORM_H__ diff --git a/src/ops/layer_norm/cpu/layer_norm_cpu.cc b/src/ops/layer_norm/cpu/layer_norm_cpu.cc new file mode 100644 index 00000000..6dd9bcc3 --- /dev/null +++ b/src/ops/layer_norm/cpu/layer_norm_cpu.cc @@ -0,0 +1,129 @@ +#include "layer_norm_cpu.h" +#include "../../../devices/cpu/common_cpu.h" +#include "../../utils.h" +#include + +infiniopStatus_t cpuCreateLayerNormDescriptor(infiniopHandle_t handle, LayerNormCpuDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t w_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t y_desc, + float epsilon) { + if (w_desc->ndim != b_desc->ndim) { + return STATUS_BAD_TENSOR_SHAPE; + } + int wDim = w_desc->ndim; + for(int i = 0; i < wDim; i++){ + if(w_desc->shape[i] != b_desc->shape[i]){ + return STATUS_BAD_TENSOR_SHAPE; + } + } + int ndim = x_desc->ndim; + for(int i = 0; i < wDim; i++){ + if(x_desc->shape[i + ndim - wDim] != w_desc->shape[i]){ + return STATUS_BAD_TENSOR_SHAPE; + } + } + if (!dtype_eq(x_desc->dt, F16) && !dtype_eq(x_desc->dt, F32)) { + return STATUS_BAD_TENSOR_DTYPE; + } + int size = 1; + int behindsize = 1; + for(int i = 0; i < ndim; i++){ + size *= static_cast(x_desc->shape[i]); + if(i >= ndim - wDim){ + behindsize *= static_cast(x_desc->shape[i]); + } + } + + *desc_ptr = new LayerNormCpuDescriptor{ + handle->device, + x_desc->dt, + size, + behindsize, + epsilon}; + + return STATUS_SUCCESS; +} + +infiniopStatus_t cpuDestroyLayerNormDescriptor(LayerNormCpuDescriptor_t desc) { + delete desc; + return STATUS_SUCCESS; +} + +void layer_norm_cpu(LayerNormCpuDescriptor_t desc, void const *x, void const *w, void const *b, void *y) { + int size = desc->size; + int behindsize = desc->behindsize; + int frontsize = size / behindsize; + float eps = desc->epsilon; + if (dtype_eq(desc->dtype, F32)) + { + auto source = reinterpret_cast(x); + auto weight = reinterpret_cast(w); + auto _bias = reinterpret_cast(b); + auto destination = reinterpret_cast(y); + for (int i = 0; i < frontsize; i++) + { + int tid = i * behindsize; + float mu = 0.0f; + for (int id = 0; id < behindsize; id++) + { + mu += source[tid + id]; + } + mu /= behindsize; + float sigma2Partial = 0.0f; + for (int id = 0; id < behindsize; id++) + { + sigma2Partial += (source[tid + id] - mu) * (source[tid + id] - mu); + } + float sigma2 = 1.0f / sqrt(sigma2Partial / behindsize + eps); + for (int id = 0; id < behindsize; id++) + { + destination[tid + id] = (source[tid + id] - mu) * weight[id] * sigma2 + _bias[id]; + } + } + } + else if (dtype_eq(desc->dtype, F16)) + { + auto source = reinterpret_cast(x); + auto weight = reinterpret_cast(w); + auto _bias = reinterpret_cast(b); + auto destination = reinterpret_cast(y); + for (int i = 0; i < frontsize; i++) + { + int tid = i * behindsize; + float mu = 0.0f; + for (int id = 0; id < behindsize; id++) + { + mu += f16_to_f32(source[tid + id]); + } + mu /= behindsize; + float sigma2Partial = 0.0f; + for (int id = 0; id < behindsize; id++) + { + sigma2Partial += (f16_to_f32(source[tid + id]) - mu) * (f16_to_f32(source[tid + id]) - mu); + } + float sigma2 = 1.0f / sqrt(sigma2Partial / behindsize + eps); + for (int id = 0; id < behindsize; id++) + { + float tmp = (f16_to_f32(source[tid + id]) - mu) * f16_to_f32(weight[id]) * sigma2 + f16_to_f32(_bias[id]); + destination[tid + id] = f32_to_f16(tmp); + } + } + } +} +infiniopStatus_t cpuGetLayerNormWorkspaceSize(LayerNormCpuDescriptor_t desc, unsigned long int *size) { + *size = 0; + return STATUS_SUCCESS; +} +infiniopStatus_t cpuLayerNorm(LayerNormCpuDescriptor_t desc, void *workspace, + uint64_t workspace_size, + void const *x, void const *w, void const *b, void *y, + void *stream) { + if (dtype_eq(desc->dtype, F16) || dtype_eq(desc->dtype, F32)) { + layer_norm_cpu(desc, x, w, b, y); + return STATUS_SUCCESS; + } + + return STATUS_BAD_TENSOR_DTYPE; +} diff --git a/src/ops/layer_norm/cpu/layer_norm_cpu.h b/src/ops/layer_norm/cpu/layer_norm_cpu.h new file mode 100644 index 00000000..1428dab5 --- /dev/null +++ b/src/ops/layer_norm/cpu/layer_norm_cpu.h @@ -0,0 +1,31 @@ +#ifndef __CPU_LAYER_NORM_H__ +#define __CPU_LAYER_NORM_H__ + +#include "operators.h" + +struct LayerNormCpuDescriptor { + Device device; + DT dtype; + int size; + int behindsize; + float epsilon; +}; + +typedef struct LayerNormCpuDescriptor *LayerNormCpuDescriptor_t; + +infiniopStatus_t cpuCreateLayerNormDescriptor(infiniopHandle_t handle, LayerNormCpuDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t w_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t y_desc, + float epsilon); + +infiniopStatus_t cpuGetLayerNormWorkspaceSize(LayerNormCpuDescriptor_t desc, unsigned long int *size); + +infiniopStatus_t cpuLayerNorm(LayerNormCpuDescriptor_t desc, void *workspace, + uint64_t workspace_size, + void const *x, void const *w, void const *b, void *y, + void *stream); +infiniopStatus_t cpuDestroyLayerNormDescriptor(LayerNormCpuDescriptor_t desc); + +#endif// __CPU_LAYER_NORM_H__ diff --git a/src/ops/layer_norm/cuda/layer_norm.cc b/src/ops/layer_norm/cuda/layer_norm.cc new file mode 100644 index 00000000..74ad2200 --- /dev/null +++ b/src/ops/layer_norm/cuda/layer_norm.cc @@ -0,0 +1,56 @@ +#include "layer_norm.cuh" +#include "../../utils.h" +#include "../../../devices/cuda/common_cuda.h" + +infiniopStatus_t cudaCreateLayerNormDescriptor(CudaHandle_t handle, + LayerNormCudaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t w_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t y_desc, + float epsilon) { + if (w_desc->ndim != b_desc->ndim) { + return STATUS_BAD_TENSOR_SHAPE; + } + int wDim = w_desc->ndim; + for(int i = 0; i < wDim; i++){ + if(w_desc->shape[i] != b_desc->shape[i]){ + return STATUS_BAD_TENSOR_SHAPE; + } + } + int ndim = x_desc->ndim; + for(int i = 0; i < wDim; i++){ + if(x_desc->shape[i + ndim - wDim] != w_desc->shape[i]){ + return STATUS_BAD_TENSOR_SHAPE; + } + } + if (!dtype_eq(x_desc->dt, F16) && !dtype_eq(x_desc->dt, F32)) { + return STATUS_BAD_TENSOR_DTYPE; + } + int size = 1; + int behindsize = 1; + for(int i = 0; i < ndim; i++){ + size *= static_cast(x_desc->shape[i]); + if(i >= ndim - wDim){ + behindsize *= static_cast(x_desc->shape[i]); + } + } + *desc_ptr = new LayerNormCudaDescriptor{ + handle->device, + handle->device_id, + x_desc->dt, + size, + behindsize, + epsilon}; + + return STATUS_SUCCESS; +} +infiniopStatus_t cudaGetLayerNormWorkspaceSize(LayerNormCudaDescriptor_t desc, unsigned long int *size) { + *size = 0; + return STATUS_SUCCESS; +} + +infiniopStatus_t cudaDestroyLayerNormDescriptor(LayerNormCudaDescriptor_t desc) { + delete desc; + return STATUS_SUCCESS; +} diff --git a/src/ops/layer_norm/cuda/layer_norm.cu b/src/ops/layer_norm/cuda/layer_norm.cu new file mode 100644 index 00000000..354a30fe --- /dev/null +++ b/src/ops/layer_norm/cuda/layer_norm.cu @@ -0,0 +1,154 @@ +#include "../../../devices/cuda/common_cuda.h" +#include "../../utils.h" +#include "layer_norm.cuh" +#include + +template +__launch_bounds__(BLOCK_DIM) + __global__ void blockLayernormKernel(T const *input, T const *scale, T const *bias, T *output, float eps, int behindsize) { + // 假设input= [A, B, C, D], axis = 2, frontsize = AB = blockDim.x, behindsize = CD + // 全局索引index = i(BCD) + j (CD) + k(D) + s + // blockIdx.x = i(B) + j;默认behindsize >= BLOCK_DIM + // scale,bias长度为behindsize,形状为[C,D] + int tid = blockIdx.x * behindsize; + float muPartial = 0.0f; + for (int id = threadIdx.x; id < behindsize; id += BLOCK_DIM) { + muPartial += static_cast(input[tid + id]);// half很多操作不支持,运算过程使用float数据 + } + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage temp_storage; + __shared__ float mu; + float muBlock = BlockReduce(temp_storage).Reduce(muPartial, cub::Sum()); + if (threadIdx.x == 0) { + mu = muBlock * __fdividef(1.0F, behindsize); + }// threadIdx.x = 0对应的是全局sum + __syncthreads(); + float sigma2Partial = 0.0f; + for (int id = threadIdx.x; id < behindsize; id += BLOCK_DIM) { + sigma2Partial += (static_cast(input[tid + id]) - mu) * (static_cast(input[tid + id]) - mu); + } + __shared__ float sigma2; + float sigma2Block = BlockReduce(temp_storage).Reduce(sigma2Partial, cub::Sum()); + if (threadIdx.x == 0) { + float sigmaTmp = sqrt(sigma2Block * __fdividef(1.0F, behindsize) + eps); + sigma2 = __fdividef(1.0F, sigmaTmp); + } + __syncthreads(); + for (int id = threadIdx.x; id < behindsize; id += BLOCK_DIM) { + output[tid + id] = static_cast(static_cast(scale[id]) * (static_cast(input[tid + id]) - mu) * sigma2 + static_cast(bias[id])); + } +} +template +struct SumOp { + __device__ __forceinline__ T operator()(const T &a, const T &b) const { + return a + b; + } +}; + +template class ReductionOp, typename T, + int thread_group_width> +__inline__ __device__ T WarpAllReduce(T val) { + for (int mask = thread_group_width / 2; mask > 0; mask /= 2) { + val = ReductionOp()(val, __shfl_xor_sync(0xffffffff, val, mask)); + } + return val; +} +template +__global__ void warpLayernormKernel(T const *input, T const *scale, T const *bias, T *output, float eps, int behindsize) { + // 默认behindsize < 1024 + int otherIdx = blockIdx.x * blockDim.y + threadIdx.y; + int tid = otherIdx * behindsize; + float muPartial = 0.0f; + for (int id = threadIdx.x; id < behindsize; id += BLOCK_DIM_x) { + muPartial += static_cast(input[tid + id]); + } + muPartial = WarpAllReduce(muPartial); + __shared__ float mu[BLOCK_DIM_y]; + + if (threadIdx.x == 0) { + mu[threadIdx.y] = muPartial * __fdividef(1.0F, behindsize); + }// threadIdx.x = 0对应的是全局sum + __syncthreads(); + float sigma2Partial = 0.0f; + for (int id = threadIdx.x; id < behindsize; id += BLOCK_DIM_x) { + sigma2Partial += (static_cast(input[tid + id]) - mu[threadIdx.y]) * (static_cast(input[tid + id]) - mu[threadIdx.y]); + } + sigma2Partial = WarpAllReduce(sigma2Partial); + __shared__ float sigma2[BLOCK_DIM_y]; + + if (threadIdx.x == 0) { + float sigmaTmp = sqrt(sigma2Partial * __fdividef(1.0F, behindsize) + eps); + sigma2[threadIdx.y] = __fdividef(1.0F, sigmaTmp); + } + __syncthreads(); + for (int id = threadIdx.x; id < behindsize; id += BLOCK_DIM_x) { + output[tid + id] = static_cast(static_cast(scale[id]) * (static_cast(input[tid + id]) - mu[threadIdx.y]) * sigma2[threadIdx.y] + static_cast(bias[id])); + } +} + +template +void layer_norm_nv_gpu(LayerNormCudaDescriptor_t desc, void const *input, void const *scale, void const *bias, void *output, void *stream) { + int size = desc->size; + int behindsize = desc->behindsize; + int num_blocks = size / behindsize; + float eps = desc->epsilon; + if (behindsize >= 1024) { + int BLOCK_DIM = 1024; + blockLayernormKernel + <<>>((T *) input, (T *) scale, (T *) bias, (T *) output, eps, behindsize); + } else if (behindsize > 31) { + int BLOCK_DIM_x = 32; + int BLOCK_DIM_y = 32; + int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y; + dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); + dim3 grid_dim(num_block_x, 1, 1); + + warpLayernormKernel + <<>>((T *) input, (T *) scale, (T *) bias, (T *) output, eps, behindsize); + } else if (behindsize > 15) { + int BLOCK_DIM_x = 16; + int BLOCK_DIM_y = 64; + int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y; + dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); + dim3 grid_dim(num_block_x, 1, 1); + + warpLayernormKernel + <<>>((T *) input, (T *) scale, (T *) bias, (T *) output, eps, behindsize); + } else if (behindsize > 7) { + int BLOCK_DIM_x = 8; + int BLOCK_DIM_y = 128; + int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y; + dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); + dim3 grid_dim(num_block_x, 1, 1); + + warpLayernormKernel + <<>>((T *) input, (T *) scale, (T *) bias, (T *) output, eps, behindsize); + } else { + int BLOCK_DIM_x = 4; + int BLOCK_DIM_y = 256; + int num_block_x = (num_blocks + BLOCK_DIM_y - 1) / BLOCK_DIM_y; + dim3 block_dim(BLOCK_DIM_x, BLOCK_DIM_y, 1); + dim3 grid_dim(num_block_x, 1, 1); + + warpLayernormKernel + <<>>((T *) input, (T *) scale, (T *) bias, (T *) output, eps, behindsize); + } +} + +infiniopStatus_t cudaLayerNorm(LayerNormCudaDescriptor_t desc, void *workspace, + uint64_t workspace_size, + void const *x, void const *w, void const *b, void *y, + void *stream) { + if (cudaSetDevice(desc->device_id) != cudaSuccess) { + return STATUS_BAD_DEVICE; + } + if (dtype_eq(desc->dtype, F16)) { + layer_norm_nv_gpu(desc, x, w, b, y, stream); + return STATUS_SUCCESS; + } + if (dtype_eq(desc->dtype, F32)) { + layer_norm_nv_gpu(desc, x, w, b, y, stream); + return STATUS_SUCCESS; + } + return STATUS_BAD_TENSOR_DTYPE; +} diff --git a/src/ops/layer_norm/cuda/layer_norm.cuh b/src/ops/layer_norm/cuda/layer_norm.cuh new file mode 100644 index 00000000..6cdb1bb6 --- /dev/null +++ b/src/ops/layer_norm/cuda/layer_norm.cuh @@ -0,0 +1,35 @@ +#ifndef __NV_GPU_LAYER_NORM_H__ +#define __NV_GPU_LAYER_NORM_H__ + +#include "../../../devices/cuda/cuda_handle.h" +#include "operators.h" + +struct LayerNormCudaDescriptor { + Device device; + int device_id; + DT dtype; + int size; + int behindsize; + float epsilon; +}; + +typedef struct LayerNormCudaDescriptor *LayerNormCudaDescriptor_t; + +infiniopStatus_t cudaCreateLayerNormDescriptor(CudaHandle_t handle, + LayerNormCudaDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t w_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t y_desc, + float epsilon); + +infiniopStatus_t cudaGetLayerNormWorkspaceSize(LayerNormCudaDescriptor_t desc, unsigned long int *size); + +infiniopStatus_t cudaLayerNorm(LayerNormCudaDescriptor_t desc, void *workspace, + uint64_t workspace_size, + void const *x, void const *w, void const *b, void *y, + void *stream); + +infiniopStatus_t cudaDestroyLayerNormDescriptor(LayerNormCudaDescriptor_t desc); + +#endif// __NV_GPU_LAYER_NORM_H__ diff --git a/src/ops/layer_norm/operator.cc b/src/ops/layer_norm/operator.cc new file mode 100644 index 00000000..5a29079f --- /dev/null +++ b/src/ops/layer_norm/operator.cc @@ -0,0 +1,111 @@ +#include "../utils.h" +#include "operators.h" +#include "ops/layer_norm/layer_norm.h" + +#ifdef ENABLE_CPU +#include "cpu/layer_norm_cpu.h" +#endif +#ifdef ENABLE_NV_GPU +#include "../../devices/cuda/common_cuda.h" +#include "../../devices/cuda/cuda_handle.h" +#include "cuda/layer_norm.cuh" +#endif +#ifdef ENABLE_CAMBRICON_MLU +#include "../../devices/bang/bang_handle.h" +#include "bang/layer_norm_bang.h" +#include "bang/layer_norm_cnnl.h" +#endif + +__C infiniopStatus_t infiniopCreateLayerNormDescriptor( + infiniopHandle_t handle, + infiniopLayerNormDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t x_desc, + infiniopTensorDescriptor_t w_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t y_desc, + float epsilon) { + switch (handle->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuCreateLayerNormDescriptor(handle, (LayerNormCpuDescriptor_t *) desc_ptr, x_desc, w_desc, b_desc, y_desc, epsilon); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaCreateLayerNormDescriptor((CudaHandle_t) handle, (LayerNormCudaDescriptor_t *) desc_ptr,x_desc, w_desc, b_desc, y_desc, epsilon); + } +#endif +#ifdef ENABLE_CAMBRICON_MLU + case DevCambriconMlu: { + //return bangCreateLayerNormDescriptor((BangHandle_t) handle, (LayerNormBangDescriptor_t *) desc_ptr, x_desc, w_desc, b_desc, y_desc, epsilon); + return cnnlCreateLayerNormDescriptor((BangHandle_t) handle, (LayerNormCnnlDescriptor_t *) desc_ptr, x_desc, w_desc, b_desc, y_desc, epsilon); + } +#endif + } + return STATUS_BAD_DEVICE; +} +__C infiniopStatus_t infiniopGetLayerNormWorkspaceSize(infiniopLayerNormDescriptor_t desc, uint64_t *size) { + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuGetLayerNormWorkspaceSize((LayerNormCpuDescriptor_t) desc, size); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaGetLayerNormWorkspaceSize((LayerNormCudaDescriptor_t) desc, size); + } + +#endif +#ifdef ENABLE_CAMBRICON_MLU + case DevCambriconMlu: { + //return bangGetLayerNormWorkspaceSize((LayerNormBangDescriptor_t) desc, size); + return cnnlGetLayerNormWorkspaceSize((LayerNormCnnlDescriptor_t) desc, size); + } +#endif + } + return STATUS_BAD_DEVICE; +} +__C infiniopStatus_t infiniopLayerNorm(infiniopLayerNormDescriptor_t desc, void *workspace, + uint64_t workspace_size, + void const *x, void const *w, void const *b, void *y, void *stream) { + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuLayerNorm((LayerNormCpuDescriptor_t) desc, workspace, workspace_size, x, w, b, y, stream); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaLayerNorm((LayerNormCudaDescriptor_t) desc, workspace, workspace_size, x, w, b, y, stream); + } + +#endif +#ifdef ENABLE_CAMBRICON_MLU + case DevCambriconMlu: { + //return bangLayerNorm((LayerNormBangDescriptor_t) desc, workspace, workspace_size, x, w, b, y, stream); + return cnnlLayerNorm((LayerNormCnnlDescriptor_t) desc, workspace, workspace_size, x, w, b, y, stream); + } +#endif + } + return STATUS_BAD_DEVICE; +} + +__C infiniopStatus_t infiniopDestroyLayerNormDescriptor(infiniopLayerNormDescriptor_t desc) { + switch (desc->device) { +#ifdef ENABLE_CPU + case DevCpu: + return cpuDestroyLayerNormDescriptor((LayerNormCpuDescriptor_t) desc); +#endif +#ifdef ENABLE_NV_GPU + case DevNvGpu: { + return cudaDestroyLayerNormDescriptor((LayerNormCudaDescriptor_t) desc); + } + +#endif +#ifdef ENABLE_CAMBRICON_MLU + case DevCambriconMlu: { + //return bangDestroyLayerNormDescriptor((LayerNormBangDescriptor_t) desc); + return cnnlDestroyLayerNormDescriptor((LayerNormCnnlDescriptor_t) desc); + } +#endif + } + return STATUS_BAD_DEVICE; +} diff --git a/src/ops/rms_norm/bang/rms_norm_cnnl.cc b/src/ops/rms_norm/bang/rms_norm_cnnl.cc deleted file mode 100644 index 01e9aacd..00000000 --- a/src/ops/rms_norm/bang/rms_norm_cnnl.cc +++ /dev/null @@ -1,56 +0,0 @@ -#include "rms_norm_cnnl.h" -#include "../../../devices/bang/common_bang.h" -#include "../../../devices/bang/handle_pool.h" -#include "../../utils.h" -#include "cnrt.h" - -RMSNormCnnlDescriptor::RMSNormCnnlDescriptor(Device device) { - this->device = device; - get_cnnl_pool(); -} - -void rms_norm_cnnl_f16(Tensor y, Tensor x, Tensor w, float epsilon, void *stream) { - ASSERT_EQ(y.layout->ndim, 2); - ASSERT_EQ(x.layout->ndim, 2); - ASSERT_EQ(w.layout->ndim, 1); - - auto n = y.layout->shape[0], - d = y.layout->shape[1]; - - ASSERT_EQ(x.layout->shape[0], n); - ASSERT_EQ(x.layout->shape[1], d); - ASSERT_EQ(w.layout->shape[0], d); - - cnnlTensorDescriptor_t yDesc, xDesc, wDesc; - cnnlCreateTensorDescriptor(&yDesc); - cnnlCreateTensorDescriptor(&xDesc); - cnnlCreateTensorDescriptor(&wDesc); - setCnnlTensor(yDesc, y.layout); - setCnnlTensor(xDesc, x.layout); - setCnnlTensor(wDesc, w.layout); - - cnnlFuseNormDescriptor_t opDesc; - cnnlCreateFuseNormDescriptor(&opDesc); - cnnlSetFuseNormDescriptor(opDesc, epsilon, 1.0, true, - false, false, false, false, - CNNL_DTYPE_HALF, CNNL_TRANSFORMER_RMSNORM); - - void *workspace; - - use_cnnl((cnrtQueue_t) stream, - [&](cnnlHandle_t handle) { - size_t wsSize; - cnnlGetFuseNormWorkspaceSize(handle, opDesc, xDesc, &wsSize); - cnrtMalloc(&workspace, wsSize); - cnnlFuseNorm(handle, opDesc, xDesc, x.data, - wDesc, w.data, nullptr, nullptr, - nullptr, nullptr, nullptr, nullptr, - workspace, wsSize, yDesc, y.data, nullptr, nullptr); - }); - - cnrtFree(workspace); - cnnlDestroyFuseNormDescriptor(opDesc); - cnnlDestroyTensorDescriptor(xDesc); - cnnlDestroyTensorDescriptor(yDesc); - cnnlDestroyTensorDescriptor(wDesc); -} diff --git a/src/ops/rms_norm/bang/rms_norm_cnnl.h b/src/ops/rms_norm/bang/rms_norm_cnnl.h deleted file mode 100644 index c76bf2d0..00000000 --- a/src/ops/rms_norm/bang/rms_norm_cnnl.h +++ /dev/null @@ -1,15 +0,0 @@ -#ifndef __CNNL_RMS_NORM_H__ -#define __CNNL_RMS_NORM_H__ - -#include "cnnl.h" -#include "cnnl_extra.h" -#include "operators.h" - -struct RMSNormCnnlDescriptor { - Device device; - RMSNormCnnlDescriptor(Device device); -}; - -void rms_norm_cnnl_f16(Tensor y, Tensor x, Tensor w, float epsilon, void *stream); - -#endif// __CNNL_RMS_NORM_H__ diff --git a/src/ops/rms_norm/operator.cc b/src/ops/rms_norm/operator.cc index e466d436..9aa4b206 100644 --- a/src/ops/rms_norm/operator.cc +++ b/src/ops/rms_norm/operator.cc @@ -13,7 +13,6 @@ #ifdef ENABLE_CAMBRICON_MLU #include "../../devices/bang/bang_handle.h" #include "bang/rms_norm_bang.h" -#include "bang/rms_norm_cnnl.h" #endif #ifdef ENABLE_ASCEND_NPU #include "ascend/rms_norm_aclnn.h" diff --git a/src/ops/utils.h b/src/ops/utils.h index ad2b65cc..fd124719 100644 --- a/src/ops/utils.h +++ b/src/ops/utils.h @@ -8,7 +8,7 @@ #include #include #include - +#include /* This file contains some useful macros and helper functions */ // check if an expression is true, and if not, print an error message and abort the program