Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/infini_operators.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,23 @@
#include "ops/attention/attention.h"
#include "ops/avg_pool/avg_pool.h"
#include "ops/causal_softmax/causal_softmax.h"
#include "ops/clip/clip.h"
#include "ops/global_avg_pool/global_avg_pool.h"
#include "ops/expand/expand.h"
#include "ops/gather/gather.h"
#include "ops/gemm/gemm.h"
#include "ops/conv/conv.h"
#include "ops/matmul/matmul.h"
#include "ops/max_pool/max_pool.h"
#include "ops/mlp/mlp.h"
#include "ops/random_sample/random_sample.h"
#include "ops/rearrange/rearrange.h"
#include "ops/reduce_max/reduce_max.h"
#include "ops/reduce_mean/reduce_mean.h"
#include "ops/reduce_min/reduce_min.h"
#include "ops/relu/relu.h"
#include "ops/rms_norm/rms_norm.h"
#include "ops/rotary_embedding/rotary_embedding.h"
#include "ops/swiglu/swiglu.h"
#include "ops/where/where.h"
#include "tensor/tensor_descriptor.h"
27 changes: 27 additions & 0 deletions include/ops/clip/clip.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#ifndef CLIP_H
#define CLIP_H

#include "../../export.h"
#include "../../operators.h"

/**
 * Public view of a Clip operator descriptor.
 *
 * Only the `device` field is declared in this header.
 * NOTE(review): backend-specific descriptors presumably carry more state
 * behind this common prefix -- confirm against the implementation.
 */
typedef struct ClipDescriptor {
Device device;
} ClipDescriptor;

/** Pointer handle through which the infiniop Clip API refers to a descriptor. */
typedef ClipDescriptor *infiniopClipDescriptor_t;

/**
 * Creates a Clip descriptor and hands it back through `desc_ptr`.
 *
 * @param handle     Library handle the descriptor is created with.
 * @param desc_ptr   Out-parameter that receives the new descriptor.
 * @param output     Tensor descriptor of the result tensor.
 * @param input      Tensor descriptor of the tensor to be clipped.
 * @param min_value  Lower clipping bound.
 * @param max_value  Upper clipping bound.
 * @return An infiniopStatus_t status code.
 *
 * NOTE(review): the Python test for this operator expects torch.clip
 * semantics and passes -inf / +inf when a bound is absent -- confirm the
 * backends follow the same convention.
 */
__C __export infiniopStatus_t infiniopCreateClipDescriptor(infiniopHandle_t handle,
infiniopClipDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t output,
infiniopTensorDescriptor_t input,
float min_value,
float max_value);

/**
 * Runs the Clip operator described by `desc`.
 *
 * @param desc    Descriptor created by infiniopCreateClipDescriptor.
 * @param output  Buffer that receives the result.
 * @param input   Buffer holding the source data; not written through this pointer.
 * @param stream  Backend execution stream; the CPU test passes NULL.
 * @return An infiniopStatus_t status code.
 *
 * NOTE(review): the Python test also calls this with `output` and `input`
 * pointing at the same buffer (in-place) -- confirm every backend supports it.
 */
__C __export infiniopStatus_t infiniopClip(infiniopClipDescriptor_t desc,
void *output,
void const *input,
void *stream);

/**
 * Destroys a Clip descriptor.
 *
 * @param desc  Descriptor to destroy; presumably must not be used afterwards -- confirm.
 * @return An infiniopStatus_t status code.
 */
__C __export infiniopStatus_t infiniopDestroyClipDescriptor(infiniopClipDescriptor_t desc);

#endif
28 changes: 28 additions & 0 deletions include/ops/gather/gather.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#ifndef GATHER_H
#define GATHER_H

#include "../../export.h"
#include "../../operators.h"

/**
 * Public view of a Gather operator descriptor.
 *
 * Only the `device` field is declared in this header.
 * NOTE(review): backend-specific descriptors presumably carry more state
 * behind this common prefix -- confirm against the implementation.
 */
typedef struct GatherDescriptor {
Device device;
} GatherDescriptor;

