28 changes: 28 additions & 0 deletions src/infiniop/ops/rms_norm/ninetoothed/build.py
@@ -0,0 +1,28 @@
import ninetoothed
from ntops.kernels import rms_norm

import infiniop.ninetoothed.build


def build():
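    # Prebuild a specialized rms_norm kernel for each combination in the
    # constexpr parameter grid below, targeting the CUDA backend.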
MAX_NDIM = 5

ndim_values = range(1, MAX_NDIM + 1)
dtype_values = (ninetoothed.float16, ninetoothed.bfloat16, ninetoothed.float32)

constexpr_param_grid = {
"ndim": ndim_values,
"num_normalized_dims": (1,),
"input_dtype": dtype_values,
"weight_dtype": dtype_values,
"output_dtype": dtype_values,
"block_size": (1024,),
}

infiniop.ninetoothed.build.build(
rms_norm.premake,
constexpr_param_grid,
caller="cuda",
op_name="rms_norm",
output_dir=infiniop.ninetoothed.build.BUILD_DIRECTORY_PATH,
)
42 changes: 42 additions & 0 deletions src/infiniop/ops/rms_norm/nvidia/rms_norm_nvidia.cu
@@ -8,6 +8,10 @@

#include "../cuda/kernel.cuh"

#ifdef ENABLE_NINETOOTHED
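// Interface to the NineToothed-built rms_norm kernels, including
// launch_rms_norm (see src/infiniop/ops/rms_norm/ninetoothed/build.py).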
#include "../../../../../build/ninetoothed/rms_norm.h"
#endif

template <unsigned int BLOCK_SIZE, typename Tcompute, typename Tdata, typename Tweight>
INFINIOP_CUDA_KERNEL rmsnormKernel(
Tdata *__restrict__ y,
@@ -112,6 +116,7 @@ infiniStatus_t Descriptor::calculate(
size_t nhead = _info.shape.size() > 2 ? _info.shape[1] : 1;
auto cuda_stream = reinterpret_cast<cudaStream_t>(stream);

#ifndef ENABLE_NINETOOTHED
// launch kernel with different block sizes
if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
CHECK_STATUS(launchKernel<CUDA_BLOCK_SIZE_1024>(batch_size, nhead, dim, y, _info.atype, stride_y_batch, stride_y_nhead, x, stride_x_batch, stride_x_nhead, w, _info.wtype, _info.epsilon, cuda_stream));
@@ -122,6 +127,43 @@
} else {
return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}
#else
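    // NineToothed path: wrap the raw pointers, shapes, and strides into
    // NineToothedTensor descriptors and dispatch to the prebuilt launch_rms_norm.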
const auto &ndim{_info.ndim()};
uint64_t dim_ = dim;

std::vector<uint64_t> empty_shape_vec(ndim);
std::vector<int64_t> empty_strides_vec(ndim);
const auto &empty_shape{empty_shape_vec.data()};
const auto &empty_strides{empty_strides_vec.data()};

auto &x_shape_vec{_info.shape};
auto &x_strides_vec{_info.x_strides};
auto x_data{x};
auto x_shape{x_shape_vec.data()};
auto x_strides{x_strides_vec.data()};
const NineToothedTensor input{const_cast<void *>(x_data), const_cast<uint64_t *>(x_shape), const_cast<int64_t *>(x_strides)};
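    // The weight reuses the input's shape but is broadcast across the leading
    // dims: all of its strides are zero except the innermost one.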
auto &w_shape_vec{_info.shape};
std::vector<int64_t> w_strides_vec(ndim);
w_strides_vec[ndim - 1] = 1;
auto w_data{w};
auto w_shape{w_shape_vec.data()};
auto w_strides{w_strides_vec.data()};
const NineToothedTensor weight{const_cast<void *>(w_data), const_cast<uint64_t *>(w_shape), const_cast<int64_t *>(w_strides)};
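    // epsilon and the normalized-element count are scalars; they reuse the
    // zero-filled placeholder shape/stride buffers declared above.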
const NineToothedTensor eps{const_cast<float *>(&_info.epsilon), empty_shape, empty_strides};
auto &y_shape_vec{_info.shape};
auto &y_strides_vec{_info.y_strides};
auto y_data{y};
auto y_shape{y_shape_vec.data()};
auto y_strides{y_strides_vec.data()};
const NineToothedTensor output{y_data, const_cast<uint64_t *>(y_shape), const_cast<int64_t *>(y_strides)};
const NineToothedTensor num_normalized_elements{const_cast<uint64_t *>(&dim_), empty_shape, empty_strides};
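    // Must match the constexpr values used in ninetoothed/build.py
    // ("num_normalized_dims": (1,), "block_size": (1024,)).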
constexpr auto num_normalized_dims{1};
constexpr auto block_size{1024};

if (launch_rms_norm(stream, input, weight, eps, output, num_normalized_elements, ndim, num_normalized_dims, _info.atype, _info.wtype, _info.atype, block_size)) {
return INFINI_STATUS_INTERNAL_ERROR;
}
#endif
return INFINI_STATUS_SUCCESS;
}
} // namespace op::rms_norm::nvidia