diff --git a/include/infini_operators.h b/include/infini_operators.h
index 9a5a2555..2002ec1b 100644
--- a/include/infini_operators.h
+++ b/include/infini_operators.h
@@ -16,4 +16,10 @@
 #include "ops/rms_norm/rms_norm.h"
 #include "ops/rotary_embedding/rotary_embedding.h"
 #include "ops/swiglu/swiglu.h"
+#include "ops/clip/clip.h"
+#include "ops/where/where.h"
+#include "ops/gather/gather.h"
+#include "ops/reduce_max/reduce_max.h"
+#include "ops/reduce_mean/reduce_mean.h"
+#include "ops/reduce_min/reduce_min.h"
 #include "tensor/tensor_descriptor.h"
diff --git a/include/ops/clip/clip.h b/include/ops/clip/clip.h
new file mode 100644
index 00000000..66d35088
--- /dev/null
+++ b/include/ops/clip/clip.h
@@ -0,0 +1,27 @@
+#ifndef CLIP_H
+#define CLIP_H
+
+#include "../../export.h"
+#include "../../operators.h"
+
+typedef struct ClipDescriptor {
+    Device device;
+} ClipDescriptor;
+
+typedef ClipDescriptor *infiniopClipDescriptor_t;
+
+__C __export infiniopStatus_t infiniopCreateClipDescriptor(infiniopHandle_t handle,
+                                                           infiniopClipDescriptor_t *desc_ptr,
+                                                           infiniopTensorDescriptor_t output,
+                                                           infiniopTensorDescriptor_t input,
+                                                           float *min,
+                                                           float *max);
+
+__C __export infiniopStatus_t infiniopClip(infiniopClipDescriptor_t desc,
+                                           void *output,
+                                           void const *input,
+                                           void *stream);
+
+__C __export infiniopStatus_t infiniopDestroyClipDescriptor(infiniopClipDescriptor_t desc);
+
+#endif
diff --git a/include/ops/gather/gather.h b/include/ops/gather/gather.h
new file mode 100644
index 00000000..a564c8d2
--- /dev/null
+++ b/include/ops/gather/gather.h
@@ -0,0 +1,29 @@
+#ifndef GATHER_H
+#define GATHER_H
+
+#include "../../export.h"
+#include "../../operators.h"
+#include <stdint.h>
+
+typedef struct GatherDescriptor {
+    Device device;
+} GatherDescriptor;
+
+typedef GatherDescriptor *infiniopGatherDescriptor_t;
+
+__C __export infiniopStatus_t infiniopCreateGatherDescriptor(infiniopHandle_t handle,
+                                                             infiniopGatherDescriptor_t *desc_ptr,
+                                                             infiniopTensorDescriptor_t output,
+                                                             infiniopTensorDescriptor_t input,
+                                                             infiniopTensorDescriptor_t indices,
+                                                             int64_t axis);
+
+__C __export infiniopStatus_t infiniopGather(infiniopGatherDescriptor_t desc,
+                                             void *output,
+                                             void const *input,
+                                             void const *indices,
+                                             void *stream);
+
+__C __export infiniopStatus_t infiniopDestroyGatherDescriptor(infiniopGatherDescriptor_t desc);
+
+#endif
diff --git a/include/ops/reduce_max/reduce_max.h b/include/ops/reduce_max/reduce_max.h
new file mode 100644
index 00000000..e2449964
--- /dev/null
+++ b/include/ops/reduce_max/reduce_max.h
@@ -0,0 +1,24 @@
+#ifndef REDUCE_MAX_H
+#define REDUCE_MAX_H
+
+#include "../../export.h"
+#include "../../operators.h"
+
+typedef struct ReduceMaxDescriptor {
+    Device device;
+} ReduceMaxDescriptor;
+typedef ReduceMaxDescriptor *infiniopReduceMaxDescriptor_t;
+
+__C __export infiniopStatus_t infiniopCreateReduceMaxDescriptor(infiniopHandle_t handle,
+                                                                infiniopReduceMaxDescriptor_t *desc_ptr,
+                                                                infiniopTensorDescriptor_t y,
+                                                                infiniopTensorDescriptor_t x,
+                                                                int64_t const *axes,
+                                                                uint64_t n_axes,
+                                                                int keep_dims);
+
+__C __export infiniopStatus_t infiniopReduceMax(infiniopReduceMaxDescriptor_t desc, void *y, void const *x, void *stream);
+
+__C __export infiniopStatus_t infiniopDestroyReduceMaxDescriptor(infiniopReduceMaxDescriptor_t desc);
+
+#endif
diff --git a/include/ops/reduce_mean/reduce_mean.h b/include/ops/reduce_mean/reduce_mean.h
new file mode 100644
index 00000000..66215c65
--- /dev/null
+++ b/include/ops/reduce_mean/reduce_mean.h
@@ -0,0 +1,24 @@
+#ifndef REDUCE_MEAN_H
+#define REDUCE_MEAN_H
+
+#include "../../export.h"
+#include "../../operators.h"
+
+typedef struct ReduceMeanDescriptor {
+    Device device;
+} ReduceMeanDescriptor;
+typedef ReduceMeanDescriptor *infiniopReduceMeanDescriptor_t;
+
+__C __export infiniopStatus_t infiniopCreateReduceMeanDescriptor(infiniopHandle_t handle,
+                                                                 infiniopReduceMeanDescriptor_t *desc_ptr,
+                                                                 infiniopTensorDescriptor_t y,
+                                                                 infiniopTensorDescriptor_t x,
+                                                                 int64_t const *axes,
+                                                                 uint64_t n_axes,
+                                                                 int keep_dims);
+
+__C __export infiniopStatus_t infiniopReduceMean(infiniopReduceMeanDescriptor_t desc, void *y, void const *x, void *stream);
+
+__C __export infiniopStatus_t infiniopDestroyReduceMeanDescriptor(infiniopReduceMeanDescriptor_t desc);
+
+#endif
diff --git a/include/ops/reduce_min/reduce_min.h b/include/ops/reduce_min/reduce_min.h
new file mode 100644
index 00000000..feb9200e
--- /dev/null
+++ b/include/ops/reduce_min/reduce_min.h
@@ -0,0 +1,24 @@
+#ifndef REDUCE_MIN_H
+#define REDUCE_MIN_H
+
+#include "../../export.h"
+#include "../../operators.h"
+
+typedef struct ReduceMinDescriptor {
+    Device device;
+} ReduceMinDescriptor;
+typedef ReduceMinDescriptor *infiniopReduceMinDescriptor_t;
+
+__C __export infiniopStatus_t infiniopCreateReduceMinDescriptor(infiniopHandle_t handle,
+                                                                infiniopReduceMinDescriptor_t *desc_ptr,
+                                                                infiniopTensorDescriptor_t y,
+                                                                infiniopTensorDescriptor_t x,
+                                                                int64_t const *axes,
+                                                                uint64_t n_axes,
+                                                                int keep_dims);
+
+__C __export infiniopStatus_t infiniopReduceMin(infiniopReduceMinDescriptor_t desc, void *y, void const *x, void *stream);
+
+__C __export infiniopStatus_t infiniopDestroyReduceMinDescriptor(infiniopReduceMinDescriptor_t desc);
+
+#endif
diff --git a/include/ops/where/where.h b/include/ops/where/where.h
new file mode 100644
index 00000000..0cec3c5e
--- /dev/null
+++ b/include/ops/where/where.h
@@ -0,0 +1,29 @@
+#ifndef WHERE_H
+#define WHERE_H
+
+#include "../../export.h"
+#include "../../operators.h"
+
+typedef struct WhereDescriptor {
+    Device device;
+} WhereDescriptor;
+
+typedef WhereDescriptor *infiniopWhereDescriptor_t;
+
+__C __export infiniopStatus_t infiniopCreateWhereDescriptor(infiniopHandle_t handle,
+                                                            infiniopWhereDescriptor_t *desc_ptr,
+                                                            infiniopTensorDescriptor_t output,
+                                                            infiniopTensorDescriptor_t condition,
+                                                            infiniopTensorDescriptor_t x,
+                                                            infiniopTensorDescriptor_t y);
+
+__C __export infiniopStatus_t infiniopWhere(infiniopWhereDescriptor_t desc,
+                                            void *output,
+                                            void const *condition,
+                                            void const *x,
+                                            void const *y,
+                                            void *stream);
+
+__C __export infiniopStatus_t infiniopDestroyWhereDescriptor(infiniopWhereDescriptor_t desc);
+
+#endif
diff --git a/operatorspy/tests/clip.py b/operatorspy/tests/clip.py
new file mode 100644
index 00000000..318d4048
--- /dev/null
+++ b/operatorspy/tests/clip.py
@@ -0,0 +1,133 @@
+from ctypes import POINTER, Structure, c_int32, c_void_p, c_float
+import ctypes
+import sys
+import os
+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
+from operatorspy import (
+    open_lib,
+    to_tensor,
+    DeviceEnum,
+    infiniopHandle_t,
+    infiniopTensorDescriptor_t,
+    create_handle,
+    destroy_handle,
+    check_error,
+)
+
+from operatorspy.tests.test_utils import get_args
+import torch
+
+
+class ClipDescriptor(Structure):
+    _fields_ = [("device", c_int32)]
+
+
+infiniopClipDescriptor_t = POINTER(ClipDescriptor)
+
+
+def clip(x, min, max):
+    return torch.clip(x, min, max)
+
+
+def test(
+    lib,
+    handle,
+    torch_device,
+    c_shape,
+    min,
+    max,
+    tensor_dtype=torch.float16,
+):
+    print(
+        f"Testing Clip on {torch_device} with c_shape:{c_shape} dtype:{tensor_dtype}"
+    )
+
+    input = torch.rand(c_shape, dtype=tensor_dtype).to(torch_device)
+    output = torch.empty(c_shape, dtype=tensor_dtype).to(torch_device)
+    # Compare against None explicitly so that a bound of 0.0 is not mistaken
+    # for "no bound"; None falls back to the dtype's full range.
+    min_v = min if min is not None else torch.finfo(tensor_dtype).min
+    max_v = max if max is not None else torch.finfo(tensor_dtype).max
+    min_val = torch.tensor(min_v, dtype=tensor_dtype).to(torch_device)
+    max_val = torch.tensor(max_v, dtype=tensor_dtype).to(torch_device)
+    # Round the bounds to the working dtype so the reference matches the kernel
+    min_value = min_val.item()
+    max_value = max_val.item()
+
+    ans = clip(input, min_val, max_val)
+
+    input_tensor = to_tensor(input, lib)
+    output_tensor = to_tensor(output, lib)
+    descriptor = infiniopClipDescriptor_t()
+
+    min_c = c_float(min_value)
+    max_c = c_float(max_value)
+
+    check_error(
+        lib.infiniopCreateClipDescriptor(
+            handle,
+            ctypes.byref(descriptor),
+            output_tensor.descriptor,
+            input_tensor.descriptor,
+            ctypes.byref(min_c) if min is not None else None,
+            ctypes.byref(max_c) if max is not None else None,
+        )
+    )
+
+    input_tensor.descriptor.contents.invalidate()
+    output_tensor.descriptor.contents.invalidate()
+
+    check_error(
+        lib.infiniopClip(descriptor, output_tensor.data, input_tensor.data, None)
+    )
+
+    assert torch.allclose(output, ans, atol=0, rtol=0)
+    check_error(lib.infiniopDestroyClipDescriptor(descriptor))
+
+
+def test_cpu(lib, test_cases):
+    device = DeviceEnum.DEVICE_CPU
+    handle = create_handle(lib, device)
+    for c_shape, min, max in test_cases:
+        test(lib, handle, "cpu", c_shape, min, max, tensor_dtype=torch.float16)
+        test(lib, handle, "cpu", c_shape, min, max, tensor_dtype=torch.float32)
+    destroy_handle(lib, handle)
+
+
+if __name__ == "__main__":
+    test_cases = [
+        # c_shape, min, max
+        ((1, 3), 0.2, 0.4),
+        ((3, 3), -0.1, 0.7),
+        ((2, 20, 3), 0.5, 0.9),
+        ((32, 20, 512), -0.2, 0.9),
+        ((32, 256, 112, 112), 0.1, None),
+        ((3, 2, 4, 5), None, None),
+    ]
+    args = get_args()
+    lib = open_lib()
+    lib.infiniopCreateClipDescriptor.restype = c_int32
+    lib.infiniopCreateClipDescriptor.argtypes = [
+        infiniopHandle_t,
+        POINTER(infiniopClipDescriptor_t),
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        POINTER(c_float),
+        POINTER(c_float),
+    ]
+    lib.infiniopClip.restype = c_int32
+    lib.infiniopClip.argtypes = [
+        infiniopClipDescriptor_t,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+    ]
+    lib.infiniopDestroyClipDescriptor.restype = c_int32
+    lib.infiniopDestroyClipDescriptor.argtypes = [
+        infiniopClipDescriptor_t,
+    ]
+
+    test_cpu(lib, test_cases)
+    print("\033[92mTest passed!\033[0m")
diff --git a/operatorspy/tests/gather.py b/operatorspy/tests/gather.py
new file mode 100644
index 00000000..f63bd71b
--- /dev/null
+++ b/operatorspy/tests/gather.py
@@ -0,0 +1,125 @@
+import sys
+import os
+import ctypes
+import torch
+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
+from operatorspy import (
+    open_lib,
+    to_tensor,
+    DeviceEnum,
+    infiniopHandle_t,
+    infiniopTensorDescriptor_t,
+    create_handle,
+    destroy_handle,
+    check_error,
+)
+
+from ctypes import POINTER, Structure, c_int32, c_void_p
+
+class GatherDescriptor(Structure):
+    _fields_ = [("device", c_int32)]
+
+infiniopGatherDescriptor_t = POINTER(GatherDescriptor)
+
+
+def gather(data, axis, indices):
+    # Output shape: data.shape[:axis] + indices.shape + data.shape[axis+1:]
+    out_shape = list(data.shape[:axis]) + list(indices.shape) + list(data.shape[axis+1:])
+    # Build a coordinate grid over the output tensor
+    grids = torch.meshgrid(*[torch.arange(s, device=data.device) for s in out_shape], indexing='ij')
+    # Assemble the advanced-indexing list for data
+    index_list = []
+    # The first `axis` dimensions of data map directly to the first `axis` grids
+    for i in range(axis):
+        index_list.append(grids[i])
+    # Dimension `axis` of data is indexed by `indices`, broadcast to the output shape
+    new_shape = [1] * axis + list(indices.shape) + [1] * (data.dim() - axis - 1)
+    indices_expanded = indices.view(new_shape).expand(out_shape)
+    index_list.append(indices_expanded)
+    # The remaining dimensions of data map to the grids after the indices block
+    for i in range(axis, data.dim() - 1):
+        index_list.append(grids[i + len(indices.shape)])
+    return data[tuple(index_list)]
+
+
+def test(lib, handle, torch_device, data_shape, indices_shape, axis, tensor_dtype=torch.float16):
+    print(f"Testing Gather on {torch_device} with data_shape:{data_shape}, indices_shape:{indices_shape}, axis:{axis}, dtype:{tensor_dtype}")
+    data = torch.rand(data_shape, dtype=tensor_dtype).to(torch_device)
+    indices = torch.randint(0, data_shape[axis], indices_shape, dtype=torch.int64).to(torch_device)
+    output = torch.empty(data.shape[:axis] + indices.shape + data.shape[axis+1:], dtype=tensor_dtype).to(torch_device)
+    ans = gather(data, axis, indices)
+
+    data_tensor = to_tensor(data, lib)
+    indices_tensor = to_tensor(indices, lib)
+    output_tensor = to_tensor(output, lib)
+
+    descriptor = infiniopGatherDescriptor_t()
+    check_error(lib.infiniopCreateGatherDescriptor(handle,
+                                                   ctypes.byref(descriptor),
+                                                   output_tensor.descriptor,
+                                                   data_tensor.descriptor,
+                                                   indices_tensor.descriptor,
+                                                   ctypes.c_int64(axis)))
+
+    data_tensor.descriptor.contents.invalidate()
+    indices_tensor.descriptor.contents.invalidate()
+    output_tensor.descriptor.contents.invalidate()
+
+    check_error(lib.infiniopGather(descriptor,
+                                   output_tensor.data,
+                                   data_tensor.data,
+                                   indices_tensor.data,
+                                   None))
+    assert torch.allclose(output, ans, atol=0, rtol=0)
+    check_error(lib.infiniopDestroyGatherDescriptor(descriptor))
+
+def test_cpu(lib, test_cases):
+    device = DeviceEnum.DEVICE_CPU
+    handle = create_handle(lib, device)
+    for data_shape, indices_shape, axis in test_cases:
+        test(lib, handle, "cpu", data_shape, indices_shape, axis, tensor_dtype=torch.float16)
+        test(lib, handle, "cpu", data_shape, indices_shape, axis, tensor_dtype=torch.float32)
+    destroy_handle(lib, handle)

+if __name__ == "__main__":
+    test_cases = [
+        # (data_shape, indices_shape, axis)
+        ((3, 4), (2,), 0),
+        ((3, 4), (3,), 1),
+        ((2, 3, 4), (2,), 1),
+        ((3, 2), (2, 2), 0),
+        ((3, 3), (1, 2), 1),
+    ]
+    from operatorspy.tests.test_utils import get_args
+    args = get_args()
+    lib = open_lib()
+    lib.infiniopCreateGatherDescriptor.restype = ctypes.c_int32
+    lib.infiniopCreateGatherDescriptor.argtypes = [
+        infiniopHandle_t,
+        POINTER(infiniopGatherDescriptor_t),
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        ctypes.c_int64,
+    ]
+    lib.infiniopGather.restype = ctypes.c_int32
+    lib.infiniopGather.argtypes = [
+        infiniopGatherDescriptor_t,
+        ctypes.c_void_p,
+        ctypes.c_void_p,
+        ctypes.c_void_p,
+        ctypes.c_void_p,
+    ]
+    lib.infiniopDestroyGatherDescriptor.restype = ctypes.c_int32
+    lib.infiniopDestroyGatherDescriptor.argtypes = [
+        infiniopGatherDescriptor_t,
+    ]
+
+    # Run the CPU tests unconditionally; the original pair of
+    # `if args.cpu:` / `if not args.cpu:` branches always ran them anyway.
+    test_cpu(lib, test_cases)
+    print("\033[92mTest passed!\033[0m")
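[Note outside the diff] The pure-PyTorch `gather` reference above implements ONNX-style Gather (output shape `data.shape[:axis] + indices.shape + data.shape[axis+1:]`), which is not the same operator as `torch.gather`. A simpler, equivalent cross-check, assuming only stock PyTorch: ONNX Gather is an `index_select` over the flattened indices followed by a reshape.

    import torch

    def gather_ref(data, axis, indices):
        # ONNX Gather == index_select on flattened indices + reshape
        out_shape = data.shape[:axis] + indices.shape + data.shape[axis + 1:]
        return torch.index_select(data, axis, indices.flatten()).reshape(out_shape)

    data = torch.rand(3, 4)
    idx = torch.tensor([[0, 2], [1, 0]])
    print(gather_ref(data, 1, idx).shape)  # torch.Size([3, 2, 2])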
diff --git a/operatorspy/tests/reduce_max.py b/operatorspy/tests/reduce_max.py
new file mode 100644
index 00000000..23c43415
--- /dev/null
+++ b/operatorspy/tests/reduce_max.py
@@ -0,0 +1,185 @@
+from ctypes import POINTER, Structure, c_int32, c_void_p, c_uint64, c_int64
+import ctypes
+import sys
+import os
+import time
+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
+from operatorspy import (
+    open_lib,
+    to_tensor,
+    DeviceEnum,
+    infiniopHandle_t,
+    infiniopTensorDescriptor_t,
+    create_handle,
+    destroy_handle,
+    check_error,
+)
+
+from operatorspy.tests.test_utils import get_args
+import torch
+from typing import Tuple
+
+# Constants that control whether to profile the PyTorch and library functions.
+# NOTE: a synchronization call (e.g., cudaDeviceSynchronize() for CUDA) must be
+# added to the library function manually for the timings to be meaningful.
+PROFILE = False
+NUM_PRERUN = 10
+NUM_ITERATIONS = 1000
+
+
+class ReduceMaxDescriptor(Structure):
+    _fields_ = [("device", c_int32)]
+
+
+infiniopReduceMaxDescriptor_t = POINTER(ReduceMaxDescriptor)
+
+
+def reduce_max(x, axes, keepdim=False):
+    return torch.amax(x, dim=axes, keepdim=keepdim)
+
+
+def inferShape(x_shape, axes, keep_dims):
+    output_shape = list(x_shape)
+
+    if keep_dims:
+        for axis in axes:
+            # Convert negative axis to positive
+            actual_axis = axis if axis >= 0 else len(x_shape) + axis
+            output_shape[actual_axis] = 1
+    else:
+        # Sort axes in descending order to avoid index shifting after removal
+        sorted_axes = sorted([axis if axis >= 0 else len(x_shape) + axis for axis in axes], reverse=True)
+        for axis in sorted_axes:
+            output_shape.pop(axis)
+
+    return tuple(output_shape)
+
+# convert a python tuple to a ctypes void pointer
+def tuple_to_void_p(py_tuple: Tuple):
+    array = ctypes.c_int64 * len(py_tuple)
+    data_array = array(*py_tuple)
+    return ctypes.cast(data_array, ctypes.c_void_p)
+
+def test(
+    lib,
+    handle,
+    torch_device,
+    x_shape,
+    axes,
+    keep_dims,
+    tensor_dtype=torch.float32,
+):
+    print(
+        f"Testing ReduceMax on {torch_device} with x_shape:{x_shape} axes:{axes} keep_dims:{keep_dims} dtype:{tensor_dtype}"
+    )
+
+    x = torch.rand(x_shape, dtype=tensor_dtype).to(torch_device)
+    y_shape = inferShape(x_shape, axes, keep_dims)
+    y = torch.zeros(y_shape, dtype=tensor_dtype).to(torch_device)
+
+    for i in range(NUM_PRERUN if PROFILE else 1):
+        ans = reduce_max(x, axes, keep_dims)
+
+    if PROFILE:
+        start_time = time.time()
+        for i in range(NUM_ITERATIONS):
+            ans = reduce_max(x, axes, keep_dims)
+        elapsed = (time.time() - start_time) / NUM_ITERATIONS
+        print(f"pytorch time: {elapsed:.6f}")
+
+    x_tensor = to_tensor(x, lib)
+    y_tensor = to_tensor(y, lib)
+    descriptor = infiniopReduceMaxDescriptor_t()
+
+    check_error(
+        lib.infiniopCreateReduceMaxDescriptor(
+            handle,
+            ctypes.byref(descriptor),
+            y_tensor.descriptor,
+            x_tensor.descriptor,
+            tuple_to_void_p(axes),
+            len(axes),
+            1 if keep_dims else 0,
+        )
+    )
+
+    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
+    x_tensor.descriptor.contents.invalidate()
+    y_tensor.descriptor.contents.invalidate()
+
+    for i in range(NUM_PRERUN if PROFILE else 1):
+        check_error(
+            lib.infiniopReduceMax(
+                descriptor,
+                y_tensor.data,
+                x_tensor.data,
+                None,
+            )
+        )
+    if PROFILE:
+        start_time = time.time()
+        for i in range(NUM_ITERATIONS):
+            check_error(
+                lib.infiniopReduceMax(
+                    descriptor,
+                    y_tensor.data,
+                    x_tensor.data,
+                    None,
+                )
+            )
+        elapsed = (time.time() - start_time) / NUM_ITERATIONS
+        print(f"    lib time: {elapsed:.6f}")
+
+    if tensor_dtype == torch.float16:
+        assert torch.allclose(y, ans, atol=0, rtol=1e-3)
+    else:
+        assert torch.allclose(y, ans, atol=0, rtol=1e-5)
+    check_error(lib.infiniopDestroyReduceMaxDescriptor(descriptor))
+
+
+def test_cpu(lib, test_cases):
+    device = DeviceEnum.DEVICE_CPU
+    handle = create_handle(lib, device)
+    for x_shape, axes, keep_dims in test_cases:
+        test(lib, handle, "cpu", x_shape, axes, keep_dims, tensor_dtype=torch.float32)
+        test(lib, handle, "cpu", x_shape, axes, keep_dims, tensor_dtype=torch.float16)
+    destroy_handle(lib, handle)
+
+
+if __name__ == "__main__":
+    test_cases = [
+        # x_shape, axes, keep_dims
+        ((2, 3, 4, 5), (1,), True),
+        ((2, 3, 4, 5), (1,), False),
+        ((2, 3, 4, 5), (2, 1), True),
+        ((2, 3, 4, 5), (1, 2), False),
+        ((2, 3, 4, 5), (0, 1, 2, 3), True),
+        ((2, 3, 4, 5), (-1, -2), True),
+    ]
+    args = get_args()
+    lib = open_lib()
+    lib.infiniopCreateReduceMaxDescriptor.restype = c_int32
+    lib.infiniopCreateReduceMaxDescriptor.argtypes = [
+        infiniopHandle_t,
+        POINTER(infiniopReduceMaxDescriptor_t),
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        c_void_p,
+        c_uint64,
+        c_int32,
+    ]
+    lib.infiniopReduceMax.restype = c_int32
+    lib.infiniopReduceMax.argtypes = [
+        infiniopReduceMaxDescriptor_t,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+    ]
+    lib.infiniopDestroyReduceMaxDescriptor.restype = c_int32
+    lib.infiniopDestroyReduceMaxDescriptor.argtypes = [
+        infiniopReduceMaxDescriptor_t,
+    ]
+
+    test_cpu(lib, test_cases)
+    print("\033[92mTest passed!\033[0m")
diff --git a/operatorspy/tests/reduce_mean.py b/operatorspy/tests/reduce_mean.py
new file mode 100644
index 00000000..06f9b2bf
--- /dev/null
+++ b/operatorspy/tests/reduce_mean.py
@@ -0,0 +1,185 @@
+from ctypes import POINTER, Structure, c_int32, c_void_p, c_uint64, c_int64
+import ctypes
+import sys
+import os
+import time
+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
+from operatorspy import (
+    open_lib,
+    to_tensor,
+    DeviceEnum,
+    infiniopHandle_t,
+    infiniopTensorDescriptor_t,
+    create_handle,
+    destroy_handle,
+    check_error,
+)
+
+from operatorspy.tests.test_utils import get_args
+import torch
+from typing import Tuple
+
+# Constants that control whether to profile the PyTorch and library functions.
+# NOTE: a synchronization call (e.g., cudaDeviceSynchronize() for CUDA) must be
+# added to the library function manually for the timings to be meaningful.
+PROFILE = False
+NUM_PRERUN = 10
+NUM_ITERATIONS = 1000
+
+
+class ReduceMeanDescriptor(Structure):
+    _fields_ = [("device", c_int32)]
+
+
+infiniopReduceMeanDescriptor_t = POINTER(ReduceMeanDescriptor)
+
+
+def reduce_mean(x, axes, keepdim=False):
+    return torch.mean(x, dim=axes, keepdim=keepdim)
+
+
+def inferShape(x_shape, axes, keep_dims):
+    output_shape = list(x_shape)
+
+    if keep_dims:
+        for axis in axes:
+            # Convert negative axis to positive
+            actual_axis = axis if axis >= 0 else len(x_shape) + axis
+            output_shape[actual_axis] = 1
+    else:
+        # Sort axes in descending order to avoid index shifting after removal
+        sorted_axes = sorted([axis if axis >= 0 else len(x_shape) + axis for axis in axes], reverse=True)
+        for axis in sorted_axes:
+            output_shape.pop(axis)
+
+    return tuple(output_shape)
+
+# convert a python tuple to a ctypes void pointer
+def tuple_to_void_p(py_tuple: Tuple):
+    array = ctypes.c_int64 * len(py_tuple)
+    data_array = array(*py_tuple)
+    return ctypes.cast(data_array, ctypes.c_void_p)
+
+def test(
+    lib,
+    handle,
+    torch_device,
+    x_shape,
+    axes,
+    keep_dims,
+    tensor_dtype=torch.float32,
+):
+    print(
+        f"Testing ReduceMean on {torch_device} with x_shape:{x_shape} axes:{axes} keep_dims:{keep_dims} dtype:{tensor_dtype}"
+    )
+
+    x = torch.rand(x_shape, dtype=tensor_dtype).to(torch_device)
+    y_shape = inferShape(x_shape, axes, keep_dims)
+    y = torch.zeros(y_shape, dtype=tensor_dtype).to(torch_device)
+
+    for i in range(NUM_PRERUN if PROFILE else 1):
+        ans = reduce_mean(x, axes, keep_dims)
+
+    if PROFILE:
+        start_time = time.time()
+        for i in range(NUM_ITERATIONS):
+            ans = reduce_mean(x, axes, keep_dims)
+        elapsed = (time.time() - start_time) / NUM_ITERATIONS
+        print(f"pytorch time: {elapsed:.6f}")
+
+    x_tensor = to_tensor(x, lib)
+    y_tensor = to_tensor(y, lib)
+    descriptor = infiniopReduceMeanDescriptor_t()
+
+    check_error(
+        lib.infiniopCreateReduceMeanDescriptor(
+            handle,
+            ctypes.byref(descriptor),
+            y_tensor.descriptor,
+            x_tensor.descriptor,
+            tuple_to_void_p(axes),
+            len(axes),
+            1 if keep_dims else 0,
+        )
+    )
+
+    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
+    x_tensor.descriptor.contents.invalidate()
+    y_tensor.descriptor.contents.invalidate()
+
+    for i in range(NUM_PRERUN if PROFILE else 1):
+        check_error(
+            lib.infiniopReduceMean(
+                descriptor,
+                y_tensor.data,
+                x_tensor.data,
+                None,
+            )
+        )
+    if PROFILE:
+        start_time = time.time()
+        for i in range(NUM_ITERATIONS):
+            check_error(
+                lib.infiniopReduceMean(
+                    descriptor,
+                    y_tensor.data,
+                    x_tensor.data,
+                    None,
+                )
+            )
+        elapsed = (time.time() - start_time) / NUM_ITERATIONS
+        print(f"    lib time: {elapsed:.6f}")
+
+    if tensor_dtype == torch.float16:
+        assert torch.allclose(y, ans, atol=0, rtol=1e-3)
+    else:
+        assert torch.allclose(y, ans, atol=0, rtol=1e-5)
+    check_error(lib.infiniopDestroyReduceMeanDescriptor(descriptor))
+
+
+def test_cpu(lib, test_cases):
+    device = DeviceEnum.DEVICE_CPU
+    handle = create_handle(lib, device)
+    for x_shape, axes, keep_dims in test_cases:
+        test(lib, handle, "cpu", x_shape, axes, keep_dims, tensor_dtype=torch.float32)
+        test(lib, handle, "cpu", x_shape, axes, keep_dims, tensor_dtype=torch.float16)
+    destroy_handle(lib, handle)
+
+
+if __name__ == "__main__":
+    test_cases = [
+        # x_shape, axes, keep_dims
+        ((2, 3, 4, 5), (1,), True),
+        ((2, 3, 4, 5), (1,), False),
+        ((2, 3, 4, 5), (1, 2), True),
+        ((2, 3, 4, 5), (1, 2), False),
+        ((2, 3, 4, 5), (0, 1, 2, 3), True),
+        ((2, 3, 4, 5), (-1, -2), True),
+    ]
+    args = get_args()
+    lib = open_lib()
+    lib.infiniopCreateReduceMeanDescriptor.restype = c_int32
+    lib.infiniopCreateReduceMeanDescriptor.argtypes = [
+        infiniopHandle_t,
+        POINTER(infiniopReduceMeanDescriptor_t),
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        c_void_p,
+        c_uint64,
+        c_int32,
+    ]
+    lib.infiniopReduceMean.restype = c_int32
+    lib.infiniopReduceMean.argtypes = [
+        infiniopReduceMeanDescriptor_t,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+    ]
+    lib.infiniopDestroyReduceMeanDescriptor.restype = c_int32
+    lib.infiniopDestroyReduceMeanDescriptor.argtypes = [
+        infiniopReduceMeanDescriptor_t,
+    ]
+
+    test_cpu(lib, test_cases)
+    print("\033[92mTest passed!\033[0m")
diff --git a/operatorspy/tests/reduce_min.py b/operatorspy/tests/reduce_min.py
new file mode 100644
index 00000000..5cb94e7c
--- /dev/null
+++ b/operatorspy/tests/reduce_min.py
@@ -0,0 +1,185 @@
+from ctypes import POINTER, Structure, c_int32, c_void_p, c_uint64, c_int64
+import ctypes
+import sys
+import os
+import time
+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
+from operatorspy import (
+    open_lib,
+    to_tensor,
+    DeviceEnum,
+    infiniopHandle_t,
+    infiniopTensorDescriptor_t,
+    create_handle,
+    destroy_handle,
+    check_error,
+)
+
+from operatorspy.tests.test_utils import get_args
+import torch
+from typing import Tuple
+
+# Constants that control whether to profile the PyTorch and library functions.
+# NOTE: a synchronization call (e.g., cudaDeviceSynchronize() for CUDA) must be
+# added to the library function manually for the timings to be meaningful.
+PROFILE = False
+NUM_PRERUN = 10
+NUM_ITERATIONS = 1000
+
+
+class ReduceMinDescriptor(Structure):
+    _fields_ = [("device", c_int32)]
+
+
+infiniopReduceMinDescriptor_t = POINTER(ReduceMinDescriptor)
+
+
+def reduce_min(x, axes, keepdim=False):
+    return torch.amin(x, dim=axes, keepdim=keepdim)
+
+
+def inferShape(x_shape, axes, keep_dims):
+    output_shape = list(x_shape)
+
+    if keep_dims:
+        for axis in axes:
+            # Convert negative axis to positive
+            actual_axis = axis if axis >= 0 else len(x_shape) + axis
+            output_shape[actual_axis] = 1
+    else:
+        # Sort axes in descending order to avoid index shifting after removal
+        sorted_axes = sorted([axis if axis >= 0 else len(x_shape) + axis for axis in axes], reverse=True)
+        for axis in sorted_axes:
+            output_shape.pop(axis)
+
+    return tuple(output_shape)
+
+# convert a python tuple to a ctypes void pointer
+def tuple_to_void_p(py_tuple: Tuple):
+    array = ctypes.c_int64 * len(py_tuple)
+    data_array = array(*py_tuple)
+    return ctypes.cast(data_array, ctypes.c_void_p)
+
+def test(
+    lib,
+    handle,
+    torch_device,
+    x_shape,
+    axes,
+    keep_dims,
+    tensor_dtype=torch.float32,
+):
+    print(
+        f"Testing ReduceMin on {torch_device} with x_shape:{x_shape} axes:{axes} keep_dims:{keep_dims} dtype:{tensor_dtype}"
+    )
+
+    x = torch.rand(x_shape, dtype=tensor_dtype).to(torch_device)
+    y_shape = inferShape(x_shape, axes, keep_dims)
+    y = torch.zeros(y_shape, dtype=tensor_dtype).to(torch_device)
+
+    for i in range(NUM_PRERUN if PROFILE else 1):
+        ans = reduce_min(x, axes, keep_dims)
+
+    if PROFILE:
+        start_time = time.time()
+        for i in range(NUM_ITERATIONS):
+            ans = reduce_min(x, axes, keep_dims)
+        elapsed = (time.time() - start_time) / NUM_ITERATIONS
+        print(f"pytorch time: {elapsed:.6f}")
+
+    x_tensor = to_tensor(x, lib)
+    y_tensor = to_tensor(y, lib)
+    descriptor = infiniopReduceMinDescriptor_t()
+
+    check_error(
+        lib.infiniopCreateReduceMinDescriptor(
+            handle,
+            ctypes.byref(descriptor),
+            y_tensor.descriptor,
+            x_tensor.descriptor,
+            tuple_to_void_p(axes),
+            len(axes),
+            1 if keep_dims else 0,
+        )
+    )
+
+    # Invalidate the shape and strides in the descriptor to prevent them from being directly used by the kernel
+    x_tensor.descriptor.contents.invalidate()
+    y_tensor.descriptor.contents.invalidate()
+
+    for i in range(NUM_PRERUN if PROFILE else 1):
+        check_error(
+            lib.infiniopReduceMin(
+                descriptor,
+                y_tensor.data,
+                x_tensor.data,
+                None,
+            )
+        )
+    if PROFILE:
+        start_time = time.time()
+        for i in range(NUM_ITERATIONS):
+            check_error(
+                lib.infiniopReduceMin(
+                    descriptor,
+                    y_tensor.data,
+                    x_tensor.data,
+                    None,
+                )
+            )
+        elapsed = (time.time() - start_time) / NUM_ITERATIONS
+        print(f"    lib time: {elapsed:.6f}")
+
+    if tensor_dtype == torch.float16:
+        assert torch.allclose(y, ans, atol=0, rtol=1e-3)
+    else:
+        assert torch.allclose(y, ans, atol=0, rtol=1e-5)
+    check_error(lib.infiniopDestroyReduceMinDescriptor(descriptor))
+
+
+def test_cpu(lib, test_cases):
+    device = DeviceEnum.DEVICE_CPU
+    handle = create_handle(lib, device)
+    for x_shape, axes, keep_dims in test_cases:
+        test(lib, handle, "cpu", x_shape, axes, keep_dims, tensor_dtype=torch.float32)
+        test(lib, handle, "cpu", x_shape, axes, keep_dims, tensor_dtype=torch.float16)
+    destroy_handle(lib, handle)
+
+
+if __name__ == "__main__":
+    test_cases = [
+        # x_shape, axes, keep_dims
+        ((2, 3, 4, 5), (1,), True),
+        ((2, 3, 4, 5), (1,), False),
+        ((2, 3, 4, 5), (1, 2), True),
+        ((2, 3, 4, 5), (1, 2), False),
+        ((2, 3, 4, 5), (0, 1, 2, 3), True),
+        ((2, 3, 4, 5), (-1, -2), True),
+    ]
+    args = get_args()
+    lib = open_lib()
+    lib.infiniopCreateReduceMinDescriptor.restype = c_int32
+    lib.infiniopCreateReduceMinDescriptor.argtypes = [
+        infiniopHandle_t,
+        POINTER(infiniopReduceMinDescriptor_t),
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        c_void_p,
+        c_uint64,
+        c_int32,
+    ]
+    lib.infiniopReduceMin.restype = c_int32
+    lib.infiniopReduceMin.argtypes = [
+        infiniopReduceMinDescriptor_t,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+    ]
+    lib.infiniopDestroyReduceMinDescriptor.restype = c_int32
+    lib.infiniopDestroyReduceMinDescriptor.argtypes = [
+        infiniopReduceMinDescriptor_t,
+    ]
+
+    test_cpu(lib, test_cases)
+    print("\033[92mTest passed!\033[0m")
diff --git a/operatorspy/tests/where.py b/operatorspy/tests/where.py
new file mode 100644
index 00000000..3c57689d
--- /dev/null
+++ b/operatorspy/tests/where.py
@@ -0,0 +1,136 @@
+from ctypes import POINTER, Structure, c_int32, c_void_p
+import ctypes
+import sys
+import os
+
+sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
+from operatorspy import (
+    open_lib,
+    to_tensor,
+    DeviceEnum,
+    infiniopHandle_t,
+    infiniopTensorDescriptor_t,
+    create_handle,
+    destroy_handle,
+    check_error,
+)
+
+from operatorspy.tests.test_utils import get_args
+import torch
+
+
+class WhereDescriptor(Structure):
+    _fields_ = [("device", c_int32)]
+
+
+infiniopWhereDescriptor_t = POINTER(WhereDescriptor)
+
+
+def where(condition, x, y):
+    return torch.where(condition.bool(), x, y)
+
+
+def test(
+    lib,
+    handle,
+    torch_device,
+    output_shape,
+    condition_shape,
+    x_shape,
+    y_shape,
+    tensor_dtype=torch.float16,
+):
+    print(
+        f"Testing Where on {torch_device} with output_shape:{output_shape} condition_shape:{condition_shape} x_shape:{x_shape} y_shape:{y_shape} dtype:{tensor_dtype}"
+    )
+
+    condition = torch.randint(0, 2, condition_shape, dtype=torch.uint8).to(torch_device)
+    x = torch.rand(x_shape, dtype=tensor_dtype).to(torch_device)
+    y = torch.rand(y_shape, dtype=tensor_dtype).to(torch_device)
+    output = torch.rand(output_shape, dtype=tensor_dtype).to(torch_device)
+
+    ans = where(condition, x, y)
+
+    condition_tensor = to_tensor(condition, lib)
+    x_tensor = to_tensor(x, lib)
+    y_tensor = to_tensor(y, lib)
+    output_tensor = to_tensor(output, lib)
+    descriptor = infiniopWhereDescriptor_t()
+
+    check_error(
+        lib.infiniopCreateWhereDescriptor(
+            handle,
+            ctypes.byref(descriptor),
+            output_tensor.descriptor,
+            condition_tensor.descriptor,
+            x_tensor.descriptor,
+            y_tensor.descriptor,
+        )
+    )
+
+    condition_tensor.descriptor.contents.invalidate()
+    x_tensor.descriptor.contents.invalidate()
+    y_tensor.descriptor.contents.invalidate()
+    output_tensor.descriptor.contents.invalidate()
+
+    check_error(
+        lib.infiniopWhere(descriptor, output_tensor.data, condition_tensor.data, x_tensor.data, y_tensor.data, None)
+    )
+    assert torch.allclose(output, ans, atol=0, rtol=0)
+    check_error(lib.infiniopDestroyWhereDescriptor(descriptor))
+
+
+def test_cpu(lib, test_cases):
+    device = DeviceEnum.DEVICE_CPU
+    handle = create_handle(lib, device)
+    for output_shape, condition_shape, x_shape, y_shape in test_cases:
+        test(lib, handle, "cpu", output_shape, condition_shape, x_shape, y_shape, tensor_dtype=torch.float16)
+        test(lib, handle, "cpu", output_shape, condition_shape, x_shape, y_shape, tensor_dtype=torch.float32)
+    destroy_handle(lib, handle)
+
+
+if __name__ == "__main__":
+    test_cases = [
+        # output_shape, condition_shape, x_shape, y_shape
+        ((1, 3), (1, 3), (1, 3), (1, 3)),
+        ((), (), (), ()),
+        ((3, 3), (3, 3), (3, 3), (3, 3)),
+        ((2, 20, 3), (2, 1, 3), (2, 20, 3), (2, 20, 3)),
+        ((32, 20, 512), (32, 20, 512), (32, 20, 512), (32, 20, 512)),
+        ((32, 256, 112, 112), (32, 256, 112, 1), (32, 256, 112, 112), (32, 256, 112, 112)),
+        ((2, 4, 3), (2, 1, 3), (4, 3), (4, 3)),
+        ((2, 3, 4, 5), (2, 3, 4, 5), (5,), (5,)),
+        ((3, 2, 4, 5), (4, 5), (3, 2, 1, 1), (3, 2, 1, 1)),
+    ]
+    args = get_args()
+    lib = open_lib()
+    lib.infiniopCreateWhereDescriptor.restype = c_int32
+    lib.infiniopCreateWhereDescriptor.argtypes = [
+        infiniopHandle_t,
+        POINTER(infiniopWhereDescriptor_t),
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+        infiniopTensorDescriptor_t,
+    ]
+    lib.infiniopWhere.restype = c_int32
+    lib.infiniopWhere.argtypes = [
+        infiniopWhereDescriptor_t,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+        c_void_p,
+    ]
+    lib.infiniopDestroyWhereDescriptor.restype = c_int32
+    lib.infiniopDestroyWhereDescriptor.argtypes = [
+        infiniopWhereDescriptor_t,
+    ]
+
+    # Run the CPU tests unconditionally; the original pair of
+    # `if args.cpu:` / `if not args.cpu:` branches always ran them anyway.
+    test_cpu(lib, test_cases)
+    print("\033[92mTest passed!\033[0m")
diff --git a/src/ops/clip/operator.cc b/src/ops/clip/operator.cc
new file mode 100644
index 00000000..db915a16
--- /dev/null
+++ b/src/ops/clip/operator.cc
@@ -0,0 +1,90 @@
+#include "../../devices/cpu/common_cpu.h"
+#include "../utils.h"
+#include "operators.h"
+#include "ops/clip/clip.h"
+#include "status.h"
+#include "tensor/tensor_descriptor.h"
+#include <algorithm>
+#include <functional>
+#include <limits>
+#include <numeric>
+#include <type_traits>
+
+struct _ClipDescriptor {
+    Device device;
+    DT dtype;
+    uint64_t output_data_size;
+    float min_value;
+    float max_value;
+};
+
+typedef struct _ClipDescriptor *_ClipDescriptor_t;
+
+__C __export infiniopStatus_t infiniopCreateClipDescriptor(infiniopHandle_t handle,
+                                                           infiniopClipDescriptor_t *desc_ptr,
+                                                           infiniopTensorDescriptor_t output,
+                                                           infiniopTensorDescriptor_t input,
+                                                           float *min,
+                                                           float *max) {
+    if (!is_contiguous(output) || !is_contiguous(input)) {
+        return STATUS_BAD_TENSOR_STRIDES;
+    }
+    if (output->dt != input->dt) {
+        return STATUS_BAD_TENSOR_DTYPE;
+    }
+    // supported types: f16, f32
+    if (output->dt != F16 && output->dt != F32) {
+        return STATUS_BAD_TENSOR_DTYPE;
+    }
+
+    uint64_t output_data_size = std::accumulate(output->shape, output->shape + output->ndim, 1ULL, std::multiplies<uint64_t>());
+
+    // A null bound pointer means "unbounded" on that side
+    float min_val = (min != nullptr) ? *min : std::numeric_limits<float>::lowest();
+    float max_val = (max != nullptr) ? *max : std::numeric_limits<float>::max();
+
+    *(_ClipDescriptor_t *) desc_ptr = new _ClipDescriptor{
+        handle->device,
+        output->dt,
+        output_data_size,
+        min_val,
+        max_val};
+    return STATUS_SUCCESS;
+}
+
+template<typename Tdata>
+infiniopStatus_t clip_cpu(_ClipDescriptor_t desc, void *output, void const *input) {
+    auto output_data = reinterpret_cast<Tdata *>(output);
+    auto input_data = reinterpret_cast<Tdata const *>(input);
+    float min_value = desc->min_value;
+    float max_value = desc->max_value;
+
+    for (uint64_t i = 0; i < desc->output_data_size; ++i) {
+        if constexpr (std::is_same<Tdata, uint16_t>::value) {
+            // fp16 values are clamped in fp32 and rounded back to half
+            output_data[i] = f32_to_f16(std::min(std::max(f16_to_f32(input_data[i]), min_value), max_value));
+        } else {
+            output_data[i] = std::min(std::max(input_data[i], min_value), max_value);
+        }
+    }
+
+    return STATUS_SUCCESS;
+}
+
+__C __export infiniopStatus_t infiniopClip(infiniopClipDescriptor_t desc,
+                                           void *output,
+                                           void const *input,
+                                           void *stream) {
+    auto _desc = (_ClipDescriptor_t) desc;
+    auto dtype = _desc->dtype;
+
+    if (dtype == F16) {
+        return clip_cpu<uint16_t>(_desc, output, input);
+    } else if (dtype == F32) {
+        return clip_cpu<float>(_desc, output, input);
+    }
+
+    return STATUS_BAD_TENSOR_DTYPE;
+}
+
+__C __export infiniopStatus_t infiniopDestroyClipDescriptor(infiniopClipDescriptor_t desc) {
+    delete (_ClipDescriptor_t) desc;
+    return STATUS_SUCCESS;
+}
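[Note outside the diff] The descriptor above bakes the bounds in at creation time: a NULL min/max pointer is replaced by the float limits, and the fp16 path clamps in fp32 before rounding back to half. A minimal sketch of the same semantics, assuming only stock PyTorch (the helper name is hypothetical):

    import torch

    def clip_semantics(x: torch.Tensor, min=None, max=None):
        # None means "unbounded" on that side, matching the NULL-pointer contract
        lo = -float("inf") if min is None else min
        hi = float("inf") if max is None else max
        # clamp in fp32, then round back to the input dtype, as the fp16 kernel does
        return x.float().clamp(lo, hi).to(x.dtype)

    x = torch.tensor([-2.0, 0.5, 3.0], dtype=torch.float16)
    print(clip_semantics(x, min=0.0))  # tensor([0.0000, 0.5000, 3.0000], dtype=torch.float16)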
diff --git a/src/ops/gather/operator.cc b/src/ops/gather/operator.cc
new file mode 100644
index 00000000..2eebc1f7
--- /dev/null
+++ b/src/ops/gather/operator.cc
@@ -0,0 +1,180 @@
+#include "../../devices/cpu/common_cpu.h"
+#include "../utils.h"
+#include "data_type.h"
+#include "operators.h"
+#include "ops/gather/gather.h"
+#include "status.h"
+#include "tensor/tensor_descriptor.h"
+#include <algorithm>
+#include <cstdint>
+#include <functional>
+#include <numeric>
+
+struct _GatherDescriptor {
+    Device device;
+    DT dtype;
+    DT indices_type;
+    uint64_t data_ndim;
+    uint64_t indices_ndim;
+    uint64_t output_ndim;
+    uint64_t output_data_size;
+    uint64_t const *output_shape;
+    uint64_t const *input_strides;
+    uint64_t const *indices_strides;
+    uint64_t *output_indices;
+    int64_t axis;
+};
+
+typedef struct _GatherDescriptor *_GatherDescriptor_t;
+
+inline void incrementOne(uint64_t *indices, uint64_t const *shape, uint64_t ndim) {
+    for (int64_t i = ndim - 1; i >= 0; --i) {
+        if (++indices[i] != shape[i]) {
+            return;
+        }
+        indices[i] = 0;
+    }
+}
+
+inline uint64_t compactToFlat(uint64_t const *indices, uint64_t const *strides, uint64_t ndim) {
+    return std::inner_product(indices, indices + ndim, strides, uint64_t(0));
+}
+
+__C __export infiniopStatus_t infiniopCreateGatherDescriptor(infiniopHandle_t handle,
+                                                             infiniopGatherDescriptor_t *desc_ptr,
+                                                             infiniopTensorDescriptor_t output,
+                                                             infiniopTensorDescriptor_t input,
+                                                             infiniopTensorDescriptor_t indices,
+                                                             int64_t axis) {
+    if (!is_contiguous(input) || !is_contiguous(indices) || !is_contiguous(output)) {
+        return STATUS_BAD_TENSOR_STRIDES;
+    }
+    if (indices->dt != I32 && indices->dt != I64) {
+        return STATUS_BAD_TENSOR_DTYPE;
+    }
+    if (output->dt != input->dt) {
+        return STATUS_BAD_TENSOR_DTYPE;
+    }
+
+    uint64_t data_ndim = input->ndim;
+    uint64_t indices_ndim = indices->ndim;
+    if (axis < -int64_t(data_ndim) || axis >= int64_t(data_ndim)) {
+        return STATUS_BAD_PARAM;
+    }
+    if (axis < 0) {
+        axis += data_ndim;
+    }
+
+    // expected output shape: input.shape[:axis] + indices.shape + input.shape[axis+1:]
+    uint64_t output_ndim = data_ndim - 1 + indices_ndim;
+    uint64_t *output_shape = new uint64_t[output_ndim];
+    for (uint64_t i = 0; i < uint64_t(axis); ++i) {
+        output_shape[i] = input->shape[i];
+    }
+    for (uint64_t i = 0; i < indices_ndim; ++i) {
+        output_shape[axis + i] = indices->shape[i];
+    }
+    for (uint64_t i = axis + 1; i < data_ndim; ++i) {
+        output_shape[i + indices_ndim - 1] = input->shape[i];
+    }
+
+    for (uint64_t i = 0; i < output_ndim; ++i) {
+        if (output_shape[i] != output->shape[i]) {
+            delete[] output_shape;// avoid leaking on the error path
+            return STATUS_BAD_TENSOR_SHAPE;
+        }
+    }
+
+    uint64_t output_data_size = std::accumulate(output->shape, output->shape + output->ndim, 1ULL, std::multiplies<uint64_t>());
+
+    uint64_t *output_indices = new uint64_t[output_ndim];
+    std::fill(output_indices, output_indices + output_ndim, 0);
+
+    uint64_t *input_strides = new uint64_t[data_ndim];
+    for (uint64_t i = 0; i < data_ndim; ++i) {
+        input_strides[i] = input->strides[i];
+    }
+    uint64_t *indices_strides = new uint64_t[indices_ndim];
+    for (uint64_t i = 0; i < indices_ndim; ++i) {
+        indices_strides[i] = indices->strides[i];
+    }
+
+    *(_GatherDescriptor_t *) desc_ptr = new _GatherDescriptor{
+        handle->device,
+        output->dt,
+        indices->dt,
+        data_ndim,
+        indices_ndim,
+        output_ndim,
+        output_data_size,
+        output_shape,
+        input_strides,
+        indices_strides,
+        output_indices,
+        axis};
+
+    return STATUS_SUCCESS;
+}
+
+template<typename Tdata, typename Tind>
+infiniopStatus_t gather_cpu(_GatherDescriptor_t desc, void *output, void const *input, void const *indices) {
+    auto input_ = reinterpret_cast<Tdata const *>(input);
+    auto indices_ = reinterpret_cast<Tind const *>(indices);
+    auto output_ = reinterpret_cast<Tdata *>(output);
+
+    // Reset the cursor so repeated calls on the same descriptor start from zero
+    auto output_indices = desc->output_indices;
+    std::fill(output_indices, output_indices + desc->output_ndim, 0);
+
+    // Scratch buffer for the input multi-index, hoisted out of the loop
+    uint64_t *in_multi = new uint64_t[desc->data_ndim];
+    for (uint64_t i = 0; i < desc->output_data_size; ++i, incrementOne(output_indices, desc->output_shape, desc->output_ndim)) {
+        // The output coordinates [axis, axis + indices_ndim) address the indices tensor
+        uint64_t flat_indices = compactToFlat(output_indices + desc->axis, desc->indices_strides, desc->indices_ndim);
+        int64_t gather_index = indices_[flat_indices];
+        // Build the input coordinate: prefix dims + gathered index + suffix dims
+        for (uint64_t j = 0; j < uint64_t(desc->axis); ++j)
+            in_multi[j] = output_indices[j];
+        in_multi[desc->axis] = gather_index;
+        for (uint64_t j = desc->axis + 1; j < desc->data_ndim; ++j)
+            in_multi[j] = output_indices[j + desc->indices_ndim - 1];
+        uint64_t in_flat = compactToFlat(in_multi, desc->input_strides, desc->data_ndim);
+        output_[i] = input_[in_flat];
+    }
+    delete[] in_multi;
+    return STATUS_SUCCESS;
+}
+
+__C __export infiniopStatus_t infiniopGather(infiniopGatherDescriptor_t desc,
+                                             void *output,
+                                             void const *input,
+                                             void const *indices,
+                                             void *stream) {
+    auto _desc = (_GatherDescriptor_t) desc;
+    auto dtype = _desc->dtype;
+    auto indices_type = _desc->indices_type;
+
+    if (dtype == F16) {
+        if (indices_type == I32) {
+            return gather_cpu<uint16_t, int32_t>(_desc, output, input, indices);
+        } else if (indices_type == I64) {
+            return gather_cpu<uint16_t, int64_t>(_desc, output, input, indices);
+        }
+    } else if (dtype == F32) {
+        if (indices_type == I32) {
+            return gather_cpu<float, int32_t>(_desc, output, input, indices);
+        } else if (indices_type == I64) {
+            return gather_cpu<float, int64_t>(_desc, output, input, indices);
+        }
+    }
+
+    return STATUS_BAD_TENSOR_DTYPE;
+}
+
+__C __export infiniopStatus_t infiniopDestroyGatherDescriptor(infiniopGatherDescriptor_t desc) {
+    auto _desc = (_GatherDescriptor_t) desc;
+    delete[] _desc->output_shape;
+    delete[] _desc->input_strides;
+    delete[] _desc->indices_strides;
+    delete[] _desc->output_indices;
+
+    delete _desc;
+    return STATUS_SUCCESS;
+}
diff --git a/src/ops/reduce/cpu/reduce_cpu.cc b/src/ops/reduce/cpu/reduce_cpu.cc
new file mode 100644
index 00000000..d05b4e98
--- /dev/null
+++ b/src/ops/reduce/cpu/reduce_cpu.cc
@@ -0,0 +1,263 @@
+#include "reduce_cpu.h"
+#include "../../utils.h"
+#include <algorithm>
+#include <cstring>
+#include <functional>
+#include <limits>
+#include <numeric>
+#include <vector>
+
+inline uint64_t getTotalSize(const uint64_t *arr, uint64_t ndim) {
+    return std::accumulate(arr, arr + ndim, 1ULL, std::multiplies<uint64_t>());
+}
+
+infiniopStatus_t cpuCreateReduceDescriptor(infiniopHandle_t handle,
+                                           ReduceCpuDescriptor_t *desc_ptr,
+                                           infiniopTensorDescriptor_t y,
+                                           infiniopTensorDescriptor_t x,
+                                           int64_t const *axes,
+                                           uint64_t n_axes,
+                                           int keep_dims,
+                                           int reduce_type) {
+    uint64_t x_ndim = x->ndim;
+    uint64_t y_ndim = y->ndim;
+    if (n_axes == 0 || n_axes > x_ndim) {
+        return STATUS_BAD_PARAM;
+    }
+    // shape check
+    if (keep_dims) {
+        if (x_ndim != y_ndim) {
+            return STATUS_BAD_TENSOR_SHAPE;
+        }
+    } else {
+        if (x_ndim - n_axes != y_ndim) {
+            return STATUS_BAD_TENSOR_SHAPE;
+        }
+    }
+    if (y->dt != F16 && y->dt != F32) {
+        return STATUS_BAD_TENSOR_DTYPE;
+    }
+    if (y->dt != x->dt) {
+        return STATUS_BAD_TENSOR_DTYPE;
+    }
+    if (!is_contiguous(y) || !is_contiguous(x)) {
+        return STATUS_BAD_TENSOR_STRIDES;
+    }
+    if (reduce_type < 0 || reduce_type > 2) {
+        return STATUS_BAD_PARAM;
+    }
+    // every axis must lie in [-x_ndim, x_ndim)
+    for (uint64_t i = 0; i < n_axes; i++) {
+        if (axes[i] < -int64_t(x_ndim) || axes[i] >= int64_t(x_ndim)) {
+            return STATUS_BAD_PARAM;
+        }
+    }
+
+    const auto x_size = getTotalSize(x->shape, x->ndim);
+    const auto y_size = getTotalSize(y->shape, y->ndim);
+
+    uint64_t *x_shape = new uint64_t[x_ndim];
+    uint64_t *y_shape = new uint64_t[y_ndim];
+    memcpy(x_shape, x->shape, x_ndim * sizeof(uint64_t));
+    memcpy(y_shape, y->shape, y_ndim * sizeof(uint64_t));
+
+    // normalize negative axes
+    int64_t *axes_ = new int64_t[n_axes];
+    for (uint64_t i = 0; i < n_axes; i++) {
+        axes_[i] = axes[i] >= 0 ? axes[i] : axes[i] + x_ndim;
+    }
+
+    *desc_ptr = new ReduceCpuDescriptor{
+        DevCpu,
+        y->dt,
+        x->ndim,
+        y_size,
+        x_size,
+        x_shape,
+        y_shape,
+        axes_,
+        n_axes,
+        keep_dims,
+        reduce_type};
+
+    return STATUS_SUCCESS;
+}
+
+infiniopStatus_t cpuDestroyReduceDescriptor(ReduceCpuDescriptor_t desc) {
+    delete[] desc->x_shape;
+    delete[] desc->y_shape;
+    delete[] desc->axes;
+    delete desc;
+    return STATUS_SUCCESS;
+}
+
+inline bool isOnReduceAxis(int64_t dim, int64_t const *axes, uint64_t n_axes) {
+    for (size_t i = 0; i < n_axes; ++i) {
+        if (dim == axes[i]) {
+            return true;
+        }
+    }
+    return false;
+}
+
+uint64_t flattenIndex(uint64_t const *indices, uint64_t const *shape, uint64_t ndim) {
+    uint64_t idx = 0;
+    uint64_t stride = 1;
+    for (int64_t i = ndim - 1; i >= 0; --i) {
+        idx += indices[i] * stride;
+        stride *= shape[i];
+    }
+    return idx;
+}
+
+void generateReduceIndices(uint64_t dim, uint64_t *curr_index, uint64_t const *x_shape,
+                           int64_t const *axes, uint64_t n_axes, uint64_t ndim,
+                           std::vector<uint64_t> &flat_indices) {
+    if (dim == ndim) {
+        flat_indices.push_back(flattenIndex(curr_index, x_shape, ndim));
+        return;
+    }
+
+    if (isOnReduceAxis(dim, axes, n_axes)) {
+        // enumerate every position along a reduced axis
+        for (uint64_t i = 0; i < x_shape[dim]; ++i) {
+            curr_index[dim] = i;
+            generateReduceIndices(dim + 1, curr_index, x_shape, axes, n_axes, ndim, flat_indices);
+        }
+    } else {
+        generateReduceIndices(dim + 1, curr_index, x_shape, axes, n_axes, ndim, flat_indices);
+    }
+}
+
+void getReduceIndices(uint64_t y_idx, uint64_t const *y_shape, uint64_t const *x_shape,
+                      int64_t const *axes, uint64_t n_axes, uint64_t ndim, uint64_t y_ndim,
+                      bool keep_dims, std::vector<uint64_t> &flat_indices) {
+    // Convert the flat y_idx into a multi-dimensional index of y
+    std::vector<uint64_t> y_indices(y_ndim, 0);
+    uint64_t temp = y_idx;
+    for (int64_t i = y_ndim - 1; i >= 0; --i) {
+        y_indices[i] = temp % y_shape[i];
+        temp /= y_shape[i];
+    }
+
+    // Map y's multi-index onto x's multi-index
+    std::vector<uint64_t> x_indices(ndim, 0);
+    uint64_t y_dim = 0;
+
+    for (uint64_t x_dim = 0; x_dim < ndim; ++x_dim) {
+        if (isOnReduceAxis(x_dim, axes, n_axes)) {
+            if (keep_dims) {
+                y_dim++;
+            }
+            x_indices[x_dim] = 0;
+        } else {
+            x_indices[x_dim] = y_indices[y_dim++];
+        }
+    }
+
+    flat_indices.clear();
+    generateReduceIndices(0, x_indices.data(), x_shape, axes, n_axes, ndim, flat_indices);
+}
+
+template<typename T>
+T performReduce(T const *x, const std::vector<uint64_t> &indices, int reduce_type) {
+    if (indices.empty()) {
+        return 0;
+    }
+
+    T result;
+    switch (reduce_type) {
+        case 0:// Max
+            result = std::numeric_limits<T>::lowest();
+            for (uint64_t idx : indices) {
+                result = std::max(result, x[idx]);
+            }
+            break;
+        case 1:// Min
+            result = std::numeric_limits<T>::max();
+            for (uint64_t idx : indices) {
+                result = std::min(result, x[idx]);
+            }
+            break;
+        case 2:// Mean
+            result = 0;
+            for (uint64_t idx : indices) {
+                result += x[idx];
+            }
+            result /= static_cast<T>(indices.size());
+            break;
+        default:
+            result = 0;
+    }
+
+    return result;
+}
+
+// fp16 specialization: accumulate in fp32 and round back to half once
+template<>
+uint16_t performReduce(uint16_t const *x, const std::vector<uint64_t> &indices, int reduce_type) {
+    if (indices.empty()) {
+        return 0;
+    }
+
+    float result;
+    switch (reduce_type) {
+        case 0:// Max
+            result = -std::numeric_limits<float>::max();
+            for (uint64_t idx : indices) {
+                result = std::max(result, f16_to_f32(x[idx]));
+            }
+            break;
+        case 1:// Min
+            result = std::numeric_limits<float>::max();
+            for (uint64_t idx : indices) {
+                result = std::min(result, f16_to_f32(x[idx]));
+            }
+            break;
+        case 2:// Mean
+            result = 0;
+            for (uint64_t idx : indices) {
+                result += f16_to_f32(x[idx]);
+            }
+            result /= static_cast<float>(indices.size());
+            break;
+        default:
+            result = 0;
+    }
+
+    return f32_to_f16(result);
+}
+
+template<typename Tdata>
+infiniopStatus_t reduce_cpu(ReduceCpuDescriptor_t desc, void *y, void const *x) {
+    auto x_ = reinterpret_cast<Tdata const *>(x);
+    auto y_ = reinterpret_cast<Tdata *>(y);
+
+    std::vector<uint64_t> reduce_indices;
+
+#pragma omp parallel for private(reduce_indices)
+    for (uint64_t i = 0; i < desc->y_size; ++i) {
+        // Collect every input index that reduces into this output element
+        getReduceIndices(i, desc->y_shape, desc->x_shape, desc->axes, desc->n_axes,
+                         desc->ndim, desc->keep_dims ? desc->ndim : desc->ndim - desc->n_axes,
+                         desc->keep_dims, reduce_indices);
+
+        // reduce into the output element; the uint16_t specialization handles fp16
+        y_[i] = performReduce<Tdata>(x_, reduce_indices, desc->reduce_type);
+    }
+
+    return STATUS_SUCCESS;
+}
+
+
+infiniopStatus_t cpuReduce(ReduceCpuDescriptor_t desc, void *y, void const *x, void *stream) {
+    if (desc->dt == F16) {
+        return reduce_cpu<uint16_t>(desc, y, x);
+    } else if (desc->dt == F32) {
+        return reduce_cpu<float>(desc, y, x);
+    }
+    return STATUS_BAD_TENSOR_DTYPE;
+}
diff --git a/src/ops/reduce/cpu/reduce_cpu.h b/src/ops/reduce/cpu/reduce_cpu.h
new file mode 100644
index 00000000..7c0cb862
--- /dev/null
+++ b/src/ops/reduce/cpu/reduce_cpu.h
@@ -0,0 +1,44 @@
+#ifndef __CPU_REDUCE_H__
+#define __CPU_REDUCE_H__
+
+#include "../../../devices/cpu/common_cpu.h"
+#include "operators.h"
+#include <cstdint>
+
+struct ReduceCpuDescriptor {
+    Device device;
+    DataLayout dt;
+    uint64_t ndim;
+    uint64_t y_size;
+    uint64_t x_size;
+    uint64_t *x_shape;
+    uint64_t *y_shape;
+    int64_t *axes;
+    uint64_t n_axes;
+    int keep_dims;
+    int reduce_type;// 0: max, 1: min, 2: mean
+};
+
+typedef struct ReduceCpuDescriptor *ReduceCpuDescriptor_t;
+
+infiniopStatus_t cpuCreateReduceDescriptor(infiniopHandle_t handle,
+                                           ReduceCpuDescriptor_t *desc_ptr,
+                                           infiniopTensorDescriptor_t y,
+                                           infiniopTensorDescriptor_t x,
+                                           int64_t const *axes,
+                                           uint64_t n_axes,
+                                           int keep_dims,
+                                           int reduce_type);
+
+infiniopStatus_t cpuReduce(ReduceCpuDescriptor_t desc,
+                           void *y,
+                           void const *x,
+                           void *stream);
+
+infiniopStatus_t cpuDestroyReduceDescriptor(ReduceCpuDescriptor_t desc);
+
+#endif
diff --git a/src/ops/reduce/operator.cc b/src/ops/reduce/operator.cc
new file mode 100644
index 00000000..550e7d92
--- /dev/null
+++ b/src/ops/reduce/operator.cc
@@ -0,0 +1,36 @@
+#include "../utils.h"
+#include "operators.h"
+#include "reduce.h"
+
+#ifdef ENABLE_CPU
+#include "cpu/reduce_cpu.h"
+#endif
+
+__C infiniopStatus_t infiniopCreateReduceDescriptor(
+    infiniopHandle_t handle,
+    infiniopReduceDescriptor_t *desc_ptr,
+    infiniopTensorDescriptor_t y,
+    infiniopTensorDescriptor_t x,
+    int64_t const *axes,
+    uint64_t n_axes,
+    int keep_dims,
+    int reduce_type) {
+    if (handle->device == DevCpu) {
+        return cpuCreateReduceDescriptor(handle, (ReduceCpuDescriptor_t *) desc_ptr, y, x, axes, n_axes, keep_dims, reduce_type);
+    }
+    return STATUS_BAD_DEVICE;
+}
+
+__C infiniopStatus_t infiniopReduce(infiniopReduceDescriptor_t desc, void *y, void const *x, void *stream) {
+    if (desc->device == DevCpu) {
+        return cpuReduce((ReduceCpuDescriptor_t) desc, y, x, stream);
+    }
+    return STATUS_BAD_DEVICE;
+}
+
+__C infiniopStatus_t infiniopDestroyReduceDescriptor(infiniopReduceDescriptor_t desc) {
+    if (desc->device == DevCpu) {
+        return cpuDestroyReduceDescriptor((ReduceCpuDescriptor_t) desc);
+    }
+    return STATUS_BAD_DEVICE;
+}
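[Note outside the diff] reduce_max, reduce_mean, and reduce_min (below) are thin wrappers over this shared implementation: each creates a generic reduce descriptor with reduce_type 0, 1, or 2 and forwards every call to infiniopReduce. A sketch of the same one-integer dispatch in Python (names hypothetical, not part of the PR):

    import torch

    REDUCE_OPS = {0: torch.amax, 1: torch.amin, 2: torch.mean}  # matches the reduce_type codes

    def reduce_dispatch(x, axes, keep_dims, reduce_type):
        # one entry point, one integer selecting the reduction, as in operator.cc
        return REDUCE_OPS[reduce_type](x, dim=axes, keepdim=keep_dims)

    x = torch.rand(2, 3, 4)
    print(reduce_dispatch(x, (1, 2), True, 2).shape)  # torch.Size([2, 1, 1])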
diff --git a/src/ops/reduce/reduce.h b/src/ops/reduce/reduce.h
new file mode 100644
index 00000000..ea0ca928
--- /dev/null
+++ b/src/ops/reduce/reduce.h
@@ -0,0 +1,24 @@
+#ifndef REDUCE_H
+#define REDUCE_H
+
+#include "export.h"
+#include "operators.h"
+
+typedef struct ReduceDescriptor {
+    Device device;
+} ReduceDescriptor;
+typedef ReduceDescriptor *infiniopReduceDescriptor_t;
+
+__C infiniopStatus_t infiniopCreateReduceDescriptor(infiniopHandle_t handle,
+                                                    infiniopReduceDescriptor_t *desc_ptr,
+                                                    infiniopTensorDescriptor_t y,
+                                                    infiniopTensorDescriptor_t x,
+                                                    int64_t const *axes,
+                                                    uint64_t n_axes,
+                                                    int keep_dims,
+                                                    int reduce_type);
+
+__C infiniopStatus_t infiniopReduce(infiniopReduceDescriptor_t desc, void *y, void const *x, void *stream);
+
+__C infiniopStatus_t infiniopDestroyReduceDescriptor(infiniopReduceDescriptor_t desc);
+#endif
diff --git a/src/ops/reduce_max/operator.cc b/src/ops/reduce_max/operator.cc
new file mode 100644
index 00000000..019acd41
--- /dev/null
+++ b/src/ops/reduce_max/operator.cc
@@ -0,0 +1,41 @@
+#include "../reduce/reduce.h"
+#include "../utils.h"
+#include "export.h"
+#include "ops/reduce_max/reduce_max.h"
+
+struct _ReduceMaxDescriptor {
+    Device device;
+    infiniopReduceDescriptor_t reduce_desc;
+};
+
+typedef struct _ReduceMaxDescriptor *_ReduceMaxDescriptor_t;
+
+__C __export infiniopStatus_t infiniopCreateReduceMaxDescriptor(infiniopHandle_t handle,
+                                                                infiniopReduceMaxDescriptor_t *desc_ptr,
+                                                                infiniopTensorDescriptor_t y,
+                                                                infiniopTensorDescriptor_t x,
+                                                                int64_t const *axes,
+                                                                uint64_t n_axes,
+                                                                int keep_dims) {
+    infiniopReduceDescriptor_t reduce_desc;
+    // reduce_type 0 selects max
+    CHECK_STATUS(infiniopCreateReduceDescriptor(handle, &reduce_desc, y, x, axes, n_axes, keep_dims, 0), STATUS_SUCCESS);
+
+    *(_ReduceMaxDescriptor_t *) desc_ptr = new _ReduceMaxDescriptor{
+        handle->device,
+        reduce_desc};
+
+    return STATUS_SUCCESS;
+}
+
+__C __export infiniopStatus_t infiniopReduceMax(infiniopReduceMaxDescriptor_t desc, void *y, void const *x, void *stream) {
+    auto _desc = (_ReduceMaxDescriptor_t) desc;
+    CHECK_STATUS(infiniopReduce(_desc->reduce_desc, y, x, stream), STATUS_SUCCESS);
+    return STATUS_SUCCESS;
+}
+
+__C __export infiniopStatus_t infiniopDestroyReduceMaxDescriptor(infiniopReduceMaxDescriptor_t desc) {
+    CHECK_STATUS(infiniopDestroyReduceDescriptor(((_ReduceMaxDescriptor_t) desc)->reduce_desc), STATUS_SUCCESS);
+    // delete through the concrete type that was allocated
+    delete (_ReduceMaxDescriptor_t) desc;
+    return STATUS_SUCCESS;
+}
diff --git a/src/ops/reduce_mean/operator.cc b/src/ops/reduce_mean/operator.cc
new file mode 100644
index 00000000..90f867d2
--- /dev/null
+++ b/src/ops/reduce_mean/operator.cc
@@ -0,0 +1,41 @@
+#include "../reduce/reduce.h"
+#include "../utils.h"
+#include "export.h"
+#include "ops/reduce_mean/reduce_mean.h"
+
+struct _ReduceMeanDescriptor {
+    Device device;
+    infiniopReduceDescriptor_t reduce_desc;
+};
+
+typedef struct _ReduceMeanDescriptor *_ReduceMeanDescriptor_t;
+
+__C __export infiniopStatus_t infiniopCreateReduceMeanDescriptor(infiniopHandle_t handle,
+                                                                 infiniopReduceMeanDescriptor_t *desc_ptr,
+                                                                 infiniopTensorDescriptor_t y,
+                                                                 infiniopTensorDescriptor_t x,
+                                                                 int64_t const *axes,
+                                                                 uint64_t n_axes,
+                                                                 int keep_dims) {
+    infiniopReduceDescriptor_t reduce_desc;
+    // reduce_type 2 selects mean
+    CHECK_STATUS(infiniopCreateReduceDescriptor(handle, &reduce_desc, y, x, axes, n_axes, keep_dims, 2), STATUS_SUCCESS);
+
+    *(_ReduceMeanDescriptor_t *) desc_ptr = new _ReduceMeanDescriptor{
+        handle->device,
+        reduce_desc};
+
+    return STATUS_SUCCESS;
+}
+
+__C __export infiniopStatus_t infiniopReduceMean(infiniopReduceMeanDescriptor_t desc, void *y, void const *x, void *stream) {
+    auto _desc = (_ReduceMeanDescriptor_t) desc;
+    CHECK_STATUS(infiniopReduce(_desc->reduce_desc, y, x, stream), STATUS_SUCCESS);
+    return STATUS_SUCCESS;
+}
+
+__C __export infiniopStatus_t infiniopDestroyReduceMeanDescriptor(infiniopReduceMeanDescriptor_t desc) {
+    CHECK_STATUS(infiniopDestroyReduceDescriptor(((_ReduceMeanDescriptor_t) desc)->reduce_desc), STATUS_SUCCESS);
+    // delete through the concrete type that was allocated
+    delete (_ReduceMeanDescriptor_t) desc;
+    return STATUS_SUCCESS;
+}
diff --git a/src/ops/reduce_min/operator.cc b/src/ops/reduce_min/operator.cc
new file mode 100644
index 00000000..ae1e79e3
--- /dev/null
+++ b/src/ops/reduce_min/operator.cc
@@ -0,0 +1,41 @@
+#include "../reduce/reduce.h"
+#include "../utils.h"
+#include "export.h"
+#include "ops/reduce_min/reduce_min.h"
+
+struct _ReduceMinDescriptor {
+    Device device;
+    infiniopReduceDescriptor_t reduce_desc;
+};
+
+typedef struct _ReduceMinDescriptor *_ReduceMinDescriptor_t;
+
+__C __export infiniopStatus_t infiniopCreateReduceMinDescriptor(infiniopHandle_t handle,
+                                                                infiniopReduceMinDescriptor_t *desc_ptr,
+                                                                infiniopTensorDescriptor_t y,
+                                                                infiniopTensorDescriptor_t x,
+                                                                int64_t const *axes,
+                                                                uint64_t n_axes,
+                                                                int keep_dims) {
+    infiniopReduceDescriptor_t reduce_desc;
+    // reduce_type 1 selects min
+    CHECK_STATUS(infiniopCreateReduceDescriptor(handle, &reduce_desc, y, x, axes, n_axes, keep_dims, 1), STATUS_SUCCESS);
+
+    *(_ReduceMinDescriptor_t *) desc_ptr = new _ReduceMinDescriptor{
+        handle->device,
+        reduce_desc};
+
+    return STATUS_SUCCESS;
+}
+
+__C __export infiniopStatus_t infiniopReduceMin(infiniopReduceMinDescriptor_t desc, void *y, void const *x, void *stream) {
+    auto _desc = (_ReduceMinDescriptor_t) desc;
+    CHECK_STATUS(infiniopReduce(_desc->reduce_desc, y, x, stream), STATUS_SUCCESS);
+    return STATUS_SUCCESS;
+}
+
+__C __export infiniopStatus_t infiniopDestroyReduceMinDescriptor(infiniopReduceMinDescriptor_t desc) {
+    CHECK_STATUS(infiniopDestroyReduceDescriptor(((_ReduceMinDescriptor_t) desc)->reduce_desc), STATUS_SUCCESS);
+    // delete through the concrete type that was allocated
+    delete (_ReduceMinDescriptor_t) desc;
+    return STATUS_SUCCESS;
+}
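[Note outside the diff] The Where implementation below broadcasts condition/x/y against the output by right-aligning each input's shape and zeroing the stride of every broadcast dimension, so the same multi-index walk keeps re-reading one element along those axes. A sketch of the stride trick in Python (the helper name is hypothetical):

    def broadcast_strides(out_shape, shape, strides):
        # Right-align `shape` against `out_shape`; broadcast dims get stride 0
        offset = len(out_shape) - len(shape)
        return [0 if i < offset or out_shape[i] != shape[i - offset] else strides[i - offset]
                for i in range(len(out_shape))]

    print(broadcast_strides((2, 4, 3), (1, 3), (3, 1)))  # [0, 0, 1]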
diff --git a/src/ops/where/operator.cc b/src/ops/where/operator.cc
new file mode 100644
index 00000000..e260fff4
--- /dev/null
+++ b/src/ops/where/operator.cc
@@ -0,0 +1,143 @@
+#include "../../devices/cpu/common_cpu.h"
+#include "../utils.h"
+#include "data_type.h"
+#include "operators.h"
+#include "ops/where/where.h"
+#include "status.h"
+#include "tensor/tensor_descriptor.h"
+#include <algorithm>
+#include <cstdint>
+#include <functional>
+#include <numeric>
+
+struct _WhereDescriptor {
+    Device device;
+    DT dtype;
+    uint64_t ndim;
+    uint64_t output_data_size;
+    uint64_t const *output_shape;
+    uint64_t const *condition_strides;
+    uint64_t const *x_strides;
+    uint64_t const *y_strides;
+    uint64_t *output_indices;
+};
+
+typedef struct _WhereDescriptor *_WhereDescriptor_t;
+
+inline void incrementOne(uint64_t *indices, uint64_t const *shape, uint64_t ndim) {
+    for (int64_t i = ndim - 1; i >= 0; --i) {
+        if (++indices[i] != shape[i]) {
+            return;
+        }
+        indices[i] = 0;
+    }
+}
+
+inline uint64_t compactToFlat(uint64_t const *indices, uint64_t const *strides, uint64_t ndim) {
+    return std::inner_product(indices, indices + ndim, strides, uint64_t(0));
+}
+
+__C __export infiniopStatus_t infiniopCreateWhereDescriptor(infiniopHandle_t handle,
+                                                            infiniopWhereDescriptor_t *desc_ptr,
+                                                            infiniopTensorDescriptor_t output,
+                                                            infiniopTensorDescriptor_t condition,
+                                                            infiniopTensorDescriptor_t x,
+                                                            infiniopTensorDescriptor_t y) {
+    if (!isValidBroadcastShape(output, condition) ||
+        !isValidBroadcastShape(output, x) ||
+        !isValidBroadcastShape(output, y)) {
+        return STATUS_BAD_TENSOR_SHAPE;
+    }
+    if (!is_contiguous(x) || !is_contiguous(y) || !is_contiguous(condition) || !is_contiguous(output)) {
+        return STATUS_BAD_TENSOR_STRIDES;
+    }
+    if (condition->dt != U8) {
+        return STATUS_BAD_TENSOR_DTYPE;
+    }
+    if (output->dt != x->dt || output->dt != y->dt) {
+        return STATUS_BAD_TENSOR_DTYPE;
+    }
+
+    uint64_t output_data_size = std::accumulate(output->shape, output->shape + output->ndim, 1ULL, std::multiplies<uint64_t>());
+    uint64_t ndim = output->ndim;
+    // Right-align each input's shape against the output shape; broadcast
+    // dimensions get stride 0 so the walk below re-reads the same element.
+    uint64_t *condition_strides = new uint64_t[ndim];
+    uint64_t *x_strides = new uint64_t[ndim];
+    uint64_t *y_strides = new uint64_t[ndim];
+    for (size_t i = 0; i < ndim; i++) {
+        condition_strides[i] = (i < ndim - condition->ndim || output->shape[i] != condition->shape[i + condition->ndim - ndim]) ? 0 : condition->strides[i + condition->ndim - ndim];
+        x_strides[i] = (i < ndim - x->ndim || output->shape[i] != x->shape[i + x->ndim - ndim]) ? 0 : x->strides[i + x->ndim - ndim];
+        y_strides[i] = (i < ndim - y->ndim || output->shape[i] != y->shape[i + y->ndim - ndim]) ? 0 : y->strides[i + y->ndim - ndim];
+    }
+
+    uint64_t *output_indices = new uint64_t[ndim];
+    std::fill(output_indices, output_indices + ndim, 0);
+    uint64_t *output_shape = new uint64_t[ndim];
+    std::copy(output->shape, output->shape + ndim, output_shape);
+
+    *(_WhereDescriptor_t *) desc_ptr = new _WhereDescriptor{
+        handle->device,
+        output->dt,
+        ndim,
+        output_data_size,
+        output_shape,
+        condition_strides,
+        x_strides,
+        y_strides,
+        output_indices};
+
+    return STATUS_SUCCESS;
+}
+
+template<typename Tdata>
+infiniopStatus_t where_cpu(_WhereDescriptor_t desc, void *output, void const *condition, void const *x, void const *y) {
+    auto x_ = reinterpret_cast<Tdata const *>(x);
+    auto y_ = reinterpret_cast<Tdata const *>(y);
+    auto output_ = reinterpret_cast<Tdata *>(output);
+    auto condition_ = reinterpret_cast<uint8_t const *>(condition);
+
+    // Reset the cursor so repeated calls on the same descriptor start from zero
+    auto indices = desc->output_indices;
+    std::fill(indices, indices + desc->ndim, 0);
+
+    for (uint64_t i = 0; i < desc->output_data_size; ++i, incrementOne(indices, desc->output_shape, desc->ndim)) {
+        auto x_index = compactToFlat(indices, desc->x_strides, desc->ndim);
+        auto y_index = compactToFlat(indices, desc->y_strides, desc->ndim);
+        auto condition_index = compactToFlat(indices, desc->condition_strides, desc->ndim);
+
+        // Selection only copies element bits, so fp16 needs no f16/f32 round-trip
+        output_[i] = condition_[condition_index] ? x_[x_index] : y_[y_index];
+    }
+
+    return STATUS_SUCCESS;
+}
+
+__C __export infiniopStatus_t infiniopWhere(infiniopWhereDescriptor_t desc,
+                                            void *output,
+                                            void const *condition,
+                                            void const *x,
+                                            void const *y,
+                                            void *stream) {
+    auto _desc = (_WhereDescriptor_t) desc;
+    auto dtype = _desc->dtype;
+
+    if (dtype == F16) {
+        return where_cpu<uint16_t>(_desc, output, condition, x, y);
+    } else if (dtype == F32) {
+        return where_cpu<float>(_desc, output, condition, x, y);
+    }
+
+    return STATUS_BAD_TENSOR_DTYPE;
+}
+
+__C __export infiniopStatus_t infiniopDestroyWhereDescriptor(infiniopWhereDescriptor_t desc) {
+    delete[] ((_WhereDescriptor_t) desc)->output_shape;
+    delete[] ((_WhereDescriptor_t) desc)->condition_strides;
+    delete[] ((_WhereDescriptor_t) desc)->x_strides;
+    delete[] ((_WhereDescriptor_t) desc)->y_strides;
+    delete[] ((_WhereDescriptor_t) desc)->output_indices;
+
+    delete (_WhereDescriptor_t) desc;
+    return STATUS_SUCCESS;
+}
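[Note outside the diff] One convention worth flagging for reviewers: the condition tensor is stored as U8 (one byte per element) rather than a bool type, and any nonzero byte selects x, which is exactly what the tests model with randint(0, 2, dtype=torch.uint8) plus .bool() in the reference. A quick check of that convention in plain PyTorch:

    import torch

    cond = torch.tensor([1, 0, 2], dtype=torch.uint8)
    x = torch.tensor([1.0, 2.0, 3.0])
    y = torch.tensor([9.0, 9.0, 9.0])
    # nonzero -> x, zero -> y, as in the kernel's `condition_[...] ? x : y`
    print(torch.where(cond.bool(), x, y))  # tensor([1., 9., 3.])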