/** Pointer handle through which the infiniop Gather API refers to a descriptor. */
typedef GatherDescriptor *infiniopGatherDescriptor_t;

/**
 * Creates a Gather descriptor and hands it back through `desc_ptr`.
 *
 * @param handle    Library handle the descriptor is created with.
 * @param desc_ptr  Out-parameter that receives the new descriptor.
 * @param output    Tensor descriptor of the result tensor.
 * @param data      Tensor descriptor of the source tensor.
 * @param indices   Tensor descriptor of the index tensor.
 * @param axis      Axis argument of the operator.
 * @return An infiniopStatus_t status code.
 *
 * NOTE(review): the operator semantics are not visible in this header. The
 * parameter names suggest picking entries of `data` along `axis` at the
 * positions listed in `indices`; whether negative `axis` values are accepted
 * is also unknown -- confirm against the implementation.
 */
__C __export infiniopStatus_t infiniopCreateGatherDescriptor(infiniopHandle_t handle,
infiniopGatherDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t output,
infiniopTensorDescriptor_t data,
infiniopTensorDescriptor_t indices,
int64_t axis);

/**
 * Runs the Gather operator described by `desc`.
 *
 * @param desc     Descriptor created by infiniopCreateGatherDescriptor.
 * @param output   Buffer that receives the result.
 * @param data     Buffer holding the source data; not written through this pointer.
 * @param indices  Buffer holding the indices; not written through this pointer.
 * @param stream   Backend execution stream (NOTE(review): presumably may be
 *                 NULL on CPU, as in the Clip test -- confirm).
 * @return An infiniopStatus_t status code.
 */
__C __export infiniopStatus_t infiniopGather(infiniopGatherDescriptor_t desc,
void *output,
void const *data,
void const *indices,
void *stream);

/**
 * Destroys a Gather descriptor.
 *
 * @param desc  Descriptor to destroy; presumably must not be used afterwards -- confirm.
 * @return An infiniopStatus_t status code.
 */
__C __export infiniopStatus_t infiniopDestroyGatherDescriptor(infiniopGatherDescriptor_t desc);

#endif
30 changes: 30 additions & 0 deletions include/ops/reduce_max/reduce_max.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#ifndef REDUCE_MAX_H
#define REDUCE_MAX_H

#include "../../export.h"
#include "../../operators.h"

/**
 * Public view of a ReduceMax operator descriptor.
 *
 * Only the `device` field is declared in this header.
 * NOTE(review): backend-specific descriptors presumably carry more state
 * behind this common prefix -- confirm against the implementation.
 */
typedef struct ReduceMaxDescriptor {
Device device;
} ReduceMaxDescriptor;

/** Pointer handle through which the infiniop ReduceMax API refers to a descriptor. */
typedef ReduceMaxDescriptor *infiniopReduceMaxDescriptor_t;

/**
 * Creates a ReduceMax descriptor and hands it back through `desc_ptr`.
 *
 * @param handle                Library handle the descriptor is created with.
 * @param desc_ptr              Out-parameter that receives the new descriptor.
 * @param reduced               Tensor descriptor of the reduction result.
 * @param data                  Tensor descriptor of the tensor being reduced.
 * @param axes                  Array of axes to reduce over.
 * @param axes_ndim             NOTE(review): presumably the number of entries
 *                              in `axes` -- confirm.
 * @param keepdims              NOTE(review): presumably keeps reduced axes as
 *                              size-1 dimensions when true -- confirm.
 * @param noop_with_empty_axes  NOTE(review): presumably makes an empty `axes`
 *                              list mean "no reduction" instead of "reduce
 *                              all axes" -- confirm.
 * @return An infiniopStatus_t status code.
 *
 * NOTE(review): `axes` is taken as a non-const pointer; whether the array is
 * modified or must outlive this call is not visible here.
 */
__C __export infiniopStatus_t infiniopCreateReduceMaxDescriptor(infiniopHandle_t handle,
infiniopReduceMaxDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t reduced,
infiniopTensorDescriptor_t data,
int64_t *axes,
uint64_t axes_ndim,
bool keepdims,
bool noop_with_empty_axes);

