Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/infinicore.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ typedef enum {
INFINI_STATUS_BAD_TENSOR_SHAPE = 11,
INFINI_STATUS_BAD_TENSOR_STRIDES = 12,
INFINI_STATUS_INSUFFICIENT_WORKSPACE = 13,
INFINI_STATUS_NOT_ALIGNED = 14,
} infiniStatus_t;

typedef enum {
Expand Down Expand Up @@ -70,6 +71,9 @@ typedef enum {
INFINI_DTYPE_C64 = 17,
INFINI_DTYPE_C128 = 18,
INFINI_DTYPE_BF16 = 19,
INFINI_DTYPE_F8_E4M3 = 20,
INFINI_DTYPE_F8_E5M2 = 21,
INFINI_DTYPE_F8_UE8M0 = 22,
} infiniDtype_t;

#endif // __INFINICORE_API_H__
6 changes: 4 additions & 2 deletions include/infiniop/operator_descriptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
// Base descriptor for all operators
struct InfiniopDescriptor;

__C __export infiniStatus_t infiniopGetDescriptorDeviceType(const struct InfiniopDescriptor *desc_ptr, infiniDevice_t *device_type);
__C __export infiniStatus_t infiniopGetDescriptorDeviceId(const struct InfiniopDescriptor *desc_ptr, int *device_id);
__C __export infiniStatus_t infiniopGetDescriptorDeviceType(
const struct InfiniopDescriptor *desc_ptr, infiniDevice_t *device_type);
__C __export infiniStatus_t infiniopGetDescriptorDeviceId(
const struct InfiniopDescriptor *desc_ptr, int *device_id);

#endif //__INFINIOP_OPERATOR_DESCRIPTOR_API_H__
26 changes: 26 additions & 0 deletions include/infiniop/ops/linear.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
#ifndef __INFINIOP_LINEAR_API_H__
#define __INFINIOP_LINEAR_API_H__

// Self-containment: this header is exported through a C API (`__C`), and it
// uses `bool`/`size_t` below — include their defining headers explicitly so C
// translation units (pre-C23) can include this file on its own.
#include <stdbool.h>
#include <stddef.h>

#include "../operator_descriptor.h"

// Opaque descriptor handle for the linear operator.
typedef struct InfiniopDescriptor *infiniopLinearDescriptor_t;

// Create a linear-operator descriptor from the tensor descriptors of the
// output `d` and the operands `a`, `b`, `c`.
// NOTE(review): the computed formula is presumably the cublasLt-style
//   d = alpha * (a x b) + beta * c + bias
// with optional per-tensor/blockwise scales — confirm against the
// implementation before relying on this.
__C __export infiniStatus_t infiniopCreateLinearDescriptor(
    infiniopHandle_t handle, infiniopLinearDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t d_desc, infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t b_desc, infiniopTensorDescriptor_t c_desc);

// Query the workspace size (in bytes) required by infiniopLinear for `desc`.
__C __export infiniStatus_t
infiniopGetLinearWorkspaceSize(infiniopLinearDescriptor_t desc, size_t *size);

// Execute the linear operator.
//   alpha, beta          : scalar multipliers.
//   a/b/c/d_scale        : optional quantization scale buffers (may relate to
//                          the FP8 dtypes added in infinicore.h — TODO confirm).
//   is_blockwise         : select blockwise scaling.
//   is_a_1d_scaled,
//   is_b_1d_scaled       : select 1-D (per-row/col) scaling for a / b.
//   workspace(_size)     : scratch buffer obtained per
//                          infiniopGetLinearWorkspaceSize.
//   stream               : backend stream/queue to launch on (may be NULL for
//                          the default stream — verify per backend).
__C __export infiniStatus_t infiniopLinear(
    infiniopLinearDescriptor_t desc, float alpha, const void *a,
    const void *a_scale, const void *b, const void *b_scale, float beta,
    const void *c, const void *c_scale, const void *bias, void *d,
    const void *d_scale, bool is_blockwise, bool is_a_1d_scaled,
    bool is_b_1d_scaled, void *workspace, size_t workspace_size, void *stream);

// Destroy a descriptor created by infiniopCreateLinearDescriptor.
__C __export infiniStatus_t
infiniopDestroyLinearDescriptor(infiniopLinearDescriptor_t desc);

#endif // __INFINIOP_LINEAR_API_H__
24 changes: 24 additions & 0 deletions include/infiniop/ops/quantize.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#ifndef __INFINIOP_QUANTIZE_API_H__
#define __INFINIOP_QUANTIZE_API_H__

// Self-containment: this C-facing header uses `bool` and `size_t`; include
// their defining headers so C translation units (pre-C23) compile without
// relying on transitive includes.
#include <stdbool.h>
#include <stddef.h>

#include "../operator_descriptor.h"

// Opaque descriptor handle for the quantize operator.
typedef struct InfiniopDescriptor *infiniopQuantizeDescriptor_t;

// Create a quantize-operator descriptor from the input tensor descriptor and
// the descriptors of the quantized output and its scale output.
__C __export infiniStatus_t infiniopCreateQuantizeDescriptor(
    infiniopHandle_t handle, infiniopQuantizeDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t input_desc,
    infiniopTensorDescriptor_t output_q_desc,
    infiniopTensorDescriptor_t output_s_desc);

// Query the workspace size (in bytes) required by infiniopQuantize for `desc`.
__C __export infiniStatus_t infiniopGetQuantizeWorkspaceSize(
    infiniopQuantizeDescriptor_t desc, size_t *size);

// Execute the quantization.
//   input                : source tensor.
//   output_q / output_s  : quantized values and their scales.
//   group_size           : number of elements sharing one scale (grouped
//                          quantization — TODO confirm exact grouping axis).
//   eps                  : numerical floor applied during scale computation
//                          (presumably to avoid division by zero — verify).
//   min_8bit / max_8bit  : clamp range of the quantized representation.
//   scale_ue8m0          : emit scales in the UE8M0 format (see
//                          INFINI_DTYPE_F8_UE8M0 in infinicore.h).
//   stream               : backend stream/queue to launch on.
__C __export infiniStatus_t infiniopQuantize(
    infiniopQuantizeDescriptor_t desc, void *workspace, size_t workspace_size,
    void *input, void *output_q, void *output_s, int group_size, double eps,
    double min_8bit, double max_8bit, bool scale_ue8m0, void *stream);

// Destroy a descriptor created by infiniopCreateQuantizeDescriptor.
__C __export infiniStatus_t
infiniopDestroyQuantizeDescriptor(infiniopQuantizeDescriptor_t desc);

#endif // __INFINIOP_QUANTIZE_API_H__
1 change: 1 addition & 0 deletions scripts/install.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
PROJECT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
os.chdir(PROJECT_DIR)


def run_cmd(cmd):
    """Run *cmd* through the system shell, raising CalledProcessError on a
    non-zero exit status."""
    subprocess.run(
        cmd,
        shell=True,
        check=True,
        text=True,
        encoding="utf-8",
    )

Expand Down
12 changes: 12 additions & 0 deletions src/infiniop/devices/nvidia/nvidia_common.cu
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,18 @@ infiniStatus_t Handle::Internal::useCudnn(cudaStream_t stream, const Fn<cudnnHan
}
#endif

#ifdef ENABLE_CUBLASLT_API
// Run `f` with a cublasLt handle taken from (and returned to) the pool.
// `stream` is unused here; presumably the callback captures the stream it
// needs — kept for signature symmetry with useCublas/useCudnn (TODO confirm).
infiniStatus_t Handle::Internal::useCublasLt(cudaStream_t stream, const Fn<cublasLtHandle_t> &f) const {
    auto handle = blaslt_handles.pop();
    if (!handle) {
        // Pool was empty: create a fresh handle. Create into a local first —
        // dereferencing an empty std::optional (as `&(*handle)` did) is UB.
        cublasLtHandle_t raw;
        CHECK_CUBLASLT(cublasLtCreate(&raw));
        handle = raw;
    }
    // Always return the handle to the pool, even when the callback fails;
    // otherwise an error path would leak the handle (it is never destroyed).
    auto status = f(*handle);
    blaslt_handles.push(std::move(*handle));
    CHECK_STATUS(status);
    return INFINI_STATUS_SUCCESS;
}
#endif

int Handle::Internal::warpSize() const { return _warp_size; }
int Handle::Internal::maxThreadsPerBlock() const { return _max_threads_per_block; }
int Handle::Internal::blockSizeX() const { return _block_size[0]; }
Expand Down
18 changes: 18 additions & 0 deletions src/infiniop/devices/nvidia/nvidia_handle.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,26 @@
#include "../pool.h"
#include "nvidia_handle.h"
#include <cublas_v2.h>
#include <cuda.h>
#include <cuda_fp8.h>
#include <functional>

#ifdef ENABLE_CUDNN_API
#include <cudnn.h>
#endif

#ifdef ENABLE_CUBLASLT_API
#include <cublasLt.h>
#if CUDA_VERSION >= 12090
#define SUPPORT_FP8_BLOCKWISE_SCALE 1
#else
#define SUPPORT_FP8_BLOCKWISE_SCALE 0
#endif
#endif

#define CHECK_CUBLAS(API) CHECK_INTERNAL(API, CUBLAS_STATUS_SUCCESS)
#define CHECK_CUDNN(API) CHECK_INTERNAL(API, CUDNN_STATUS_SUCCESS)
#define CHECK_CUBLASLT(API) CHECK_INTERNAL(API, CUBLAS_STATUS_SUCCESS)

namespace device::nvidia {

Expand All @@ -21,6 +33,9 @@ class Handle::Internal {
#ifdef ENABLE_CUDNN_API
Pool<cudnnHandle_t> dnn_handles;
#endif
#ifdef ENABLE_CUBLASLT_API
Pool<cublasLtHandle_t> blaslt_handles;
#endif

int _warp_size,
_max_threads_per_block,
Expand All @@ -37,6 +52,9 @@ public:
#ifdef ENABLE_CUDNN_API
infiniStatus_t useCudnn(cudaStream_t stream, const Fn<cudnnHandle_t> &f) const;
#endif
#ifdef ENABLE_CUBLASLT_API
infiniStatus_t useCublasLt(cudaStream_t stream, const Fn<cublasLtHandle_t> &f) const;
#endif

int warpSize() const;
int maxThreadsPerBlock() const;
Expand Down
120 changes: 62 additions & 58 deletions src/infiniop/elementwise/cpu/elementwise_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,22 @@
#include <utility>

/**
* @brief Define the process for initializing a Descriptor of an elementwise operation
* for its CPU implementation
* @brief Define the process for initializing a Descriptor of an elementwise
* operation for its CPU implementation
*
* @param HANDLE The device handle.
* @param DTYPE The output dtype.
* @param OUT_DESC The output tensor descriptor.
* @param INPUT_DESC_VEC A vector containing input tensor descriptors.
*/
#define CREATE_ELEMENTWISE_CPU_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, INPUT_DESC_VEC) \
#define CREATE_ELEMENTWISE_CPU_DESCRIPTOR(HANDLE, DTYPE, OUT_DESC, \
INPUT_DESC_VEC) \
\
auto info_result = op::elementwise::ElementwiseInfo::create(OUT_DESC, INPUT_DESC_VEC); \
CHECK_RESULT(info_result); \
\
*desc_ptr = new Descriptor( \
DTYPE, \
info_result.take(), \
nullptr, \
0, \
HANDLE->device, \
HANDLE->device_id);
*desc_ptr = new Descriptor(DTYPE, info_result.take(), nullptr, 0, \
HANDLE->device, HANDLE->device_id);

namespace op::elementwise::cpu {

Expand Down Expand Up @@ -62,18 +58,17 @@ class DeviceImpl final {
* @return infiniStatus_t Status indicating success or failure.
*/
template <typename Op, typename Tdata, typename... Args>
infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args);
infiniStatus_t calculate(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream, Args &&...args);

/**
* @brief Dispatches an elementwise operation with heterogeneous input types.
*
* Supports operations where each input may have a different type, as defined by Op.
* The number of input types must match the operation's expected input count.
* Supports operations where each input may have a different type, as defined
* by Op. The number of input types must match the operation's expected input
* count.
*
* @tparam Op The elementwise operation to perform.
* @tparam Tout Output data type.
Expand All @@ -86,15 +81,12 @@ class DeviceImpl final {
* @param args Additional backend-specific arguments.
* @return infiniStatus_t Status indicating success or failure.
*/
template <typename Op, typename Tout, typename... Tin,
typename... Args,
template <typename Op, typename Tout, typename... Tin, typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
infiniStatus_t calculate(
const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args);
infiniStatus_t calculate(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream, Args &&...args);
};

// Define the Opaque struct for CPU, which is empty
Expand All @@ -106,74 +98,86 @@ utils::Result<DeviceImpl> DeviceImpl::create(Args &&...args) {
}

// Perform elementwise operation for different input types
template <typename Op, typename Tout, typename... Tin, size_t... Is, typename... Args,
template <typename Op, typename Tout, typename... Tin, size_t... Is,
typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int> = 0>
void calculate_impl(const op::elementwise::ElementwiseInfo &info,
void *output,
void calculate_impl(const op::elementwise::ElementwiseInfo &info, void *output,
const std::vector<const void *> &inputs,
std::index_sequence<Is...>,
Args &&...args) {
std::index_sequence<Is...>, Args &&...args) {

Tout *out = reinterpret_cast<Tout *>(output);
std::tuple<const Tin *...> input_ptrs = {reinterpret_cast<const Tin *>(inputs[Is])...};
std::tuple<const Tin *...> input_ptrs = {
reinterpret_cast<const Tin *>(inputs[Is])...};
ptrdiff_t output_size = info.getOutputSize();

#pragma omp parallel for
for (ptrdiff_t i = 0; i < output_size; ++i) {
size_t out_idx = info.isOutputContiguous()
? i
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getOutputShape(), info.getOutputStrides());
: op::common_cpu::indexToOffset(
i, info.getNdim(), info.getOutputShape(),
info.getOutputStrides());

auto get_input_idx = [&](size_t input_id) {
return info.getInputContiguous()[input_id]
? i
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id));
: op::common_cpu::indexToOffset(
i, info.getNdim(), info.getInputShape(input_id),
info.getInputStrides(input_id));
};

out[out_idx] = utils::cast<Tout>(
Op{}.template operator()<Tout, Tin...>(std::get<Is>(input_ptrs)[get_input_idx(Is)]..., std::forward<Args>(args)...));
out[out_idx] = utils::cast<Tout>(Op{}.template operator()<Tout, Tin...>(
std::get<Is>(input_ptrs)[get_input_idx(Is)]...,
std::forward<Args>(args)...));
}
}

