10 changes: 6 additions & 4 deletions include/device.h
@@ -2,10 +2,12 @@
#define __DEVICE_H__

enum DeviceEnum {
DevCpu,
DevNvGpu,
DevCambriconMlu,
DevAscendNpu,
DevCpu = 0,
DevNvGpu = 1,
DevCambriconMlu = 2,
DevAscendNpu = 3,
DevMetaxGpu = 4,
DevMthreadsGpu = 5,
};

typedef enum DeviceEnum Device;
2 changes: 2 additions & 0 deletions operatorspy/devices.py
@@ -3,3 +3,5 @@ class DeviceEnum:
DEVICE_CUDA = 1
DEVICE_BANG = 2
DEVICE_ASCEND = 3
DEVICE_MACA = 4
DEVICE_MUSA = 5
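
The Python-side DeviceEnum must stay in step with the C enum in include/device.h, which is why the C values are now explicit. A minimal consistency check, hypothetical and not part of this PR, assuming DeviceEnum is importable from operatorspy.devices:

# Hypothetical sync check: values mirror include/device.h.
from operatorspy.devices import DeviceEnum

assert DeviceEnum.DEVICE_MACA == 4  # DevMetaxGpu
assert DeviceEnum.DEVICE_MUSA == 5  # DevMthreadsGpu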
12 changes: 11 additions & 1 deletion operatorspy/tests/causal_softmax.py
@@ -111,6 +111,14 @@ def test_ascend(lib, test_cases):

destroy_handle(lib, handle)

def test_maca(lib, test_cases):
device = DeviceEnum.DEVICE_MACA
handle = create_handle(lib, device)
for x_shape, x_stride in test_cases:
test(lib, handle, "cuda", x_shape, x_stride)

destroy_handle(lib, handle)

if __name__ == "__main__":
test_cases = [
# x_shape, x_stride
@@ -151,6 +159,8 @@ def test_ascend(lib, test_cases):
test_bang(lib, test_cases)
if args.ascend:
test_ascend(lib, test_cases)
if not (args.cpu or args.cuda or args.bang or args.ascend):
if args.maca:
test_maca(lib, test_cases)
if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca):
test_cpu(lib, test_cases)
print("\033[92mTest passed!\033[0m")
36 changes: 35 additions & 1 deletion operatorspy/tests/matmul.py
@@ -293,6 +293,38 @@ def test_ascend(lib, test_cases):

destroy_handle(lib, handle)

def test_maca(lib, test_cases):
device = DeviceEnum.DEVICE_MACA
handle = create_handle(lib, device)

for (
alpha,
beta,
a_shape,
b_shape,
c_shape,
a_stride,
b_stride,
c_stride,
dtype,
) in test_cases:
test(
lib,
handle,
"cuda",
alpha,
beta,
a_shape,
b_shape,
c_shape,
a_stride,
b_stride,
c_stride,
dtype,
)

destroy_handle(lib, handle)

if __name__ == "__main__":
test_cases = [
# alpha, beta, a_shape, b_shape, c_shape, a_stride, b_stride, c_stride, dtype
@@ -353,6 +385,8 @@ def test_ascend(lib, test_cases):
test_bang(lib, test_cases)
if args.ascend:
test_ascend(lib, test_cases)
if not (args.cpu or args.cuda or args.bang or args.ascend):
if args.maca:
test_maca(lib, test_cases)
if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca):
test_cpu(lib, test_cases)
print("\033[92mTest passed!\033[0m")
24 changes: 20 additions & 4 deletions operatorspy/tests/random_sample.py
@@ -83,12 +83,18 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_
)
data = torch.arange(voc).float() * 0.0001
_perm = torch.randperm(voc)
data = data[_perm].to(x_dtype).to(torch_device)
if (torch_device == 'maca'):
data = data[_perm].to(x_dtype).to('cuda')
else:
data = data[_perm].to(x_dtype).to(torch_device)
if(topp > 0 and topk > 1):
ans = random_sample(data.to("cpu"), random_val, topp, topk, voc, temperature, "cpu")
else:
ans = random_sample_0(data)
indices = torch.zeros([1], dtype=torch.int64).to(torch_device)
if(torch_device == 'maca'):
indices = torch.zeros([1], dtype = torch.int64).to('cuda')
else:
indices = torch.zeros([1], dtype = torch.uint64).to(torch_device)
x_tensor = to_tensor(data, lib)
indices_tensor = to_tensor(indices, lib)
indices_tensor.descriptor.contents.dt = U64 # treat int64 as uint64
@@ -163,7 +169,15 @@ def test_ascend(lib, test_cases):
handle = create_handle(lib, device)
for (voc, random_val, topp, topk, temperature) in test_cases:
test(lib, handle, "npu", voc, random_val, topp, topk, temperature)
destroy_handle(lib, handle)
destroy_handle(lib, handle)

def test_maca(lib, test_cases):
device = DeviceEnum.DEVICE_MACA
handle = create_handle(lib, device)
for (voc, random_val, topp, topk, temperature) in test_cases:
test(lib, handle, "maca", voc, random_val, topp, topk, temperature)
destroy_handle(lib, handle)



if __name__ == "__main__":
@@ -220,6 +234,8 @@ def test_ascend(lib, test_cases):
test_bang(lib, test_cases)
if args.ascend:
test_ascend(lib, test_cases)
if not (args.cpu or args.cuda or args.bang or args.ascend):
if args.maca:
test_maca(lib, test_cases)
if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca):
test_cpu(lib, test_cases)
print("\033[92mTest passed!\033[0m")
11 changes: 11 additions & 0 deletions operatorspy/tests/rearrange.py
@@ -108,6 +108,15 @@ def test_ascend(lib, test_cases):
test(lib, handle, "npu", x_shape, x_stride, y_shape, y_stride)
destroy_handle(lib, handle)

def test_maca(lib, test_cases):
device = DeviceEnum.DEVICE_MACA
handle = create_handle(lib, device)
for test_case in test_cases:
x_shape, x_stride = test_case[0]
y_shape, y_stride = test_case[1]
test(lib, handle, "cuda", x_shape, x_stride, y_shape, y_stride)
destroy_handle(lib, handle)

if __name__ == "__main__":
args = get_args()
test_cases = [
@@ -145,4 +154,6 @@ def test_ascend(lib, test_cases):
test_bang(lib, test_cases)
if args.ascend:
test_ascend(lib, test_cases)
if args.maca:
test_maca(lib, test_cases)
print("\033[92mTest passed!\033[0m")
12 changes: 11 additions & 1 deletion operatorspy/tests/rms_norm.py
@@ -117,6 +117,14 @@ def test_ascend(lib, test_cases):

destroy_handle(lib, handle)

def test_maca(lib, test_cases):
device = DeviceEnum.DEVICE_MACA
handle = create_handle(lib, device)
for (y_shape, x_shape, w_shape, dtype, w_dtype) in test_cases:
test(lib, handle, "cuda", y_shape, x_shape, w_shape, dtype, w_dtype)

destroy_handle(lib, handle)

if __name__ == "__main__":
test_cases = [
# y_shape, x_shape, w_shape, dtype, w_dtype
@@ -164,6 +172,8 @@ def test_ascend(lib, test_cases):
test_bang(lib, test_cases)
if args.ascend:
test_ascend(lib, test_cases)
if not (args.cpu or args.cuda or args.bang or args.ascend):
if args.maca:
test_maca(lib, test_cases)
if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca):
test_cpu(lib, test_cases)
print("\033[92mTest passed!\033[0m")
17 changes: 14 additions & 3 deletions operatorspy/tests/rotary_embedding.py
@@ -45,7 +45,6 @@ def rotary_embedding(t, pos, theta, torch_device):
)
freqs = torch.outer(pos, freqs)
freqs_cis = torch.polar(torch.ones_like(freqs), freqs)

t_ = torch.view_as_complex(t.reshape(*t.shape[:-1], -1, 2))
freqs_cis = reshape_for_broadcast(freqs_cis, t_)
t_out = torch.view_as_real(t_ * freqs_cis).flatten(2).to(t.dtype)
@@ -82,6 +81,10 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
ans = rotary_embedding(t, posTmp, theta, "cpu").to(torch_device)
pos = pos.to(torch_device)
t = t.to(torch_device)
elif torch_device == 'maca':
ans = rotary_embedding(t, posTmp, theta, "cpu").to('cuda')
pos = pos.to('cuda')
t = t.to('cuda')
else:
t = t.to(torch_device)
pos = pos.to(torch_device)
@@ -133,7 +136,6 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16):
None,
)
)

assert torch.allclose(t, ans, atol=1e-4, rtol=1e-2)
check_error(lib.infiniopDestroyRoPEDescriptor(descriptor))

@@ -172,6 +174,13 @@ def test_ascend(lib, test_cases) :
test(lib, handle, "npu", shape, strides, dtype)
destroy_handle(lib, handle)

def test_maca(lib, test_cases) :
device = DeviceEnum.DEVICE_MACA
handle = create_handle(lib, device)
for shape, strides, dtype in test_cases:
test(lib, handle, "maca", shape, strides, dtype)
destroy_handle(lib, handle)

if __name__ == "__main__":
test_cases = [
((1, 32, 128), None, torch.float16),
@@ -222,6 +231,8 @@ def test_ascend(lib, test_cases):
test_bang(lib, test_cases)
if args.ascend:
test_ascend(lib, test_cases)
if not (args.cpu or args.cuda or args.bang or args.ascend):
if args.maca:
test_maca(lib, test_cases)
if not (args.cpu or args.cuda or args.bang or args.ascend or args.maca):
test_cpu(lib, test_cases)
print("\033[92mTest passed!\033[0m")
14 changes: 14 additions & 0 deletions operatorspy/tests/swiglu.py
@@ -250,6 +250,18 @@ def test_ascend(lib, test_cases):

destroy_handle(lib, handle)

def test_maca(lib, test_cases):
device = DeviceEnum.DEVICE_MACA
handle = create_handle(lib, device)

for shape, a_stride, b_stride, c_stride, dtype in test_cases:
test_out_of_place(
lib, handle, "cuda", shape, a_stride, b_stride, c_stride, dtype)
test_in_place1(lib, handle, "cuda", shape, a_stride, b_stride, dtype)
test_in_place2(lib, handle, "cuda", shape, a_stride, b_stride, dtype)

destroy_handle(lib, handle)


if __name__ == "__main__":
test_cases = [
@@ -293,4 +305,6 @@ def test_ascend(lib, test_cases):
test_bang(lib, test_cases)
if args.ascend:
test_ascend(lib, test_cases)
if args.maca:
test_maca(lib, test_cases)
print("\033[92mTest passed!\033[0m")
5 changes: 5 additions & 0 deletions operatorspy/tests/test_utils.py
@@ -27,6 +27,11 @@ def get_args():
action="store_true",
help="Run ASCEND NPU test",
)
parser.add_argument(
"--maca",
action="store_true",
help="Run ASCEND NPU test",
)

return parser.parse_args()

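With the flag registered, any of the updated test scripts can target the new backend from the command line, e.g. python operatorspy/tests/matmul.py --maca.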
2 changes: 2 additions & 0 deletions operatorspy/utils.py
@@ -50,6 +50,8 @@ def create_workspace(size, torch_device):
if size == 0:
return None
import torch
if (torch_device == 'maca'):
return torch.zeros(size=(size,), dtype=torch.uint8, device='cuda')
return torch.zeros(size=(size,), dtype=torch.uint8, device=torch_device)

def create_handle(lib, device, id=0):
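
The create_workspace special case captures a convention used throughout this PR: the infiniop handle is created for DEVICE_MACA, but PyTorch tensors for that backend are allocated on the 'cuda' device string (the assumption being that the MACA torch plugin registers under the CUDA device type). A minimal sketch of the mapping, with a hypothetical helper name:

def torch_device_for(logical_device):
    # 'maca' exists only on the infiniop side; torch tensors for MACA
    # live on 'cuda', while every other backend maps one-to-one.
    return 'cuda' if logical_device == 'maca' else logical_device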
13 changes: 13 additions & 0 deletions src/devices/handle.cc
@@ -11,6 +11,9 @@
#ifdef ENABLE_ASCEND_NPU
#include "./ascend/ascend_handle.h"
#endif
#ifdef ENABLE_METAX_GPU
#include "./maca/maca_handle.h"
#endif


__C infiniopStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr, Device device, int device_id) {
@@ -40,6 +43,11 @@ __C infiniopStatus_t infiniopCreateHandle(infiniopHandle_t *handle_ptr, Device d
case DevAscendNpu: {
return createAscendHandle((AscendHandle_t *) handle_ptr, device_id);
}
#endif
#ifdef ENABLE_METAX_GPU
case DevMetaxGpu: {
return createMacaHandle((MacaHandle_t *) handle_ptr, device_id);
}
#endif
}
return STATUS_BAD_DEVICE;
@@ -68,6 +76,11 @@ __C infiniopStatus_t infiniopDestroyHandle(infiniopHandle_t handle) {
case DevAscendNpu: {
return deleteAscendHandle((AscendHandle_t) handle);
}
#endif
#ifdef ENABLE_METAX_GPU
case DevMetaxGpu: {
return deleteMacaHandle((MacaHandle_t) handle);
}
#endif
}
return STATUS_BAD_DEVICE;
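
From the Python harness, the new DevMetaxGpu cases are reached through the existing handle helpers. A minimal smoke test, assuming the library is built with ENABLE_METAX_GPU and that create_handle and destroy_handle both live in operatorspy/utils.py:

from operatorspy.devices import DeviceEnum
from operatorspy.utils import create_handle, destroy_handle

def smoke_test_maca(lib):
    # lib is the loaded infiniop shared library (a ctypes CDLL).
    handle = create_handle(lib, DeviceEnum.DEVICE_MACA)  # -> createMacaHandle
    destroy_handle(lib, handle)                          # -> deleteMacaHandle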