/**
 * Reports the workspace size required by infiniopReduceMax for `desc`.
 *
 * @param desc  Descriptor created by infiniopCreateReduceMaxDescriptor.
 * @param size  Out-parameter receiving the size (NOTE(review): presumably in bytes -- confirm).
 * @return An infiniopStatus_t status code.
 */
__C __export infiniopStatus_t infiniopGetReduceMaxWorkspaceSize(infiniopReduceMaxDescriptor_t desc, uint64_t *size);

/**
 * Runs the ReduceMax operator described by `desc`.
 *
 * @param desc            Descriptor created by infiniopCreateReduceMaxDescriptor.
 * @param workspace       Caller-provided scratch buffer.
 * @param workspace_size  Size of `workspace` (NOTE(review): presumably must be
 *                        at least the value reported by
 *                        infiniopGetReduceMaxWorkspaceSize -- confirm).
 * @param reduced         Buffer that receives the reduction result.
 * @param data            Buffer holding the source data; not written through this pointer.
 * @param stream          Backend execution stream.
 * @return An infiniopStatus_t status code.
 */
__C __export infiniopStatus_t infiniopReduceMax(infiniopReduceMaxDescriptor_t desc,
void *workspace, uint64_t workspace_size,
void *reduced, void const *data, void *stream);

/**
 * Destroys a ReduceMax descriptor.
 *
 * @param desc  Descriptor to destroy; presumably must not be used afterwards -- confirm.
 * @return An infiniopStatus_t status code.
 */
__C __export infiniopStatus_t infiniopDestroyReduceMaxDescriptor(infiniopReduceMaxDescriptor_t desc);

#endif
30 changes: 30 additions & 0 deletions include/ops/reduce_mean/reduce_mean.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#ifndef REDUCE_MEAN_H
#define REDUCE_MEAN_H

#include "../../export.h"
#include "../../operators.h"

/**
 * Public view of a ReduceMean operator descriptor.
 *
 * Only the `device` field is declared in this header.
 * NOTE(review): backend-specific descriptors presumably carry more state
 * behind this common prefix -- confirm against the implementation.
 */
typedef struct ReduceMeanDescriptor {
Device device;
} ReduceMeanDescriptor;

/** Pointer handle through which the infiniop ReduceMean API refers to a descriptor. */
typedef ReduceMeanDescriptor *infiniopReduceMeanDescriptor_t;

/**
 * Creates a ReduceMean descriptor and hands it back through `desc_ptr`.
 *
 * @param handle                Library handle the descriptor is created with.
 * @param desc_ptr              Out-parameter that receives the new descriptor.
 * @param reduced               Tensor descriptor of the reduction result.
 * @param data                  Tensor descriptor of the tensor being reduced.
 * @param axes                  Array of axes to reduce over.
 * @param axes_ndim             NOTE(review): presumably the number of entries
 *                              in `axes` -- confirm.
 * @param keepdims              NOTE(review): presumably keeps reduced axes as
 *                              size-1 dimensions when true -- confirm.
 * @param noop_with_empty_axes  NOTE(review): presumably makes an empty `axes`
 *                              list mean "no reduction" instead of "reduce
 *                              all axes" -- confirm.
 * @return An infiniopStatus_t status code.
 *
 * NOTE(review): `axes` is taken as a non-const pointer; whether the array is
 * modified or must outlive this call is not visible here.
 */
__C __export infiniopStatus_t infiniopCreateReduceMeanDescriptor(infiniopHandle_t handle,
infiniopReduceMeanDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t reduced,
infiniopTensorDescriptor_t data,
int64_t *axes,
uint64_t axes_ndim,
bool keepdims,
bool noop_with_empty_axes);

/**
 * Reports the workspace size required by infiniopReduceMean for `desc`.
 *
 * @param desc  Descriptor created by infiniopCreateReduceMeanDescriptor.
 * @param size  Out-parameter receiving the size (NOTE(review): presumably in bytes -- confirm).
 * @return An infiniopStatus_t status code.
 */