// Invoke elementwise operation for different input types
template <typename Op, typename Tout, typename... Tin, typename... Args, std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int>>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args) {
template <typename Op, typename Tout, typename... Tin, typename... Args,
std::enable_if_t<(sizeof...(Tin) == Op::num_inputs), int>>
infiniStatus_t
DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *output, const std::vector<const void *> &inputs,
void *stream, Args &&...args) {

static_assert(sizeof...(Tin) == Op::num_inputs, "Input type count mismatch");
calculate_impl<Op, Tout, Tin...>(info, output, inputs, std::make_index_sequence<sizeof...(Tin)>{}, std::forward<Args>(args)...);
calculate_impl<Op, Tout, Tin...>(info, output, inputs,
std::make_index_sequence<sizeof...(Tin)>{},
std::forward<Args>(args)...);
return INFINI_STATUS_SUCCESS;
}

// Perform elementwise operation when all inputs have the same type
template <typename Op, typename Tdata, size_t... Is, typename... Args>
void calculate_impl(const op::elementwise::ElementwiseInfo &info,
void *output,
void calculate_impl(const op::elementwise::ElementwiseInfo &info, void *output,
const std::vector<const void *> &inputs,
std::index_sequence<Is...>,
Args &&...args) {
std::index_sequence<Is...>, Args &&...args) {

Tdata *out = reinterpret_cast<Tdata *>(output);
std::array<const Tdata *, sizeof...(Is)> ins = {reinterpret_cast<const Tdata *>(inputs[Is])...};
std::array<const Tdata *, sizeof...(Is)> ins = {
reinterpret_cast<const Tdata *>(inputs[Is])...};
const ptrdiff_t output_size = info.getOutputSize();

#pragma omp parallel for if (output_size > 1024)
for (ptrdiff_t i = 0; i < output_size; ++i) {
size_t out_idx = info.isOutputContiguous()
? i
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getOutputShape(), info.getOutputStrides());
: op::common_cpu::indexToOffset(
i, info.getNdim(), info.getOutputShape(),
info.getOutputStrides());

auto get_input_idx = [&](size_t input_id) {
return info.getInputContiguous()[input_id]
? i
: op::common_cpu::indexToOffset(i, info.getNdim(), info.getInputShape(input_id), info.getInputStrides(input_id));
: op::common_cpu::indexToOffset(
i, info.getNdim(), info.getInputShape(input_id),
info.getInputStrides(input_id));
};

if constexpr (std::is_same_v<Tdata, fp16_t> || std::is_same_v<Tdata, bf16_t>) {
out[out_idx] = utils::cast<Tdata>(Op{}(utils::cast<float>(ins[Is][get_input_idx(Is)])..., std::forward<Args>(args)...));
out[out_idx] = utils::cast<Tdata>(
Op{}(utils::cast<float>(ins[Is][get_input_idx(Is)])...,
std::forward<Args>(args)...));
} else {
out[out_idx] = Op{}(ins[Is][get_input_idx(Is)]..., std::forward<Args>(args)...);
}
Expand All @@ -182,16 +186,16 @@ void calculate_impl(const op::elementwise::ElementwiseInfo &info,

// Invoke elementwise operation when all inputs have the same type
template <typename Op, typename Tdata, typename... Args>
infiniStatus_t DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *output,
const std::vector<const void *> &inputs,
void *stream,
Args &&...args) {
infiniStatus_t
DeviceImpl::calculate(const op::elementwise::ElementwiseInfo &info,
void *output, const std::vector<const void *> &inputs,
void *stream, Args &&...args) {
constexpr size_t N = Op::num_inputs;
calculate_impl<Op, Tdata>(info, output, inputs, std::make_index_sequence<N>{}, std::forward<Args>(args)...);
calculate_impl<Op, Tdata>(info, output, inputs, std::make_index_sequence<N>{},
std::forward<Args>(args)...);
return INFINI_STATUS_SUCCESS;
}

} // namespace op::elementwise::cpu

#endif // __INFINIOP_ELEMENTWISE_CPU_H__
#endif // __INFINIOP_ELEMENTWISE_CPU_H__
Loading