From 9ea4a25193ecbb572fe9ba50f7551e1f72cd7d1c Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Mon, 20 Jan 2025 18:22:55 +0800 Subject: [PATCH 1/4] Add debug mode and apply to all current operators' tests --- operatorspy/tests/add.py | 21 ++++- operatorspy/tests/attention.py | 7 +- operatorspy/tests/avg_pool.py | 19 +++- operatorspy/tests/causal_softmax.py | 8 +- operatorspy/tests/conv.py | 22 ++++- operatorspy/tests/expand.py | 11 ++- operatorspy/tests/gemm.py | 21 ++++- operatorspy/tests/global_avg_pool.py | 19 +++- operatorspy/tests/matmul.py | 21 ++++- operatorspy/tests/max_pool.py | 20 +++- operatorspy/tests/mlp.py | 20 +++- operatorspy/tests/random_sample.py | 15 ++- operatorspy/tests/rearrange.py | 13 ++- operatorspy/tests/relu.py | 13 ++- operatorspy/tests/rms_norm.py | 20 +++- operatorspy/tests/rotary_embedding.py | 20 +++- operatorspy/tests/swiglu.py | 29 +++++- operatorspy/tests/test_utils.py | 130 ++++++++++++++++++++++++++ 18 files changed, 392 insertions(+), 37 deletions(-) diff --git a/operatorspy/tests/add.py b/operatorspy/tests/add.py index 455014cc..e985d9ef 100644 --- a/operatorspy/tests/add.py +++ b/operatorspy/tests/add.py @@ -15,10 +15,21 @@ check_error, ) -from operatorspy.tests.test_utils import get_args +from operatorspy.tests.test_utils import ( + get_args, + debug, + get_tolerance, +) from enum import Enum, auto import torch +DEBUG = False + +# the atol and rtol for each data type +tolerance_map = { + torch.float16: {'atol': 0, 'rtol': 1e-3}, + torch.float32: {'atol': 0, 'rtol': 1e-5}, +} class Inplace(Enum): OUT_OF_PLACE = auto() @@ -83,7 +94,11 @@ def test( check_error( lib.infiniopAdd(descriptor, c_tensor.data, a_tensor.data, b_tensor.data, None) ) - assert torch.allclose(c, ans, atol=0, rtol=1e-3) + + atol, rtol = get_tolerance(tolerance_map, tensor_dtype) + if DEBUG: + debug(c, ans, atol=atol, rtol=rtol) + assert torch.allclose(c, ans, atol=atol, rtol=rtol) check_error(lib.infiniopDestroyAddDescriptor(descriptor)) @@ -157,6 +172,8 @@ def test_bang(lib, test_cases): infiniopAddDescriptor_t, ] + if args.debug: + DEBUG = True if args.cpu: test_cpu(lib, test_cases) if args.cuda: diff --git a/operatorspy/tests/attention.py b/operatorspy/tests/attention.py index f5449aaa..830904e9 100644 --- a/operatorspy/tests/attention.py +++ b/operatorspy/tests/attention.py @@ -18,10 +18,11 @@ create_workspace, ) -from operatorspy.tests.test_utils import get_args +from operatorspy.tests.test_utils import get_args, debug import torch import torch.nn.functional as F +DEBUG = False class AttentionDescriptor(Structure): _fields_ = [("device", c_int32)] @@ -184,6 +185,8 @@ def test( ) ) + if DEBUG: + debug(out, ans, atol=1e-4, rtol=1e-2) assert torch.allclose(out, ans, atol=1e-4, rtol=1e-2) check_error(lib.infiniopDestroyAttentionDescriptor(descriptor)) @@ -406,6 +409,8 @@ def test_bang(lib, test_cases): infiniopAttentionDescriptor_t, ] + if args.debug: + DEBUG = True if args.cpu: test_cpu(lib, test_cases) if args.cuda: diff --git a/operatorspy/tests/avg_pool.py b/operatorspy/tests/avg_pool.py index 9c240789..345e4b64 100644 --- a/operatorspy/tests/avg_pool.py +++ b/operatorspy/tests/avg_pool.py @@ -16,10 +16,15 @@ check_error, ) -from operatorspy.tests.test_utils import get_args +from operatorspy.tests.test_utils import ( + get_args, + debug, + get_tolerance, +) import torch from typing import Tuple +DEBUG = False # constant for control whether profile the pytorch and lib functions # NOTE: need to manually add synchronization function to the lib function, # e.g., cudaDeviceSynchronize() for CUDA @@ -27,6 +32,11 @@ NUM_PRERUN = 10 NUM_ITERATIONS = 1000 +# the atol and rtol for each data type +tolerance_map = { + torch.float16: {'atol': 0, 'rtol': 1e-3}, + torch.float32: {'atol': 0, 'rtol': 1e-5}, +} class AvgPoolDescriptor(Structure): _fields_ = [("device", c_int32)] @@ -156,7 +166,10 @@ def test( elapsed = (time.time() - start_time) / NUM_ITERATIONS print(f" lib time: {elapsed :6f}") - assert torch.allclose(y, ans, atol=0, rtol=1e-3) + atol, rtol = get_tolerance(tolerance_map, tensor_dtype) + if DEBUG: + debug(y, ans, atol=atol, rtol=rtol) + assert torch.allclose(y, ans, atol=atol, rtol=rtol) check_error(lib.infiniopDestroyAvgPoolDescriptor(descriptor)) @@ -228,6 +241,8 @@ def test_bang(lib, test_cases): infiniopAvgPoolDescriptor_t, ] + if args.debug: + DEBUG = True if args.cpu: test_cpu(lib, test_cases) if args.cuda: diff --git a/operatorspy/tests/causal_softmax.py b/operatorspy/tests/causal_softmax.py index 1ad304b2..4627998a 100644 --- a/operatorspy/tests/causal_softmax.py +++ b/operatorspy/tests/causal_softmax.py @@ -18,9 +18,10 @@ create_workspace, ) -from operatorspy.tests.test_utils import get_args +from operatorspy.tests.test_utils import get_args, debug import torch +DEBUG = False class CausalSoftmaxDescriptor(Structure): _fields_ = [("device", c_int32)] @@ -72,6 +73,9 @@ def test(lib, handle, torch_device, x_shape, x_stride=None, x_dtype=torch.float1 None, ) ) + + if DEBUG: + debug(x, ans, atol=0, rtol=1e-2) assert torch.allclose(x, ans, atol=0, rtol=1e-2) check_error(lib.infiniopDestroyCausalSoftmaxDescriptor(descriptor)) @@ -143,6 +147,8 @@ def test_ascend(lib, test_cases): infiniopCausalSoftmaxDescriptor_t, ] + if args.debug: + DEBUG = True if args.cpu: test_cpu(lib, test_cases) if args.cuda: diff --git a/operatorspy/tests/conv.py b/operatorspy/tests/conv.py index 7e7ea953..4e090e9f 100644 --- a/operatorspy/tests/conv.py +++ b/operatorspy/tests/conv.py @@ -16,13 +16,18 @@ check_error, ) -from operatorspy.tests.test_utils import get_args +from operatorspy.tests.test_utils import ( + get_args, + debug, + get_tolerance, +) import torch import math import ctypes from torch.nn import functional as F from typing import List, Tuple +DEBUG = False # constant for control whether profile the pytorch and lib functions # NOTE: need to manually add synchronization function to the lib function, # e.g., cudaDeviceSynchronize() for CUDA @@ -30,6 +35,11 @@ NUM_PRERUN = 10 NUM_ITERATIONS = 1000 +# the atol and rtol for each data type +tolerance_map = { + torch.float16: {'atol': 0, 'rtol': 1e-2}, + torch.float32: {'atol': 0, 'rtol': 1e-3}, +} class ConvDescriptor(Structure): _fields_ = [("device", c_int32)] @@ -177,10 +187,10 @@ def test( elapsed = (time.time() - start_time) / NUM_ITERATIONS print(f" lib time: {elapsed :6f}") - if (tensor_dtype == torch.float16): - assert torch.allclose(y, ans, atol=0, rtol=1e-2) - else: - assert torch.allclose(y, ans, atol=0, rtol=1e-3) + atol, rtol = get_tolerance(tolerance_map, tensor_dtype) + if DEBUG: + debug(y, ans, atol=atol, rtol=rtol) + assert torch.allclose(y, ans, atol=atol, rtol=rtol) check_error(lib.infiniopDestroyConvDescriptor(descriptor)) @@ -286,6 +296,8 @@ def test_bang(lib, test_cases): infiniopConvDescriptor_t, ] + if args.debug: + DEBUG = True if args.cpu: test_cpu(lib, test_cases) if args.cuda: diff --git a/operatorspy/tests/expand.py b/operatorspy/tests/expand.py index e060ad73..34505e01 100644 --- a/operatorspy/tests/expand.py +++ b/operatorspy/tests/expand.py @@ -17,9 +17,11 @@ rearrange_tensor, ) -from operatorspy.tests.test_utils import get_args +from operatorspy.tests.test_utils import get_args, debug import torch +DEBUG = False + # constant for control whether profile the pytorch and lib functions # NOTE: need to manually add synchronization function to the lib function, # e.g., cudaDeviceSynchronize() for CUDA @@ -101,7 +103,10 @@ def test( ) elapsed = (time.time() - start_time) / NUM_ITERATIONS print(f" lib time: {elapsed :6f}") - assert torch.allclose(y, ans, atol=0, rtol=1e-3) + + if DEBUG: + debug(y, ans, atol=0, rtol=0) + assert torch.allclose(y, ans, atol=0, rtol=0) check_error(lib.infiniopDestroyExpandDescriptor(descriptor)) @@ -168,6 +173,8 @@ def test_bang(lib, test_cases): infiniopExpandDescriptor_t, ] + if args.debug: + DEBUG = True if args.cpu: test_cpu(lib, test_cases) if args.cuda: diff --git a/operatorspy/tests/gemm.py b/operatorspy/tests/gemm.py index 5da99eac..eb303a5c 100644 --- a/operatorspy/tests/gemm.py +++ b/operatorspy/tests/gemm.py @@ -17,9 +17,15 @@ rearrange_tensor, ) -from operatorspy.tests.test_utils import get_args +from operatorspy.tests.test_utils import ( + get_args, + debug, + get_tolerance, +) import torch +DEBUG = False + # constant for control whether profile the pytorch and lib functions # NOTE: need to manually add synchronization function to the lib function, # e.g., cudaDeviceSynchronize() for CUDA @@ -27,6 +33,12 @@ NUM_PRERUN = 10 NUM_ITERATIONS = 1000 +# the atol and rtol for each data type +tolerance_map = { + torch.float16: {'atol': 0, 'rtol': 1e-2}, + torch.float32: {'atol': 0, 'rtol': 1e-2}, +} + class GEMMDescriptor(Structure): _fields_ = [("device", c_int32)] @@ -161,7 +173,10 @@ def test( elapsed = (time.time() - start_time) / NUM_ITERATIONS print(f" lib time: {elapsed :6f}") - assert torch.allclose(y, ans, atol=0, rtol=1e-2) + atol, rtol = get_tolerance(tolerance_map, dtype) + if DEBUG: + debug(y, ans, atol=atol, rtol=rtol) + assert torch.allclose(y, ans, atol=atol, rtol=rtol) check_error(lib.infiniopDestroyGEMMDescriptor(descriptor)) @@ -363,6 +378,8 @@ def test_bang(lib, test_cases): infiniopGEMMDescriptor_t, ] + if args.debug: + DEBUG = True if args.cpu: test_cpu(lib, test_cases) if args.cuda: diff --git a/operatorspy/tests/global_avg_pool.py b/operatorspy/tests/global_avg_pool.py index 33f7b64d..c447acb7 100644 --- a/operatorspy/tests/global_avg_pool.py +++ b/operatorspy/tests/global_avg_pool.py @@ -16,9 +16,14 @@ check_error, ) -from operatorspy.tests.test_utils import get_args +from operatorspy.tests.test_utils import ( + get_args, + debug, + get_tolerance, +) import torch, time +DEBUG = False # constant for control whether profile the pytorch and lib functions # NOTE: need to manually add synchronization function to the lib function, # e.g., cudaDeviceSynchronize() for CUDA @@ -26,6 +31,11 @@ NUM_PRERUN = 10 NUM_ITERATIONS = 1000 +# the atol and rtol for each data type +tolerance_map = { + torch.float16: {'atol': 0, 'rtol': 1e-3}, + torch.float32: {'atol': 0, 'rtol': 1e-4}, +} class GlobalAvgPoolDescriptor(Structure): _fields_ = [("device", c_int32)] @@ -118,7 +128,10 @@ def test( elapsed = (time.time() - start_time) / NUM_ITERATIONS print(f" lib time: {elapsed :6f}") - assert torch.allclose(y, ans, atol=0, rtol=1e-3) + atol, rtol = get_tolerance(tolerance_map, tensor_dtype) + if DEBUG: + debug(y, ans, atol=atol, rtol=rtol) + assert torch.allclose(y, ans, atol=atol, rtol=rtol) check_error(lib.infiniopDestroyGlobalAvgPoolDescriptor(descriptor)) @@ -197,6 +210,8 @@ def test_bang(lib, test_cases): infiniopGlobalAvgPoolDescriptor_t, ] + if args.debug: + DEBUG = True if args.cpu: test_cpu(lib, test_cases) if args.cuda: diff --git a/operatorspy/tests/matmul.py b/operatorspy/tests/matmul.py index ac4b0f7f..7ebb8e9f 100644 --- a/operatorspy/tests/matmul.py +++ b/operatorspy/tests/matmul.py @@ -19,13 +19,25 @@ create_workspace, ) -from operatorspy.tests.test_utils import get_args, synchronize_device +from operatorspy.tests.test_utils import ( + get_args, + synchronize_device, + debug, + get_tolerance, +) import torch +DEBUG = False PROFILE = False NUM_PRERUN = 10 NUM_ITERATIONS = 1000 +# the atol and rtol for each data type +tolerance_map = { + torch.float16: {'atol': 0, 'rtol': 1e-2}, + torch.float32: {'atol': 0, 'rtol': 1e-3}, +} + class MatmulDescriptor(Structure): _fields_ = [("device", c_int32)] @@ -115,7 +127,10 @@ def test( ) ) - assert torch.allclose(c, ans, atol=0, rtol=1e-2) + atol, rtol = get_tolerance(tolerance_map, dtype) + if DEBUG: + debug(c, ans, atol=atol, rtol=rtol) + assert torch.allclose(c, ans, atol=atol, rtol=rtol) if PROFILE: for i in range(NUM_PRERUN): @@ -343,6 +358,8 @@ def test_ascend(lib, test_cases): infiniopMatmulDescriptor_t, ] + if args.debug: + DEBUG = True if args.profile: PROFILE = True if args.cpu: diff --git a/operatorspy/tests/max_pool.py b/operatorspy/tests/max_pool.py index ffc0bb19..45d35cfc 100644 --- a/operatorspy/tests/max_pool.py +++ b/operatorspy/tests/max_pool.py @@ -16,10 +16,15 @@ check_error, ) -from operatorspy.tests.test_utils import get_args +from operatorspy.tests.test_utils import ( + get_args, + debug, + get_tolerance, +) import torch from typing import Tuple +DEBUG = False # constant for control whether profile the pytorch and lib functions # NOTE: need to manually add synchronization function to the lib function, # e.g., cudaDeviceSynchronize() for CUDA @@ -27,6 +32,12 @@ NUM_PRERUN = 10 NUM_ITERATIONS = 1000 +# the atol and rtol for each data type +tolerance_map = { + torch.float16: {'atol': 0, 'rtol': 1e-3}, + torch.float32: {'atol': 0, 'rtol': 1e-5}, +} + class MaxPoolDescriptor(Structure): _fields_ = [("device", c_int32)] @@ -153,7 +164,10 @@ def test( elapsed = (time.time() - start_time) / NUM_ITERATIONS print(f" lib time: {elapsed :6f}") - assert torch.allclose(y, ans, atol=0, rtol=1e-3) + atol, rtol = get_tolerance(tolerance_map, tensor_dtype) + if DEBUG: + debug(y, ans, atol=atol, rtol=rtol) + assert torch.allclose(y, ans, atol=atol, rtol=rtol) check_error(lib.infiniopDestroyMaxPoolDescriptor(descriptor)) @@ -225,6 +239,8 @@ def test_bang(lib, test_cases): infiniopMaxPoolDescriptor_t, ] + if args.debug: + DEBUG = True if args.cpu: test_cpu(lib, test_cases) if args.cuda: diff --git a/operatorspy/tests/mlp.py b/operatorspy/tests/mlp.py index 668d7861..5acbfdf2 100644 --- a/operatorspy/tests/mlp.py +++ b/operatorspy/tests/mlp.py @@ -18,10 +18,20 @@ create_workspace, ) -from operatorspy.tests.test_utils import get_args +from operatorspy.tests.test_utils import ( + get_args, + debug, + get_tolerance, +) import torch import torch.nn as nn +DEBUG = False + +# the atol and rtol for each data type +tolerance_map = { + torch.float16: {'atol': 0, 'rtol': 2e-2}, +} class MLPDescriptor(Structure): _fields_ = [("device", c_int32)] @@ -135,7 +145,11 @@ def test( None, ) ) - assert torch.allclose(y, ans, atol=0, rtol=2e-2) + + atol, rtol = get_tolerance(tolerance_map, dtype) + if DEBUG: + debug(y, ans, atol=atol, rtol=rtol) + assert torch.allclose(y, ans, atol=atol, rtol=rtol) check_error(lib.infiniopDestroyMLPDescriptor(descriptor)) @@ -305,6 +319,8 @@ def test_bang(lib, test_cases): infiniopMLPDescriptor_t, ] + if args.debug: + DEBUG = True if args.cpu: test_cpu(lib, test_cases) if args.cuda: diff --git a/operatorspy/tests/random_sample.py b/operatorspy/tests/random_sample.py index 98a8dceb..d93cff65 100644 --- a/operatorspy/tests/random_sample.py +++ b/operatorspy/tests/random_sample.py @@ -13,14 +13,17 @@ create_handle, destroy_handle, check_error, - rearrange_tensor, create_workspace, U64, ) -from operatorspy.tests.test_utils import get_args +from operatorspy.tests.test_utils import ( + get_args, + debug_all, +) import torch +DEBUG = False class RandomSampleDescriptor(Structure): _fields_ = [("device", c_int32)] @@ -128,6 +131,12 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_ if torch_device == "npu": torch.npu.synchronize() + if DEBUG: + debug_all((indices[0].type(ans.dtype), data[ans]), + (ans, data[indices[0]]), + "or", + atol=0, + rtol=0) assert indices[0].type(ans.dtype) == ans or data[ans] == data[indices[0]] check_error(lib.infiniopDestroyRandomSampleDescriptor(descriptor)) @@ -212,6 +221,8 @@ def test_ascend(lib, test_cases): infiniopRandomSampleDescriptor_t, ] + if args.debug: + DEBUG = True if args.cpu: test_cpu(lib, test_cases) if args.cuda: diff --git a/operatorspy/tests/rearrange.py b/operatorspy/tests/rearrange.py index e9cc81b9..0b118c9e 100644 --- a/operatorspy/tests/rearrange.py +++ b/operatorspy/tests/rearrange.py @@ -17,7 +17,10 @@ rearrange_tensor, ) -from operatorspy.tests.test_utils import get_args +from operatorspy.tests.test_utils import ( + get_args, + debug, +) import torch @@ -64,7 +67,10 @@ def test( check_error( lib.infiniopRearrange(descriptor, y_tensor.data, x_tensor.data, None) ) - assert torch.allclose(x, y, atol=0, rtol=1e-3) + + if DEBUG: + debug(x, y, atol=0, rtol=0) + assert torch.allclose(x, y, atol=0, rtol=0) check_error(lib.infiniopDestroyRearrangeDescriptor(descriptor)) @@ -137,6 +143,9 @@ def test_ascend(lib, test_cases): ] lib.infiniopDestroyRearrangeDescriptor.restype = c_int32 lib.infiniopDestroyRearrangeDescriptor.argtypes = [infiniopRearrangeDescriptor_t] + + if args.debug: + DEBUG = True if args.cpu: test_cpu(lib, test_cases) if args.cuda: diff --git a/operatorspy/tests/relu.py b/operatorspy/tests/relu.py index b7f76627..a5a61f1b 100644 --- a/operatorspy/tests/relu.py +++ b/operatorspy/tests/relu.py @@ -16,10 +16,14 @@ check_error, ) -from operatorspy.tests.test_utils import get_args +from operatorspy.tests.test_utils import ( + get_args, + debug, +) from enum import Enum, auto import torch +DEBUG = False # constant for control whether profile the pytorch and lib functions # NOTE: need to manually add synchronization function to the lib function, # e.g., cudaDeviceSynchronize() for CUDA @@ -100,7 +104,10 @@ def test( elapsed = (time.time() - start_time) / NUM_ITERATIONS print(f" lib time: {elapsed :6f}") - assert torch.allclose(y, ans, atol=0, rtol=1e-3) + atol, rtol = 0, 0 + if DEBUG: + debug(y, ans, atol=atol, rtol=rtol) + assert torch.allclose(y, ans, atol=atol, rtol=rtol) check_error(lib.infiniopDestroyReluDescriptor(descriptor)) @@ -166,6 +173,8 @@ def test_bang(lib, test_cases): infiniopReluDescriptor_t, ] + if args.debug: + DEBUG = True if args.cpu: test_cpu(lib, test_cases) if args.cuda: diff --git a/operatorspy/tests/rms_norm.py b/operatorspy/tests/rms_norm.py index 13cf1ccf..07fb8ca4 100644 --- a/operatorspy/tests/rms_norm.py +++ b/operatorspy/tests/rms_norm.py @@ -17,9 +17,20 @@ create_workspace, ) -from operatorspy.tests.test_utils import get_args +from operatorspy.tests.test_utils import ( + get_args, + debug, + get_tolerance, +) import torch +DEBUG = False + +# the atol and rtol for each data type +tolerance_map = { + torch.float16: {'atol': 0, 'rtol': 1e-3}, +} + class RMSNormDescriptor(Structure): _fields_ = [("device", c_int32)] @@ -83,7 +94,10 @@ def test(lib, handle, torch_device, y_shape, x_shape, w_shape, dtype=torch.float ) ) - assert torch.allclose(y.to(dtype), ans.to(dtype), atol=1e-3, rtol=1e-3) + atol, rtol = get_tolerance(tolerance_map, dtype) + if DEBUG: + debug(y.to(dtype), ans.to(dtype), atol=atol, rtol=rtol) + assert torch.allclose(y.to(dtype), ans.to(dtype), atol=atol, rtol=rtol) check_error(lib.infiniopDestroyRMSNormDescriptor(descriptor)) def test_cpu(lib, test_cases): @@ -156,6 +170,8 @@ def test_ascend(lib, test_cases): infiniopRMSNormDescriptor_t, ] + if args.debug: + DEBUG = True if args.cpu: test_cpu(lib, test_cases) if args.cuda: diff --git a/operatorspy/tests/rotary_embedding.py b/operatorspy/tests/rotary_embedding.py index 081d2f91..e8077b0f 100644 --- a/operatorspy/tests/rotary_embedding.py +++ b/operatorspy/tests/rotary_embedding.py @@ -19,9 +19,19 @@ U64, ) -from operatorspy.tests.test_utils import get_args +from operatorspy.tests.test_utils import ( + get_args, + debug, + get_tolerance, +) import torch +DEBUG = False + +# the atol and rtol for each data type +tolerance_map = { + torch.float16: {'atol': 1e-4, 'rtol': 1e-2}, +} class RoPEDescriptor(Structure): _fields_ = [("device", c_int32)] @@ -134,7 +144,10 @@ def test(lib, handle, torch_device, shape, strides=None, dtype=torch.float16): ) ) - assert torch.allclose(t, ans, atol=1e-4, rtol=1e-2) + atol, rtol = get_tolerance(tolerance_map, dtype) + if DEBUG: + debug(t, ans, atol=atol, rtol=rtol) + assert torch.allclose(t, ans, atol=atol, rtol=rtol) check_error(lib.infiniopDestroyRoPEDescriptor(descriptor)) @@ -214,6 +227,9 @@ def test_ascend(lib, test_cases) : lib.infiniopDestroyRoPEDescriptor.argtypes = [ infiniopRoPEDescriptor_t, ] + + if args.debug: + DEBUG = True if args.cpu: test_cpu(lib, test_cases) if args.cuda: diff --git a/operatorspy/tests/swiglu.py b/operatorspy/tests/swiglu.py index 7fb447a1..abf6b0bb 100644 --- a/operatorspy/tests/swiglu.py +++ b/operatorspy/tests/swiglu.py @@ -17,9 +17,19 @@ rearrange_tensor, ) -from operatorspy.tests.test_utils import get_args +from operatorspy.tests.test_utils import ( + get_args, + debug, + get_tolerance, +) import torch +DEBUG = False + +# the atol and rtol for each data type +tolerance_map = { + torch.float16: {'atol': 1e-4, 'rtol': 1e-2}, +} class SwiGLUDescriptor(Structure): _fields_ = [("device", c_int32)] @@ -86,7 +96,10 @@ def test_out_of_place( ) ) - assert torch.allclose(c, ans, atol=1e-4, rtol=1e-2) + atol, rtol = get_tolerance(tolerance_map, dtype) + if DEBUG: + debug(c, ans, atol=atol, rtol=rtol) + assert torch.allclose(c, ans, atol=atol, rtol=rtol) print("out-of-place Test passed!") check_error(lib.infiniopDestroySwiGLUDescriptor(descriptor)) @@ -137,7 +150,10 @@ def test_in_place1( ) ) - assert torch.allclose(a, ans, atol=1e-4, rtol=1e-2) + atol, rtol = get_tolerance(tolerance_map, dtype) + if DEBUG: + debug(a, ans, atol=atol, rtol=rtol) + assert torch.allclose(a, ans, atol=atol, rtol=rtol) print("in-place1 Test passed!") check_error(lib.infiniopDestroySwiGLUDescriptor(descriptor)) @@ -188,7 +204,10 @@ def test_in_place2( ) ) - assert torch.allclose(b, ans, atol=1e-4, rtol=1e-2) + atol, rtol = get_tolerance(tolerance_map, dtype) + if DEBUG: + debug(b, ans, atol=atol, rtol=rtol) + assert torch.allclose(b, ans, atol=atol, rtol=rtol) check_error(lib.infiniopDestroySwiGLUDescriptor(descriptor)) @@ -285,6 +304,8 @@ def test_ascend(lib, test_cases): infiniopSwiGLUDescriptor_t, ] + if args.debug: + DEBUG = True if args.cpu: test_cpu(lib, test_cases) if args.cuda: diff --git a/operatorspy/tests/test_utils.py b/operatorspy/tests/test_utils.py index 47635b6e..cf700180 100644 --- a/operatorspy/tests/test_utils.py +++ b/operatorspy/tests/test_utils.py @@ -1,3 +1,5 @@ +from typing import Sequence + def get_args(): import argparse @@ -7,6 +9,11 @@ def get_args(): action="store_true", help="Whether profile tests", ) + parser.add_argument( + "--debug", + action="store_true", + help="Whether turn on debug mode", + ) parser.add_argument( "--cpu", action="store_true", @@ -39,3 +46,126 @@ def synchronize_device(torch_device): torch.npu.synchronize() elif torch_device == "mlu": torch.mlu.synchronize() + + +def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True): + """ + Debugging function to compare two tensors (actual and desired) and print discrepancies. + + Arguments: + - actual : The tensor containing the actual computed values. + - desired : The tensor containing the expected values that `actual` should be compared to. + - atol : optional (default=0) + The absolute tolerance for the comparison. + - rtol : optional (default=1e-2) + The relative tolerance for the comparison. + - equal_nan : bool, optional (default=False) + If True, `NaN` values in `actual` and `desired` will be considered equal. + - verbose : bool, optional (default=True) + If True, the function will print detailed information about any discrepancies between the tensors. + """ + import numpy as np + print_discrepancy(actual, desired, atol, rtol, verbose) + np.testing.assert_allclose(actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True, strict=True) + +def debug_all(actual_vals: Sequence, desired_vals: Sequence, condition: str, atol=0, rtol=1e-2, equal_nan=False, verbose=True): + """ + Debugging function to compare two sequences of values (actual and desired), prints discrepancies + + Arguments: + - actual_vals (Sequence): A sequence (e.g., list or tuple) of actual computed values. + - desired_vals (Sequence): A sequence (e.g., list or tuple) of desired (expected) values to compare against. + - condition (str): A string specifying the condition for passing the test. It must be either: + - 'or': Test passes if any pair of actual and desired values satisfies the tolerance criteria. + - 'and': Test passes if all pairs of actual and desired values satisfy the tolerance criteria. + - atol (float, optional): Absolute tolerance. Default is 0. + - rtol (float, optional): Relative tolerance. Default is 1e-2. + - equal_nan (bool, optional): If True, NaN values in both actual and desired are considered equal. Default is False. + - verbose (bool, optional): If True, detailed output is printed for each comparison. Default is True. + + Raises: + - AssertionError: If the condition is not satisfied based on the provided `condition`, `atol`, and `rtol`. + - ValueError: If the length of `actual_vals` and `desired_vals` do not match. + - AssertionError: If the specified `condition` is not 'or' or 'and'. + """ + assert len(actual_vals) == len(desired_vals) + assert condition in {"or", "and"} + import numpy as np + + passed = False if condition == "or" else True + + for index, (actual, desired) in enumerate(zip(actual_vals, desired_vals)): + print(f"Condition #{index + 1}: {actual} == {desired}") + indices = print_discrepancy(actual, desired, atol, rtol, verbose) + if condition == "or": + if not passed and len(indices) == 0: + passed = True + elif condition == "and": + if passed and len(indices) != 0: + passed = False + print(f"\033[31mThe condition has not been satisfied: Condition #{index + 1}\033[0m") + np.testing.assert_allclose(actual.cpu(), desired.cpu(), rtol, atol, equal_nan, verbose=True, strict=True) + assert passed, "\033[31mThe condition has not been satisfied\033[0m" + + +def print_discrepancy( + actual, expected, atol=0, rtol=1e-3, verbose=True +): + if actual.shape != expected.shape: + raise ValueError("Tensors must have the same shape to compare.") + + import torch + import sys + + is_terminal = sys.stdout.isatty() + + # Calculate the difference mask based on atol and rtol + diff_mask = torch.abs(actual - expected) > (atol + rtol * torch.abs(expected)) + diff_indices = torch.nonzero(diff_mask, as_tuple=False) + delta = actual - expected + + # Display format: widths for columns + col_width = [18, 20, 20, 20] + decimal_places = [0, 12, 12, 12] + total_width = sum(col_width) + sum(decimal_places) + + def add_color(text, color_code): + if is_terminal: + return f"\033[{color_code}m{text}\033[0m" + else: + return text + + if verbose: + for idx in diff_indices: + index_tuple = tuple(idx.tolist()) + actual_str = f"{actual[index_tuple]:<{col_width[1]}.{decimal_places[1]}f}" + expected_str = f"{expected[index_tuple]:<{col_width[2]}.{decimal_places[2]}f}" + delta_str = f"{delta[index_tuple]:<{col_width[3]}.{decimal_places[3]}f}" + print( + f" > Index: {str(index_tuple):<{col_width[0]}}" + f"actual: {add_color(actual_str, 31)}" + f"expect: {add_color(expected_str, 32)}" + f"delta: {add_color(delta_str, 33)}" + ) + + print("-" * total_width) + print(add_color("INFO:", 35)) + print(f" - Actual dtype: {actual.dtype}") + print(f" - Desired dtype: {expected.dtype}") + print(f" - Atol: {atol}") + print(f" - Rtol: {rtol}") + print(f" - Mismatched elements: {len(diff_indices)} / {actual.numel()} ({len(diff_indices) / actual.numel() * 100}%)") + print(f" - Min(actual) : {torch.min(actual):<{col_width[1]}} | Max(actual) : {torch.max(actual):<{col_width[2]}}") + print(f" - Min(desired): {torch.min(expected):<{col_width[1]}} | Max(desired): {torch.max(expected):<{col_width[2]}}") + print(f" - Min(delta) : {torch.min(delta):<{col_width[1]}} | Max(delta) : {torch.max(delta):<{col_width[2]}}") + print("-" * total_width + "\n") + + return diff_indices + + +def get_tolerance(tolerance_map, tensor_dtype, default_atol=0, default_rtol=1e-3): + """ + Returns the atol and rtol for a given tensor data type in the tolerance_map. + If the given data type is not found, it returns the provided default tolerance values. + """ + return tolerance_map.get(tensor_dtype, {'atol': default_atol, 'rtol': default_rtol}).values() From 8f256da3ed8ca2cfc543d29e987bf827a6826ef6 Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Mon, 20 Jan 2025 18:38:00 +0800 Subject: [PATCH 2/4] Add comment format and assertion error message for debug functions, misc. --- operatorspy/tests/random_sample.py | 4 ++-- operatorspy/tests/test_utils.py | 10 +++++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/operatorspy/tests/random_sample.py b/operatorspy/tests/random_sample.py index d93cff65..6618e53a 100644 --- a/operatorspy/tests/random_sample.py +++ b/operatorspy/tests/random_sample.py @@ -132,8 +132,8 @@ def test(lib, handle, torch_device, voc, random_val, topp, topk, temperature, x_ torch.npu.synchronize() if DEBUG: - debug_all((indices[0].type(ans.dtype), data[ans]), - (ans, data[indices[0]]), + debug_all((indices[0].type(ans.dtype), data[indices[0]]), + (ans, data[ans]), "or", atol=0, rtol=0) diff --git a/operatorspy/tests/test_utils.py b/operatorspy/tests/test_utils.py index cf700180..ac6534e1 100644 --- a/operatorspy/tests/test_utils.py +++ b/operatorspy/tests/test_utils.py @@ -53,6 +53,7 @@ def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True): Debugging function to compare two tensors (actual and desired) and print discrepancies. Arguments: + ---------- - actual : The tensor containing the actual computed values. - desired : The tensor containing the expected values that `actual` should be compared to. - atol : optional (default=0) @@ -70,9 +71,11 @@ def debug(actual, desired, atol=0, rtol=1e-2, equal_nan=False, verbose=True): def debug_all(actual_vals: Sequence, desired_vals: Sequence, condition: str, atol=0, rtol=1e-2, equal_nan=False, verbose=True): """ - Debugging function to compare two sequences of values (actual and desired), prints discrepancies + Debugging function to compare two sequences of values (actual and desired) pair by pair, results + are linked by the given logical condition, and prints discrepancies Arguments: + ---------- - actual_vals (Sequence): A sequence (e.g., list or tuple) of actual computed values. - desired_vals (Sequence): A sequence (e.g., list or tuple) of desired (expected) values to compare against. - condition (str): A string specifying the condition for passing the test. It must be either: @@ -84,12 +87,13 @@ def debug_all(actual_vals: Sequence, desired_vals: Sequence, condition: str, ato - verbose (bool, optional): If True, detailed output is printed for each comparison. Default is True. Raises: + ---------- - AssertionError: If the condition is not satisfied based on the provided `condition`, `atol`, and `rtol`. - ValueError: If the length of `actual_vals` and `desired_vals` do not match. - AssertionError: If the specified `condition` is not 'or' or 'and'. """ - assert len(actual_vals) == len(desired_vals) - assert condition in {"or", "and"} + assert len(actual_vals) == len(desired_vals), "Invalid Length" + assert condition in {"or", "and"}, "Invalid condition: should be either 'or' or 'and'" import numpy as np passed = False if condition == "or" else True From 0944c977a49d70173374fcc18c92df76738093cd Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Mon, 20 Jan 2025 18:43:42 +0800 Subject: [PATCH 3/4] Add DEBUG to rearrange.py --- operatorspy/tests/rearrange.py | 1 + 1 file changed, 1 insertion(+) diff --git a/operatorspy/tests/rearrange.py b/operatorspy/tests/rearrange.py index 0b118c9e..58febbc0 100644 --- a/operatorspy/tests/rearrange.py +++ b/operatorspy/tests/rearrange.py @@ -23,6 +23,7 @@ ) import torch +DEBUG = False class RerrangeDescriptor(Structure): _fields_ = [("device", c_int32)] From 46807b2fa590b1d1993684f3eee24a00047fb4a3 Mon Sep 17 00:00:00 2001 From: Zimin Li Date: Wed, 22 Jan 2025 19:04:21 +0800 Subject: [PATCH 4/4] Change condition tab to blue and tweak the display format --- operatorspy/tests/test_utils.py | 21 ++++++++++----------- 1 file changed, 10 insertions(+), 11 deletions(-) diff --git a/operatorspy/tests/test_utils.py b/operatorspy/tests/test_utils.py index ac6534e1..40214ce1 100644 --- a/operatorspy/tests/test_utils.py +++ b/operatorspy/tests/test_utils.py @@ -99,7 +99,7 @@ def debug_all(actual_vals: Sequence, desired_vals: Sequence, condition: str, ato passed = False if condition == "or" else True for index, (actual, desired) in enumerate(zip(actual_vals, desired_vals)): - print(f"Condition #{index + 1}: {actual} == {desired}") + print(f" \033[36mCondition #{index + 1}:\033[0m {actual} == {desired}") indices = print_discrepancy(actual, desired, atol, rtol, verbose) if condition == "or": if not passed and len(indices) == 0: @@ -152,16 +152,15 @@ def add_color(text, color_code): f"delta: {add_color(delta_str, 33)}" ) - print("-" * total_width) - print(add_color("INFO:", 35)) - print(f" - Actual dtype: {actual.dtype}") - print(f" - Desired dtype: {expected.dtype}") - print(f" - Atol: {atol}") - print(f" - Rtol: {rtol}") - print(f" - Mismatched elements: {len(diff_indices)} / {actual.numel()} ({len(diff_indices) / actual.numel() * 100}%)") - print(f" - Min(actual) : {torch.min(actual):<{col_width[1]}} | Max(actual) : {torch.max(actual):<{col_width[2]}}") - print(f" - Min(desired): {torch.min(expected):<{col_width[1]}} | Max(desired): {torch.max(expected):<{col_width[2]}}") - print(f" - Min(delta) : {torch.min(delta):<{col_width[1]}} | Max(delta) : {torch.max(delta):<{col_width[2]}}") + print(add_color(" INFO:", 35)) + print(f" - Actual dtype: {actual.dtype}") + print(f" - Desired dtype: {expected.dtype}") + print(f" - Atol: {atol}") + print(f" - Rtol: {rtol}") + print(f" - Mismatched elements: {len(diff_indices)} / {actual.numel()} ({len(diff_indices) / actual.numel() * 100}%)") + print(f" - Min(actual) : {torch.min(actual):<{col_width[1]}} | Max(actual) : {torch.max(actual):<{col_width[2]}}") + print(f" - Min(desired): {torch.min(expected):<{col_width[1]}} | Max(desired): {torch.max(expected):<{col_width[2]}}") + print(f" - Min(delta) : {torch.min(delta):<{col_width[1]}} | Max(delta) : {torch.max(delta):<{col_width[2]}}") print("-" * total_width + "\n") return diff_indices