__C __export infiniopStatus_t infiniopGetReduceMeanWorkspaceSize(infiniopReduceMeanDescriptor_t desc, uint64_t *size);

/**
 * Runs the ReduceMean operator described by `desc`.
 *
 * @param desc            Descriptor created by infiniopCreateReduceMeanDescriptor.
 * @param workspace       Caller-provided scratch buffer.
 * @param workspace_size  Size of `workspace` (NOTE(review): presumably must be
 *                        at least the value reported by
 *                        infiniopGetReduceMeanWorkspaceSize -- confirm).
 * @param reduced         Buffer that receives the reduction result.
 * @param data            Buffer holding the source data; not written through this pointer.
 * @param stream          Backend execution stream.
 * @return An infiniopStatus_t status code.
 */
__C __export infiniopStatus_t infiniopReduceMean(infiniopReduceMeanDescriptor_t desc,
void *workspace, uint64_t workspace_size,
void *reduced, void const *data, void *stream);

/**
 * Destroys a ReduceMean descriptor.
 *
 * @param desc  Descriptor to destroy; presumably must not be used afterwards -- confirm.
 * @return An infiniopStatus_t status code.
 */
__C __export infiniopStatus_t infiniopDestroyReduceMeanDescriptor(infiniopReduceMeanDescriptor_t desc);

#endif
30 changes: 30 additions & 0 deletions include/ops/reduce_min/reduce_min.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#ifndef REDUCE_MIN_H
#define REDUCE_MIN_H

#include "../../export.h"
#include "../../operators.h"

/**
 * Public view of a ReduceMin operator descriptor.
 *
 * Only the `device` field is declared in this header.
 * NOTE(review): backend-specific descriptors presumably carry more state
 * behind this common prefix -- confirm against the implementation.
 */
typedef struct ReduceMinDescriptor {
Device device;
} ReduceMinDescriptor;

/** Pointer handle through which the infiniop ReduceMin API refers to a descriptor. */
typedef ReduceMinDescriptor *infiniopReduceMinDescriptor_t;

/**
 * Creates a ReduceMin descriptor and hands it back through `desc_ptr`.
 *
 * @param handle                Library handle the descriptor is created with.
 * @param desc_ptr              Out-parameter that receives the new descriptor.
 * @param reduced               Tensor descriptor of the reduction result.
 * @param data                  Tensor descriptor of the tensor being reduced.
 * @param axes                  Array of axes to reduce over.
 * @param axes_ndim             NOTE(review): presumably the number of entries
 *                              in `axes` -- confirm.
 * @param keepdims              NOTE(review): presumably keeps reduced axes as
 *                              size-1 dimensions when true -- confirm.
 * @param noop_with_empty_axes  NOTE(review): presumably makes an empty `axes`
 *                              list mean "no reduction" instead of "reduce
 *                              all axes" -- confirm.
 * @return An infiniopStatus_t status code.
 *
 * NOTE(review): `axes` is taken as a non-const pointer; whether the array is
 * modified or must outlive this call is not visible here.
 */
__C __export infiniopStatus_t infiniopCreateReduceMinDescriptor(infiniopHandle_t handle,
infiniopReduceMinDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t reduced,
infiniopTensorDescriptor_t data,
int64_t *axes,
uint64_t axes_ndim,
bool keepdims,
bool noop_with_empty_axes);

/**
 * Reports the workspace size required by infiniopReduceMin for `desc`.
 *
 * @param desc  Descriptor created by infiniopCreateReduceMinDescriptor.
 * @param size  Out-parameter receiving the size (NOTE(review): presumably in bytes -- confirm).
 * @return An infiniopStatus_t status code.
 */
__C __export infiniopStatus_t infiniopGetReduceMinWorkspaceSize(infiniopReduceMinDescriptor_t desc, uint64_t *size);

