diff --git a/src/infiniop/ops/rms_norm/ninetoothed/build.py b/src/infiniop/ops/rms_norm/ninetoothed/build.py new file mode 100644 index 000000000..a01d7b053 --- /dev/null +++ b/src/infiniop/ops/rms_norm/ninetoothed/build.py @@ -0,0 +1,28 @@ +import ninetoothed +from ntops.kernels import rms_norm + +import infiniop.ninetoothed.build + + +def build(): + MAX_NDIM = 5 + + ndim_values = range(1, MAX_NDIM + 1) + dtype_values = (ninetoothed.float16, ninetoothed.bfloat16, ninetoothed.float32) + + constexpr_param_grid = { + "ndim": ndim_values, + "num_normalized_dims": (1,), + "input_dtype": dtype_values, + "weight_dtype": dtype_values, + "output_dtype": dtype_values, + "block_size": (1024,), + } + + infiniop.ninetoothed.build.build( + rms_norm.premake, + constexpr_param_grid, + caller="cuda", + op_name="rms_norm", + output_dir=infiniop.ninetoothed.build.BUILD_DIRECTORY_PATH, + ) diff --git a/src/infiniop/ops/rms_norm/nvidia/rms_norm_nvidia.cu b/src/infiniop/ops/rms_norm/nvidia/rms_norm_nvidia.cu index c0d379a61..98371b628 100644 --- a/src/infiniop/ops/rms_norm/nvidia/rms_norm_nvidia.cu +++ b/src/infiniop/ops/rms_norm/nvidia/rms_norm_nvidia.cu @@ -8,6 +8,10 @@ #include "../cuda/kernel.cuh" +#ifdef ENABLE_NINETOOTHED +#include "../../../../../build/ninetoothed/rms_norm.h" +#endif + template INFINIOP_CUDA_KERNEL rmsnormKernel( Tdata *__restrict__ y, @@ -112,6 +116,7 @@ infiniStatus_t Descriptor::calculate( size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1; auto cuda_stream = reinterpret_cast(stream); +#ifndef ENABLE_NINETOOTHED // launch kernel with different block sizes if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) { CHECK_STATUS(launchKernel(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, cuda_stream)); @@ -122,6 +127,43 @@ infiniStatus_t Descriptor::calculate( } else { return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED; } +#else + const auto &ndim{_info.ndim()}; + uint64_t dim_ = dim; + + std::vector empty_shape_vec(ndim); + std::vector empty_strides_vec(ndim); + const auto &empty_shape{empty_shape_vec.data()}; + const auto &empty_strides{empty_strides_vec.data()}; + + auto &x_shape_vec{_info.shape}; + auto &x_strides_vec{_info.x_strides}; + auto x_data{x}; + auto x_shape{x_shape_vec.data()}; + auto x_strides{x_strides_vec.data()}; + const NineToothedTensor input{const_cast(x_data), const_cast(x_shape), const_cast(x_strides)}; + auto &w_shape_vec{_info.shape}; + std::vector w_strides_vec(ndim); + w_strides_vec[ndim - 1] = 1; + auto w_data{w}; + auto w_shape{w_shape_vec.data()}; + auto w_strides{w_strides_vec.data()}; + const NineToothedTensor weight{const_cast(w_data), const_cast(w_shape), const_cast(w_strides)}; + const NineToothedTensor eps{const_cast(&_info.epsilon), empty_shape, empty_strides}; + auto &y_shape_vec{_info.shape}; + auto &y_strides_vec{_info.y_strides}; + auto y_data{y}; + auto y_shape{y_shape_vec.data()}; + auto y_strides{y_strides_vec.data()}; + const NineToothedTensor output{y_data, const_cast(y_shape), const_cast(y_strides)}; + const NineToothedTensor num_normalized_elements{const_cast(&dim_), empty_shape, empty_strides}; + constexpr auto num_normalized_dims{1}; + constexpr auto block_size{1024}; + + if (launch_rms_norm(stream, input, weight, eps, output, num_normalized_elements, ndim, num_normalized_dims, _info.atype, _info.wtype, _info.atype, block_size)) { + return INFINI_STATUS_INTERNAL_ERROR; + } +#endif return INFINI_STATUS_SUCCESS; } } // namespace op::rms_norm::nvidia