test/infiniop/causal_softmax.py (85 additions, 82 deletions)
import torch
import ctypes
import sys
import os

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))

from ctypes import POINTER, Structure, c_int32, c_size_t, c_uint64, c_void_p, c_float
from libinfiniop import (
    infiniopHandle_t,
    infiniopTensorDescriptor_t,
    create_handle,
    destroy_handle,
    open_lib,
    to_tensor,
    get_test_devices,
    check_error,
    rearrange_if_needed,
    create_workspace,
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
)
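
# NOTE: check_error() raises on a non-zero infiniop status code; a minimal
# sketch of the behavior this test relies on (the message format is an
# assumption, not the library's actual code):
#
#     def check_error(status):
#         if status != 0:
#             raise RuntimeError(f"infiniop call failed with status {status}")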
# ==============================================================================
# Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES = [
    # x_shape, x_stride
    ((32, 512), None),
    ((32, 512), (1024, 1)),
    ((32, 5, 5), None),
    ((32, 20, 512), None),
    ((32, 20, 512), (20480, 512, 1)),  # Ascend does not yet support non-contiguous layouts
]

# Data types used for testing
_TENSOR_DTYPES = [torch.float16]

# Tolerance map for different data types
_TOLERANCE_MAP = {
    torch.float16: {"atol": 0, "rtol": 1e-2},
}

DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000
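
# NOTE: get_tolerance() resolves per-dtype tolerances from _TOLERANCE_MAP; a
# sketch of the presumable lookup (the fallback defaults are assumptions):
#
#     def get_tolerance(tolerance_map, dtype, default_atol=0, default_rtol=1e-3):
#         tol = tolerance_map.get(dtype, {"atol": default_atol, "rtol": default_rtol})
#         return tol["atol"], tol["rtol"]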


class CausalSoftmaxDescriptor(Structure):
    # ... (body collapsed in the diff view) ...


def causal_softmax(x):
    # ... (body collapsed in the diff view) ...
    return torch.nn.functional.softmax(masked, dim=-1).to(type)
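
# NOTE: The body of causal_softmax() is collapsed in the diff view above. Below
# is a sketch of a typical causal-masked reference consistent with the visible
# return line; the mask construction and intermediate names are assumptions,
# not the hidden code:
#
#     def causal_softmax(x):
#         type = x.dtype
#         # Row i may attend to column j only when j <= i + (seq_k - seq_q),
#         # i.e. a lower-triangular mask aligned to the last column.
#         mask = torch.tril(
#             torch.ones(x.shape[-2], x.shape[-1]), diagonal=x.shape[-1] - x.shape[-2]
#         ).to(x.device)
#         masked = x.float().masked_fill(mask == 0, float("-inf"))
#         return torch.nn.functional.softmax(masked, dim=-1).to(type)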


def test(lib, handle, torch_device, x_shape, x_stride=None, dtype=torch.float16):
    print(
        f"Testing CausalSoftmax on {torch_device} with x_shape:{x_shape} x_stride:{x_stride} dtype:{dtype}"
    )

    x = torch.rand(x_shape, dtype=dtype).to(torch_device)

    ans = causal_softmax(x)

    x = rearrange_if_needed(x, x_stride)

    x_tensor = to_tensor(x, lib)

    descriptor = infiniopCausalSoftmaxDescriptor_t()
    check_error(
        lib.infiniopCreateCausalSoftmaxDescriptor(
            handle, ctypes.byref(descriptor), x_tensor.descriptor
        )
    )
    # Invalidate the shape and strides in the descriptor to prevent them from
    # being directly used by the kernel
    x_tensor.descriptor.contents.invalidate()
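
    # NOTE: invalidate() is a libinfiniop test helper. Per the comment above,
    # the point is to corrupt the host-side copy of the metadata so a kernel
    # that wrongly reads shape/strides through the raw descriptor fails loudly
    # instead of silently passing. A sketch of what it presumably does (the
    # field names are assumptions):
    #
    #     def invalidate(self):
    #         for i in range(self.ndim):
    #             self.shape[i] = 0
    #             self.strides[i] = 0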

    workspace_size = c_uint64(0)
    check_error(
        lib.infiniopGetCausalSoftmaxWorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
    workspace = create_workspace(workspace_size.value, x.device)
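
    # NOTE: create_workspace() is a libinfiniop helper. The call site below
    # treats the result as an optional torch tensor, so something like the
    # following is the presumable behavior (names are assumptions):
    #
    #     def create_workspace(size, device):
    #         if size == 0:
    #             return None
    #         return torch.empty(size, dtype=torch.uint8, device=device)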
    # Wrap the library call in a closure so the profiling block below can
    # reuse it.
    def lib_causal_softmax():
        check_error(
            lib.infiniopCausalSoftmax(
                descriptor,
                workspace.data_ptr() if workspace is not None else None,
                workspace_size.value,
                x_tensor.data,
                None,
            )
        )

    lib_causal_softmax()

    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(x, ans, atol=atol, rtol=rtol)
    assert torch.allclose(x, ans, atol=atol, rtol=rtol)

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation("PyTorch", lambda: causal_softmax(x), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        profile_operation("    lib", lambda: lib_causal_softmax(), torch_device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on

    check_error(lib.infiniopDestroyCausalSoftmaxDescriptor(descriptor))
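
# NOTE: rearrange_if_needed() folds the old `if x_stride is not None:
# x = rearrange_tensor(x, x_stride)` pattern into one call; a sketch of the
# presumable helper (buffer sizing assumes non-negative strides):
#
#     def rearrange_if_needed(tensor, stride):
#         # Return a copy whose memory follows the requested strides, or the
#         # tensor itself when no custom stride is requested.
#         if stride is None:
#             return tensor
#         numel = 1 + sum((s - 1) * st for s, st in zip(tensor.shape, stride))
#         buf = torch.empty(numel, dtype=tensor.dtype, device=tensor.device)
#         out = buf.as_strided(tensor.shape, stride)
#         out.copy_(tensor)
#         return out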


if __name__ == "__main__":
    args = get_args()
    lib = open_lib()

    lib.infiniopCreateCausalSoftmaxDescriptor.restype = c_int32
    lib.infiniopCreateCausalSoftmaxDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopCausalSoftmaxDescriptor_t),
        infiniopTensorDescriptor_t,
    ]

    lib.infiniopGetCausalSoftmaxWorkspaceSize.restype = c_int32
    lib.infiniopGetCausalSoftmaxWorkspaceSize.argtypes = [
        infiniopCausalSoftmaxDescriptor_t,
        POINTER(c_uint64),
    ]

    lib.infiniopCausalSoftmax.restype = c_int32
    lib.infiniopCausalSoftmax.argtypes = [
        infiniopCausalSoftmaxDescriptor_t,
        # ... (entries collapsed in the diff view) ...
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroyCausalSoftmaxDescriptor.restype = c_int32
    lib.infiniopDestroyCausalSoftmaxDescriptor.argtypes = [
        infiniopCausalSoftmaxDescriptor_t,
    ]

    # Configure testing options
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    for device in get_test_devices(args):
        test_operator(lib, device, test, _TEST_CASES, _TENSOR_DTYPES)
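
    # NOTE: test_operator() drives the sweep over _TEST_CASES x _TENSOR_DTYPES;
    # a sketch of the loop it presumably runs (handle management per device, as
    # in the old per-device test_* functions; names are assumptions):
    #
    #     def test_operator(lib, device, test_func, test_cases, dtypes):
    #         handle = create_handle(lib, device)
    #         for shape, stride in test_cases:
    #             for dtype in dtypes:
    #                 test_func(lib, handle, device, shape, stride, dtype=dtype)
    #         destroy_handle(lib, handle)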

    print("\033[92mTest passed!\033[0m")
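
# Typical invocations (flag names inferred from args.debug / args.profile /
# args.num_prerun / args.num_iterations and the device flags of the old
# harness; treat them as assumptions, not documented CLI):
#
#     python test/infiniop/causal_softmax.py --cpu
#     python test/infiniop/causal_softmax.py --cuda --debug
#     python test/infiniop/causal_softmax.py --cuda --profile --num_prerun 10 --num_iterations 1000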