/**
 * Runs the ReduceMin operator described by `desc`.
 *
 * @param desc            Descriptor created by infiniopCreateReduceMinDescriptor.
 * @param workspace       Caller-provided scratch buffer.
 * @param workspace_size  Size of `workspace` (NOTE(review): presumably must be
 *                        at least the value reported by
 *                        infiniopGetReduceMinWorkspaceSize -- confirm).
 * @param reduced         Buffer that receives the reduction result.
 * @param data            Buffer holding the source data; not written through this pointer.
 * @param stream          Backend execution stream.
 * @return An infiniopStatus_t status code.
 */
__C __export infiniopStatus_t infiniopReduceMin(infiniopReduceMinDescriptor_t desc,
void *workspace, uint64_t workspace_size,
void *reduced, void const *data, void *stream);

/**
 * Destroys a ReduceMin descriptor.
 *
 * @param desc  Descriptor to destroy; presumably must not be used afterwards -- confirm.
 * @return An infiniopStatus_t status code.
 */
__C __export infiniopStatus_t infiniopDestroyReduceMinDescriptor(infiniopReduceMinDescriptor_t desc);

#endif
29 changes: 29 additions & 0 deletions include/ops/where/where.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#ifndef WHERE_H
#define WHERE_H

#include "../../export.h"
#include "../../operators.h"

/**
 * Public view of a Where operator descriptor.
 *
 * Only the `device` field is declared in this header.
 * NOTE(review): backend-specific descriptors presumably carry more state
 * behind this common prefix -- confirm against the implementation.
 */
typedef struct WhereDescriptor {
Device device;
} WhereDescriptor;

/** Pointer handle through which the infiniop Where API refers to a descriptor. */
typedef WhereDescriptor *infiniopWhereDescriptor_t;

/**
 * Creates a Where descriptor and hands it back through `desc_ptr`.
 *
 * @param handle     Library handle the descriptor is created with.
 * @param desc_ptr   Out-parameter that receives the new descriptor.
 * @param output     Tensor descriptor of the result tensor.
 * @param condition  Tensor descriptor of the condition tensor.
 * @param x          Tensor descriptor of the first value tensor.
 * @param y          Tensor descriptor of the second value tensor.
 * @return An infiniopStatus_t status code.
 *
 * NOTE(review): the operator semantics are not visible in this header. The
 * names suggest an element-wise select (take `x` where `condition` holds,
 * `y` otherwise); the required dtype of `condition` and any broadcasting
 * rules are unknown -- confirm against the implementation.
 */
__C __export infiniopStatus_t infiniopCreateWhereDescriptor(infiniopHandle_t handle,
infiniopWhereDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t output,
infiniopTensorDescriptor_t condition,
infiniopTensorDescriptor_t x,
infiniopTensorDescriptor_t y);

/**
 * Runs the Where operator described by `desc`.
 *
 * @param desc       Descriptor created by infiniopCreateWhereDescriptor.
 * @param output     Buffer that receives the result.
 * @param condition  Buffer holding the condition data; not written through this pointer.
 * @param x          Buffer holding the first value tensor; not written through this pointer.
 * @param y          Buffer holding the second value tensor; not written through this pointer.
 * @param stream     Backend execution stream.
 * @return An infiniopStatus_t status code.
 */
__C __export infiniopStatus_t infiniopWhere(infiniopWhereDescriptor_t desc,
void *output,
void const *condition,
void const *x,
void const *y,
void *stream);

/**
 * Destroys a Where descriptor.
 *
 * @param desc  Descriptor to destroy; presumably must not be used afterwards -- confirm.
 * @return An infiniopStatus_t status code.
 */
__C __export infiniopStatus_t infiniopDestroyWhereDescriptor(infiniopWhereDescriptor_t desc);

#endif
160 changes: 160 additions & 0 deletions operatorspy/tests/clip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
from ctypes import POINTER, Structure, c_int32, c_void_p, c_float
import ctypes
import sys
import os

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from operatorspy import (
open_lib,
to_tensor,
DeviceEnum,
infiniopHandle_t,
infiniopTensorDescriptor_t,
create_handle,
destroy_handle,
check_error,
)

from operatorspy.tests.test_utils import get_args
from enum import Enum, auto
import torch


