Skip to content

Commit ae67c74

Browse files
tianyuxbeargongchensu
authored and committed
issue/456/feat: add equal operator
1 parent 3959c94 commit ae67c74

File tree

15 files changed

+825
-3
lines changed

15 files changed

+825
-3
lines changed

include/infiniop.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "infiniop/ops/clip.h"
99
#include "infiniop/ops/conv.h"
1010
#include "infiniop/ops/dequantize_awq.h"
11+
#include "infiniop/ops/equal.h"
1112
#include "infiniop/ops/gemm.h"
1213
#include "infiniop/ops/mul.h"
1314
#include "infiniop/ops/random_sample.h"

include/infiniop/ops/equal.h

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#ifndef __INFINIOP_EQUAL_API_H__
2+
#define __INFINIOP_EQUAL_API_H__
3+
4+
#include "../operator_descriptor.h"
5+
6+
typedef struct InfiniopDescriptor *infiniopEqualDescriptor_t;
7+
8+
__C __export infiniStatus_t infiniopCreateEqualDescriptor(infiniopHandle_t handle,
9+
infiniopEqualDescriptor_t *desc_ptr,
10+
infiniopTensorDescriptor_t c,
11+
infiniopTensorDescriptor_t a,
12+
infiniopTensorDescriptor_t b);
13+
14+
__C __export infiniStatus_t infiniopGetEqualWorkspaceSize(infiniopEqualDescriptor_t desc, size_t *size);
15+
16+
__C __export infiniStatus_t infiniopEqual(infiniopEqualDescriptor_t desc,
17+
void *workspace,
18+
size_t workspace_size,
19+
void *c,
20+
const void *a,
21+
const void *b,
22+
void *stream);
23+
24+
__C __export infiniStatus_t infiniopDestroyEqualDescriptor(infiniopEqualDescriptor_t desc);
25+
26+
#endif

src/infiniop-test/include/ops.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ DECLARE_INFINIOP_TEST(add)
1616
DECLARE_INFINIOP_TEST(causal_softmax)
1717
DECLARE_INFINIOP_TEST(rearrange)
1818
DECLARE_INFINIOP_TEST(sub)
19+
DECLARE_INFINIOP_TEST(equal)
1920

2021
#define REGISTER_INFINIOP_TEST(name) \
2122
{ \
@@ -43,6 +44,7 @@ DECLARE_INFINIOP_TEST(sub)
4344
REGISTER_INFINIOP_TEST(causal_softmax) \
4445
REGISTER_INFINIOP_TEST(rearrange) \
4546
REGISTER_INFINIOP_TEST(sub) \
47+
REGISTER_INFINIOP_TEST(equal) \
4648
}
4749

