diff --git a/include/infinicore.h b/include/infinicore.h index f9d4662d8..192331048 100644 --- a/include/infinicore.h +++ b/include/infinicore.h @@ -34,6 +34,7 @@ typedef enum { INFINI_STATUS_BAD_TENSOR_SHAPE = 11, INFINI_STATUS_BAD_TENSOR_STRIDES = 12, INFINI_STATUS_INSUFFICIENT_WORKSPACE = 13, + INFINI_STATUS_NOT_ALIGNED = 14, } infiniStatus_t; typedef enum { @@ -70,6 +71,9 @@ typedef enum { INFINI_DTYPE_C64 = 17, INFINI_DTYPE_C128 = 18, INFINI_DTYPE_BF16 = 19, + INFINI_DTYPE_F8_E4M3 = 20, + INFINI_DTYPE_F8_E5M2 = 21, + INFINI_DTYPE_F8_UE8M0 = 22, } infiniDtype_t; #endif // __INFINICORE_API_H__ diff --git a/include/infiniop/operator_descriptor.h b/include/infiniop/operator_descriptor.h index b47271f1a..58ddb2c64 100644 --- a/include/infiniop/operator_descriptor.h +++ b/include/infiniop/operator_descriptor.h @@ -7,7 +7,9 @@ // Base descriptor for all operators struct InfiniopDescriptor; -__C __export infiniStatus_t infiniopGetDescriptorDeviceType(const struct InfiniopDescriptor *desc_ptr, infiniDevice_t *device_type); -__C __export infiniStatus_t infiniopGetDescriptorDeviceId(const struct InfiniopDescriptor *desc_ptr, int *device_id); +__C __export infiniStatus_t infiniopGetDescriptorDeviceType( + const struct InfiniopDescriptor *desc_ptr, infiniDevice_t *device_type); +__C __export infiniStatus_t infiniopGetDescriptorDeviceId( + const struct InfiniopDescriptor *desc_ptr, int *device_id); #endif //__INFINIOP_OPERATOR_DESCRIPTOR_API_H__ diff --git a/include/infiniop/ops/linear.h b/include/infiniop/ops/linear.h new file mode 100644 index 000000000..7c731d934 --- /dev/null +++ b/include/infiniop/ops/linear.h @@ -0,0 +1,26 @@ +#ifndef __INFINIOP_LINEAR_API_H__ +#define __INFINIOP_LINEAR_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopLinearDescriptor_t; + +__C __export infiniStatus_t infiniopCreateLinearDescriptor( + infiniopHandle_t handle, infiniopLinearDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t d_desc, infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t c_desc); + +__C __export infiniStatus_t +infiniopGetLinearWorkspaceSize(infiniopLinearDescriptor_t desc, size_t *size); + +__C __export infiniStatus_t infiniopLinear( + infiniopLinearDescriptor_t desc, float alpha, const void *a, + const void *a_scale, const void *b, const void *b_scale, float beta, + const void *c, const void *c_scale, const void *bias, void *d, + const void *d_scale, bool is_blockwise, bool is_a_1d_scaled, + bool is_b_1d_scaled, void *workspace, size_t workspace_size, void *stream); + +__C __export infiniStatus_t +infiniopDestroyLinearDescriptor(infiniopLinearDescriptor_t desc); + +#endif \ No newline at end of file diff --git a/include/infiniop/ops/quantize.h b/include/infiniop/ops/quantize.h new file mode 100644 index 000000000..43ddded78 --- /dev/null +++ b/include/infiniop/ops/quantize.h @@ -0,0 +1,24 @@ +#ifndef __INFINIOP_QUANTIZE_API_H__ +#define __INFINIOP_QUANTIZE_API_H__ + +#include "../operator_descriptor.h" + +typedef struct InfiniopDescriptor *infiniopQuantizeDescriptor_t; + +__C __export infiniStatus_t infiniopCreateQuantizeDescriptor( + infiniopHandle_t handle, infiniopQuantizeDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t input_desc, + infiniopTensorDescriptor_t output_q_desc, + infiniopTensorDescriptor_t output_s_desc); + +__C __export infiniStatus_t infiniopGetQuantizeWorkspaceSize( + infiniopQuantizeDescriptor_t desc, size_t *size); + +__C __export infiniStatus_t infiniopQuantize( + 
infiniopQuantizeDescriptor_t desc, void *workspace, size_t workspace_size, + void *input, void *output_q, void *output_s, int group_size, double eps, + double min_8bit, double max_8bit, bool scale_ue8m0, void *stream); + +__C __export infiniStatus_t +infiniopDestroyQuantizeDescriptor(infiniopQuantizeDescriptor_t desc); +#endif \ No newline at end of file diff --git a/scripts/install.py b/scripts/install.py index 2e420ee9f..d789ff813 100644 --- a/scripts/install.py +++ b/scripts/install.py @@ -7,6 +7,7 @@ PROJECT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) os.chdir(PROJECT_DIR) + def run_cmd(cmd): subprocess.run(cmd, text=True, encoding="utf-8", check=True, shell=True) diff --git a/src/infiniop/devices/nvidia/nvidia_common.cu b/src/infiniop/devices/nvidia/nvidia_common.cu index 536dff853..a3ec59d49 100644 --- a/src/infiniop/devices/nvidia/nvidia_common.cu +++ b/src/infiniop/devices/nvidia/nvidia_common.cu @@ -49,6 +49,18 @@ infiniStatus_t Handle::Internal::useCudnn(cudaStream_t stream, const Fn &f) const { + auto handle = blaslt_handles.pop(); + if (!handle) { + CHECK_CUBLASLT(cublasLtCreate(&(*handle))); + } + CHECK_STATUS(f(*handle)); + blaslt_handles.push(std::move(*handle)); + return INFINI_STATUS_SUCCESS; +} +#endif + int Handle::Internal::warpSize() const { return _warp_size; } int Handle::Internal::maxThreadsPerBlock() const { return _max_threads_per_block; } int Handle::Internal::blockSizeX() const { return _block_size[0]; } diff --git a/src/infiniop/devices/nvidia/nvidia_handle.cuh b/src/infiniop/devices/nvidia/nvidia_handle.cuh index 1dcb4521e..1c65493d2 100644 --- a/src/infiniop/devices/nvidia/nvidia_handle.cuh +++ b/src/infiniop/devices/nvidia/nvidia_handle.cuh @@ -5,14 +5,26 @@ #include "../pool.h" #include "nvidia_handle.h" #include +#include +#include #include #ifdef ENABLE_CUDNN_API #include #endif +#ifdef ENABLE_CUBLASLT_API +#include +#if CUDA_VERSION >= 12090 +#define SUPPORT_FP8_BLOCKWISE_SCALE 1 +#else +#define SUPPORT_FP8_BLOCKWISE_SCALE 0 +#endif +#endif + #define CHECK_CUBLAS(API) CHECK_INTERNAL(API, CUBLAS_STATUS_SUCCESS) #define CHECK_CUDNN(API) CHECK_INTERNAL(API, CUDNN_STATUS_SUCCESS) +#define CHECK_CUBLASLT(API) CHECK_INTERNAL(API, CUBLAS_STATUS_SUCCESS) namespace device::nvidia { @@ -21,6 +33,9 @@ class Handle::Internal { #ifdef ENABLE_CUDNN_API Pool dnn_handles; #endif +#ifdef ENABLE_CUBLASLT_API + Pool blaslt_handles; +#endif int _warp_size, _max_threads_per_block, @@ -37,6 +52,9 @@ public: #ifdef ENABLE_CUDNN_API infiniStatus_t useCudnn(cudaStream_t stream, const Fn &f) const; #endif +#ifdef ENABLE_CUBLASLT_API + infiniStatus_t useCublasLt(cudaStream_t stream, const Fn &f) const; +#endif int warpSize() const; int maxThreadsPerBlock() const; diff --git a/src/infiniop/elementwise/cpu/elementwise_cpu.h b/src/infiniop/elementwise/cpu/elementwise_cpu.h index 487cb5bdb..9eae7e3a9 100644 --- a/src/infiniop/elementwise/cpu/elementwise_cpu.h +++ b/src/infiniop/elementwise/cpu/elementwise_cpu.h @@ -6,26 +6,22 @@ #include /** - * @brief Define the process for initializing a Descriptor of an elementwise operation - * for its CPU implementation + * @brief Define the process for initializing a Descriptor of an elementwise + * operation for its CPU implementation * * @param HANDLE The device handle. * @param DTYPE The output dtype. * @param OUT_DESC The output tensor descriptor. * @param INPUT_DESC_VEC A vector containing input tensor descriptors. 
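* @note On CPU the descriptor is created with no opaque state and zero workspace size (the nullptr and 0 arguments in the macro body); only the dtype and the ElementwiseInfo are kept.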
*/ -#define CREATE_ELEMENTWISE_CPU_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \ +#define CREATE_ELEMENTWISE_CPU_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, \ + INPUT_DESC_VEC) \ \ auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \ CHECK_RESULT(info_result); \ \ - *desc_ptr = new Descriptor( \ - DTYPE, \ - info_result.take(), \ - nullptr, \ - 0, \ - HANDLE->device, \ - HANDLE->device_id); + *desc_ptr = new Descriptor(DTYPE, info_result.take(), nullptr, 0, \ + HANDLE->device, HANDLE->device_id); namespace op::elementwise::cpu { @@ -62,18 +58,17 @@ class DeviceImpl final { * @return infiniStatus_t Status indicating success or failure. */ template - infiniStatus_t calculate( - const op::elementwise::ElementwiseInfo &info, - void *output, - const std::vector &inputs, - void *stream, - Args &&...args); + infiniStatus_t calculate(const op::elementwise::ElementwiseInfo &info, + void *output, + const std::vector &inputs, + void *stream, Args &&...args); /** * @brief Dispatches an elementwise operation with heterogeneous input types. * - * Supports operations where each input may have a different type, as defined by Op. - * The number of input types must match the operation's expected input count. + * Supports operations where each input may have a different type, as defined + * by Op. The number of input types must match the operation's expected input + * count. * * @tparam Op The elementwise operation to perform. * @tparam Tout Output data type. @@ -86,15 +81,12 @@ class DeviceImpl final { * @param args Additional backend-specific arguments. * @return infiniStatus_t Status indicating success or failure. */ - template = 0> - infiniStatus_t calculate( - const op::elementwise::ElementwiseInfo &info, - void *output, - const std::vector &inputs, - void *stream, - Args &&...args); + infiniStatus_t calculate(const op::elementwise::ElementwiseInfo &info, + void *output, + const std::vector &inputs, + void *stream, Args &&...args); }; // Define the Opaque struct for CPU, which is empty @@ -106,74 +98,86 @@ utils::Result DeviceImpl::create(Args &&...args) { } // Perform elementwise operation for different input types -template = 0> -void calculate_impl(const op::elementwise::ElementwiseInfo &info, - void *output, +void calculate_impl(const op::elementwise::ElementwiseInfo &info, void *output, const std::vector &inputs, - std::index_sequence, - Args &&...args) { + std::index_sequence, Args &&...args) { Tout *out = reinterpret_cast(output); - std::tuple input_ptrs = {reinterpret_cast(inputs[Is])...}; + std::tuple input_ptrs = { + reinterpret_cast(inputs[Is])...}; ptrdiff_t output_size = info.getOutputSize(); #pragma omp parallel for for (ptrdiff_t i = 0; i < output_size; ++i) { size_t out_idx = info.isOutputContiguous() ? i - : op::common_cpu::indexToOffset(i, info.getNdim(), info.getOutputShape(), info.getOutputStrides()); + : op::common_cpu::indexToOffset( + i, info.getNdim(), info.getOutputShape(), + info.getOutputStrides()); auto get_input_idx = [&](size_t input_id) { return info.getInputContiguous()[input_id] ? 
i - : op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id)); + : op::common_cpu::indexToOffset( + i, info.getNdim(), info.getInputShape(input_id), + info.getInputStrides(input_id)); }; - out[out_idx] = utils::cast( - Op{}.template operator()(std::get(input_ptrs)[get_input_idx(Is)]..., std::forward(args)...)); + out[out_idx] = utils::cast(Op{}.template operator()( + std::get(input_ptrs)[get_input_idx(Is)]..., + std::forward(args)...)); } } // Invoke elementwise operation for different input types -template > -infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, - void *output, - const std::vector &inputs, - void *stream, - Args &&...args) { +template > +infiniStatus_t +DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, + void *output, const std::vector &inputs, + void *stream, Args &&...args) { static_assert(sizeof...(Tin) == Op::num_inputs, "Input type count mismatch"); - calculate_impl(info, output, inputs, std::make_index_sequence{}, std::forward(args)...); + calculate_impl(info, output, inputs, + std::make_index_sequence{}, + std::forward(args)...); return INFINI_STATUS_SUCCESS; } // Perform elementwise operation when all inputs have the same type template -void calculate_impl(const op::elementwise::ElementwiseInfo &info, - void *output, +void calculate_impl(const op::elementwise::ElementwiseInfo &info, void *output, const std::vector &inputs, - std::index_sequence, - Args &&...args) { + std::index_sequence, Args &&...args) { Tdata *out = reinterpret_cast(output); - std::array ins = {reinterpret_cast(inputs[Is])...}; + std::array ins = { + reinterpret_cast(inputs[Is])...}; const ptrdiff_t output_size = info.getOutputSize(); #pragma omp parallel for if (output_size > 1024) for (ptrdiff_t i = 0; i < output_size; ++i) { size_t out_idx = info.isOutputContiguous() ? i - : op::common_cpu::indexToOffset(i, info.getNdim(), info.getOutputShape(), info.getOutputStrides()); + : op::common_cpu::indexToOffset( + i, info.getNdim(), info.getOutputShape(), + info.getOutputStrides()); auto get_input_idx = [&](size_t input_id) { return info.getInputContiguous()[input_id] ? 
i - : op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id)); + : op::common_cpu::indexToOffset( + i, info.getNdim(), info.getInputShape(input_id), + info.getInputStrides(input_id)); }; if constexpr (std::is_same_v || std::is_same_v) { - out[out_idx] = utils::cast(Op{}(utils::cast(ins[Is][get_input_idx(Is)])..., std::forward(args)...)); + out[out_idx] = utils::cast( + Op{}(utils::cast(ins[Is][get_input_idx(Is)])..., + std::forward(args)...)); } else { out[out_idx] = Op{}(ins[Is][get_input_idx(Is)]..., std::forward(args)...); } @@ -182,16 +186,16 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info, // Invoke elementwise operation when all inputs have the same type template -infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, - void *output, - const std::vector &inputs, - void *stream, - Args &&...args) { +infiniStatus_t +DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info, + void *output, const std::vector &inputs, + void *stream, Args &&...args) { constexpr size_t N = Op::num_inputs; - calculate_impl(info, output, inputs, std::make_index_sequence{}, std::forward(args)...); + calculate_impl(info, output, inputs, std::make_index_sequence{}, + std::forward(args)...); return INFINI_STATUS_SUCCESS; } } // namespace op::elementwise::cpu -#endif // __INFINIOP_ELEMENTWISE_CPU_H__ +#endif // __INFINIOP_ELEMENTWISE_CPU_H__ \ No newline at end of file diff --git a/src/infiniop/ops/linear/info.h b/src/infiniop/ops/linear/info.h new file mode 100644 index 000000000..2fd44a01f --- /dev/null +++ b/src/infiniop/ops/linear/info.h @@ -0,0 +1,130 @@ +#ifndef __GEMM_INFO_H__ +#define __GEMM_INFO_H__ + +#include "../../../utils.h" +#include "../../operator.h" +#include "../../tensor.h" +#include + +namespace op::linear { + +struct BlasMatrix { + size_t ndim; + size_t batch; + ptrdiff_t stride; + size_t rows; + size_t cols; + ptrdiff_t row_stride; + ptrdiff_t col_stride; + infiniDtype_t dtype; + + static utils::Result create(infiniopTensorDescriptor_t layout) { + BlasMatrix ans; + + if (layout->ndim() == 2) { + ans.ndim = 2; + ans.batch = 1; + ans.stride = 0; + ans.rows = layout->dim(0); + ans.cols = layout->dim(1); + ans.row_stride = layout->stride(0); + ans.col_stride = layout->stride(1); + ans.dtype = layout->dtype(); + } else if (layout->ndim() == 3) { + ans.ndim = 3; + ans.batch = layout->dim(0); + ans.stride = ans.batch == 1 ? 0 : layout->stride(0); + ans.rows = layout->dim(1); + ans.cols = layout->dim(2); + ans.row_stride = layout->stride(1); + ans.col_stride = layout->stride(2); + ans.dtype = layout->dtype(); + } else { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + if (ans.row_stride != 1 && ans.col_stride != 1) { + return INFINI_STATUS_BAD_TENSOR_STRIDES; + } + + return utils::Result(ans); + } + + bool match_batch(size_t _batch) const { + return batch == _batch || batch == 1; + } + + void transpose() { + std::swap(rows, cols); + std::swap(row_stride, col_stride); + } + + ptrdiff_t ld() const { return row_stride == 1 ? 
col_stride : row_stride; } +}; + +enum class MatrixLayout : char { + COL_MAJOR, + ROW_MAJOR, +}; + +class MatmulInfo { + MatmulInfo() = default; + +public: + BlasMatrix a_matrix; + BlasMatrix b_matrix; + BlasMatrix c_matrix; + BlasMatrix d_matrix; + + size_t m, n, k, batch; + bool is_transed; + + static utils::Result create(infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t c_desc, + infiniopTensorDescriptor_t d_desc, + MatrixLayout layout) { + + auto a_matrix = BlasMatrix::create(a_desc); + CHECK_RESULT(a_matrix); + + auto b_matrix = BlasMatrix::create(b_desc); + CHECK_RESULT(b_matrix); + + auto c_matrix = BlasMatrix::create(c_desc); + CHECK_RESULT(c_matrix); + + auto d_matrix = BlasMatrix::create(d_desc); + CHECK_RESULT(d_matrix); + + if ((c_matrix->rows != a_matrix->rows || c_matrix->cols != b_matrix->cols || a_matrix->cols != b_matrix->rows) && a_desc->dtype() != INFINI_DTYPE_F8_E4M3 && a_desc->dtype() != INFINI_DTYPE_F8_E5M2) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + auto batch = c_matrix->batch; + if (!a_matrix->match_batch(batch) || !b_matrix->match_batch(batch)) { + return INFINI_STATUS_BAD_TENSOR_SHAPE; + } + + auto is_transed = false; + if (((layout == MatrixLayout::COL_MAJOR && c_matrix->col_stride == 1) || (layout == MatrixLayout::ROW_MAJOR && c_matrix->row_stride == 1)) && a_desc->dtype() != INFINI_DTYPE_F8_E4M3 && a_desc->dtype() != INFINI_DTYPE_F8_E5M2) { + c_matrix->transpose(); + b_matrix->transpose(); + a_matrix->transpose(); + std::swap(a_matrix, b_matrix); + is_transed = true; + } + + auto m = c_matrix->rows; + auto n = c_matrix->cols; + auto k = a_matrix->cols; + + return utils::Result( + MatmulInfo{a_matrix.take(), b_matrix.take(), c_matrix.take(), + d_matrix.take(), m, n, k, batch, is_transed}); + } +}; + +} // namespace op::linear + +#endif // __GEMM_INFO_H__ diff --git a/src/infiniop/ops/linear/linear.h b/src/infiniop/ops/linear/linear.h new file mode 100644 index 000000000..8f53531b2 --- /dev/null +++ b/src/infiniop/ops/linear/linear.h @@ -0,0 +1,45 @@ +#ifndef __LINEAR_H__ +#define __LINEAR_H__ + +#include "../../operator.h" +#include "info.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::linear::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + infiniDtype_t _dtype; \ + op::linear::MatmulInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor(infiniDtype_t dtype, op::linear::MatmulInfo info, \ + size_t workspace_size_, Opaque *opaque, \ + infiniDevice_t device_type, int device_id) \ + : InfiniopDescriptor{device_type, device_id}, _opaque(opaque), \ + _dtype(dtype), _info(info), _workspace_size(workspace_size_) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create(infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t d_desc, \ + infiniopTensorDescriptor_t a_desc, \ + infiniopTensorDescriptor_t b_desc, \ + infiniopTensorDescriptor_t c_desc); \ + \ + infiniStatus_t calculate(float alpha, const void *a, const void *a_scale, \ + const void *b, const void *b_scale, float beta, \ + const void *c, const void *c_scale, \ + const void *bias, void *d, const void *d_scale, \ + bool is_blockwise, bool is_a_1d_scaled, \ + bool is_b_1d_scaled, void *workspace, \ + size_t workspace_size, void *stream) const; \ + }; \ + } + +#endif // __LINEAR_H__ diff --git a/src/infiniop/ops/linear/nvidia/linear_nvidia.cu 
b/src/infiniop/ops/linear/nvidia/linear_nvidia.cu new file mode 100644 index 000000000..6befb535e --- /dev/null +++ b/src/infiniop/ops/linear/nvidia/linear_nvidia.cu @@ -0,0 +1,329 @@ +#include "../../../devices/nvidia/nvidia_handle.cuh" +#include "linear_nvidia.cuh" + +namespace op::linear::nvidia { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { + delete _opaque; +} + +infiniStatus_t Descriptor::create( + infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t d_desc, + infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, + infiniopTensorDescriptor_t c_desc) { + auto handle = reinterpret_cast(handle_); + auto dtype = d_desc->dtype(); + + CHECK_DTYPE(dtype, INFINI_DTYPE_F8_E4M3, INFINI_DTYPE_F8_E5M2, INFINI_DTYPE_F16, INFINI_DTYPE_BF16, INFINI_DTYPE_F32); + + auto result = op::linear::MatmulInfo::create(a_desc, b_desc, c_desc, d_desc, op::linear::MatrixLayout::COL_MAJOR); + CHECK_RESULT(result); + + *desc_ptr = new Descriptor( + dtype, result.take(), 0, + new Opaque{handle->internal()}, + handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t Descriptor::calculate( + float alpha, + const void *a, + const void *a_scale, + const void *b, + const void *b_scale, + float beta, + const void *c, + const void *c_scale, + const void *bias, + void *d, + const void *d_scale, + bool is_blockwise, + bool is_a_1d_scaled, + bool is_b_1d_scaled, + void *workspace, + size_t workspace_size, + void *stream) const { + cublasComputeType_t compute_type; + int returnedResults = 0; + const int8_t fast_accum_mode = 0; + cublasLtMatmulPreference_t preference = NULL; + cublasLtMatmulHeuristicResult_t heuristicResult = {}; + cublasLtEpilogue_t epilogue; + + cublasLtMatmulDesc_t lt_desc = NULL; + cublasLtMatrixLayout_t a_layout = NULL, b_layout = NULL, c_layout = NULL, d_layout = NULL; + + cudaDataType a_type, b_type, c_type, d_type, scale_type, bias_type; + + switch (_info.a_matrix.dtype) { + case INFINI_DTYPE_F8_E4M3: + a_type = CUDA_R_8F_E4M3; + compute_type = CUBLAS_COMPUTE_32F; + scale_type = CUDA_R_32F; + switch (_info.b_matrix.dtype) { + case INFINI_DTYPE_F8_E4M3: + b_type = CUDA_R_8F_E4M3; + switch (_info.c_matrix.dtype) { + case INFINI_DTYPE_BF16: + c_type = CUDA_R_16BF; + switch (_info.d_matrix.dtype) { + case INFINI_DTYPE_BF16: + d_type = CUDA_R_16BF; + bias_type = CUDA_R_16BF; + break; + case INFINI_DTYPE_F8_E4M3: + d_type = CUDA_R_8F_E4M3; + bias_type = CUDA_R_16BF; + break; + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + break; + case INFINI_DTYPE_F16: + c_type = CUDA_R_16F; + switch (_info.d_matrix.dtype) { + case INFINI_DTYPE_F16: + d_type = CUDA_R_16F; + bias_type = CUDA_R_16F; + break; + case INFINI_DTYPE_F8_E4M3: + d_type = CUDA_R_8F_E4M3; + bias_type = CUDA_R_16F; + break; + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + break; + case INFINI_DTYPE_F32: + c_type = CUDA_R_32F; + switch (_info.d_matrix.dtype) { + case INFINI_DTYPE_F32: + d_type = CUDA_R_32F; + bias_type = CUDA_R_32F; + break; + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + break; + default: + return INFINI_STATUS_NOT_IMPLEMENTED; + } + break; + default: + return INFINI_STATUS_NOT_IMPLEMENTED; + } + break; + + case INFINI_DTYPE_F16: + a_type = b_type = c_type = d_type = CUDA_R_16F; + bias_type = CUDA_R_16F; + compute_type = CUBLAS_COMPUTE_32F; + scale_type = CUDA_R_32F; + break; + + case INFINI_DTYPE_BF16: + a_type = b_type = c_type = d_type = CUDA_R_16BF; + bias_type = CUDA_R_16BF; + 
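// As with the F16 case above, BF16 inputs accumulate in FP32 and use an FP32 scale type (set on the next two lines).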
compute_type = CUBLAS_COMPUTE_32F; + scale_type = CUDA_R_32F; + break; + + case INFINI_DTYPE_F32: + a_type = b_type = c_type = d_type = CUDA_R_32F; + bias_type = CUDA_R_16BF; + compute_type = CUBLAS_COMPUTE_32F; + scale_type = CUDA_R_32F; + break; + + default: + return INFINI_STATUS_NOT_IMPLEMENTED; + } + // auto info = result.take(); + + /* To use tensor- or block-scaled FP8 kernels: + * A must be transposed and B non-transposed (The “TN” format) + * on Ada (compute capability 8.9), Hopper (compute capability 9.0), + * and Blackwell GeForce (compute capability 12.x) GPUs. + */ + CHECK_CUBLASLT(cublasLtMatmulDescCreate(<_desc, compute_type, scale_type)); + BlasMatrix a_matrix = _info.a_matrix, b_matrix = _info.b_matrix; + cublasOperation_t op_a, op_b; + if (_info.a_matrix.dtype == INFINI_DTYPE_F8_E4M3) { + bool transa = true; + bool transb = false; + a_matrix = _info.b_matrix; + b_matrix = _info.a_matrix; + + const int m = transa ? a_matrix.rows : a_matrix.cols; + const int k = transa ? a_matrix.cols : a_matrix.rows; + const int n = transb ? b_matrix.cols : b_matrix.rows; + int lda = k, ldb = k, ldc = m, ldd = m; + + op_a = transa ? CUBLAS_OP_T : CUBLAS_OP_N; + op_b = transb ? CUBLAS_OP_T : CUBLAS_OP_N; + +// Note: in cuBLAS term, tensor name A and B are swapped. +#if SUPPORT_FP8_BLOCKWISE_SCALE + if (is_blockwise) { + cublasLtMatmulMatrixScale_t a_scale_mode, b_scale_mode; + if (is_b_1d_scaled && is_a_1d_scaled) { + a_scale_mode = CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F; + b_scale_mode = CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F; + } else if (!is_b_1d_scaled && is_a_1d_scaled) { + // So this corresponds to 2Dx1D GEMM. + a_scale_mode = CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F; + b_scale_mode = CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; + } else if (is_b_1d_scaled && !is_a_1d_scaled) { + // So this corresponds to 1Dx2D GEMM. + a_scale_mode = CUBLASLT_MATMUL_MATRIX_SCALE_BLK128x128_32F; + b_scale_mode = CUBLASLT_MATMUL_MATRIX_SCALE_VEC128_32F; + } else { + return INFINI_STATUS_NOT_IMPLEMENTED; + } + CHECK_CUBLASLT(cublasLtMatmulDescSetAttribute(lt_desc, CUBLASLT_MATMUL_DESC_A_SCALE_MODE, &a_scale_mode, sizeof(a_scale_mode))); + CHECK_CUBLASLT(cublasLtMatmulDescSetAttribute(lt_desc, CUBLASLT_MATMUL_DESC_B_SCALE_MODE, &b_scale_mode, sizeof(b_scale_mode))); + } +#endif + + CHECK_CUBLASLT(cublasLtMatrixLayoutCreate(&a_layout, a_type, op_a == CUBLAS_OP_N ? m : k, op_a == CUBLAS_OP_N ? k : m, lda)); + CHECK_CUBLASLT(cublasLtMatrixLayoutCreate(&b_layout, b_type, op_b == CUBLAS_OP_N ? k : n, op_b == CUBLAS_OP_N ? n : k, ldb)); + CHECK_CUBLASLT(cublasLtMatrixLayoutCreate(&c_layout, c_type, m, n, ldc)); + + CHECK_CUBLASLT(cublasLtMatrixLayoutCreate(&d_layout, d_type, m, n, ldd)); + + } else { + op_a = _info.a_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T; + op_b = _info.b_matrix.row_stride == 1 ? CUBLAS_OP_N : CUBLAS_OP_T; + + CHECK_CUBLASLT(cublasLtMatrixLayoutCreate(&a_layout, a_type, op_a ? _info.k : _info.m, op_a ? _info.m : _info.k, _info.a_matrix.ld())); + CHECK_CUBLASLT(cublasLtMatrixLayoutCreate(&b_layout, b_type, op_b ? _info.n : _info.k, op_b ? 
_info.k : _info.n, _info.b_matrix.ld())); + CHECK_CUBLASLT(cublasLtMatrixLayoutCreate(&c_layout, c_type, _info.m, _info.n, _info.c_matrix.ld())); + CHECK_CUBLASLT(cublasLtMatrixLayoutCreate(&d_layout, d_type, _info.m, _info.n, _info.c_matrix.ld())); + } + + CHECK_CUBLASLT(cublasLtMatmulDescSetAttribute(lt_desc, CUBLASLT_MATMUL_DESC_TRANSA, &op_a, sizeof(op_a))); + CHECK_CUBLASLT(cublasLtMatmulDescSetAttribute(lt_desc, CUBLASLT_MATMUL_DESC_TRANSB, &op_b, sizeof(op_b))); + + // set scale type to FP32 (needs to be FP16 if and only if using CUBLAS_COMPUTE_16F, so it's FP32 even for FP8!) + // CHECK_CUBLASLT(cublasLtMatmulDescSetAttribute(lt_desc, CUBLASLT_MATMUL_DESC_PE, &, sizeof(scale_type))); + // cuBlasLt requires C in fp8 mode to be BF16 or FP32 + + int batch = static_cast(_info.batch); + + if (batch > 1) { + CHECK_CUBLASLT(cublasLtMatrixLayoutSetAttribute(a_layout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); + CHECK_CUBLASLT(cublasLtMatrixLayoutSetAttribute(b_layout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); + CHECK_CUBLASLT(cublasLtMatrixLayoutSetAttribute(c_layout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); + CHECK_CUBLASLT(cublasLtMatrixLayoutSetAttribute(d_layout, CUBLASLT_MATRIX_LAYOUT_BATCH_COUNT, &batch, sizeof(batch))); + + CHECK_CUBLASLT(cublasLtMatrixLayoutSetAttribute(a_layout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &a_matrix.stride, sizeof(a_matrix.stride))); + CHECK_CUBLASLT(cublasLtMatrixLayoutSetAttribute(b_layout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &b_matrix.stride, sizeof(b_matrix.stride))); + CHECK_CUBLASLT(cublasLtMatrixLayoutSetAttribute(c_layout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &_info.c_matrix.stride, sizeof(_info.c_matrix.stride))); + CHECK_CUBLASLT(cublasLtMatrixLayoutSetAttribute(d_layout, CUBLASLT_MATRIX_LAYOUT_STRIDED_BATCH_OFFSET, &_info.c_matrix.stride, sizeof(_info.c_matrix.stride))); + } + + CHECK_CUBLASLT(cublasLtMatmulDescSetAttribute(lt_desc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, &fast_accum_mode, sizeof(fast_accum_mode))); + + if (_info.is_transed && a_type != CUDA_R_8F_E4M3 && a_type != CUDA_R_8F_E5M2) { + std::swap(a, b); + } else if (a_type == CUDA_R_8F_E4M3 || a_type == CUDA_R_8F_E5M2) { + std::swap(a, b); + std::swap(a_scale, b_scale); + } + + if ((a_type == CUDA_R_8F_E4M3 || b_type == CUDA_R_8F_E4M3) && (((uintptr_t)a % 16) != 0 || ((uintptr_t)b % 16) != 0 || ((uintptr_t)c % 16) != 0 || ((uintptr_t)d % 16) != 0)) { + return INFINI_STATUS_NOT_ALIGNED; + } + + if (a_scale) { + CHECK_CUBLASLT(cublasLtMatmulDescSetAttribute(lt_desc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &a_scale, sizeof(a_scale))); + } + if (b_scale) { + CHECK_CUBLASLT(cublasLtMatmulDescSetAttribute(lt_desc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &b_scale, sizeof(b_scale))); + } + if (c_scale) { + CHECK_CUBLASLT(cublasLtMatmulDescSetAttribute(lt_desc, CUBLASLT_MATMUL_DESC_C_SCALE_POINTER, &c_scale, sizeof(c_scale))); + } + if (d_scale) { + CHECK_CUBLASLT(cublasLtMatmulDescSetAttribute(lt_desc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &d_scale, sizeof(d_scale))); + } + + if (bias) { + epilogue = CUBLASLT_EPILOGUE_BIAS; + CHECK_CUBLASLT(cublasLtMatmulDescSetAttribute(lt_desc, CUBLASLT_MATMUL_DESC_EPILOGUE, &epilogue, sizeof(epilogue))); + CHECK_CUBLASLT(cublasLtMatmulDescSetAttribute(lt_desc, CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_type, sizeof(bias_type))); + CHECK_CUBLASLT(cublasLtMatmulDescSetAttribute(lt_desc, CUBLASLT_MATMUL_DESC_BIAS_POINTER, &bias, sizeof(bias))); + } + + 
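// The remaining calls follow the usual cuBLASLt sequence: a matmul preference
// capped at the caller-provided workspace, a heuristic query that is used only
// as a feasibility check (zero returned results maps to
// INFINI_STATUS_NOT_IMPLEMENTED), and cublasLtMatmul launched with a null algo
// so cuBLASLt selects one internally. The preference, layouts, and matmul
// descriptor are destroyed inside the same lambda once the matmul is enqueued.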
CHECK_CUBLASLT(cublasLtMatmulPreferenceCreate(&preference)); + CHECK_CUBLASLT(cublasLtMatmulPreferenceSetAttribute(preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, &workspace_size, sizeof(workspace_size))); + + CHECK_STATUS(_opaque->internal->useCublasLt( + (cudaStream_t)stream, + [&](cublasLtHandle_t handle) { + CHECK_CUBLASLT(cublasLtMatmulAlgoGetHeuristic( + handle, + lt_desc, + a_layout, + b_layout, + c_layout, + d_layout, + preference, + 1, + &heuristicResult, + &returnedResults)); + if (returnedResults == 0) { + return INFINI_STATUS_NOT_IMPLEMENTED; + } + + CHECK_CUBLASLT( + cublasLtMatmul( + handle, + lt_desc, + &alpha, + a, + a_layout, + b, + b_layout, + &beta, + c, + c_layout, + d, + d_layout, + nullptr, + workspace, + workspace_size, + reinterpret_cast(stream))); + + if (preference) { + CHECK_CUBLASLT(cublasLtMatmulPreferenceDestroy(preference)); + } + if (d_layout) { + CHECK_CUBLASLT(cublasLtMatrixLayoutDestroy(d_layout)); + } + if (c_layout) { + CHECK_CUBLASLT(cublasLtMatrixLayoutDestroy(c_layout)); + } + if (b_layout) { + CHECK_CUBLASLT(cublasLtMatrixLayoutDestroy(b_layout)); + } + if (a_layout) { + CHECK_CUBLASLT(cublasLtMatrixLayoutDestroy(a_layout)); + } + if (lt_desc) { + CHECK_CUBLASLT(cublasLtMatmulDescDestroy(lt_desc)); + } + return INFINI_STATUS_SUCCESS; + })); + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::linear::nvidia diff --git a/src/infiniop/ops/linear/nvidia/linear_nvidia.cuh b/src/infiniop/ops/linear/nvidia/linear_nvidia.cuh new file mode 100644 index 000000000..6c41b69ab --- /dev/null +++ b/src/infiniop/ops/linear/nvidia/linear_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __LINEAR_CUDA_CUH__ +#define __LINEAR_CUDA_CUH__ + +#include "../linear.h" + +DESCRIPTOR(nvidia) + +#endif // __LINEAR_CUH__ diff --git a/src/infiniop/ops/linear/operator.cc b/src/infiniop/ops/linear/operator.cc new file mode 100644 index 000000000..325c64cac --- /dev/null +++ b/src/infiniop/ops/linear/operator.cc @@ -0,0 +1,93 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/linear.h" + +#ifdef ENABLE_NVIDIA_API +#include "nvidia/linear_nvidia.cuh" +#endif + +__C infiniStatus_t infiniopCreateLinearDescriptor( + infiniopHandle_t handle, infiniopLinearDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t d_desc, infiniopTensorDescriptor_t a_desc, + infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t c_desc) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return op::linear::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + d_desc, a_desc, b_desc, c_desc) + + switch (handle->device) { +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__C infiniStatus_t +infiniopGetLinearWorkspaceSize(infiniopLinearDescriptor_t desc, size_t *size) { +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc) \ + ->workspaceSize(); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef GET +} + +__C infiniStatus_t infiniopLinear( + infiniopLinearDescriptor_t desc, float alpha, const void *a, + const void *a_scale, const void *b, const void *b_scale, float beta, + const void *c, const void *c_scale, const void *bias, void *d, + const void *d_scale, bool is_blockwise, bool is_a_1d_scaled, + bool is_b_1d_scaled, void *workspace, size_t workspace_size, 
void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(alpha, a, a_scale, b, b_scale, beta, c, c_scale, bias, d, \ + d_scale, is_blockwise, is_a_1d_scaled, is_b_1d_scaled, \ + workspace, workspace_size, stream) + switch (desc->device_type) { + +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__C infiniStatus_t +infiniopDestroyLinearDescriptor(infiniopLinearDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast(desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + DELETE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef DELETE +} \ No newline at end of file diff --git a/src/infiniop/ops/quantize/info.h b/src/infiniop/ops/quantize/info.h new file mode 100644 index 000000000..0a5809646 --- /dev/null +++ b/src/infiniop/ops/quantize/info.h @@ -0,0 +1,36 @@ +#ifndef __QUANTIZE_INFO_H__ +#define __QUANTIZE_INFO_H__ + +#include "../../../utils.h" +#include "../../tensor.h" +#include + +// 需要一个值,要的是input的大小(x * y * z) +// 需要 out_q.dtype out_s.stride +// input.size input.dim +// + +namespace op::quantize { + +class QuantizeInfo { + QuantizeInfo() = default; + +public: + infiniopTensorDescriptor_t _input_desc, _output_q_desc, _output_s_desc; + + infiniopTensorDescriptor_t input() const { return _input_desc; } + infiniopTensorDescriptor_t output_q() const { return _output_q_desc; } + infiniopTensorDescriptor_t output_s() const { return _output_s_desc; } + + static utils::Result + create(infiniopTensorDescriptor_t input_desc, + infiniopTensorDescriptor_t output_q_desc, + infiniopTensorDescriptor_t output_s_desc) { + + return utils::Result(QuantizeInfo{input_desc, output_q_desc, output_s_desc}); + } +}; + +} // namespace op::quantize + +#endif // __DEQUANTIZE_INFO_H__ \ No newline at end of file diff --git a/src/infiniop/ops/quantize/nvidia/quantize_group_8bit.cu b/src/infiniop/ops/quantize/nvidia/quantize_group_8bit.cu new file mode 100644 index 000000000..b8e6cb269 --- /dev/null +++ b/src/infiniop/ops/quantize/nvidia/quantize_group_8bit.cu @@ -0,0 +1,112 @@ +#include + +#include "../../../devices/nvidia/nvidia_handle.cuh" +#include "quantize_group_8bit.cuh" +#include "quantize_group_8bit_nvidia.cuh" +#include + +namespace op::quantize::nvidia { + +struct Descriptor::Opaque { + std::shared_ptr internal; +}; + +Descriptor::~Descriptor() { delete _opaque; } + +infiniStatus_t Descriptor::create(infiniopHandle_t handle_, + Descriptor **desc_ptr, + infiniopTensorDescriptor_t input_desc, + infiniopTensorDescriptor_t output_q_desc, + infiniopTensorDescriptor_t output_s_desc) { + auto handle = reinterpret_cast(handle_); + auto dtype = output_q_desc->dtype(); + CHECK_DTYPE(dtype, INFINI_DTYPE_I8, INFINI_DTYPE_F8_E4M3, INFINI_DTYPE_F8_UE8M0); + auto result = QuantizeInfo::create(input_desc, output_q_desc, output_s_desc); + + *desc_ptr = new Descriptor(dtype, result.take(), 0, + new Opaque{handle->internal()}, + handle->device, handle->device_id); + return INFINI_STATUS_SUCCESS; +} + +infiniStatus_t +Descriptor::calculate(void *workspace, + size_t workspace_size, + void *input, + void *output_q, + void *output_s, + int group_size, + double eps, + double min_8bit, + double max_8bit, + bool scale_ue8m0, + void *stream) const { + auto cuda_stream = reinterpret_cast(stream); + const 
int num_groups = _info.input()->numel() / group_size; + constexpr int THREADS_PER_GROUP = 16; + + int groups_per_block = 1; + + if (num_groups % 16 == 0) { + groups_per_block = 16; + } else if (num_groups % 8 == 0) { + groups_per_block = 8; + } else if (num_groups % 4 == 0) { + groups_per_block = 4; + } else if (num_groups % 2 == 0) { + groups_per_block = 2; + } + + auto dst_type = _info.output_q()->dtype(); + const int num_blocks = num_groups / groups_per_block; + const int num_threads = groups_per_block * THREADS_PER_GROUP; + + const bool is_column_major = _info.output_s()->stride(0) < _info.output_s()->stride(1); + const int hidden_dim = _info.input()->shape()[_info.input()->ndim() - 1]; + const int num_groups_per_row = hidden_dim / group_size; + const int scale_stride = _info.output_s()->stride(1); +#define LAUNCH_KERNEL(T, DST_DTYPE) \ + do { \ + dim3 grid(num_blocks); \ + dim3 block(num_threads); \ + if (is_column_major) { \ + if (scale_ue8m0) { \ + per_token_group_quant_8bit_kernel \ + <<>>( \ + static_cast(input), output_q, \ + static_cast(output_s), group_size, \ + num_groups, groups_per_block, (float)eps, (float)min_8bit, \ + (float)max_8bit, num_groups_per_row, scale_stride); \ + } else { \ + per_token_group_quant_8bit_kernel \ + <<>>( \ + static_cast(input), output_q, \ + static_cast(output_s), group_size, \ + num_groups, groups_per_block, (float)eps, (float)min_8bit, \ + (float)max_8bit, num_groups_per_row, scale_stride); \ + } \ + } else { \ + assert(!scale_ue8m0); \ + per_token_group_quant_8bit_kernel \ + <<>>( \ + static_cast(input), output_q, \ + static_cast(output_s), group_size, \ + num_groups, groups_per_block, (float)eps, (float)min_8bit, \ + (float)max_8bit); \ + } \ + } while (0) + + if (_info.input()->dtype() == INFINI_DTYPE_F16 && _dtype == INFINI_DTYPE_F8_E4M3) { + switch (_dtype) { + case INFINI_DTYPE_F8_E4M3: + LAUNCH_KERNEL(half, __nv_fp8_e4m3); + break; + default: + return INFINI_STATUS_BAD_TENSOR_DTYPE; + } + } +#undef LAUNCH_KERNEL + return INFINI_STATUS_SUCCESS; +} + +} // namespace op::quantize::nvidia \ No newline at end of file diff --git a/src/infiniop/ops/quantize/nvidia/quantize_group_8bit.cuh b/src/infiniop/ops/quantize/nvidia/quantize_group_8bit.cuh new file mode 100644 index 000000000..f07c5c8c9 --- /dev/null +++ b/src/infiniop/ops/quantize/nvidia/quantize_group_8bit.cuh @@ -0,0 +1,117 @@ +#include +#include +#include +#include + +#include "../../../devices/nvidia/nvidia_common.cuh" + +__device__ __forceinline__ float GroupReduceMax(float val, const int tid) { + unsigned mask = 0xffff; + + val = fmaxf(val, __shfl_xor_sync(mask, val, 8)); + val = fmaxf(val, __shfl_xor_sync(mask, val, 4)); + val = fmaxf(val, __shfl_xor_sync(mask, val, 2)); + val = fmaxf(val, __shfl_xor_sync(mask, val, 1)); + return val; +} + +template < + typename T, + typename DST_DTYPE, + bool IS_COLUMN_MAJOR = false, + bool SCALE_UE8M0 = false, + typename scale_packed_t = std::conditional_t> +__global__ void per_token_group_quant_8bit_kernel( + const T *__restrict__ input, + void *__restrict__ output_q, + scale_packed_t *__restrict__ output_s, + const int group_size, + const int num_groups, + const int groups_per_block, + const float eps, + const float min_8bit, + const float max_8bit, + const int num_groups_per_row = 0, + const int scale_stride = 0) { + const int threads_per_group = 16; + const int64_t local_group_id = threadIdx.x / threads_per_group; + const int lane_id = threadIdx.x % threads_per_group; + + const int64_t block_group_id = blockIdx.x * groups_per_block; + const 
int64_t global_group_id = block_group_id + local_group_id; + const int64_t block_group_offset = global_group_id * group_size; + + float local_absmax = eps; + + using scale_element_t = std::conditional_t; + static_assert(sizeof(scale_packed_t) % sizeof(scale_element_t) == 0); + + const T *group_input = input + block_group_offset; + DST_DTYPE *group_output = static_cast(output_q) + block_group_offset; + scale_element_t *scale_output; + + if constexpr (IS_COLUMN_MAJOR) { + const int num_elems_per_pack = static_cast(sizeof(scale_packed_t) / sizeof(scale_element_t)); + const int row_idx = global_group_id / num_groups_per_row; + const int col_idx_unpacked = global_group_id % num_groups_per_row; + const int col_idx = col_idx_unpacked / num_elems_per_pack; + const int pack_idx = col_idx_unpacked % num_elems_per_pack; + scale_output = reinterpret_cast(output_s) + (col_idx * scale_stride * num_elems_per_pack + row_idx * num_elems_per_pack + pack_idx); + } else { + static_assert(!SCALE_UE8M0); + scale_output = output_s + global_group_id; + } + + constexpr uint32_t vec_size = 16 / sizeof(T); + + const int32_t num_vec_elems = group_size / vec_size; + + for (int32_t i = lane_id; i < num_vec_elems; i += 16) { + T *input_vec = new T[vec_size]; + for (int j = 0; j < vec_size; j++) { + input_vec[j] = *(group_input + i * vec_size + j); + } + +#pragma unroll + for (uint32_t j = 0; j < vec_size; ++j) { + float val = static_cast(input_vec[j]); + float abs_val = fabsf(val); + local_absmax = fmaxf(local_absmax, abs_val); + } + delete input_vec; + } + + local_absmax = GroupReduceMax(local_absmax, lane_id); + + float y_s = local_absmax / max_8bit; + if constexpr (SCALE_UE8M0) { + y_s = exp2f(ceilf(log2f(fmaxf(y_s, 1e-10f)))); + } + + // TODO can optimize + scale_element_t y_s_quant; + if constexpr (SCALE_UE8M0) { + y_s_quant = (uint8_t)(((int)log2f(y_s)) + 127); + } else { + y_s_quant = y_s; + } + + if (lane_id == 0) { + *scale_output = y_s_quant; + } + + for (int32_t i = lane_id; i < num_vec_elems; i += 16) { + T *input_vec = new T[vec_size]; + for (int j = 0; j < vec_size; j++) { + input_vec[j] = *(group_input + i * vec_size + j); + } + +#pragma unroll + for (uint32_t j = 0; j < vec_size; ++j) { + float val = static_cast(input_vec[j]); + float q_val = fminf(fmaxf(val / y_s, min_8bit), max_8bit); + group_output[i * vec_size + j] = DST_DTYPE(q_val); + } + delete input_vec; + } +} diff --git a/src/infiniop/ops/quantize/nvidia/quantize_group_8bit_nvidia.cuh b/src/infiniop/ops/quantize/nvidia/quantize_group_8bit_nvidia.cuh new file mode 100644 index 000000000..b06875011 --- /dev/null +++ b/src/infiniop/ops/quantize/nvidia/quantize_group_8bit_nvidia.cuh @@ -0,0 +1,8 @@ +#ifndef __QUANTIZE_CUH__ +#define __QUANTIZE_CUH__ + +#include "../quantize.h" + +DESCRIPTOR(nvidia) + +#endif // __QUANTIZE_CUH__ diff --git a/src/infiniop/ops/quantize/operator.cc b/src/infiniop/ops/quantize/operator.cc new file mode 100644 index 000000000..23ad8f305 --- /dev/null +++ b/src/infiniop/ops/quantize/operator.cc @@ -0,0 +1,95 @@ +#include "../../operator.h" +#include "../../handle.h" +#include "infiniop/ops/quantize.h" + +#ifdef ENABLE_NVIDIA_API +#include "nvidia/quantize_group_8bit_nvidia.cuh" +#endif + +__C infiniStatus_t infiniopCreateQuantizeDescriptor( + infiniopHandle_t handle, infiniopQuantizeDescriptor_t *desc_ptr, + infiniopTensorDescriptor_t input_desc, + infiniopTensorDescriptor_t output_q_desc, + infiniopTensorDescriptor_t output_s_desc) { + +#define CREATE(CASE, NAMESPACE) \ + case CASE: \ + return 
op::quantize::NAMESPACE::Descriptor::create( \ + handle, \ + reinterpret_cast(desc_ptr), \ + input_desc, output_q_desc, output_s_desc) + + switch (handle->device) { +#ifdef ENABLE_NVIDIA_API + CREATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CREATE +} + +__C infiniStatus_t infiniopGetQuantizeWorkspaceSize( + infiniopQuantizeDescriptor_t desc, size_t *size) { +#define GET(CASE, NAMESPACE) \ + case CASE: \ + *size = reinterpret_cast(desc) \ + ->workspaceSize(); \ + return INFINI_STATUS_SUCCESS + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + GET(INFINI_DEVICE_NVIDIA, nvidia); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } +#undef GET +} + +__C infiniStatus_t infiniopQuantize(infiniopQuantizeDescriptor_t desc, + void *workspace, size_t workspace_size, + void *input, void *output_q, void *output_s, + int group_size, double eps, double min_8bit, + double max_8bit, bool scale_ue8m0, + void *stream) { + +#define CALCULATE(CASE, NAMESPACE) \ + case CASE: \ + return reinterpret_cast(desc) \ + ->calculate(workspace, workspace_size, input, output_q, output_s, \ + group_size, eps, min_8bit, max_8bit, scale_ue8m0, stream) + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + CALCULATE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef CALCULATE +} + +__C infiniStatus_t +infiniopDestroyQuantizeDescriptor(infiniopQuantizeDescriptor_t desc) { + +#define DELETE(CASE, NAMESPACE) \ + case CASE: \ + delete reinterpret_cast( \ + desc); \ + return INFINI_STATUS_SUCCESS; + + switch (desc->device_type) { +#ifdef ENABLE_NVIDIA_API + DELETE(INFINI_DEVICE_NVIDIA, nvidia); +#endif + default: + return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED; + } + +#undef DELETE +} + +// #endif \ No newline at end of file diff --git a/src/infiniop/ops/quantize/quantize.h b/src/infiniop/ops/quantize/quantize.h new file mode 100644 index 000000000..9ceb157f6 --- /dev/null +++ b/src/infiniop/ops/quantize/quantize.h @@ -0,0 +1,41 @@ +#ifndef __QUANTIZE_H__ +#define __QUANTIZE_H__ + +#include "../../operator.h" +#include "info.h" + +#define DESCRIPTOR(NAMESPACE) \ + \ + namespace op::quantize::NAMESPACE { \ + class Descriptor final : public InfiniopDescriptor { \ + struct Opaque; \ + Opaque *_opaque; \ + infiniDtype_t _dtype; \ + QuantizeInfo _info; \ + size_t _workspace_size; \ + \ + Descriptor(infiniDtype_t dtype, QuantizeInfo info, size_t workspace_size_, \ + Opaque *opaque, infiniDevice_t device_type, int device_id) \ + : InfiniopDescriptor{device_type, device_id}, _opaque(opaque), \ + _dtype(dtype), _info(info), _workspace_size(workspace_size_) {} \ + \ + public: \ + ~Descriptor(); \ + \ + size_t workspaceSize() const { return _workspace_size; } \ + \ + static infiniStatus_t create(infiniopHandle_t handle, \ + Descriptor **desc_ptr, \ + infiniopTensorDescriptor_t input_desc, \ + infiniopTensorDescriptor_t output_q_desc, \ + infiniopTensorDescriptor_t output_s_desc); \ + \ + infiniStatus_t calculate(void *workspace, size_t workspace_size, \ + void *input, void *output_q, void *output_s, \ + int group_size, double eps, double min_8bit, \ + double max_8bit, bool scale_ue8m0, \ + void *stream) const; \ + }; \ + } + +#endif // __QUANTIZE_H__ diff --git a/test/infiniop-test/test_generate/__init__.py b/test/infiniop-test/test_generate/__init__.py index a61f63f7c..8db1e6755 100644 --- a/test/infiniop-test/test_generate/__init__.py +++ 
b/test/infiniop-test/test_generate/__init__.py @@ -1 +1,8 @@ -from .infiniop_test import InfiniopTestCase, InfiniopTestWriter, np_dtype_to_ggml, gguf_strides, contiguous_gguf_strides, process_zero_stride_tensor +from .infiniop_test import ( + InfiniopTestCase, + InfiniopTestWriter, + np_dtype_to_ggml, + gguf_strides, + contiguous_gguf_strides, + process_zero_stride_tensor, +) diff --git a/test/infiniop-test/test_generate/testcases/add.py b/test/infiniop-test/test_generate/testcases/add.py index 2adf19a9f..4a55acc29 100644 --- a/test/infiniop-test/test_generate/testcases/add.py +++ b/test/infiniop-test/test_generate/testcases/add.py @@ -4,7 +4,14 @@ from typing import List from numpy.lib.stride_tricks import as_strided -from .. import InfiniopTestWriter, InfiniopTestCase, np_dtype_to_ggml, gguf_strides, contiguous_gguf_strides, process_zero_stride_tensor +from .. import ( + InfiniopTestWriter, + InfiniopTestCase, + np_dtype_to_ggml, + gguf_strides, + contiguous_gguf_strides, + process_zero_stride_tensor, +) def add( @@ -26,7 +33,6 @@ def __init__( c: np.ndarray, shape_c: List[int] | None, stride_c: List[int] | None, - ): super().__init__("add") self.a = a @@ -39,7 +45,6 @@ def __init__( self.shape_c = shape_c self.stride_c = stride_c - def write_test(self, test_writer: "InfiniopTestWriter"): super().write_test(test_writer) if self.shape_a is not None: @@ -49,12 +54,22 @@ def write_test(self, test_writer: "InfiniopTestWriter"): if self.shape_c is not None: test_writer.add_array(test_writer.gguf_key("c.shape"), self.shape_c) if self.stride_a is not None: - test_writer.add_array(test_writer.gguf_key("a.strides"), gguf_strides(*self.stride_a)) + test_writer.add_array( + test_writer.gguf_key("a.strides"), gguf_strides(*self.stride_a) + ) if self.stride_b is not None: - test_writer.add_array(test_writer.gguf_key("b.strides"), gguf_strides(*self.stride_b)) + test_writer.add_array( + test_writer.gguf_key("b.strides"), gguf_strides(*self.stride_b) + ) test_writer.add_array( test_writer.gguf_key("c.strides"), - gguf_strides(*self.stride_c if self.stride_c is not None else contiguous_gguf_strides(self.shape_c)) + gguf_strides( + *( + self.stride_c + if self.stride_c is not None + else contiguous_gguf_strides(self.shape_c) + ) + ), ) test_writer.add_tensor( test_writer.gguf_key("a"), self.a, raw_dtype=np_dtype_to_ggml(self.a.dtype) @@ -116,7 +131,6 @@ def write_test(self, test_writer: "InfiniopTestWriter"): stride_c=stride_c, ) test_cases.append(test_case) - + test_writer.add_tests(test_cases) test_writer.save() - \ No newline at end of file diff --git a/test/infiniop-test/test_generate/testcases/causal_softmax.py b/test/infiniop-test/test_generate/testcases/causal_softmax.py index 74c3efcf0..037701865 100644 --- a/test/infiniop-test/test_generate/testcases/causal_softmax.py +++ b/test/infiniop-test/test_generate/testcases/causal_softmax.py @@ -4,7 +4,13 @@ from typing import List from enum import Enum, auto -from .. import InfiniopTestWriter, InfiniopTestCase, np_dtype_to_ggml, gguf_strides, contiguous_gguf_strides +from .. 
import ( + InfiniopTestWriter, + InfiniopTestCase, + np_dtype_to_ggml, + gguf_strides, + contiguous_gguf_strides, +) def causal_softmax(x): @@ -37,8 +43,8 @@ def __init__( super().__init__("causal_softmax") self.x = x self.y = y - self.shape_x=shape_x - self.shape_y=shape_y + self.shape_x = shape_x + self.shape_y = shape_y self.stride_x = stride_x self.stride_y = stride_y @@ -49,10 +55,18 @@ def write_test(self, test_writer: "InfiniopTestWriter"): if self.shape_y is not None: test_writer.add_array(test_writer.gguf_key("y.shape"), self.shape_y) if self.stride_x is not None: - test_writer.add_array(test_writer.gguf_key("x.strides"), gguf_strides(*self.stride_x)) + test_writer.add_array( + test_writer.gguf_key("x.strides"), gguf_strides(*self.stride_x) + ) test_writer.add_array( test_writer.gguf_key("y.strides"), - gguf_strides(*self.stride_y if self.stride_y is not None else contiguous_gguf_strides(self.shape_y)) + gguf_strides( + *( + self.stride_y + if self.stride_y is not None + else contiguous_gguf_strides(self.shape_y) + ) + ), ) test_writer.add_tensor( test_writer.gguf_key("x"), @@ -102,6 +116,6 @@ def write_test(self, test_writer: "InfiniopTestWriter"): stride_y, ) test_cases.append(test_case) - + test_writer.add_tests(test_cases) test_writer.save() diff --git a/test/infiniop-test/test_generate/testcases/clip.py b/test/infiniop-test/test_generate/testcases/clip.py index f08a59929..786153197 100644 --- a/test/infiniop-test/test_generate/testcases/clip.py +++ b/test/infiniop-test/test_generate/testcases/clip.py @@ -2,7 +2,13 @@ import gguf from typing import List, Optional, Tuple -from .. import InfiniopTestWriter, InfiniopTestCase, np_dtype_to_ggml, gguf_strides, contiguous_gguf_strides +from .. import ( + InfiniopTestWriter, + InfiniopTestCase, + np_dtype_to_ggml, + gguf_strides, + contiguous_gguf_strides, +) def clip( @@ -35,7 +41,7 @@ def random_tensor(shape, dtype): Returns: Random tensor with the specified shape and dtype """ - return (np.random.rand(*shape).astype(dtype) * 4.0 - 2.0) + return np.random.rand(*shape).astype(dtype) * 4.0 - 2.0 class ClipTestCase(InfiniopTestCase): @@ -52,7 +58,7 @@ def __init__( max_val: np.ndarray, max_stride: Optional[List[int]], y: np.ndarray, - y_shape: Optional[List[int]], + y_shape: Optional[List[int]], y_stride: Optional[List[int]], ): super().__init__("clip") @@ -63,7 +69,7 @@ def __init__( self.max_val = max_val self.max_stride = max_stride self.y = y - self.y_shape=y_shape + self.y_shape = y_shape self.y_stride = y_stride def write_test(self, test_writer: "InfiniopTestWriter"): @@ -71,57 +77,64 @@ def write_test(self, test_writer: "InfiniopTestWriter"): # Add strides as arrays if they exist if self.x_stride is not None: - test_writer.add_array(test_writer.gguf_key("x.strides"), gguf_strides(*self.x_stride)) + test_writer.add_array( + test_writer.gguf_key("x.strides"), gguf_strides(*self.x_stride) + ) if self.min_stride is not None: - test_writer.add_array(test_writer.gguf_key("min_val.strides"), gguf_strides(*self.min_stride)) + test_writer.add_array( + test_writer.gguf_key("min_val.strides"), gguf_strides(*self.min_stride) + ) if self.max_stride is not None: - test_writer.add_array(test_writer.gguf_key("max_val.strides"), gguf_strides(*self.max_stride)) + test_writer.add_array( + test_writer.gguf_key("max_val.strides"), gguf_strides(*self.max_stride) + ) if self.y_shape is not None: test_writer.add_array(test_writer.gguf_key("y.shape"), self.y_shape) test_writer.add_array( test_writer.gguf_key("y.strides"), - gguf_strides(*self.y_stride 
if self.y_stride is not None else contiguous_gguf_strides(self.y_shape)) + gguf_strides( + *( + self.y_stride + if self.y_stride is not None + else contiguous_gguf_strides(self.y_shape) + ) + ), ) # Add tensors to the test test_writer.add_tensor( - test_writer.gguf_key("x"), - self.x, - raw_dtype=np_dtype_to_ggml(self.x.dtype) + test_writer.gguf_key("x"), self.x, raw_dtype=np_dtype_to_ggml(self.x.dtype) ) test_writer.add_tensor( test_writer.gguf_key("min_val"), self.min_val, - raw_dtype=np_dtype_to_ggml(self.min_val.dtype) + raw_dtype=np_dtype_to_ggml(self.min_val.dtype), ) test_writer.add_tensor( test_writer.gguf_key("max_val"), self.max_val, - raw_dtype=np_dtype_to_ggml(self.max_val.dtype) + raw_dtype=np_dtype_to_ggml(self.max_val.dtype), ) test_writer.add_tensor( - test_writer.gguf_key("y"), - self.y, - raw_dtype=np_dtype_to_ggml(self.y.dtype) + test_writer.gguf_key("y"), self.y, raw_dtype=np_dtype_to_ggml(self.y.dtype) ) # Calculate the expected result ans = clip( self.x.astype(np.float64), self.min_val.astype(np.float64), - self.max_val.astype(np.float64) + self.max_val.astype(np.float64), ) # Add the expected result to the test test_writer.add_tensor( - test_writer.gguf_key("ans"), - ans, - raw_dtype=gguf.GGMLQuantizationType.F64 + test_writer.gguf_key("ans"), ans, raw_dtype=gguf.GGMLQuantizationType.F64 ) + if __name__ == "__main__": test_writer = InfiniopTestWriter("clip.gguf") @@ -130,23 +143,23 @@ def write_test(self, test_writer: "InfiniopTestWriter"): # Test case shapes shapes = [ - (10,), # 1D tensor - (5, 10), # 2D tensor - (2, 3, 4), # 3D tensor - (7, 13), # Prime dimensions - (1, 1), # Minimum shape - (100, 100), # Large shape - (16, 16, 16), # Large 3D + (10,), # 1D tensor + (5, 10), # 2D tensor + (2, 3, 4), # 3D tensor + (7, 13), # Prime dimensions + (1, 1), # Minimum shape + (100, 100), # Large shape + (16, 16, 16), # Large 3D ] # Test case min/max values min_max_values = [ - (-1.0, 1.0), # Standard range - (0.0, 2.0), # Positive range - (-2.0, 0.0), # Negative range - (-1000.0, 1000.0), # Large range - (-0.001, 0.001), # Small range - (0.0, 0.0), # min=max + (-1.0, 1.0), # Standard range + (0.0, 2.0), # Positive range + (-2.0, 0.0), # Negative range + (-1000.0, 1000.0), # Large range + (-0.001, 0.001), # Small range + (0.0, 0.0), # min=max ] # Data types to test @@ -171,7 +184,7 @@ def write_test(self, test_writer: "InfiniopTestWriter"): max_stride=None, y=y, y_shape=shape, - y_stride=None + y_stride=None, ) ) @@ -199,7 +212,7 @@ def write_test(self, test_writer: "InfiniopTestWriter"): max_stride=row_stride, y=y, y_shape=shape, - y_stride=row_stride + y_stride=row_stride, ) ) @@ -219,7 +232,7 @@ def write_test(self, test_writer: "InfiniopTestWriter"): max_stride=col_stride, y=y, y_shape=shape, - y_stride=col_stride + y_stride=col_stride, ) ) @@ -239,7 +252,7 @@ def write_test(self, test_writer: "InfiniopTestWriter"): max_stride=row_stride, y=y, y_shape=shape, - y_stride=col_stride + y_stride=col_stride, ) ) diff --git a/test/infiniop-test/test_generate/testcases/mul.py b/test/infiniop-test/test_generate/testcases/mul.py index 00c427bcb..ad4f6b806 100644 --- a/test/infiniop-test/test_generate/testcases/mul.py +++ b/test/infiniop-test/test_generate/testcases/mul.py @@ -2,30 +2,36 @@ import gguf from typing import List -from .. import InfiniopTestWriter, InfiniopTestCase, np_dtype_to_ggml, gguf_strides, contiguous_gguf_strides +from .. 
import ( + InfiniopTestWriter, + InfiniopTestCase, + np_dtype_to_ggml, + gguf_strides, + contiguous_gguf_strides, +) -def mul( - a: np.ndarray, - b: np.ndarray -): + +def mul(a: np.ndarray, b: np.ndarray): return np.multiply(a, b) + def random_tensor(shape, dtype): rate = 1e-3 var = 0.5 * rate # 数值范围在[-5e-4, 5e-4] return rate * np.random.rand(*shape).astype(dtype) - var + class MulTestCase(InfiniopTestCase): def __init__( self, a: np.ndarray, - shape_a: List[int] | None, + shape_a: List[int] | None, stride_a: List[int] | None, b: np.ndarray, - shape_b: List[int] | None, + shape_b: List[int] | None, stride_b: List[int] | None, c: np.ndarray, - shape_c: List[int] | None, + shape_c: List[int] | None, stride_c: List[int] | None, ): super().__init__("mul") @@ -39,7 +45,6 @@ def __init__( self.shape_c = shape_c self.stride_c = stride_c - def write_test(self, test_writer: "InfiniopTestWriter"): super().write_test(test_writer) if self.shape_a is not None: @@ -49,12 +54,22 @@ def write_test(self, test_writer: "InfiniopTestWriter"): if self.shape_c is not None: test_writer.add_array(test_writer.gguf_key("c.shape"), self.shape_c) if self.stride_a is not None: - test_writer.add_array(test_writer.gguf_key("a.strides"), gguf_strides(*self.stride_a)) + test_writer.add_array( + test_writer.gguf_key("a.strides"), gguf_strides(*self.stride_a) + ) if self.stride_b is not None: - test_writer.add_array(test_writer.gguf_key("b.strides"), gguf_strides(*self.stride_b)) + test_writer.add_array( + test_writer.gguf_key("b.strides"), gguf_strides(*self.stride_b) + ) test_writer.add_array( test_writer.gguf_key("c.strides"), - gguf_strides(*self.stride_c if self.stride_c is not None else contiguous_gguf_strides(self.shape_c)) + gguf_strides( + *( + self.stride_c + if self.stride_c is not None + else contiguous_gguf_strides(self.shape_c) + ) + ), ) test_writer.add_tensor( @@ -68,7 +83,7 @@ def write_test(self, test_writer: "InfiniopTestWriter"): ) a_fp64 = self.a.astype(np.float64) b_fp64 = self.b.astype(np.float64) - + ans_fp64 = np.multiply(a_fp64, b_fp64) ans = mul(self.a, self.b) test_writer.add_tensor( @@ -80,7 +95,8 @@ def write_test(self, test_writer: "InfiniopTestWriter"): raw_dtype=np_dtype_to_ggml(ans_fp64.dtype), ) -if __name__ == '__main__': + +if __name__ == "__main__": test_writer = InfiniopTestWriter("mul.gguf") test_cases = [] @@ -96,16 +112,15 @@ def write_test(self, test_writer: "InfiniopTestWriter"): ((2048, 2560), (2560, 1), (1, 2048), (2560, 1)), ((4, 48, 64), (64 * 48, 64, 1), (1, 4, 192), None), ((4, 48, 64), None, (1, 4, 192), (48 * 64, 64, 1)), - ] + ] _TENSOR_DTYPES_ = [np.float32, np.float16] - + for dtype in _TENSOR_DTYPES_: for shape, stride_a, stride_b, stride_c in _TEST_CASES_: a = random_tensor(shape, dtype) b = random_tensor(shape, dtype) c = np.empty(tuple(0 for _ in shape), dtype=dtype) - test_cases.append( MulTestCase( a=a, @@ -118,7 +133,7 @@ def write_test(self, test_writer: "InfiniopTestWriter"): shape_c=shape, stride_c=stride_c, ) - ) - + ) + test_writer.add_tests(test_cases) test_writer.save() diff --git a/test/infiniop-test/test_generate/testcases/rearrange.py b/test/infiniop-test/test_generate/testcases/rearrange.py index 9617a1fc0..3d3a0e73b 100644 --- a/test/infiniop-test/test_generate/testcases/rearrange.py +++ b/test/infiniop-test/test_generate/testcases/rearrange.py @@ -1,14 +1,21 @@ import torch from typing import List -from .. import InfiniopTestWriter, InfiniopTestCase, np_dtype_to_ggml, gguf_strides, contiguous_gguf_strides +from .. 
import ( + InfiniopTestWriter, + InfiniopTestCase, + np_dtype_to_ggml, + gguf_strides, + contiguous_gguf_strides, +) + def row_major_strides(shape): """生成张量的行优先stride - + Args: shape: 张量形状 - + Returns: 行优先strides列表 """ @@ -19,12 +26,13 @@ def row_major_strides(shape): strides.insert(0, stride) return strides + def column_major_strides(shape): """生成张量的列优先stride - + Args: shape: 张量形状 - + Returns: 列优先strides列表 """ @@ -35,6 +43,7 @@ def column_major_strides(shape): strides.append(stride) return strides + def rearrange_using_torch(src: torch.Tensor, dst_strides: List[int]) -> torch.Tensor: """ 使用torch的rearrange函数计算结果 @@ -66,27 +75,35 @@ def __init__( self.shape = shape self.src_strides = src_strides self.dst_strides = dst_strides - + def write_test(self, test_writer: "InfiniopTestWriter"): super().write_test(test_writer) - + # 写入形状信息 if self.shape is not None: test_writer.add_array(test_writer.gguf_key("src.shape"), self.shape) test_writer.add_array(test_writer.gguf_key("dst.shape"), self.shape) - + # 写入strides信息 if self.src_strides is not None: - test_writer.add_array(test_writer.gguf_key("src.strides"), gguf_strides(*self.src_strides)) + test_writer.add_array( + test_writer.gguf_key("src.strides"), gguf_strides(*self.src_strides) + ) test_writer.add_array( test_writer.gguf_key("dst.strides"), - gguf_strides(*self.dst_strides if self.dst_strides is not None else contiguous_gguf_strides(self.shape)) + gguf_strides( + *( + self.dst_strides + if self.dst_strides is not None + else contiguous_gguf_strides(self.shape) + ) + ), ) - + # 转换torch tensor为numpy用于写入文件 src_numpy = self.src.detach().cpu().numpy() dst_numpy = self.dst.detach().cpu().numpy() - + # 写入张量数据 test_writer.add_tensor( test_writer.gguf_key("src"), @@ -98,9 +115,13 @@ def write_test(self, test_writer: "InfiniopTestWriter"): dst_numpy, raw_dtype=np_dtype_to_ggml(dst_numpy.dtype), ) - + # 计算并写入答案 - dst_strides_for_ans = self.dst_strides if self.dst_strides is not None else list(contiguous_gguf_strides(self.shape)) + dst_strides_for_ans = ( + self.dst_strides + if self.dst_strides is not None + else list(contiguous_gguf_strides(self.shape)) + ) ans_torch = rearrange_using_torch(self.src, dst_strides_for_ans) ans_numpy = ans_torch.detach().cpu().numpy() test_writer.add_tensor( @@ -109,6 +130,7 @@ def write_test(self, test_writer: "InfiniopTestWriter"): raw_dtype=np_dtype_to_ggml(src_numpy.dtype), ) + if __name__ == "__main__": test_writer = InfiniopTestWriter("rearrange.gguf") test_cases = [] @@ -117,12 +139,20 @@ def write_test(self, test_writer: "InfiniopTestWriter"): # (shape, src_stride, dst_stride) ((100, 100), (1, 100), (100, 1)), ((4, 4), (1, 4), (4, 1)), - ((4, 6, 64), (64, 4*64, 1), (6*64, 64, 1)), + ((4, 6, 64), (64, 4 * 64, 1), (6 * 64, 64, 1)), ((2000, 2000), (1, 2000), (2000, 1)), ((2001, 2001), (1, 2001), (2001, 1)), ((2, 2, 2, 4), (16, 8, 4, 1), (16, 8, 1, 2)), - ((3, 4, 7, 53, 9), row_major_strides((3, 4, 7, 53, 9)), column_major_strides((3, 4, 7, 53, 9))), - ((3, 4, 50, 50, 5, 7), row_major_strides((3, 4, 50, 50, 5, 7)), column_major_strides((3, 4, 50, 50, 5, 7))), + ( + (3, 4, 7, 53, 9), + row_major_strides((3, 4, 7, 53, 9)), + column_major_strides((3, 4, 7, 53, 9)), + ), + ( + (3, 4, 50, 50, 5, 7), + row_major_strides((3, 4, 50, 50, 5, 7)), + column_major_strides((3, 4, 50, 50, 5, 7)), + ), ] _TENSOR_DTYPES_ = [torch.float32, torch.float16] @@ -132,7 +162,7 @@ def write_test(self, test_writer: "InfiniopTestWriter"): src = torch.rand(*shape, dtype=dtype) # 生成目标张量,使用正确的形状 dst = torch.empty(shape, dtype=dtype) - + 
test_case = RearrangeTestCase( src=src, dst=dst, @@ -140,7 +170,7 @@ def write_test(self, test_writer: "InfiniopTestWriter"): src_strides=src_strides, dst_strides=dst_strides, ) - test_cases.append(test_case) + test_cases.append(test_case) test_writer.add_tests(test_cases) - test_writer.save() + test_writer.save() diff --git a/test/infiniop-test/test_generate/testcases/rms_norm.py b/test/infiniop-test/test_generate/testcases/rms_norm.py index cc1937aae..b86e2cee5 100644 --- a/test/infiniop-test/test_generate/testcases/rms_norm.py +++ b/test/infiniop-test/test_generate/testcases/rms_norm.py @@ -1,11 +1,19 @@ import numpy as np from typing import List -from .. import InfiniopTestWriter, InfiniopTestCase, np_dtype_to_ggml, gguf_strides, contiguous_gguf_strides +from .. import ( + InfiniopTestWriter, + InfiniopTestCase, + np_dtype_to_ggml, + gguf_strides, + contiguous_gguf_strides, +) + def random_tensor(shape: tuple, dtype: np.dtype) -> np.ndarray: return np.random.uniform(-1.0, 1.0, shape).astype(dtype) * 0.001 + def rms_norm(x: np.ndarray, w: np.ndarray, epsilon: float) -> np.ndarray: """ 使用numpy计算rms_norm结果 @@ -16,13 +24,14 @@ def rms_norm(x: np.ndarray, w: np.ndarray, epsilon: float) -> np.ndarray: Returns: 输出张量, 形状与 input 相同 """ - squared = x ** 2 + squared = x**2 mean = np.mean(squared, axis=-1, keepdims=True) rms = np.sqrt(mean + epsilon) - + normalized = x / rms return normalized * w + class RMSNormTestCase(InfiniopTestCase): def __init__( self, @@ -40,9 +49,9 @@ def __init__( self.y = y self.shape = shape self.epsilon = epsilon - self.x_strides=x_strides - self.y_strides=y_strides - + self.x_strides = x_strides + self.y_strides = y_strides + def write_test(self, test_writer: "InfiniopTestWriter"): super().write_test(test_writer) test_writer.add_float32(test_writer.gguf_key("epsilon"), self.epsilon) @@ -50,10 +59,18 @@ def write_test(self, test_writer: "InfiniopTestWriter"): test_writer.add_array(test_writer.gguf_key("x.shape"), self.shape) test_writer.add_array(test_writer.gguf_key("y.shape"), self.shape) if self.x_strides is not None: - test_writer.add_array(test_writer.gguf_key("x.strides"), gguf_strides(*self.x_strides)) + test_writer.add_array( + test_writer.gguf_key("x.strides"), gguf_strides(*self.x_strides) + ) test_writer.add_array( test_writer.gguf_key("y.strides"), - gguf_strides(*self.y_strides if self.y_strides is not None else contiguous_gguf_strides(self.shape)) + gguf_strides( + *( + self.y_strides + if self.y_strides is not None + else contiguous_gguf_strides(self.shape) + ) + ), ) test_writer.add_tensor( test_writer.gguf_key("x"), @@ -70,13 +87,16 @@ def write_test(self, test_writer: "InfiniopTestWriter"): self.y, raw_dtype=np_dtype_to_ggml(self.y.dtype), ) - ans = rms_norm(self.x.astype(np.float64), self.w.astype(np.float64), self.epsilon) + ans = rms_norm( + self.x.astype(np.float64), self.w.astype(np.float64), self.epsilon + ) test_writer.add_tensor( test_writer.gguf_key("ans"), ans, raw_dtype=np_dtype_to_ggml(np.float64), ) + if __name__ == "__main__": test_writer = InfiniopTestWriter("rms_norm.gguf") test_cases = [] @@ -116,9 +136,9 @@ def write_test(self, test_writer: "InfiniopTestWriter"): shape=shape, x_strides=x_strides, y_strides=y_strides, - epsilon=epsilon + epsilon=epsilon, ) - test_cases.append(test_case) + test_cases.append(test_case) test_writer.add_tests(test_cases) test_writer.save() diff --git a/test/infiniop-test/test_generate/testcases/rope.py b/test/infiniop-test/test_generate/testcases/rope.py index 7af729940..ee6016445 100644 --- 
a/test/infiniop-test/test_generate/testcases/rope.py +++ b/test/infiniop-test/test_generate/testcases/rope.py @@ -4,7 +4,14 @@ from typing import List from enum import Enum -from .. import InfiniopTestWriter, InfiniopTestCase, np_dtype_to_ggml, gguf_strides, contiguous_gguf_strides +from .. import ( + InfiniopTestWriter, + InfiniopTestCase, + np_dtype_to_ggml, + gguf_strides, + contiguous_gguf_strides, +) + class Algorithm(Enum): GPT_J = 0 @@ -21,7 +28,6 @@ def _rope(sin, cos, t1, t2): return t_out_1, t_out_2 - dh = t.shape[-1] assert dh % 2 == 0, "Embedding dimension must be even." @@ -36,7 +42,7 @@ def _rope(sin, cos, t1, t2): t_out[..., 0::2] = t_out_even t_out[..., 1::2] = t_out_odd else: - half_dim = dh // 2 + half_dim = dh // 2 t_first = t[..., :half_dim] t_second = t[..., half_dim:] @@ -51,7 +57,9 @@ def _rope(sin, cos, t1, t2): def sin_cos_table(pos, dim, theta, dtype): assert dim % 2 == 0, "Embedding dimension must be even." - freqs = 1.0 / (theta ** (np.arange(0, dim, 2)[: (dim // 2)].astype(np.float32) / dim)) + freqs = 1.0 / ( + theta ** (np.arange(0, dim, 2)[: (dim // 2)].astype(np.float32) / dim) + ) angles = np.outer(pos, freqs) @@ -103,19 +111,33 @@ def write_test(self, test_writer: "InfiniopTestWriter"): test_writer.add_array(test_writer.gguf_key("x.shape"), self.shape_x) test_writer.add_array( test_writer.gguf_key("y.strides"), - gguf_strides(*self.stride_y if self.stride_y is not None else contiguous_gguf_strides(self.shape_y)) + gguf_strides( + *( + self.stride_y + if self.stride_y is not None + else contiguous_gguf_strides(self.shape_y) + ) + ), ) if self.stride_x is not None: - test_writer.add_array(test_writer.gguf_key("x.strides"), gguf_strides(*self.stride_x)) + test_writer.add_array( + test_writer.gguf_key("x.strides"), gguf_strides(*self.stride_x) + ) test_writer.add_tensor( - test_writer.gguf_key("pos_ids"), self.pos_ids, raw_dtype=np_dtype_to_ggml(self.pos_ids.dtype) + test_writer.gguf_key("pos_ids"), + self.pos_ids, + raw_dtype=np_dtype_to_ggml(self.pos_ids.dtype), ) test_writer.add_tensor( - test_writer.gguf_key("sin_table"), self.sin_table, raw_dtype=np_dtype_to_ggml(self.sin_table.dtype) + test_writer.gguf_key("sin_table"), + self.sin_table, + raw_dtype=np_dtype_to_ggml(self.sin_table.dtype), ) test_writer.add_tensor( - test_writer.gguf_key("cos_table"), self.cos_table, raw_dtype=np_dtype_to_ggml(self.cos_table.dtype) + test_writer.gguf_key("cos_table"), + self.cos_table, + raw_dtype=np_dtype_to_ggml(self.cos_table.dtype), ) ans = rotary_embedding( self.x.astype(np.float64), @@ -128,8 +150,6 @@ def write_test(self, test_writer: "InfiniopTestWriter"): ) - - if __name__ == "__main__": # ============================================================================== # Configuration (Internal Use Only) @@ -146,7 +166,6 @@ def write_test(self, test_writer: "InfiniopTestWriter"): ((3, 32, 128), (8000, 200, 1), (7000, 128, 1)), ] - _ALGO = [ Algorithm.GPT_J, Algorithm.GPT_NEOX, @@ -162,7 +181,9 @@ def write_test(self, test_writer: "InfiniopTestWriter"): x = np.random.rand(*shape).astype(dtype) y = np.empty(tuple(0 for _ in shape), dtype=dtype) pos_ids = np.arange(0, x.shape[0], dtype=np.int32) - sin_table, cos_table = sin_cos_table(pos_ids, x.shape[2], theta=1e5, dtype=dtype) + sin_table, cos_table = sin_cos_table( + pos_ids, x.shape[2], theta=1e5, dtype=dtype + ) test_case = RoPETestCase( y=y, x=x, diff --git a/test/infiniop-test/test_generate/testcases/swiglu.py b/test/infiniop-test/test_generate/testcases/swiglu.py index cb692b613..aa3450fed 100644 --- 
a/test/infiniop-test/test_generate/testcases/swiglu.py +++ b/test/infiniop-test/test_generate/testcases/swiglu.py @@ -2,7 +2,14 @@ import gguf from typing import List -from .. import InfiniopTestWriter, InfiniopTestCase, np_dtype_to_ggml, gguf_strides, contiguous_gguf_strides, process_zero_stride_tensor +from .. import ( + InfiniopTestWriter, + InfiniopTestCase, + np_dtype_to_ggml, + gguf_strides, + contiguous_gguf_strides, + process_zero_stride_tensor, +) def swiglu( @@ -26,7 +33,6 @@ def __init__( c: np.ndarray, shape_c: List[int] | None, stride_c: List[int] | None, - ): super().__init__("swiglu") self.a = a @@ -39,7 +45,6 @@ def __init__( self.shape_c = shape_c self.stride_c = stride_c - def write_test(self, test_writer: "InfiniopTestWriter"): super().write_test(test_writer) if self.shape_a is not None: @@ -47,14 +52,24 @@ def write_test(self, test_writer: "InfiniopTestWriter"): if self.shape_b is not None: test_writer.add_array(test_writer.gguf_key("b.shape"), self.shape_b) if self.shape_c is not None: - test_writer.add_array(test_writer.gguf_key("c.shape"), self.shape_c) + test_writer.add_array(test_writer.gguf_key("c.shape"), self.shape_c) if self.stride_a is not None: - test_writer.add_array(test_writer.gguf_key("a.strides"), gguf_strides(*self.stride_a)) + test_writer.add_array( + test_writer.gguf_key("a.strides"), gguf_strides(*self.stride_a) + ) if self.stride_b is not None: - test_writer.add_array(test_writer.gguf_key("b.strides"), gguf_strides(*self.stride_b)) + test_writer.add_array( + test_writer.gguf_key("b.strides"), gguf_strides(*self.stride_b) + ) test_writer.add_array( test_writer.gguf_key("c.strides"), - gguf_strides(*self.stride_c if self.stride_c is not None else contiguous_gguf_strides(self.shape_c)) + gguf_strides( + *( + self.stride_c + if self.stride_c is not None + else contiguous_gguf_strides(self.shape_c) + ) + ), ) test_writer.add_tensor( test_writer.gguf_key("a"), self.a, raw_dtype=np_dtype_to_ggml(self.a.dtype) diff --git a/test/infiniop/attention.py b/test/infiniop/attention.py index aa7241963..7c85e1a04 100644 --- a/test/infiniop/attention.py +++ b/test/infiniop/attention.py @@ -23,7 +23,6 @@ ) - def causal_softmax(x): type = x.dtype mask = torch.tril(torch.ones_like(x), diagonal=-1).flip(dims=[-2, -1]) diff --git a/test/infiniop/libinfiniop/datatypes.py b/test/infiniop/libinfiniop/datatypes.py index 633aaafa7..41738397f 100644 --- a/test/infiniop/libinfiniop/datatypes.py +++ b/test/infiniop/libinfiniop/datatypes.py @@ -19,6 +19,8 @@ class InfiniDtype: C32 = 17 C64 = 18 BF16 = 19 + F8E4M3 = 20 + F8E5M2 = 21 InfiniDtypeNames = { @@ -42,4 +44,6 @@ class InfiniDtype: InfiniDtype.C32: "C32", InfiniDtype.C64: "C64", InfiniDtype.BF16: "BF16", + InfiniDtype.F8E4M3: "F8E4M3", + InfiniDtype.F8E5M2: "F8E5M2", } diff --git a/test/infiniop/libinfiniop/op_register.py b/test/infiniop/libinfiniop/op_register.py index ba1ce33df..e6b677baa 100644 --- a/test/infiniop/libinfiniop/op_register.py +++ b/test/infiniop/libinfiniop/op_register.py @@ -4,7 +4,7 @@ infiniopOperatorDescriptor_t, ) -from ctypes import c_int32, c_void_p, c_size_t, POINTER, c_float +from ctypes import c_int32, c_void_p, c_size_t, POINTER, c_float, c_bool, c_double class OpRegister: @@ -495,6 +495,92 @@ def conv_(lib): ] +@OpRegister.operator +def linear_(lib): + lib.infiniopCreateLinearDescriptor.restype = c_int32 + lib.infiniopCreateLinearDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + 
infiniopTensorDescriptor_t, + ] + + lib.infiniopGetLinearWorkspaceSize.restype = c_int32 + lib.infiniopGetLinearWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopLinear.restype = c_int32 + lib.infiniopLinear.argtypes = [ + infiniopOperatorDescriptor_t, + c_float, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_float, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_bool, + c_bool, + c_bool, + c_void_p, + c_size_t, + c_void_p, + ] + + lib.infiniopDestroyLinearDescriptor.restype = c_int32 + lib.infiniopDestroyLinearDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + +@OpRegister.operator +def quantize_(lib): + lib.infiniopCreateQuantizeDescriptor.restype = c_int32 + lib.infiniopCreateQuantizeDescriptor.argtypes = [ + infiniopHandle_t, + POINTER(infiniopOperatorDescriptor_t), + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + infiniopTensorDescriptor_t, + ] + + lib.infiniopGetQuantizeWorkspaceSize.restype = c_int32 + lib.infiniopGetQuantizeWorkspaceSize.argtypes = [ + infiniopOperatorDescriptor_t, + POINTER(c_size_t), + ] + + lib.infiniopQuantize.restype = c_int32 + lib.infiniopQuantize.argtypes = [ + infiniopOperatorDescriptor_t, + c_void_p, + c_size_t, + c_void_p, + c_void_p, + c_void_p, + c_int32, + c_double, + c_double, + c_double, + c_bool, + # c_void_p, + # c_void_p, + c_void_p, + ] + + lib.infiniopDestroyQuantizeDescriptor.restype = c_int32 + lib.infiniopDestroyQuantizeDescriptor.argtypes = [ + infiniopOperatorDescriptor_t, + ] + + @OpRegister.operator def topkrouter_(lib): lib.infiniopCreateTopkrouterDescriptor.restype = c_int32 diff --git a/test/infiniop/libinfiniop/utils.py b/test/infiniop/libinfiniop/utils.py index 162b199fe..9bb08cf58 100644 --- a/test/infiniop/libinfiniop/utils.py +++ b/test/infiniop/libinfiniop/utils.py @@ -80,7 +80,13 @@ def __init__( torch_shape, dtype=to_torch_dtype(dt), device=torch_device_map[device] ) elif mode == "randint": - self._torch_tensor = torch.randint(-2000000000,2000000000, torch_shape,dtype=to_torch_dtype(dt), device=torch_device_map[device]) + self._torch_tensor = torch.randint( + -2000000000, + 2000000000, + torch_shape, + dtype=to_torch_dtype(dt), + device=torch_device_map[device], + ) elif mode == "manual": assert set_tensor is not None assert torch_shape == list(set_tensor.shape) @@ -120,14 +126,19 @@ def data(self): def is_broadcast(self): return self.strides is not None and 0 in self.strides - + @staticmethod - def from_binary(binary_file, shape, strides, dt: InfiniDtype, device: InfiniDeviceEnum): + def from_binary( + binary_file, shape, strides, dt: InfiniDtype, device: InfiniDeviceEnum + ): data = np.fromfile(binary_file, dtype=to_numpy_dtype(dt)) base = torch.from_numpy(data) - torch_tensor = torch.as_strided(base, size=shape, stride=strides).to(torch_device_map[device]) + torch_tensor = torch.as_strided(base, size=shape, stride=strides).to( + torch_device_map[device] + ) return TestTensor( - shape, strides, dt, device, mode="binary", set_tensor=torch_tensor) + shape, strides, dt, device, mode="binary", set_tensor=torch_tensor + ) @staticmethod def from_torch(torch_tensor, dt: InfiniDtype, device: InfiniDeviceEnum): @@ -137,6 +148,26 @@ def from_torch(torch_tensor, dt: InfiniDtype, device: InfiniDeviceEnum): shape_, strides_, dt, device, mode="manual", set_tensor=torch_tensor ) + def convert_pricesion(self, dtype: InfiniDtype): + torch_shape = [] + torch_strides = [] if self.strides is not None else None + for i in range(len(self.shape)): + if 
self.strides is not None and self.strides[i] == 0: + torch_shape.append(1) + torch_strides.append(1) + elif self.strides is not None and self.strides[i] != 0: + torch_shape.append(self.shape[i]) + torch_strides.append(self.strides[i]) + else: + torch_shape.append(self.shape[i]) + self._torch_tensor = self._torch_tensor.to(to_torch_dtype(dtype)) + self.dt = dtype + if self.strides is not None: + self._data_tensor = rearrange_tensor(self._torch_tensor, torch_strides) + else: + self._data_tensor = self._torch_tensor.clone() + super().__init__(self.dt, self.shape, self.strides) + def to_torch_dtype(dt: InfiniDtype, compatability_mode=False): if dt == InfiniDtype.I8: @@ -165,6 +196,10 @@ def to_torch_dtype(dt: InfiniDtype, compatability_mode=False): return torch.int32 if compatability_mode else torch.uint32 elif dt == InfiniDtype.U64: return torch.int64 if compatability_mode else torch.uint64 + elif dt == InfiniDtype.F8E4M3: + return torch.float8_e4m3fn + elif dt == InfiniDtype.F8E5M2: + return torch.float8_e5m2 else: raise ValueError("Unsupported data type") @@ -200,7 +235,6 @@ def to_numpy_dtype(dt: InfiniDtype, compatability_mode=False): raise ValueError("Unsupported data type") - class TestWorkspace: def __init__(self, size, device): if size != 0: @@ -476,7 +510,7 @@ def print_discrepancy( actual = actual.to("cpu") expected = expected.to("cpu") - + actual_isnan = torch.isnan(actual) expected_isnan = torch.isnan(expected) diff --git a/test/infiniop/linear.py b/test/infiniop/linear.py new file mode 100644 index 000000000..68d2f18f5 --- /dev/null +++ b/test/infiniop/linear.py @@ -0,0 +1,194 @@ +import torch +import ctypes +from ctypes import c_uint64 +from libinfiniop import ( + LIBINFINIOP, + TestTensor, + get_test_devices, + check_error, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, + TestWorkspace, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceNames, + infiniopOperatorDescriptor_t, +) + +# ============================================================================== +# Configuration (Internal Use Only) +# ============================================================================== +# These are not meant to be imported from other modules +_TEST_CASES = [ + # alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride + (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), None, None, None), + (1.0, 0.0, (2, 4, 2048), (2, 2048, 2048), (2, 4, 2048), None, None, None), + (1.0, 0.0, (1, 2048), (2048, 2048), (1, 2048), (4096, 1), (4096, 1), (4096, 1)), + (1.0, 1.0, (6, 2048), (2048, 2560), (6, 2560), (2048, 1), (1, 2048), (2560, 1)), + (1.0 / 8.0, 0.0, (4, 8 * 6, 64), (4, 64, 6), (4, 8 * 6, 6), None, None, None), +] + +# Data types used for testing +# _TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32] +_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16, InfiniDtype.F32] + +# Tolerance map for different data types +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 0, "rtol": 1e-2}, + InfiniDtype.F32: {"atol": 0, "rtol": 1e-3}, + InfiniDtype.BF16: {"atol": 0, "rtol": 5e-2}, +} + +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + + +# PyTorch implementation for matrix multiplication +def gemm(d, _c, beta, _a, _b, alpha): + try: + if _c.ndim == 2: + torch.addmm(_c, _a, _b, beta=beta, alpha=alpha, out=d) + elif _c.ndim == 3: + torch.baddbmm(_c, _a, _b, beta=beta, alpha=alpha, out=d) + else: + raise + except Exception: + torch.matmul(_a, _b, out=d) + d.mul_(alpha).add_(_c, alpha=beta) + + +# The argument list should be (lib, 
handle, torch_device, , dtype) +# The should keep the same order as the one specified in _TEST_CASES +def test( + handle, + device, + alpha, + beta, + a_shape, + b_shape, + c_shape, + a_stride=None, + b_stride=None, + c_stride=None, + dtype=InfiniDtype.F16, + sync=None, +): + print( + f"Testing LINEAR on {InfiniDeviceNames[device]} with alpha:{alpha}, beta:{beta}," + f" a_shape:{a_shape}, b_shape:{b_shape}, c_shape:{c_shape}," + f" a_stride:{a_stride}, b_stride:{b_stride}, c_stride:{c_stride}, dtype:{InfiniDtypeNames[dtype]}" + ) + + # Initialize tensors + a = TestTensor(a_shape, a_stride, dtype, device) + b = TestTensor(b_shape, b_stride, dtype, device) + c = TestTensor(c_shape, c_stride, dtype, device, mode="ones") + d = TestTensor(c_shape, c_stride, dtype, device, mode="zeros") + ans = TestTensor(c_shape, c_stride, dtype, device, mode="zeros") + + # Compute the PyTorch reference result + def torch_gemm(): + gemm( + ans.torch_tensor(), + c.torch_tensor(), + beta, + a.torch_tensor(), + b.torch_tensor(), + alpha, + ) + + torch_gemm() + + if sync is not None: + sync() + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateLinearDescriptor( + handle, + ctypes.byref(descriptor), + d.descriptor, + a.descriptor, + b.descriptor, + c.descriptor, + ) + ) + + # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel + for tensor in [a, b, c]: + tensor.destroy_desc() + + # Get workspace size and create workspace + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetLinearWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, device) + + # Execute infiniop gemm operator + def lib_linear(): + check_error( + LIBINFINIOP.infiniopLinear( + descriptor, + alpha, + a.data(), + None, + b.data(), + None, + beta, + c.data(), + None, + None, + d.data(), + None, + False, + False, + False, + workspace.data(), + 0, + None, + ) + ) + + lib_linear() + + # Validate results + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + + if DEBUG: + debug(c.actual_tensor(), ans.torch_tensor(), atol=atol, rtol=rtol) + assert torch.allclose(d.actual_tensor(), ans.torch_tensor(), atol=atol, rtol=rtol) + + # Profiling workflow + if PROFILE: + # fmt: off + profile_operation("PyTorch", lambda: torch_gemm(), device, NUM_PRERUN, NUM_ITERATIONS) + profile_operation(" lib", lambda: lib_linear(), device, NUM_PRERUN, NUM_ITERATIONS) + # fmt: on + check_error(LIBINFINIOP.infiniopDestroyLinearDescriptor(descriptor)) + + +# ============================================================================== +# Main Execution +# ============================================================================== +if __name__ == "__main__": + args = get_args() + + # Configure testing options + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + # Execute tests + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") diff --git a/test/infiniop/linear_fp8.py b/test/infiniop/linear_fp8.py new file mode 100644 index 000000000..618c27861 --- /dev/null +++ b/test/infiniop/linear_fp8.py @@ -0,0 +1,265 @@ +import torch +import ctypes +from ctypes import c_uint64 +from libinfiniop import ( + LIBINFINIOP, + TestTensor, + get_test_devices, + check_error, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, + TestWorkspace, + 
InfiniDtype, + InfiniDtypeNames, + InfiniDeviceNames, + infiniopOperatorDescriptor_t, +) +from libinfiniop import to_torch_dtype, torch_device_map +import numpy as np + + +# ============================================================================== +# Configuration (Internal Use Only) +# ============================================================================== +# These are not meant to be imported from other modules +_TEST_CASES = [ + # alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride + (1.0, 1.0, (16, 2048), (2048, 2048), (16, 2048), None, None, None), + (1.0, 0.0, (2, 16, 2048), (2, 2048, 2048), (2, 16, 2048), None, None, None), + (1.0, 1.0, (6, 2048), (2560, 2048), (6, 2560), None, None, None), + (1.0 / 8.0, 0.0, (4, 8 * 6, 64), (4, 16, 64), (4, 8 * 6, 16), None, None, None), +] + + +# A B C D BIAS +# _TENSOR_DTYPES = [[InfiniDtype.F8E4M3, InfiniDtype.F8E4M3, InfiniDtype.F16, InfiniDtype.F16, InfiniDtype.F16]] + +# Tolerance map for different data types +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 0, "rtol": 1e-2}, + InfiniDtype.F32: {"atol": 0, "rtol": 1e-3}, + InfiniDtype.BF16: {"atol": 0, "rtol": 5e-2}, + InfiniDtype.F8E4M3: {"atol": 0, "rtol": 5e-2}, +} + +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + + +# PyTorch implementation for matrix multiplication + + +def linear_f8e4m3(_a, _b, _ans, scale_a, scale_b, alpha, bias, beta, _c): + _a = _a.to(torch.float32) + _a *= alpha + _a = _a.to(torch.float8_e4m3fn) + assert torch.cuda.get_device_capability() >= (9, 0) + if len(_a.shape) > 2: + for i in range(0, _a.shape[0]): + torch._scaled_mm( + _a[i], + _b[i].T, + scale_a=scale_a, + scale_b=scale_b, + bias=bias, + out_dtype=_ans.dtype, + out=_ans[i], + ) + else: + torch._scaled_mm( + _a, + _b.T, + scale_a=scale_a, + scale_b=scale_b, + bias=bias, + out_dtype=_ans.dtype, + out=_ans, + ) + _ans += beta * _c + + +# Data types used for testing +# _TENSOR_DTYPES = [InfiniDtype.F8E4M3, InfiniDtype.F8E4M3, InfiniDtype.F8E4M3] +_TENSOR_DTYPES = [InfiniDtype.F8E4M3] + +# A B C D BIAS +FP8_SUPPORT_COMBINES = [ + [ + InfiniDtype.F8E4M3, + InfiniDtype.F8E4M3, + InfiniDtype.F16, + InfiniDtype.F16, + InfiniDtype.F16, + ], + [ + InfiniDtype.F8E4M3, + InfiniDtype.F8E4M3, + InfiniDtype.BF16, + InfiniDtype.BF16, + InfiniDtype.BF16, + ], +] + + +# The argument list should be (lib, handle, torch_device, , dtype) +# The should keep the same order as the one specified in _TEST_CASES +def test( + handle, + device, + alpha, + beta, + a_shape, + b_shape, + c_shape, + a_stride=None, + b_stride=None, + c_stride=None, + dtype=InfiniDtype.F16, + sync=None, +): + for precision in range(0, len(FP8_SUPPORT_COMBINES)): + print( + f"Testing LINEAR_FP8 on {InfiniDeviceNames[device]} with alpha:{alpha}, beta:{beta}," + f" a_shape:{a_shape}, b_shape:{b_shape}, c_shape:{c_shape}," + f" a_stride:{a_stride}, b_stride:{b_stride}, c_stride:{c_stride}, dtype:{InfiniDtypeNames[dtype]}" + f" a_dtype:{InfiniDtypeNames[FP8_SUPPORT_COMBINES[precision][0]]}, b_dtype:{InfiniDtypeNames[FP8_SUPPORT_COMBINES[precision][1]]}" + f" c_dtype:{InfiniDtypeNames[FP8_SUPPORT_COMBINES[precision][2]]}, d_dtype:{InfiniDtypeNames[FP8_SUPPORT_COMBINES[precision][3]]}" + ) + + # Initialize tensors + a = TestTensor(a_shape, a_stride, InfiniDtype.F16, device) + if a.dt != FP8_SUPPORT_COMBINES[precision][0]: + a.convert_pricesion(FP8_SUPPORT_COMBINES[precision][0]) + b = TestTensor(b_shape, b_stride, InfiniDtype.F16, device) + if b.dt != FP8_SUPPORT_COMBINES[precision][1]: + 
b.convert_pricesion(FP8_SUPPORT_COMBINES[precision][1]) + c = TestTensor( + c_shape, c_stride, FP8_SUPPORT_COMBINES[precision][2], device, mode="zeros" + ) + d = TestTensor( + c_shape, c_stride, FP8_SUPPORT_COMBINES[precision][3], device, mode="zeros" + ) + ans = TestTensor( + c_shape, c_stride, FP8_SUPPORT_COMBINES[precision][4], device, mode="zeros" + ) + bias = TestTensor( + (c_shape[-1],), None, FP8_SUPPORT_COMBINES[precision][2], device + ) + + scale_a = torch.tensor(1.0, device=torch_device_map[device]) + scale_b = torch.tensor(1.0, device=torch_device_map[device]) + + def torch_linear(): + linear_f8e4m3( + a.torch_tensor(), + b.torch_tensor(), + ans.torch_tensor(), + scale_a=scale_a, + scale_b=scale_b, + alpha=alpha, + bias=bias.torch_tensor(), + beta=beta, + _c=c.torch_tensor(), + ) + + torch_linear() + + if sync is not None: + sync() + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateLinearDescriptor( + handle, + ctypes.byref(descriptor), + d.descriptor, + a.descriptor, + b.descriptor, + c.descriptor, + ) + ) + + # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel + for tensor in [a, b, c]: + tensor.destroy_desc() + + # Get workspace size and create workspace + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetLinearWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, device) + + scale_a_ = scale_a.clone() + scale_b_ = scale_b.clone() + + # Execute infiniop gemm operator + def lib_linear(): + check_error( + LIBINFINIOP.infiniopLinear( + descriptor, + alpha, + a.data(), + scale_a_.data_ptr(), + b.data(), + scale_b_.data_ptr(), + beta, + c.data(), + None, + bias.data(), + d.data(), + None, + False, + False, + False, + workspace.data(), + workspace_size.value, + None, + ) + ) + + lib_linear() + + # Validate results + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + + if DEBUG: + debug(c.actual_tensor(), ans.torch_tensor(), atol=atol, rtol=rtol) + + assert torch.allclose( + d.actual_tensor(), ans.torch_tensor(), atol=atol, rtol=rtol + ) + + # Profiling workflow + if PROFILE: + # fmt: off + profile_operation("PyTorch", lambda: torch_linear(), device, NUM_PRERUN, NUM_ITERATIONS) + profile_operation(" lib", lambda: lib_linear(), device, NUM_PRERUN, NUM_ITERATIONS) + # fmt: on + check_error(LIBINFINIOP.infiniopDestroyLinearDescriptor(descriptor)) + + +# ============================================================================== +# Main Execution +# ============================================================================== +if __name__ == "__main__": + args = get_args() + + # Configure testing options + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + # Execute tests + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") diff --git a/test/infiniop/linear_fp8_blockwise.py b/test/infiniop/linear_fp8_blockwise.py new file mode 100644 index 000000000..9ed023866 --- /dev/null +++ b/test/infiniop/linear_fp8_blockwise.py @@ -0,0 +1,406 @@ +import torch +import ctypes +from ctypes import c_uint64 +from libinfiniop import ( + LIBINFINIOP, + TestTensor, + get_test_devices, + check_error, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, + TestWorkspace, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceNames, + 
infiniopOperatorDescriptor_t, +) +from libinfiniop import to_torch_dtype, torch_device_map +import numpy as np + +# ============================================================================== +# Configuration (Internal Use Only) +# ============================================================================== +# These are not meant to be imported from other modules +_TEST_CASES = [ + # alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride + (1.0, 1.0, (256, 1024), (384, 1024), (256, 384), None, None, None), + (2.0, 0.0, (256, 1024), (384, 1024), (256, 384), None, None, None), + (1.0, 2.0, (256, 1024), (384, 1024), (256, 384), None, None, None), + (0.5, 1.5, (128, 2048), (512, 2048), (128, 512), None, None, None), +] + +# Data types used for testing +_TENSOR_DTYPES = [InfiniDtype.F8E4M3] + +# Tolerance map for different data types +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 0, "rtol": 1e-2}, + InfiniDtype.F32: {"atol": 0, "rtol": 1e-3}, + InfiniDtype.BF16: {"atol": 0, "rtol": 5e-2}, + InfiniDtype.F8E4M3: {"atol": 0, "rtol": 5e-2}, +} + +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + + +def pyop_ref_1x128_1x128(a, alpha, b, beta, a_descales, b_descales, C, bias): + """Reference implementation for 1x128 @ 1x128 pattern""" + assert a.dtype == torch.float8_e4m3fn, "a.dtype != torch.float8_e4m3fn" + assert b.dtype == torch.float8_e4m3fn, "b.dtype != torch.float8_e4m3fn" + assert a_descales.dtype == torch.float32, "a_descales.dtype != torch.float32" + assert b_descales.dtype == torch.float32, "b_descales.dtype != torch.float32" + a = a.to(torch.float32) + a *= alpha + a = a.to(torch.float8_e4m3fn) + + M, N, K = a.shape[0], b.shape[0], a.shape[1] + assert K == b.shape[1], "K != b.shape[1]" + + a_scales_m = a_descales.shape[1] + a_scales_k = a_descales.shape[0] + b_scales_n = b_descales.shape[1] + b_scales_k = b_descales.shape[0] + + assert a_scales_m == M, "a_scales_m != M" + assert a_scales_k * 128 == K, "a_scales_k * 128 != K" + assert b_scales_n == N, "b_scales_n != N" + assert b_scales_k * 128 == K, "b_scales_k * 128 != K" + + a = a.to(torch.float32) + b = b.to(torch.float32) + + out = torch.zeros((M, N), dtype=torch.float32).to(a.device) + + for i in range(0, M): + for j in range(0, N): + for k in range(0, K, 128): + out[i, j] += ( + (a[i, k : k + 128] @ b[j, k : k + 128].t()) + * a_descales[k // 128, i] + * b_descales[k // 128, j] + ) + C *= beta + C += out + bias + + +def pyop_ref_1x128_128x128(a, alpha, b, beta, a_descales, b_descales, C, bias): + """Reference implementation for 1x128 @ 128x128 pattern""" + assert a.dtype == torch.float8_e4m3fn, "a.dtype != torch.float8_e4m3fn" + assert b.dtype == torch.float8_e4m3fn, "b.dtype != torch.float8_e4m3fn" + assert a_descales.dtype == torch.float32, "a_descales.dtype != torch.float32" + assert b_descales.dtype == torch.float32, "b_descales.dtype != torch.float32" + a = a.to(torch.float32) + a *= alpha + a = a.to(torch.float8_e4m3fn) + + M, N, K = a.shape[0], b.shape[0], a.shape[1] + assert K == b.shape[1], "K != b.shape[1]" + + a_scales_m = a_descales.shape[1] + a_scales_k = a_descales.shape[0] + b_scales_k = b_descales.shape[1] + b_scales_n = b_descales.shape[0] + + assert a_scales_m == M, "a_scales_m != M" + assert a_scales_k * 128 == K, "a_scales_k * 128 != K" + assert b_scales_n * 128 == N, "b_scales_n * 128 != N" + assert b_scales_k * 128 == K, "b_scales_k * 128 != K" + + a = a.to(torch.float32) + b = b.to(torch.float32) + + out = torch.zeros((M, N), dtype=torch.float32).to(a.device) + + for i 
in range(0, M): + for j in range(0, N, 128): + for k in range(0, K, 128): + out[i, j : j + 128] += ( + (a[i, k : k + 128] @ b[j : j + 128, k : k + 128].t()) + * a_descales[k // 128, i] + * b_descales[j // 128, k // 128] + ) + C *= beta + C += out + bias + + +def pyop_ref_128x128_1x128(a, alpha, b, beta, a_descales, b_descales, C, bias): + """Reference implementation for 128x128 @ 1x128 pattern""" + assert a.dtype == torch.float8_e4m3fn, "a.dtype != torch.float8_e4m3fn" + assert b.dtype == torch.float8_e4m3fn, "b.dtype != torch.float8_e4m3fn" + assert a_descales.dtype == torch.float32, "a_descales.dtype != torch.float32" + assert b_descales.dtype == torch.float32, "b_descales.dtype != torch.float32" + a = a.to(torch.float32) + a *= alpha + a = a.to(torch.float8_e4m3fn) + + M, N, K = a.shape[0], b.shape[0], a.shape[1] + assert K == b.shape[1], "K != b.shape[1]" + + a_scales_m = a_descales.shape[0] + a_scales_k = a_descales.shape[1] + b_scales_k = b_descales.shape[0] + b_scales_n = b_descales.shape[1] + + assert a_scales_m * 128 == M, "a_scales_m * 128 != M" + assert a_scales_k * 128 == K, "a_scales_k * 128 != K" + assert b_scales_n == N, "b_scales_n != N" + assert b_scales_k * 128 == K, "b_scales_k * 128 != K" + + a = a.to(torch.float32) + b = b.to(torch.float32) + + out = torch.zeros((M, N), dtype=torch.float32).to(a.device) + + for i in range(0, M, 128): + for j in range(0, N): + for k in range(0, K, 128): + out[i : i + 128, j] += ( + (a[i : i + 128, k : k + 128] @ b[j, k : k + 128].t()) + * a_descales[i // 128, k // 128] + * b_descales[k // 128, j] + ) + C *= beta + C += out + bias + + +# The argument list should be (lib, handle, torch_device, , dtype) +# The should keep the same order as the one specified in _TEST_CASES + +# A B C D BIAS +FP8_SUPPORT_COMBINES = [ + [ + InfiniDtype.F8E4M3, + InfiniDtype.F8E4M3, + InfiniDtype.F16, + InfiniDtype.F16, + InfiniDtype.F16, + ], + [ + InfiniDtype.F8E4M3, + InfiniDtype.F8E4M3, + InfiniDtype.BF16, + InfiniDtype.BF16, + InfiniDtype.BF16, + ], +] + + +def test( + handle, + device, + alpha, + beta, + a_shape, + b_shape, + c_shape, + a_stride=None, + b_stride=None, + c_stride=None, + dtype=InfiniDtype.F16, + sync=None, +): + + for precision in range(0, len(FP8_SUPPORT_COMBINES)): + print( + f"Testing FP8 Linear BlockWise on {InfiniDeviceNames[device]} with alpha:{alpha}, beta:{beta}," + f" a_shape:{a_shape}, b_shape:{b_shape}, c_shape:{c_shape}," + f" a_stride:{a_stride}, b_stride:{b_stride}, c_stride:{c_stride}" + f" input_dtype:{InfiniDtypeNames[FP8_SUPPORT_COMBINES[precision][0]]}, output_dtype:{InfiniDtypeNames[FP8_SUPPORT_COMBINES[precision][-1]]}" + ) + for a_1d, b_1d in [(True, False), (False, True)]: + # 1x128 @ 1x128 的乘法非常慢,所以先注掉,等需要的时候再加上 + # for a_1d, b_1d in [(True, False), (False, True), (True, True)]: + # Initialize tensors + a = TestTensor(a_shape, None, InfiniDtype.F16, device) + if a.dt != FP8_SUPPORT_COMBINES[precision][0]: + a.convert_pricesion(FP8_SUPPORT_COMBINES[precision][0]) + b = TestTensor(b_shape, None, InfiniDtype.F16, device) + if b.dt != FP8_SUPPORT_COMBINES[precision][1]: + b.convert_pricesion(FP8_SUPPORT_COMBINES[precision][1]) + c = TestTensor(c_shape, None, FP8_SUPPORT_COMBINES[precision][2], device) + d = TestTensor( + c_shape, None, FP8_SUPPORT_COMBINES[precision][3], device, mode="zeros" + ) + + bias = ( + torch.ones( + c_shape, + device=torch_device_map[device], + dtype=c.torch_tensor().dtype, + ) + * 0.6 + ) + + if a_1d and not b_1d: + scale_a_ = TestTensor( + (int(a_shape[0] / 128), int(a_shape[1] / 128)), + 
None, + InfiniDtype.F32, + device, + ) + scale_b_ = TestTensor( + (int(b_shape[1] / 128), int(b_shape[0])), + None, + InfiniDtype.F32, + device, + ) + pyop_ref_128x128_1x128( + a.torch_tensor(), + alpha, + b.torch_tensor(), + beta, + scale_a_.torch_tensor(), + scale_b_.torch_tensor(), + c.torch_tensor(), + bias, + ) + elif not a_1d and b_1d: + scale_a_ = TestTensor( + (int(a_shape[1] / 128), int(a_shape[0])), + None, + InfiniDtype.F32, + device, + ) + scale_b_ = TestTensor( + (int(b_shape[0] / 128), int(b_shape[1] / 128)), + None, + InfiniDtype.F32, + device, + ) + pyop_ref_1x128_128x128( + a.torch_tensor(), + alpha, + b.torch_tensor(), + beta, + scale_a_.torch_tensor(), + scale_b_.torch_tensor(), + c.torch_tensor(), + bias, + ) + elif a_1d and b_1d: + scale_a_ = TestTensor( + (int(a_shape[1] / 128), int(a_shape[0])), + None, + InfiniDtype.F32, + device, + ) + scale_b_ = TestTensor( + (int(b_shape[1] / 128), int(b_shape[0])), + None, + InfiniDtype.F32, + device, + ) + pyop_ref_1x128_1x128( + a.torch_tensor(), + alpha, + b.torch_tensor(), + beta, + scale_a_.torch_tensor(), + scale_b_.torch_tensor(), + c.torch_tensor(), + bias, + ) + else: + raise Exception("不支持scale均为二维块量化的情况") + + if sync is not None: + sync() + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateLinearDescriptor( + handle, + ctypes.byref(descriptor), + d.descriptor, + a.descriptor, + b.descriptor, + c.descriptor, + ) + ) + + # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel + for tensor in [a, b, c]: + tensor.destroy_desc() + + # Get workspace size and create workspace + workspace_size = c_uint64(33554432) + check_error( + LIBINFINIOP.infiniopGetLinearWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, device) + bias_ = bias.clone() + + # Execute infiniop gemm operator + def lib_linear(): + check_error( + LIBINFINIOP.infiniopLinear( + descriptor, + alpha, + a.data(), + scale_a_.data(), + b.data(), + scale_b_.data(), + beta, + c.data(), + None, + bias_.data_ptr(), + d.data(), + None, + True, # block_wise + a_1d, # a_1d + b_1d, # b_1d + workspace.data(), + 0, + None, + ) + ) + + lib_linear() + + # Validate results + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + + # if DEBUG: + # debug(c.actual_tensor(), f, atol=atol, rtol=rtol) + assert torch.allclose( + d.actual_tensor(), + c.torch_tensor().to(d.torch_tensor().dtype), + atol=atol, + rtol=rtol, + ) + + # Profiling workflow + if PROFILE: + raise NotImplementedError + # fmt: off + # profile_operation("PyTorch", lambda: torch_gemm(), device, NUM_PRERUN, NUM_ITERATIONS) + # profile_operation(" lib", lambda: lib_linear(), device, NUM_PRERUN, NUM_ITERATIONS) + # fmt: on + check_error(LIBINFINIOP.infiniopDestroyLinearDescriptor(descriptor)) + + +# ============================================================================== +# Main Execution +# ============================================================================== +if __name__ == "__main__": + args = get_args() + + # Configure testing options + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + # Execute tests + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") diff --git a/test/infiniop/quantize.py b/test/infiniop/quantize.py new file mode 100644 index 000000000..5c733b9b7 --- /dev/null +++ 
b/test/infiniop/quantize.py @@ -0,0 +1,158 @@ +import torch +import ctypes +from ctypes import c_uint64 +from libinfiniop import ( + LIBINFINIOP, + TestTensor, + get_test_devices, + check_error, + test_operator, + get_args, + debug, + get_tolerance, + profile_operation, + TestWorkspace, + InfiniDtype, + InfiniDtypeNames, + InfiniDeviceNames, + infiniopOperatorDescriptor_t, +) + +# ============================================================================== +# Configuration (Internal Use Only) +# ============================================================================== +# These are not meant to be imported from other modules +_TEST_CASES = [ + # alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride + ((1, 2048), (2048, 2048), (1, 2048), None, None, None), +] + +# Data types used for testing +_TENSOR_DTYPES = [InfiniDtype.F16, InfiniDtype.BF16] + +# Tolerance map for different data types +_TOLERANCE_MAP = { + InfiniDtype.F16: {"atol": 0, "rtol": 2e-1}, + # InfiniDtype.F32: {"atol": 0, "rtol": 1e-3}, + InfiniDtype.BF16: {"atol": 0, "rtol": 2e-1}, +} + +DEBUG = False +PROFILE = False +NUM_PRERUN = 10 +NUM_ITERATIONS = 1000 + + +# The argument list should be (lib, handle, torch_device, , dtype) +# The should keep the same order as the one specified in _TEST_CASES +def test( + handle, + device, + a_shape, + b_shape, + c_shape, + a_stride=None, + b_stride=None, + c_stride=None, + dtype=InfiniDtype.F16, + sync=None, +): + print( + # f"Testing Gemm on {InfiniDeviceNames[device]} with alpha:{alpha}, beta:{beta}," + f" a_shape:{a_shape}, b_shape:{b_shape}, c_shape:{c_shape}," + f" a_stride:{a_stride}, b_stride:{b_stride}, c_stride:{c_stride}, dtype:{InfiniDtypeNames[dtype]}" + ) + + input = TestTensor((512, 512), None, InfiniDtype.F16, device, mode="random") + output_q = TestTensor((512, 512), None, InfiniDtype.F8E4M3, device, mode="zeros") + outpus_s = TestTensor((5, 512), None, InfiniDtype.F32, device, mode="zeros") + + def eval_1x128(x_quant, x_scale): + scale = torch.repeat_interleave(x_scale, 128, dim=0) + scale = scale[: x_quant.shape[0], : x_quant.shape[1]] + + assert ( + scale.shape == x_quant.shape + ), f"scale shape {scale.shape} not match x_quant shape {x_quant.shape}" + + x_qdq = x_quant.to(torch.float32) * scale + return x_qdq.to(torch.float32) + + descriptor = infiniopOperatorDescriptor_t() + check_error( + LIBINFINIOP.infiniopCreateQuantizeDescriptor( + handle, + ctypes.byref(descriptor), + input.descriptor, + output_q.descriptor, + outpus_s.descriptor, + ) + ) + + # Get workspace size and create workspace + workspace_size = c_uint64(0) + check_error( + LIBINFINIOP.infiniopGetQuantizeWorkspaceSize( + descriptor, ctypes.byref(workspace_size) + ) + ) + workspace = TestWorkspace(workspace_size.value, device) + + # Execute infiniop gemm operator + def lib_quantize(): + check_error( + LIBINFINIOP.infiniopQuantize( + descriptor, + workspace.data(), + workspace_size.value, + input.data(), + output_q.data(), + outpus_s.data(), + # zeros.data(), + 128, + 0, + -448, + 448, + False, + None, + ) + ) + + lib_quantize() + + ans = eval_1x128(output_q.actual_tensor(), outpus_s.actual_tensor()) + # # Validate results + atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) + + if DEBUG: + debug(input.torch_tensor().to(torch.float32), ans, atol=atol, rtol=rtol) + print(ans, input.torch_tensor().to(torch.float32)) + + # assert torch.allclose(ans, input.torch_tensor().to(torch.float32), atol=atol) + + # # Profiling workflow + # if PROFILE: + # # fmt: off + # profile_operation("PyTorch", 
lambda: torch_gemm(), device, NUM_PRERUN, NUM_ITERATIONS) + # profile_operation(" lib", lambda: lib_gemm(), device, NUM_PRERUN, NUM_ITERATIONS) + # # fmt: on + check_error(LIBINFINIOP.infiniopDestroyQuantizeDescriptor(descriptor)) + + +# ============================================================================== +# Main Execution +# ============================================================================== +if __name__ == "__main__": + args = get_args() + + # Configure testing options + DEBUG = args.debug + PROFILE = args.profile + NUM_PRERUN = args.num_prerun + NUM_ITERATIONS = args.num_iterations + + # Execute tests + for device in get_test_devices(args): + test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES) + + print("\033[92mTest passed!\033[0m") diff --git a/test/infiniop/topkrouter.py b/test/infiniop/topkrouter.py index 6f851a89f..4a7679fd2 100644 --- a/test/infiniop/topkrouter.py +++ b/test/infiniop/topkrouter.py @@ -19,7 +19,7 @@ InfiniDtypeNames, InfiniDeviceNames, infiniopOperatorDescriptor_t, - torch_device_map + torch_device_map, ) # ============================================================================== @@ -34,7 +34,7 @@ # w (weight) types # Note: 'None' means the same as input dtype -_X_DTYPES = [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16] # +_X_DTYPES = [InfiniDtype.F32, InfiniDtype.BF16, InfiniDtype.F16] # # x types used for testing _VALUE_DTYPES = [InfiniDtype.F32] @@ -55,7 +55,16 @@ def tensorInfo(data): - print("data: ", data.is_contiguous(), data.device, data.dtype, data.shape, data.stride(), data.data_ptr(), hex(data.data_ptr())) + print( + "data: ", + data.is_contiguous(), + data.device, + data.dtype, + data.shape, + data.stride(), + data.data_ptr(), + hex(data.data_ptr()), + ) class DeepseekV3TopkRouter(nn.Module): @@ -78,14 +87,22 @@ def __init__(self, correction_bias, config=None): @torch.no_grad() def get_topk_indices(self, scores): - scores_for_choice = scores.view(-1, self.n_routed_experts) + self.e_score_correction_bias.unsqueeze(0) # Size([1, 256]) + scores_for_choice = scores.view( + -1, self.n_routed_experts + ) + self.e_score_correction_bias.unsqueeze( + 0 + ) # Size([1, 256]) group_scores = ( - scores_for_choice.view(-1, self.n_group, self.n_routed_experts // self.n_group) + scores_for_choice.view( + -1, self.n_group, self.n_routed_experts // self.n_group + ) .topk(2, dim=-1)[0] .sum(dim=-1) ) - group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=True)[1] # Size([1, 4]) + group_idx = torch.topk(group_scores, k=self.topk_group, dim=-1, sorted=True)[ + 1 + ] # Size([1, 4]) group_mask = torch.zeros_like(group_scores) # Size([1, 8]) group_mask.scatter_(1, group_idx, 1) # Size([1, 8]) @@ -95,8 +112,12 @@ def get_topk_indices(self, scores): .reshape(-1, self.n_routed_experts) ) - scores_for_choice = scores_for_choice.masked_fill(~score_mask.bool(), 0.0) # Size([1, 256]) - topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=True)[1] # Size([1, 8]) + scores_for_choice = scores_for_choice.masked_fill( + ~score_mask.bool(), 0.0 + ) # Size([1, 256]) + topk_indices = torch.topk(scores_for_choice, k=self.top_k, dim=-1, sorted=True)[ + 1 + ] # Size([1, 8]) return topk_indices @@ -124,14 +145,14 @@ def torch_topkrouter(router_logits, correction_bias): def test( - handle, - device, - x_shape, - x_stride, - topk, - x_dtype=InfiniDtype.F32, - dtype=InfiniDtype.F16, - sync=None, + handle, + device, + x_shape, + x_stride, + topk, + x_dtype=InfiniDtype.F32, + dtype=InfiniDtype.F16, + sync=None, ): print( 
f"Testing topkrouter on {InfiniDeviceNames[device]} with x_shape:{x_shape}" @@ -141,8 +162,12 @@ def test( data = torch.arange(0, x_shape[0] * x_shape[1]).reshape(x_shape) N, width = x_shape - x = TestTensor(x_shape, data.stride(), x_dtype, device, scale=5.0, bias=-5.0, mode="random") - correction_bias = TestTensor([x_shape[1]], [1], InfiniDtype.F32, device, mode="random") + x = TestTensor( + x_shape, data.stride(), x_dtype, device, scale=5.0, bias=-5.0, mode="random" + ) + correction_bias = TestTensor( + [x_shape[1]], [1], InfiniDtype.F32, device, mode="random" + ) if sync is not None: sync() @@ -150,10 +175,7 @@ def test( descriptor = infiniopOperatorDescriptor_t() check_error( LIBINFINIOP.infiniopCreateTopkrouterDescriptor( - handle, - ctypes.byref(descriptor), - x.descriptor, - correction_bias.descriptor + handle, ctypes.byref(descriptor), x.descriptor, correction_bias.descriptor ) ) @@ -169,8 +191,12 @@ def test( ) workspace = TestWorkspace(workspace_size.value, x.device) - values = torch.zeros((N, topk), dtype=torch.float32, device=torch_device_map[x.device]) - indices = torch.zeros((N, topk), dtype=torch.int32, device=torch_device_map[x.device]) + values = torch.zeros( + (N, topk), dtype=torch.float32, device=torch_device_map[x.device] + ) + indices = torch.zeros( + (N, topk), dtype=torch.int32, device=torch_device_map[x.device] + ) def lib_topkrouter(): check_error( @@ -189,7 +215,9 @@ def lib_topkrouter(): ) lib_topkrouter() - lable_values, lable_indices = torch_topkrouter(x.actual_tensor(), correction_bias.actual_tensor()) + lable_values, lable_indices = torch_topkrouter( + x.actual_tensor(), correction_bias.actual_tensor() + ) atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) if DEBUG: diff --git a/xmake.lua b/xmake.lua index e517aeec2..f5ba079d0 100644 --- a/xmake.lua +++ b/xmake.lua @@ -62,10 +62,20 @@ option("cudnn") set_description("Whether to compile cudnn for Nvidia GPU") option_end() +option("cublaslt") + set_default(true) + set_showmenu(true) + set_description("Whether to compile cublaslt for Nvidia GPU") +option_end() + if has_config("cudnn") then add_defines("ENABLE_CUDNN_API") end +if has_config("cublaslt") then + add_defines("ENABLE_CUBLASLT_API") +end + -- 寒武纪 option("cambricon-mlu") set_default(false) @@ -375,4 +385,4 @@ target("infinicore") target_end() -- Tests -includes("xmake/test.lua") +includes("xmake/test.lua") \ No newline at end of file diff --git a/xmake/nvidia.lua b/xmake/nvidia.lua index 797edcb5e..375e61544 100644 --- a/xmake/nvidia.lua +++ b/xmake/nvidia.lua @@ -10,7 +10,7 @@ target("infiniop-nvidia") set_policy("build.cuda.devlink", true) set_toolchains("cuda") - add_links("cudart", "cublas") + add_links("cudart", "cublas", "cublasLt") if has_config("cudnn") then add_links("cudnn") end