class Inplace(Enum):
    """Selects whether a test case writes to a separate output or over its input."""

    OUT_OF_PLACE = auto()  # the output gets its own tensor
    INPLACE_ = auto()  # the output aliases the input tensor


class ClipDescriptor(Structure):
    """ctypes view of the library's Clip descriptor; only `device` is declared.

    NOTE(review): `device` is mapped to c_int32, which assumes the C `Device`
    type is a 32-bit integer/enum -- confirm against the C headers.
    """

    _fields_ = [("device", c_int32)]


# Pointer type used as the descriptor handle in the infiniopClip* ctypes signatures.
infiniopClipDescriptor_t = POINTER(ClipDescriptor)

def clip(x, min_value, max_value):
    """Reference implementation: limit every element of `x` to [min_value, max_value].

    `torch.clip` is an alias of `torch.clamp`, so the result is identical,
    including the case where `min_value` is greater than `max_value`.
    """
    reference = torch.clamp(x, min=min_value, max=max_value)
    return reference


def test(
    lib,
    handle,
    torch_device,
    o_shape,
    i_shape,
    min_value=None,
    max_value=None,
    tensor_dtype=torch.float16,
    inplace=Inplace.OUT_OF_PLACE,
):
    """Run one Clip case through the library and compare it with the torch reference.

    Args:
        lib: Loaded operator library exposing the infiniopClip* functions.
        handle: Library handle passed to infiniopCreateClipDescriptor.
        torch_device: Torch device string the tensors are created on.
        o_shape: Output shape; the case is skipped unless it equals `i_shape`.
        i_shape: Input shape.
        min_value: Lower bound, or None for "no lower bound" (-inf is passed).
        max_value: Upper bound, or None for "no upper bound" (+inf is passed).
        tensor_dtype: Torch dtype of the tensors and of the rounded bounds.
        inplace: Inplace.INPLACE_ makes the output alias the input buffer.

    Raises:
        AssertionError: if the library output differs from the reference.
    """
    # Adjacent f-strings instead of backslash continuations, so the message
    # never picks up source indentation.
    print(
        f"Testing Clip on {torch_device} with o_shape:{o_shape} i_shape:{i_shape} "
        f"min_value:{min_value} max_value:{max_value} "
        f"dtype:{tensor_dtype} inplace: {inplace.name}"
    )
    if o_shape != i_shape:
        print("Unsupported test: unmatched shapes for input and output")
        return

    out_of_place = inplace == Inplace.OUT_OF_PLACE
    x = torch.rand(i_shape, dtype=tensor_dtype).to(torch_device)
    output = torch.rand(o_shape, dtype=tensor_dtype).to(torch_device) if out_of_place else x

    # The bounds are routed through a tensor of `tensor_dtype` so the torch
    # reference and the library both receive the value rounded to that dtype.
    # A missing bound becomes -inf / +inf.
    min_value = torch.tensor(
        min_value if min_value is not None else float("-inf"),
        dtype=tensor_dtype,
        device=torch_device,
    )
    max_value = torch.tensor(
        max_value if max_value is not None else float("inf"),
        dtype=tensor_dtype,
        device=torch_device,
    )

    descriptor = infiniopClipDescriptor_t()
    i_tensor = to_tensor(x, lib)
    o_tensor = to_tensor(output, lib) if out_of_place else i_tensor

    # Compute the reference first: an in-place run overwrites `x`.
    ans = clip(x, min_value, max_value)

    check_error(
        lib.infiniopCreateClipDescriptor(
            handle,
            ctypes.byref(descriptor),
            o_tensor.descriptor,
            i_tensor.descriptor,
            min_value.item(),
            max_value.item(),
        )
    )

    # From here on the descriptor exists, so release it even if the kernel
    # call or the comparison fails.
    try:
        # Invalidate the shape and strides in the descriptor to prevent them
        # from being directly used by the kernel
        o_tensor.descriptor.contents.invalidate()
        i_tensor.descriptor.contents.invalidate()

        check_error(
            lib.infiniopClip(
                descriptor,
                o_tensor.data,
                i_tensor.data,
                None,
            )
        )
        # Clipping only selects existing values, so an exact match is expected.
        assert torch.allclose(output, ans, atol=0, rtol=0)
    finally:
        check_error(lib.infiniopDestroyClipDescriptor(descriptor))