4850
namespace infiniop_test {
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
#include "ops.hpp"
2+
#include "utils.hpp"
3+
#include <infinirt.h>
4+
#include <iomanip>
5+
#include <iostream>
6+
7+
namespace infiniop_test::equal {
8+
// Tensors that make up a single `equal` test case.
struct Test::Attributes {
    // First input passed to infiniopEqual.
    std::shared_ptr<Tensor> a;
    // Second input passed to infiniopEqual.
    std::shared_ptr<Tensor> b;
    // Output tensor the operator writes into.
    std::shared_ptr<Tensor> c;
    // Reference result that `c` is compared against after the run.
    std::shared_ptr<Tensor> ans;
};
14+
15+
std::shared_ptr<Test> Test::build(
16+
std::unordered_map<std::string, std::vector<uint8_t>> attributes,
17+
std::unordered_map<std::string, std::shared_ptr<Tensor>> tensors,
18+
double rtol, double atol) {
19+
auto test = std::shared_ptr<Test>(new Test(rtol, atol));
20+
test->_attributes = new Attributes();
21+
if (tensors.find("a") == tensors.end()
22+
|| tensors.find("b") == tensors.end()
23+
|| tensors.find("c") == tensors.end()
24+
|| tensors.find("ans") == tensors.end()) {
25+
throw std::runtime_error("Invalid Test");
26+
}
27+
28+
test->_attributes->a = tensors["a"];
29+
test->_attributes->b = tensors["b"];
30+
test->_attributes->c = tensors["c"];
31+
test->_attributes->ans = tensors["ans"];
32+
33+
return test;
34+
}
35+
36+
// Runs the `equal` operator once for correctness, then benchmarks it.
//
// Steps: move a/b/c to the target device, create the operator descriptor,
// query and allocate the workspace, execute once, compare `c` with the
// reference tensor `ans` via allClose, and finally time repeated executions
// with benchmark(warm_ups, iterations).
//
// Returns TEST_FAILED(...) with the matching failure kind when any step
// fails, otherwise TEST_PASSED(elapsed_time).
std::shared_ptr<infiniop_test::Result> Test::run(
    infiniopHandle_t handle, infiniDevice_t device, int device_id, size_t warm_ups, size_t iterations) {
    auto a = _attributes->a->to(device, device_id);
    auto b = _attributes->b->to(device, device_id);
    auto c = _attributes->c->to(device, device_id);

    // Scope guard: releases the descriptor and the workspace on every exit
    // path (early failure returns included). Previously both were leaked.
    struct Resources {
        infiniopEqualDescriptor_t desc = nullptr;
        void *workspace = nullptr;
        ~Resources() {
            if (workspace != nullptr) {
                infinirtFree(workspace);
            }
            if (desc != nullptr) {
                infiniopDestroyEqualDescriptor(desc);
            }
        }
    } resources;

    infiniopEqualDescriptor_t op_desc = nullptr;
    CHECK_OR(infiniopCreateEqualDescriptor(handle, &op_desc,
                                           c->desc(),
                                           a->desc(),
                                           b->desc()),
             return TEST_FAILED(OP_CREATION_FAILED, "Failed to create op descriptor."));
    resources.desc = op_desc;

    size_t workspace_size = 0;
    CHECK_OR(infiniopGetEqualWorkspaceSize(op_desc, &workspace_size),
             return TEST_FAILED(OP_CREATION_FAILED, "Failed to get workspace size."));

    void *workspace = nullptr;
    CHECK_OR(infinirtMalloc(&workspace, workspace_size),
             return TEST_FAILED(OP_CREATION_FAILED, "Failed to allocate workspace."));
    resources.workspace = workspace;

    // Single execution whose result is checked against the reference.
    CHECK_OR(infiniopEqual(op_desc, workspace, workspace_size,
                           c->data(),
                           a->data(),
                           b->data(),
                           nullptr),
             return TEST_FAILED(OP_EXECUTION_FAILED, "Failed during execution."));

    try {
        allClose(c, _attributes->ans, _rtol, _atol);
    } catch (const std::exception &e) {
        return TEST_FAILED(RESULT_INCORRECT, e.what());
    }

    // The lambda copies the handles and tensor pointers it needs.
    // NOTE(review): assumes benchmark() has finished all launched work by the
    // time it returns, since the workspace is freed right after - confirm.
    const double elapsed_time = benchmark(
        [=]() {
            infiniopEqual(
                op_desc, workspace, workspace_size,
                c->data(),
                a->data(),
                b->data(),
                nullptr);
        },
        warm_ups, iterations);

    return TEST_PASSED(elapsed_time);
}
81+
82+
// Names of the non-tensor attributes of an `equal` test case: none.
std::vector<std::string> Test::attribute_names() {
    return std::vector<std::string>{};
}
85+
86+
// Names of every tensor an `equal` test case uses: both inputs, the output
// and the reference answer.
std::vector<std::string> Test::tensor_names() {
    return std::vector<std::string>{"a", "b", "c", "ans"};
}
89+
90+
// Names of the tensors the operator writes to: only "c".
std::vector<std::string> Test::output_names() {
    return std::vector<std::string>{"c"};
}
93+
94+
// Renders a multi-line description of the test case: the operator name, the
// info() of tensors a, b and c, and the tolerances in scientific notation
// with two digits of precision.
std::string Test::toString() const {
    std::ostringstream text;
    text << op_name() << '\n';
    text << "- a: " << _attributes->a->info() << '\n';
    text << "- b: " << _attributes->b->info() << '\n';
    text << "- c: " << _attributes->c->info() << '\n';
    // Formatting flags only affect the floating-point tolerances below.
    text << std::scientific << std::setprecision(2);
    text << "- rtol=" << _rtol << ", atol=" << _atol << '\n';
    return text.str();
}
104+
105+
// Frees the Attributes object that build() allocates with `new`.
Test::~Test() { delete _attributes; }
108+
109+
} // namespace infiniop_test::equal
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#include "equal_cpu.h"
2+
#include "infinicore.h"
3+
4+
namespace op::equal::cpu {
5+
6+
// Defaulted destructor: no explicit cleanup is written for the CPU descriptor here.
Descriptor::~Descriptor() = default;
7+
8+
// Validates the tensors of an `equal` operation and creates the CPU
// element-wise descriptor for it.
//
// handle_        - library handle, reinterpreted as the CPU device handle.
// desc_ptr       - receives the created descriptor.
// out_desc       - descriptor of the output tensor c.
// input_desc_vec - descriptors of the inputs; element 0 is a, element 1 is b.
//
// Returns INFINI_STATUS_SUCCESS on success. The dtype of a must be one of the
// types listed in CHECK_DTYPE below, b must have the same dtype as a, and
// c, a and b must pass CHECK_SAME_SHAPE.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {

    auto handle = reinterpret_cast<device::cpu::Handle *>(handle_);

    const auto &a_desc = input_desc_vec.at(0);
    const auto &b_desc = input_desc_vec.at(1);
    const auto &c_shape = out_desc->shape();
    const auto &a_shape = a_desc->shape();
    const auto &b_shape = b_desc->shape();

    auto dtype = a_desc->dtype();

    CHECK_DTYPE(dtype, INFINI_DTYPE_BOOL, INFINI_DTYPE_I8, INFINI_DTYPE_I16, INFINI_DTYPE_I32, INFINI_DTYPE_I64, INFINI_DTYPE_BF16, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);

    // calculate() reads both inputs with the element type selected by `dtype`,
    // so an input b of a different dtype would be misinterpreted.
    if (b_desc->dtype() != dtype) {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    // NOTE(review): out_desc's dtype is not validated; calculate() always
    // writes `bool` elements - confirm callers pass a matching output tensor.

    CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);

    // create CPU elementwise descriptor
    CREATE_ELEMENTWISE_CPU_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec);

    return INFINI_STATUS_SUCCESS;
}
33+
34+
// Runs the element-wise comparison on the CPU.
//
// Dispatches on the input dtype stored in the descriptor; every variant
// produces `bool` output elements and reads both inputs with the same element
// type. `workspace` and `workspace_size` are not used by this backend.
// Returns the status of the element-wise kernel, or
// INFINI_STATUS_BAD_TENSOR_DTYPE when the stored dtype has no variant.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {

    switch (_dtype) {
    case INFINI_DTYPE_BOOL:
        return _device_info->calculate<EqualOp, bool, bool, bool>(_info, output, inputs, stream);
    case INFINI_DTYPE_I8:
        return _device_info->calculate<EqualOp, bool, int8_t, int8_t>(_info, output, inputs, stream);
    case INFINI_DTYPE_I16:
        return _device_info->calculate<EqualOp, bool, int16_t, int16_t>(_info, output, inputs, stream);
    case INFINI_DTYPE_I32:
        return _device_info->calculate<EqualOp, bool, int32_t, int32_t>(_info, output, inputs, stream);
    case INFINI_DTYPE_I64:
        return _device_info->calculate<EqualOp, bool, int64_t, int64_t>(_info, output, inputs, stream);
    case INFINI_DTYPE_BF16:
        return _device_info->calculate<EqualOp, bool, bf16_t, bf16_t>(_info, output, inputs, stream);
    case INFINI_DTYPE_F16:
        return _device_info->calculate<EqualOp, bool, fp16_t, fp16_t>(_info, output, inputs, stream);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<EqualOp, bool, float, float>(_info, output, inputs, stream);
    case INFINI_DTYPE_F64:
        return _device_info->calculate<EqualOp, bool, double, double>(_info, output, inputs, stream);
    default:
        break;
    }

    // Only reached for a dtype without a case above.
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
66+
} // namespace op::equal::cpu
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#ifndef __EQUAL_CPU_H__
2+
#define __EQUAL_CPU_H__
3+
4+
#include "../../../elementwise/cpu/elementwise_cpu.h"
5+
6+
ELEMENTWISE_DESCRIPTOR(equal, cpu)
7+
8+
namespace op::equal::cpu {
9+
typedef struct EqualOp {
10+
public:
11+
static constexpr size_t num_inputs = 2;
12+
template <typename Tout, typename Ta, typename Tb>
13+
Tout operator()(const Ta &a, const Tb &b) const {
14+
if constexpr (!std::is_same_v<Ta, Tb>) {
15+
printf("Ta and Tb must be the same type!\n");
16+
std::abort();
17+
}
18+
if constexpr (std::is_same_v<Ta, bf16_t> || std::is_same_v<Ta, fp16_t>) {
19+
float f_a = utils::cast<float, Ta>(a);
20+
float f_b = utils::cast<float, Ta>(b);
21+
return f_a == f_b;
22+
} else {
23+
return a == b;
24+
}
25+
}
26+
} EqualOp;
27+
} // namespace op::equal::cpu
28+
29+
#endif // __EQUAL_CPU_H__
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
#ifndef __EQUAL_CUDA_H__
2+
#define __EQUAL_CUDA_H__
3+
4+
namespace op::equal::cuda {
5+
typedef struct EqualOp {
6+
public:
7+
static constexpr size_t num_inputs = 2;
8+
template <typename Tout, typename Ta, typename Tb>
9+
__device__ __forceinline__ Tout operator()(const Ta &a, const Tb &b) const {
10+
if constexpr (!std::is_same_v<Ta, Tb>) {
11+
printf("Ta and Tb must be the same type!\n");
12+
std::abort();
13+
}
14+
return a == b;
15+
}
16+
} EqualOp;
17+
} // namespace op::equal::cuda
18+
19+
#endif // __EQUAL_CUDA_H__
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
#ifndef __EQUAL_METAX_API_H__
2+
#define __EQUAL_METAX_API_H__
3+
4+
#include "../../../elementwise/metax/elementwise_metax_api.h"
5+
6+
ELEMENTWISE_DESCRIPTOR(equal, metax)
7+
8+
#endif // __EQUAL_METAX_API_H__
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
#include "equal_metax.h"
2+
3+
#include "../../../elementwise/metax/elementwise_metax.h"
4+
5+
#include "../cuda/kernel.cuh"
6+
7+
namespace op::equal::metax {
8+
9+
// Defaulted destructor: no explicit cleanup is written for the MetaX descriptor here.
Descriptor::~Descriptor() = default;
10+
11+
// Validates the tensors of an `equal` operation and creates the MetaX
// element-wise descriptor for it.
//
// handle_        - library handle, reinterpreted as the MetaX device handle.
// desc_ptr       - receives the created descriptor.
// out_desc       - descriptor of the output tensor c.
// input_desc_vec - descriptors of the inputs; element 0 is a, element 1 is b.
//
// Returns INFINI_STATUS_SUCCESS on success. The dtype of a must be one of the
// types listed in CHECK_DTYPE below, b must have the same dtype as a, and
// c, a and b must pass CHECK_SAME_SHAPE.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    std::vector<infiniopTensorDescriptor_t> input_desc_vec) {