def test_cpu(lib, test_cases):
    """Run every Clip test case on the CPU backend, in float16 and then float32.

    Args:
        lib: Loaded operator library.
        test_cases: Iterable of (o_shape, i_shape, min_value, max_value, inplace).
    """
    device = DeviceEnum.DEVICE_CPU
    handle = create_handle(lib, device)
    # Release the handle even when a case fails, instead of leaking it.
    try:
        for o_shape, i_shape, min_value, max_value, inplace in test_cases:
            for dtype in (torch.float16, torch.float32):
                test(
                    lib,
                    handle,
                    "cpu",
                    o_shape,
                    i_shape,
                    min_value,
                    max_value,
                    tensor_dtype=dtype,
                    inplace=inplace,
                )
    finally:
        destroy_handle(lib, handle)


if __name__ == "__main__":
    # Each case is (o_shape, i_shape, min, max, inplace). A bound of None
    # means "unbounded on that side".
    test_cases = [
        ((1, 3), (1, 3), None, None, Inplace.OUT_OF_PLACE),
        ((3, 3), (3, 3), 1.0, 2.0, Inplace.OUT_OF_PLACE),
        # Zero-dimensional tensor.
        ((), (), None, None, Inplace.OUT_OF_PLACE),
        # min greater than max.
        ((2, 20, 3), (2, 20, 3), 0., -2.0, Inplace.INPLACE_),

        ((32, 20, 512), (32, 20, 512), -0.3, 0.3, Inplace.INPLACE_),
        ((32, 256, 112, 112), (32, 256, 112, 112), -0.1, 0.1, Inplace.OUT_OF_PLACE),
        # ((32, 150, 5120), (32, 150, 5120), None, None, Inplace.OUT_OF_PLACE),

        # Cases whose output and input shapes differ are reported as
        # unsupported by test() and skipped.
        ((2, 4, 3), (2, 1, 3), None, None, Inplace.OUT_OF_PLACE),
        ((2, 3, 4, 5), (2, 3, 4, 5), -0.1, 0.1, Inplace.OUT_OF_PLACE),
        ((3, 2, 4, 5), (4, 5), None, None, Inplace.OUT_OF_PLACE),
    ]

    args = get_args()
    lib = open_lib()

    # Declare the C signatures so ctypes converts arguments correctly; in
    # particular the two bounds must be passed as C floats.
    lib.infiniopCreateClipDescriptor.restype = c_int32
    lib.infiniopCreateClipDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopClipDescriptor_t),
        infiniopTensorDescriptor_t,  # output
        infiniopTensorDescriptor_t,  # input
        c_float,  # min_value
        c_float,  # max_value
    ]
    lib.infiniopClip.restype = c_int32
    lib.infiniopClip.argtypes = [
        infiniopClipDescriptor_t,
        c_void_p,  # output
        c_void_p,  # input
        c_void_p,  # stream
    ]
    lib.infiniopDestroyClipDescriptor.restype = c_int32
    lib.infiniopDestroyClipDescriptor.argtypes = [
        infiniopClipDescriptor_t,
    ]

    # Only the CPU backend has a Clip test so far. Say so when another backend
    # is requested, rather than silently doing nothing.
    # TODO: add test_cuda / test_bang once those backends are implemented.
    skipped_backends = [
        name
        for name, requested in (("cuda", args.cuda), ("bang", args.bang))
        if requested
    ]
    for name in skipped_backends:
        print(f"Skipping Clip test on {name}: no test is implemented for this backend")

    # The CPU test runs when asked for explicitly, or when no backend flag was
    # given at all. Success is only reported when it actually ran.
    if args.cpu or not skipped_backends:
        test_cpu(lib, test_cases)
        print("\033[92mTest passed!\033[0m")
Loading