    auto handle = reinterpret_cast<device::metax::Handle *>(handle_);

    const auto &a_desc = input_desc_vec.at(0);
    const auto &b_desc = input_desc_vec.at(1);
    const auto &c_shape = out_desc->shape();
    const auto &a_shape = a_desc->shape();
    const auto &b_shape = b_desc->shape();

    auto dtype = a_desc->dtype();

    CHECK_DTYPE(dtype, INFINI_DTYPE_BOOL, INFINI_DTYPE_I8, INFINI_DTYPE_I16, INFINI_DTYPE_I32, INFINI_DTYPE_I64, INFINI_DTYPE_BF16, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_F64);

    // calculate() reads both inputs with the element type selected by `dtype`,
    // so an input b of a different dtype would be misinterpreted.
    if (b_desc->dtype() != dtype) {
        return INFINI_STATUS_BAD_TENSOR_DTYPE;
    }

    // NOTE(review): out_desc's dtype is not validated; calculate() always
    // writes `bool` elements - confirm callers pass a matching output tensor.

    CHECK_SAME_SHAPE(c_shape, a_shape, b_shape);

    // create METAX elementwise descriptor
    CREATE_ELEMENTWISE_METAX_DESCRIPTOR(handle, dtype, out_desc, input_desc_vec)

    return INFINI_STATUS_SUCCESS;
}
36+
37+
// Runs the element-wise comparison on a MetaX device.
//
// Fails with INFINI_STATUS_INSUFFICIENT_WORKSPACE when the supplied workspace
// is smaller than the size recorded in the descriptor. Otherwise dispatches
// on the stored input dtype; every variant uses the template argument 256,
// the shared cuda::EqualOp functor and `bool` output elements. Returns the
// status of the element-wise kernel, or INFINI_STATUS_BAD_TENSOR_DTYPE when
// the stored dtype has no variant.
infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *output,
    std::vector<const void *> inputs,
    void *stream) const {

    const bool workspace_too_small = workspace_size < _workspace_size;
    if (workspace_too_small) {
        return INFINI_STATUS_INSUFFICIENT_WORKSPACE;
    }

    switch (_dtype) {
    case INFINI_DTYPE_BOOL:
        return _device_info->calculate<256, cuda::EqualOp, bool, bool, bool>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_I8:
        return _device_info->calculate<256, cuda::EqualOp, bool, int8_t, int8_t>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_I16:
        return _device_info->calculate<256, cuda::EqualOp, bool, int16_t, int16_t>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_I32:
        return _device_info->calculate<256, cuda::EqualOp, bool, int32_t, int32_t>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_I64:
        return _device_info->calculate<256, cuda::EqualOp, bool, int64_t, int64_t>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_BF16:
        return _device_info->calculate<256, cuda::EqualOp, bool, cuda_bfloat16, cuda_bfloat16>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F16:
        return _device_info->calculate<256, cuda::EqualOp, bool, half, half>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F32:
        return _device_info->calculate<256, cuda::EqualOp, bool, float, float>(_info, workspace, output, inputs, stream);
    case INFINI_DTYPE_F64:
        return _device_info->calculate<256, cuda::EqualOp, bool, double, double>(_info, workspace, output, inputs, stream);
    default:
        break;
    }

    // Only reached for a dtype without a case above.
    return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
73+
} // namespace op::equal::metax

0 commit comments

Comments
 (0)